DsMarkingServiceImpl.java 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. package cn.com.qmth.am.service.impl;
  2. import java.math.BigDecimal;
  3. import java.math.RoundingMode;
  4. import java.util.Arrays;
  5. import java.util.HashMap;
  6. import java.util.Map;
  7. import java.util.regex.Matcher;
  8. import java.util.regex.Pattern;
  9. import org.apache.commons.io.IOUtils;
  10. import org.apache.commons.lang3.StringUtils;
  11. import org.slf4j.Logger;
  12. import org.slf4j.LoggerFactory;
  13. import org.springframework.beans.factory.annotation.Autowired;
  14. import org.springframework.stereotype.Service;
  15. import com.alibaba.fastjson.JSONObject;
  16. import com.qmth.boot.core.retrofit.exception.RetrofitResponseError;
  17. import cn.com.qmth.am.bean.OcrServer;
  18. import cn.com.qmth.am.bean.ds.AutoScoreRequest;
  19. import cn.com.qmth.am.bean.ds.AutoScoreResult;
  20. import cn.com.qmth.am.bean.ds.ChatReq;
  21. import cn.com.qmth.am.bean.ds.ChatResult;
  22. import cn.com.qmth.am.bean.ds.ChatRole;
  23. import cn.com.qmth.am.bean.ds.DsChoice;
  24. import cn.com.qmth.am.bean.ds.MarkingReq;
  25. import cn.com.qmth.am.bean.ds.OcrMessage;
  26. import cn.com.qmth.am.bean.ds.OcrReq;
  27. import cn.com.qmth.am.config.SysProperty;
  28. import cn.com.qmth.am.entity.QuestionEntity;
  29. import cn.com.qmth.am.enums.PromptTemplate;
  30. import cn.com.qmth.am.service.DsMarkingService;
  31. import cn.com.qmth.am.utils.FreeMarkerUtil;
  32. import cn.com.qmth.am.utils.HttpMethod;
  33. import cn.com.qmth.am.utils.OKHttpUtil;
  34. import okhttp3.Response;
  35. @Service
  36. public class DsMarkingServiceImpl implements DsMarkingService {
  37. @Autowired
  38. private SysProperty sysProperty;
  39. private static final Logger log = LoggerFactory.getLogger(DsMarkingService.class);
  40. // @SuppressWarnings("deprecation")
  41. // public static void main(String[] args) {
  42. //
  43. // Map<String, String> headers = new HashMap<>();
  44. // headers.put("Authorization", "Bearer 7dac2f2166994b8f9c6de0a8eff2814c");
  45. // Response resp = null;
  46. // try {
  47. // resp = OKHttpUtil.call(HttpMethod.POST,
  48. // "http://39.174.90.3:31091/spiritx-api/v1/chat/completions", headers,
  49. // "{\"model\":\"deepseek-r1-distill-qwen-32b-awq\",\"messages\":[{\"role\":\"user\",\"content\":\"你是谁?\"}]}");
  50. // if (resp.code() != 200) {
  51. // throw new RuntimeException("err :" + resp.body().string());
  52. // } else {
  53. // System.out.println("成功处理:" + resp.body().string());
  54. // }
  55. // } catch (Exception e) {
  56. // throw new RuntimeException(e);
  57. // } finally {
  58. // IOUtils.closeQuietly(resp);
  59. // }
  60. // }
  61. @Override
  62. public String ocr(OcrServer ocrServer, String base64) {
  63. OcrReq dreq = new OcrReq(ocrServer.getModel());
  64. dreq.addMsg(new OcrMessage(base64));
  65. String res = ocr(ocrServer, dreq);
  66. ChatResult result = JSONObject.parseObject(res, ChatResult.class);
  67. String text = result.getChoices().stream().filter(choice -> choice.getMessage().getRole() == ChatRole.assistant)
  68. .map(choice -> choice.getMessage().getContent()).findFirst().orElse("");
  69. return text;
  70. }
  71. @Override
  72. public AutoScoreResult autoScore(AutoScoreRequest request, QuestionEntity q) {
  73. String question = FreeMarkerUtil.getMarkingReq(request, q.getPromptTemplate());
  74. MarkingReq dreq = new MarkingReq(sysProperty.getMarkingModel());
  75. dreq.addMsg(ChatRole.user, question);
  76. String res = marking(dreq);
  77. DsChoice result = JSONObject.parseObject(res, DsChoice.class);
  78. try {
  79. String text = result.getMessage().getContent();
  80. AutoScoreResult scoreResult = new AutoScoreResult();
  81. if (PromptTemplate.COMMON.equals(q.getPromptTemplate())) {
  82. // 依据总分与步骤分计算最大精度
  83. int scale = Math.max(getDecimalPlaces(request.getIntervalScore()),
  84. getDecimalPlaces(request.getTotalScore()));
  85. int stepCount = request.getStandardAnswer().size();
  86. String scoreStr = null;
  87. if (stepCount > 1) {
  88. scoreStr = fomatStrByRex(text);
  89. } else {
  90. scoreStr = fomatStr(text);
  91. }
  92. String[] scores = StringUtils.split(scoreStr, ",");
  93. double[] scoreArray = new double[stepCount];
  94. for (int i = 0; i < stepCount; i++) {
  95. // 根据得分率与步骤总分计算实际得分,按最大精度保留小数位数
  96. double score = BigDecimal
  97. .valueOf(Math.min(Integer.parseInt(scores[i].trim()), 100)
  98. * request.getStandardAnswer().get(i).getScore())
  99. .divide(BigDecimal.valueOf(100), scale, RoundingMode.HALF_UP).doubleValue();
  100. scoreArray[i] = score;
  101. }
  102. scoreResult.setStepScore(scoreArray);
  103. scoreResult.setTotalScore(Arrays.stream(scoreArray).mapToObj(BigDecimal::new)
  104. .reduce(BigDecimal.ZERO, BigDecimal::add).setScale(1, BigDecimal.ROUND_HALF_UP).doubleValue());
  105. } else if (PromptTemplate.TRANSLATION.equals(q.getPromptTemplate())) {
  106. String scoreStr = fomatStr(text);
  107. double[] scoreArray = new double[1];
  108. // 根据得分率与步骤总分计算实际得分,按最大精度保留小数位数
  109. double score = Double.valueOf(scoreStr);
  110. scoreArray[0] = score;
  111. scoreResult.setStepScore(scoreArray);
  112. scoreResult.setTotalScore(score);
  113. } else {
  114. throw new RuntimeException("模版类型错误");
  115. }
  116. return scoreResult;
  117. } catch (Exception e) {
  118. log.error(e.getMessage() + " | " + res);
  119. return null;
  120. }
  121. }
  122. private String fomatStrByRex(String scoreStr) {
  123. int tag = scoreStr.lastIndexOf("</think>");
  124. if (tag != -1) {
  125. scoreStr = scoreStr.substring(tag).trim();
  126. }
  127. String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll("[0-9]\\.", "");
  128. Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}");
  129. Matcher matcher = pattern.matcher(ret);
  130. if (matcher.find()) {
  131. return matcher.group();
  132. } else {
  133. throw new RuntimeException("返回格式错误");
  134. }
  135. }
  136. private String fomatStr(String scoreStr) {
  137. scoreStr = scoreStr.substring(scoreStr.lastIndexOf("\n") + 1).trim();
  138. String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll(":", ":");
  139. ret = ret.substring(ret.lastIndexOf(":") + 1).trim();
  140. return ret;
  141. }
  142. // private String fomatStrByRex(String scoreStr) {
  143. // int tag = scoreStr.lastIndexOf("</think>");
  144. // if (tag != -1) {
  145. // scoreStr = scoreStr.substring(tag).trim();
  146. // }
  147. // String ret = scoreStr.replaceAll(",", ",").replaceAll("。",
  148. // "").replaceAll("[0-9]\\.", "");
  149. // Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}");
  150. // Matcher matcher = pattern.matcher(ret);
  151. // if (matcher.find()) {
  152. // return matcher.group();
  153. // } else {
  154. // throw new RuntimeException("返回格式错误");
  155. // }
  156. // }
  157. private int getDecimalPlaces(double value) {
  158. return Math.max(0, BigDecimal.valueOf(value).stripTrailingZeros().scale());
  159. }
  160. public static void main(String[] args) {
  161. String scoreStr = "</think>。\\n\\n\\n70,70,60\\n\\n评分结果2个3,29,110 \\n\\n考生的回答完全覆盖了所有的关键内容,逻辑清晰,术语使用准确";
  162. scoreStr = scoreStr.substring(scoreStr.lastIndexOf("</think>") + 1).trim();
  163. System.out.println(Runtime.getRuntime().availableProcessors());
  164. String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll(":", ":").replaceAll("[0-9]\\.", "");
  165. Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}");
  166. Matcher matcher = pattern.matcher(ret);
  167. if (matcher.find()) {
  168. System.out.println(matcher.group());
  169. }
  170. }
  171. @SuppressWarnings("deprecation")
  172. private String marking(ChatReq dreq) {
  173. Map<String, String> headers = new HashMap<>();
  174. headers.put("Authorization", "Bearer " + sysProperty.getMarkingKey());
  175. Response resp = null;
  176. try {
  177. resp = OKHttpUtil.call(HttpMethod.POST, sysProperty.getMarkingServer(), headers,
  178. JSONObject.toJSONString(dreq));
  179. if (resp.code() != 200) {
  180. throw new RetrofitResponseError(resp.code(), resp.body().string());
  181. } else {
  182. return resp.body().string();
  183. }
  184. } catch (Exception e) {
  185. throw new RetrofitResponseError(500, e.getMessage(), e);
  186. } finally {
  187. IOUtils.closeQuietly(resp);
  188. }
  189. }
  190. @SuppressWarnings("deprecation")
  191. private String ocr(OcrServer ocrServer, ChatReq dreq) {
  192. Map<String, String> headers = new HashMap<>();
  193. headers.put("Authorization", "Bearer " + ocrServer.getKey());
  194. Response resp = null;
  195. try {
  196. resp = OKHttpUtil.call(HttpMethod.POST, ocrServer.getServer(), headers, JSONObject.toJSONString(dreq));
  197. if (resp.code() != 200) {
  198. throw new RetrofitResponseError(resp.code(), resp.body().string());
  199. } else {
  200. return resp.body().string();
  201. }
  202. } catch (RetrofitResponseError e) {
  203. throw e;
  204. } catch (Exception e) {
  205. throw new RetrofitResponseError(500, e.getMessage(), e);
  206. } finally {
  207. IOUtils.closeQuietly(resp);
  208. }
  209. }
  210. }