فهرست منبع

core-ai autoScore update

deason 2 ماه پیش
والد
کامیت
3d6065375d
1فایلهای تغییر یافته به همراه49 افزوده شده و 23 حذف شده
  1. 49 23
      core-ai/src/main/java/com/qmth/boot/core/ai/service/AiService.java

+ 49 - 23
core-ai/src/main/java/com/qmth/boot/core/ai/service/AiService.java

@@ -2,18 +2,19 @@ package com.qmth.boot.core.ai.service;
 
 import com.qmth.boot.core.ai.client.LlmApiClient;
 import com.qmth.boot.core.ai.model.llm.*;
+import com.qmth.boot.core.ai.model.llm.score.AutoScoreModel;
 import com.qmth.boot.core.ai.model.llm.score.AutoScoreRequest;
 import com.qmth.boot.core.ai.model.llm.score.AutoScoreResult;
+import com.qmth.boot.core.ai.model.llm.score.StandardAnswer;
 import com.qmth.boot.core.retrofit.utils.SignatureInfo;
 import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.springframework.stereotype.Service;
 import org.springframework.validation.annotation.Validated;
 
 import javax.annotation.Resource;
 import javax.validation.constraints.NotNull;
-import java.math.BigDecimal;
-import java.math.RoundingMode;
-import java.util.Arrays;
 import java.util.List;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
@@ -26,6 +27,8 @@ public class AiService {
 
     private static final Pattern score_pattern = Pattern.compile("[^\\d]*(\\d+)[^\\d]*");
 
+    private static final Logger log = LoggerFactory.getLogger(AiService.class);
+
     @Resource
     private LlmApiClient llmApiClient;
 
@@ -71,39 +74,62 @@ public class AiService {
      *
      * @param request   自动判分请求参数
      * @param signature 使用当前机构AK作为鉴权信息
-     * @return 得分率,保留最多三位小数;null表示无法获取判分结果
+     * @return 判分结果;null表示无法获取判分结果
      */
     public AutoScoreResult autoScore(@NotNull @Validated AutoScoreRequest request, @NotNull SignatureInfo signature) {
         ChatResult result = llmApiClient.chatTemplate(signature, LlmAppType.AUTO_SCORE, request);
         String text = result.getChoices().stream().filter(choice -> choice.getMessage().getRole() == ChatRole.assistant)
                 .map(choice -> choice.getMessage().getContent()).findFirst().orElse("");
+
         try {
-            AutoScoreResult scoreResult = new AutoScoreResult();
-            //依据总分与步骤分计算最大精度
-            int scale = Math
-                    .max(getDecimalPlaces(request.getIntervalScore()), getDecimalPlaces(request.getTotalScore()));
+            String[] values = StringUtils.split(text.replaceAll(",", ","), ",");
             int stepCount = request.getStandardAnswer().size();
-            String[] scores = StringUtils.split(text.replaceAll(",", ","), ",");
-            double[] scoreArray = new double[stepCount];
+            if (stepCount != values.length) {
+                log.warn("评分结果无效:{},要求给出{}个关键分值,实际{}个", text, stepCount, values.length);
+                return null;
+            }
+
+            double totalScore = 0;
+            double[] scores = new double[stepCount];
             for (int i = 0; i < stepCount; i++) {
-                //根据得分率与步骤总分计算实际得分,按最大精度保留小数位数
-                double score = BigDecimal.valueOf(
-                        Math.min(Integer.parseInt(scores[i].trim()), 100) * request.getStandardAnswer().get(i)
-                                .getScore()).divide(BigDecimal.valueOf(100), scale, RoundingMode.HALF_UP).doubleValue();
-                scoreArray[i] = score;
+                double stepScore = Double.parseDouble(values[i].trim());
+                StandardAnswer step = request.getStandardAnswer().get(i);
+
+                if (AutoScoreModel.LEVEL == request.getScoreModel()) {
+                    // 按档次给分模式:关键步骤得分应介于当前档次分值区间
+                    if (stepScore < step.getLowScore()) {
+                        log.warn("关键步骤得分无效:{},当前档次分值区间{}-{}分", stepScore, step.getLowScore(), step.getHighScore());
+                        scores[i] = 0;
+                    } else if (stepScore > step.getHighScore()) {
+                        scores[i] = Math.min(stepScore, step.getHighScore());
+                    } else {
+                        scores[i] = stepScore;
+                    }
+
+                    // 处理分值符合最小间隔分 todo
+                    totalScore = Math.max(totalScore, scores[i]);
+                } else {
+                    // 按得分点给分模式:关键步骤得分应小于等于当前得分点的分值
+                    if (stepScore > step.getScore()) {
+                        log.warn("关键步骤得分无效:{},当前得分点分值{}分", stepScore, step.getScore());
+                        scores[i] = 0;
+                        continue;
+                    }
+
+                    // 处理分值符合最小间隔分 todo
+                    scores[i] = stepScore;
+                    totalScore += stepScore;
+                }
             }
-            scoreResult.setStepScore(scoreArray);
-            scoreResult.setTotalScore(
-                    Arrays.stream(scoreArray).mapToObj(BigDecimal::new).reduce(BigDecimal.ZERO, BigDecimal::add)
-                            .doubleValue());
+
+            AutoScoreResult scoreResult = new AutoScoreResult();
+            scoreResult.setStepScore(scores);
+            scoreResult.setTotalScore(totalScore);
             return scoreResult;
         } catch (Exception e) {
+            log.warn("评分结果无效:{},错误:{}", text, e.getMessage());
             return null;
         }
     }
 
-    private static int getDecimalPlaces(double value) {
-        return Math.max(0, BigDecimal.valueOf(value).stripTrailingZeros().scale());
-    }
-
 }