xiatian 2 miesięcy temu
rodzic
commit
bfde424e04

+ 10 - 0
src/main/java/cn/com/qmth/am/bean/ds/ChatContent.java

@@ -22,4 +22,14 @@ public class ChatContent {
         this.content = content;
     }
 
+    public ChatContent() {
+        super();
+    }
+
+    public ChatContent(ChatRole role, String content) {
+        super();
+        this.role = role;
+        this.content = content;
+    }
+
 }

+ 26 - 0
src/main/java/cn/com/qmth/am/bean/ds/TrainData.java

@@ -0,0 +1,26 @@
+package cn.com.qmth.am.bean.ds;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class TrainData {
+
+    private List<ChatContent> messages;
+
+    public List<ChatContent> getMessages() {
+        return messages;
+    }
+
+    public void setMessages(List<ChatContent> messages) {
+        this.messages = messages;
+    }
+
+    public TrainData(String sys, String user, String assistant) {
+        super();
+        messages = new ArrayList<>();
+        messages.add(new ChatContent(ChatRole.system, sys));
+        messages.add(new ChatContent(ChatRole.user, user));
+        messages.add(new ChatContent(ChatRole.assistant, assistant));
+    }
+
+}

+ 5 - 0
src/main/java/cn/com/qmth/am/config/InitData.java

@@ -8,6 +8,7 @@ import org.springframework.stereotype.Component;
 
 import cn.com.qmth.am.service.QuestionService;
 import cn.com.qmth.am.service.StudentScoreService;
