package cn.com.qmth.am.service.impl; import java.math.BigDecimal; import java.math.RoundingMode; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import com.alibaba.fastjson.JSONObject; import com.qmth.boot.core.retrofit.exception.RetrofitResponseError; import cn.com.qmth.am.bean.OcrServer; import cn.com.qmth.am.bean.ds.AutoScoreRequest; import cn.com.qmth.am.bean.ds.AutoScoreResult; import cn.com.qmth.am.bean.ds.ChatReq; import cn.com.qmth.am.bean.ds.ChatResult; import cn.com.qmth.am.bean.ds.ChatRole; import cn.com.qmth.am.bean.ds.DsChoice; import cn.com.qmth.am.bean.ds.MarkingReq; import cn.com.qmth.am.bean.ds.OcrMessage; import cn.com.qmth.am.bean.ds.OcrReq; import cn.com.qmth.am.config.SysProperty; import cn.com.qmth.am.entity.QuestionEntity; import cn.com.qmth.am.enums.PromptTemplate; import cn.com.qmth.am.service.DsMarkingService; import cn.com.qmth.am.utils.FreeMarkerUtil; import cn.com.qmth.am.utils.HttpMethod; import cn.com.qmth.am.utils.OKHttpUtil; import okhttp3.Response; @Service public class DsMarkingServiceImpl implements DsMarkingService { @Autowired private SysProperty sysProperty; private static final Logger log = LoggerFactory.getLogger(DsMarkingService.class); // @SuppressWarnings("deprecation") // public static void main(String[] args) { // // Map headers = new HashMap<>(); // headers.put("Authorization", "Bearer 7dac2f2166994b8f9c6de0a8eff2814c"); // Response resp = null; // try { // resp = OKHttpUtil.call(HttpMethod.POST, // "http://39.174.90.3:31091/spiritx-api/v1/chat/completions", headers, // "{\"model\":\"deepseek-r1-distill-qwen-32b-awq\",\"messages\":[{\"role\":\"user\",\"content\":\"你是谁?\"}]}"); // if (resp.code() != 200) { // throw new RuntimeException("err :" + resp.body().string()); // } else { // System.out.println("成功处理:" + resp.body().string()); // } // } catch (Exception e) { // throw new RuntimeException(e); // } finally { // IOUtils.closeQuietly(resp); // } // } @Override public String ocr(OcrServer ocrServer, String base64) { OcrReq dreq = new OcrReq(ocrServer.getModel()); dreq.addMsg(new OcrMessage(base64)); String res = ocr(ocrServer, dreq); ChatResult result = JSONObject.parseObject(res, ChatResult.class); String text = result.getChoices().stream().filter(choice -> choice.getMessage().getRole() == ChatRole.assistant) .map(choice -> choice.getMessage().getContent()).findFirst().orElse(""); return text; } @Override public AutoScoreResult autoScore(AutoScoreRequest request, QuestionEntity q) { String question = FreeMarkerUtil.getMarkingReq(request, q.getPromptTemplate()); MarkingReq dreq = new MarkingReq(sysProperty.getMarkingModel()); dreq.addMsg(ChatRole.user, question); String res = marking(dreq); DsChoice result = JSONObject.parseObject(res, DsChoice.class); try { String text = result.getMessage().getContent(); AutoScoreResult scoreResult = new AutoScoreResult(); if (PromptTemplate.COMMON.equals(q.getPromptTemplate())) { // 依据总分与步骤分计算最大精度 int scale = Math.max(getDecimalPlaces(request.getIntervalScore()), getDecimalPlaces(request.getTotalScore())); int stepCount = request.getStandardAnswer().size(); String scoreStr = null; if (stepCount > 1) { scoreStr = fomatStrByRex(text); } else { scoreStr = fomatStr(text); } String[] scores = StringUtils.split(scoreStr, ","); double[] scoreArray = new double[stepCount]; for (int i = 0; i < stepCount; i++) { // 根据得分率与步骤总分计算实际得分,按最大精度保留小数位数 double score = BigDecimal .valueOf(Math.min(Integer.parseInt(scores[i].trim()), 100) * request.getStandardAnswer().get(i).getScore()) .divide(BigDecimal.valueOf(100), scale, RoundingMode.HALF_UP).doubleValue(); scoreArray[i] = score; } scoreResult.setStepScore(scoreArray); scoreResult.setTotalScore(Arrays.stream(scoreArray).mapToObj(BigDecimal::new) .reduce(BigDecimal.ZERO, BigDecimal::add).setScale(1, BigDecimal.ROUND_HALF_UP).doubleValue()); } else if (PromptTemplate.TRANSLATION.equals(q.getPromptTemplate())) { String scoreStr = fomatStr(text); double[] scoreArray = new double[1]; // 根据得分率与步骤总分计算实际得分,按最大精度保留小数位数 double score = Double.valueOf(scoreStr); scoreArray[0] = score; scoreResult.setStepScore(scoreArray); scoreResult.setTotalScore(score); } else { throw new RuntimeException("模版类型错误"); } return scoreResult; } catch (Exception e) { log.error(e.getMessage() + " | " + res); return null; } } private String fomatStrByRex(String scoreStr) { int tag = scoreStr.lastIndexOf(""); if (tag != -1) { scoreStr = scoreStr.substring(tag).trim(); } String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll("[0-9]\\.", ""); Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}"); Matcher matcher = pattern.matcher(ret); if (matcher.find()) { return matcher.group(); } else { throw new RuntimeException("返回格式错误"); } } private String fomatStr(String scoreStr) { scoreStr = scoreStr.substring(scoreStr.lastIndexOf("\n") + 1).trim(); String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll(":", ":"); ret = ret.substring(ret.lastIndexOf(":") + 1).trim(); return ret; } // private String fomatStrByRex(String scoreStr) { // int tag = scoreStr.lastIndexOf(""); // if (tag != -1) { // scoreStr = scoreStr.substring(tag).trim(); // } // String ret = scoreStr.replaceAll(",", ",").replaceAll("。", // "").replaceAll("[0-9]\\.", ""); // Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}"); // Matcher matcher = pattern.matcher(ret); // if (matcher.find()) { // return matcher.group(); // } else { // throw new RuntimeException("返回格式错误"); // } // } private int getDecimalPlaces(double value) { return Math.max(0, BigDecimal.valueOf(value).stripTrailingZeros().scale()); } public static void main(String[] args) { String scoreStr = "。\\n\\n\\n70,70,60\\n\\n评分结果2个3,29,110 \\n\\n考生的回答完全覆盖了所有的关键内容,逻辑清晰,术语使用准确"; scoreStr = scoreStr.substring(scoreStr.lastIndexOf("") + 1).trim(); System.out.println(Runtime.getRuntime().availableProcessors()); String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll(":", ":").replaceAll("[0-9]\\.", ""); Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}"); Matcher matcher = pattern.matcher(ret); if (matcher.find()) { System.out.println(matcher.group()); } } @SuppressWarnings("deprecation") private String marking(ChatReq dreq) { Map headers = new HashMap<>(); headers.put("Authorization", "Bearer " + sysProperty.getMarkingKey()); Response resp = null; try { resp = OKHttpUtil.call(HttpMethod.POST, sysProperty.getMarkingServer(), headers, JSONObject.toJSONString(dreq)); if (resp.code() != 200) { throw new RetrofitResponseError(resp.code(), resp.body().string()); } else { return resp.body().string(); } } catch (Exception e) { throw new RetrofitResponseError(500, e.getMessage(), e); } finally { IOUtils.closeQuietly(resp); } } @SuppressWarnings("deprecation") private String ocr(OcrServer ocrServer, ChatReq dreq) { Map headers = new HashMap<>(); headers.put("Authorization", "Bearer " + ocrServer.getKey()); Response resp = null; try { resp = OKHttpUtil.call(HttpMethod.POST, ocrServer.getServer(), headers, JSONObject.toJSONString(dreq)); if (resp.code() != 200) { throw new RetrofitResponseError(resp.code(), resp.body().string()); } else { return resp.body().string(); } } catch (RetrofitResponseError e) { throw e; } catch (Exception e) { throw new RetrofitResponseError(500, e.getMessage(), e); } finally { IOUtils.closeQuietly(resp); } } }