123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227 |
- package cn.com.qmth.am.service.impl;
- import java.math.BigDecimal;
- import java.math.RoundingMode;
- import java.util.Arrays;
- import java.util.HashMap;
- import java.util.Map;
- import java.util.regex.Matcher;
- import java.util.regex.Pattern;
- import org.apache.commons.io.IOUtils;
- import org.apache.commons.lang3.StringUtils;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- import org.springframework.beans.factory.annotation.Autowired;
- import org.springframework.stereotype.Service;
- import com.alibaba.fastjson.JSONObject;
- import com.qmth.boot.core.retrofit.exception.RetrofitResponseError;
- import cn.com.qmth.am.bean.OcrServer;
- import cn.com.qmth.am.bean.ds.AutoScoreRequest;
- import cn.com.qmth.am.bean.ds.AutoScoreResult;
- import cn.com.qmth.am.bean.ds.ChatReq;
- import cn.com.qmth.am.bean.ds.ChatResult;
- import cn.com.qmth.am.bean.ds.ChatRole;
- import cn.com.qmth.am.bean.ds.DsChoice;
- import cn.com.qmth.am.bean.ds.MarkingReq;
- import cn.com.qmth.am.bean.ds.OcrMessage;
- import cn.com.qmth.am.bean.ds.OcrReq;
- import cn.com.qmth.am.config.SysProperty;
- import cn.com.qmth.am.entity.QuestionEntity;
- import cn.com.qmth.am.enums.PromptTemplate;
- import cn.com.qmth.am.service.DsMarkingService;
- import cn.com.qmth.am.utils.FreeMarkerUtil;
- import cn.com.qmth.am.utils.HttpMethod;
- import cn.com.qmth.am.utils.OKHttpUtil;
- import okhttp3.Response;
- @Service
- public class DsMarkingServiceImpl implements DsMarkingService {
- @Autowired
- private SysProperty sysProperty;
- private static final Logger log = LoggerFactory.getLogger(DsMarkingService.class);
- // @SuppressWarnings("deprecation")
- // public static void main(String[] args) {
- //
- // Map<String, String> headers = new HashMap<>();
- // headers.put("Authorization", "Bearer 7dac2f2166994b8f9c6de0a8eff2814c");
- // Response resp = null;
- // try {
- // resp = OKHttpUtil.call(HttpMethod.POST,
- // "http://39.174.90.3:31091/spiritx-api/v1/chat/completions", headers,
- // "{\"model\":\"deepseek-r1-distill-qwen-32b-awq\",\"messages\":[{\"role\":\"user\",\"content\":\"你是谁?\"}]}");
- // if (resp.code() != 200) {
- // throw new RuntimeException("err :" + resp.body().string());
- // } else {
- // System.out.println("成功处理:" + resp.body().string());
- // }
- // } catch (Exception e) {
- // throw new RuntimeException(e);
- // } finally {
- // IOUtils.closeQuietly(resp);
- // }
- // }
- @Override
- public String ocr(OcrServer ocrServer, String base64) {
- OcrReq dreq = new OcrReq(ocrServer.getModel());
- dreq.addMsg(new OcrMessage(base64));
- String res = ocr(ocrServer, dreq);
- ChatResult result = JSONObject.parseObject(res, ChatResult.class);
- String text = result.getChoices().stream().filter(choice -> choice.getMessage().getRole() == ChatRole.assistant)
- .map(choice -> choice.getMessage().getContent()).findFirst().orElse("");
- return text;
- }
- @Override
- public AutoScoreResult autoScore(AutoScoreRequest request, QuestionEntity q) {
- String question = FreeMarkerUtil.getMarkingReq(request, q.getPromptTemplate());
- MarkingReq dreq = new MarkingReq(sysProperty.getMarkingModel());
- dreq.addMsg(ChatRole.user, question);
- String res = marking(dreq);
- DsChoice result = JSONObject.parseObject(res, DsChoice.class);
- try {
- String text = result.getMessage().getContent();
- AutoScoreResult scoreResult = new AutoScoreResult();
- if (PromptTemplate.COMMON.equals(q.getPromptTemplate())) {
- // 依据总分与步骤分计算最大精度
- int scale = Math.max(getDecimalPlaces(request.getIntervalScore()),
- getDecimalPlaces(request.getTotalScore()));
- int stepCount = request.getStandardAnswer().size();
- String scoreStr = null;
- if (stepCount > 1) {
- scoreStr = fomatStrByRex(text);
- } else {
- scoreStr = fomatStr(text);
- }
- String[] scores = StringUtils.split(scoreStr, ",");
- double[] scoreArray = new double[stepCount];
- for (int i = 0; i < stepCount; i++) {
- // 根据得分率与步骤总分计算实际得分,按最大精度保留小数位数
- double score = BigDecimal
- .valueOf(Math.min(Integer.parseInt(scores[i].trim()), 100)
- * request.getStandardAnswer().get(i).getScore())
- .divide(BigDecimal.valueOf(100), scale, RoundingMode.HALF_UP).doubleValue();
- scoreArray[i] = score;
- }
- scoreResult.setStepScore(scoreArray);
- scoreResult.setTotalScore(Arrays.stream(scoreArray).mapToObj(BigDecimal::new)
- .reduce(BigDecimal.ZERO, BigDecimal::add).setScale(1, BigDecimal.ROUND_HALF_UP).doubleValue());
- } else if (PromptTemplate.TRANSLATION.equals(q.getPromptTemplate())) {
- String scoreStr = fomatStr(text);
- double[] scoreArray = new double[1];
- // 根据得分率与步骤总分计算实际得分,按最大精度保留小数位数
- double score = Double.valueOf(scoreStr);
- scoreArray[0] = score;
- scoreResult.setStepScore(scoreArray);
- scoreResult.setTotalScore(score);
- } else {
- throw new RuntimeException("模版类型错误");
- }
- return scoreResult;
- } catch (Exception e) {
- log.error(e.getMessage() + " | " + res);
- return null;
- }
- }
- private String fomatStrByRex(String scoreStr) {
- int tag = scoreStr.lastIndexOf("</think>");
- if (tag != -1) {
- scoreStr = scoreStr.substring(tag).trim();
- }
- String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll("[0-9]\\.", "");
- Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}");
- Matcher matcher = pattern.matcher(ret);
- if (matcher.find()) {
- return matcher.group();
- } else {
- throw new RuntimeException("返回格式错误");
- }
- }
- private String fomatStr(String scoreStr) {
- scoreStr = scoreStr.substring(scoreStr.lastIndexOf("\n") + 1).trim();
- String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll(":", ":");
- ret = ret.substring(ret.lastIndexOf(":") + 1).trim();
- return ret;
- }
- // private String fomatStrByRex(String scoreStr) {
- // int tag = scoreStr.lastIndexOf("</think>");
- // if (tag != -1) {
- // scoreStr = scoreStr.substring(tag).trim();
- // }
- // String ret = scoreStr.replaceAll(",", ",").replaceAll("。",
- // "").replaceAll("[0-9]\\.", "");
- // Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}");
- // Matcher matcher = pattern.matcher(ret);
- // if (matcher.find()) {
- // return matcher.group();
- // } else {
- // throw new RuntimeException("返回格式错误");
- // }
- // }
- private int getDecimalPlaces(double value) {
- return Math.max(0, BigDecimal.valueOf(value).stripTrailingZeros().scale());
- }
- public static void main(String[] args) {
- String scoreStr = "</think>。\\n\\n\\n70,70,60\\n\\n评分结果2个3,29,110 \\n\\n考生的回答完全覆盖了所有的关键内容,逻辑清晰,术语使用准确";
- scoreStr = scoreStr.substring(scoreStr.lastIndexOf("</think>") + 1).trim();
- System.out.println(Runtime.getRuntime().availableProcessors());
- String ret = scoreStr.replaceAll(",", ",").replaceAll("。", "").replaceAll(":", ":").replaceAll("[0-9]\\.", "");
- Pattern pattern = Pattern.compile("(\\d{1,3}\\s*,\\s*)+\\d{1,3}");
- Matcher matcher = pattern.matcher(ret);
- if (matcher.find()) {
- System.out.println(matcher.group());
- }
- }
- @SuppressWarnings("deprecation")
- private String marking(ChatReq dreq) {
- Map<String, String> headers = new HashMap<>();
- headers.put("Authorization", "Bearer " + sysProperty.getMarkingKey());
- Response resp = null;
- try {
- resp = OKHttpUtil.call(HttpMethod.POST, sysProperty.getMarkingServer(), headers,
- JSONObject.toJSONString(dreq));
- if (resp.code() != 200) {
- throw new RetrofitResponseError(resp.code(), resp.body().string());
- } else {
- return resp.body().string();
- }
- } catch (Exception e) {
- throw new RetrofitResponseError(500, e.getMessage(), e);
- } finally {
- IOUtils.closeQuietly(resp);
- }
- }
- @SuppressWarnings("deprecation")
- private String ocr(OcrServer ocrServer, ChatReq dreq) {
- Map<String, String> headers = new HashMap<>();
- headers.put("Authorization", "Bearer " + ocrServer.getKey());
- Response resp = null;
- try {
- resp = OKHttpUtil.call(HttpMethod.POST, ocrServer.getServer(), headers, JSONObject.toJSONString(dreq));
- if (resp.code() != 200) {
- throw new RetrofitResponseError(resp.code(), resp.body().string());
- } else {
- return resp.body().string();
- }
- } catch (RetrofitResponseError e) {
- throw e;
- } catch (Exception e) {
- throw new RetrofitResponseError(500, e.getMessage(), e);
- } finally {
- IOUtils.closeQuietly(resp);
- }
- }
- }
|