+import cn.com.qmth.am.service.impl.ToolService;
 
 @Component
 public class InitData implements CommandLineRunner {
@@ -21,6 +22,9 @@ public class InitData implements CommandLineRunner {
     @Autowired
     private QuestionService questionService;
 
+    @Autowired
+    private ToolService toolService;
+
     @Override
     public void run(String... args) throws Exception {
         File dataDir = new File(sysProperty.getDataDir());
@@ -36,6 +40,7 @@ public class InitData implements CommandLineRunner {
             sheet.mkdir();
         }
         resetTaskStatus();
+        toolService.expTrainData();
     }
 
     private void resetTaskStatus() {

+ 1 - 1
src/main/java/cn/com/qmth/am/enums/PromptTemplate.java

@@ -2,7 +2,7 @@ package cn.com.qmth.am.enums;
 
 public enum PromptTemplate {
 
-    COMMON("common.ftl"), TRANSLATION("translation.ftl"),;
+    COMMON("common.ftl"), TRANSLATION("translation.ftl"), TRAIN_TRANSLATION("train_translation.ftl"),;
 
     private PromptTemplate(String code) {
         this.code = code;

+ 2 - 0
src/main/java/cn/com/qmth/am/service/StudentScoreService.java

@@ -51,4 +51,6 @@ public interface StudentScoreService extends IService<StudentScoreEntity> {
 
     void updateMarkingScoreAndTrack(StudentScoreVo score);
 
+    List<StudentScoreEntity> findToTrain(Long questionId);
+
 }

+ 9 - 0
src/main/java/cn/com/qmth/am/service/impl/StudentScoreServiceImpl.java

@@ -454,6 +454,15 @@ public class StudentScoreServiceImpl extends ServiceImpl<StudentScoreDao, Studen
         return this.list(wrapper);
     }
 
+    @Override
+    public List<StudentScoreEntity> findToTrain(Long questionId) {
+        QueryWrapper<StudentScoreEntity> wrapper = new QueryWrapper<>();
+        LambdaQueryWrapper<StudentScoreEntity> lw = wrapper.lambda();
+        lw.eq(StudentScoreEntity::getQuestionId, questionId);
+        lw.eq(StudentScoreEntity::getScoreStatus, DataStatus.SUCCESS);
+        return this.list(wrapper);
+    }
+
     @Override
     public List<StudentScoreEntity> findAllToAiMarking() {
         QueryWrapper<StudentScoreEntity> wrapper = new QueryWrapper<>();

+ 131 - 0
src/main/java/cn/com/qmth/am/service/impl/ToolService.java

@@ -0,0 +1,131 @@
+package cn.com.qmth.am.service.impl;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
+import org.apache.commons.io.FileUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+
+import com.alibaba.fastjson.JSONObject;
+
+import cn.com.qmth.am.bean.ds.AutoScoreRequest;
+import cn.com.qmth.am.bean.ds.TrainData;
+import cn.com.qmth.am.entity.QuestionEntity;
+import cn.com.qmth.am.entity.StudentScoreEntity;
+import cn.com.qmth.am.enums.PromptTemplate;
+import cn.com.qmth.am.service.DsMarkingService;
+import cn.com.qmth.am.service.QuestionService;
+import cn.com.qmth.am.service.StudentScoreService;
+import cn.com.qmth.am.utils.FreeMarkerUtil;
+
+@Service
+public class ToolService {
+
+    private static final Logger log = LoggerFactory.getLogger(DsMarkingService.class);
+
+    @Autowired
+    private QuestionService questionService;
+
+    @Autowired
+    private StudentScoreService studentScoreService;
+
+    public void expTrainData() {
+        log.warn("*************************expTrainData start");
+        List<QuestionEntity> qs = questionService.list();
+        for (QuestionEntity q : qs) {
+            List<StudentScoreEntity> ss = studentScoreService.findToTrain(q.getId());
+            List<StudentScoreEntity> ret = getUniformlyDistributedNumbers(getValid(ss, q.getFullScore()), 10000);
+            File f = new File("e:/files/" + q.getTitle() + ".jsonl");
+            if (f.exists()) {
+                f.delete();
+            }
+            try {
+                f.createNewFile();
+                List<String> lines = new ArrayList<>();
+                for (StudentScoreEntity s : ret) {
+                    AutoScoreRequest req = new AutoScoreRequest();
+                    req.setQuestionBody(q.getContent());
+                    req.setStandardAnswer(q.getAnswer());
+                    req.setStudentAnswer(s.getAnswer());
+                    req.setSubjectName(q.getSubjectName());
+                    req.setTotalScore(q.getFullScore());
+                    req.setIntervalScore(0.1);
+                    req.setQuestionTitle(q.getTitle());
+                    req.setExt(q.getExt());
+                    req.setScoreGrades(q.getScoreGrades());
+                    TrainData td = new TrainData("作为" + q.getSubjectName() + "科目" + q.getTitle() + "试题评分员",
+                            FreeMarkerUtil.getMarkingReq(req, PromptTemplate.TRAIN_TRANSLATION), s.getAiScore() + "");
+                    lines.add(JSONObject.toJSONString(td));
+                }
+                FileUtils.writeLines(f, "utf-8", lines);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+        log.warn("*************************expTrainData end");
+    }
+
+    private List<StudentScoreEntity> getValid(List<StudentScoreEntity> ss, double score) {
+        List<StudentScoreEntity> ret = new ArrayList<>();
+        double roundedValue = Math.round(score * 0.3);
+        for (StudentScoreEntity s : ss) {
+            if (s.getAiScore() - s.getMarkingScore() < roundedValue
+                    || s.getMarkingScore() - s.getAiScore() < roundedValue) {
+                ret.add(s);
+            }
+        }
+
+        return ret;
+    }
+
+    private List<StudentScoreEntity> getUniformlyDistributedNumbers(List<StudentScoreEntity> numbers, int count) {
+        if (numbers == null || numbers.isEmpty()) {
+            return new ArrayList<>();
+        }
+
+        // 如果需要的数字数量大于或等于列表长度,直接返回原列表
+        if (count >= numbers.size()) {
+            return new ArrayList<>(numbers);
+        }
+
+        // 对列表进行排序
+        numbers.sort(new Comparator<StudentScoreEntity>() {
+
+            @Override
+            public int compare(StudentScoreEntity o1, StudentScoreEntity o2) {
+                Double c1 = o1.getAiScore();
+                Double c2 = o2.getAiScore();
+                if (c1 < c2) {
+                    return -1;
+                } else if (c1 > c2) {
+                    return 1;
+                } else {
+                    return 0;
+                }
+            }
+        });
+
+        // 计算步长(均匀分布的间隔)
+        double step = (double) (numbers.size() - 1) / (count - 1);
+
+        // 选取均匀分布的数字
+        List<StudentScoreEntity> result = new ArrayList<>();
+        for (int i = 0; i < count; i++) {
+            // 计算当前索引位置
+            int index = (int) Math.round(i * step);
+            // 确保索引不超过列表长度
+            if (index >= numbers.size()) {
+                index = numbers.size() - 1;
+            }
+            result.add(numbers.get(index));
+        }
+
+        return result;
+    }
+}

+ 26 - 0
templates/train_translation.ftl

@@ -0,0 +1,26 @@
+请严格按照以下标准为考生作答进行打分:
+
+# 试题内容
+${questionBody}
+
+<#if standardAnswer?? && (standardAnswer?size > 0) >
+# 参考答案
+<#list standardAnswer as item> 
+${item.content}
+</#list>
+</#if>
+
+# 评分规则
+${scoreGrades}
+
+# 评分流程
+1. 理解试题内容与评分规则,分析考生作答内容,对比参考答案,准确判断考生作答属于哪个档次
+2. 从所属档次的分值范围中选择一个合适的分数作为最终评分结果,能准确反映考生作答的情况
+3. 若考生作答仅包含试题名称或试题内容,没有有效作答,直接判为0分
+
+
+# 考生作答
+${studentAnswer}
+
+# 输出要求:
+直接输出最终评分结果,用数字表示,无需其他文字说明。