|
@@ -0,0 +1,131 @@
|
|
|
+package cn.com.qmth.am.service.impl;
|
|
|
+
|
|
|
+import java.io.File;
|
|
|
+import java.io.IOException;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Comparator;
|
|
|
+import java.util.List;
|
|
|
+
|
|
|
+import org.apache.commons.io.FileUtils;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.stereotype.Service;
|
|
|
+
|
|
|
+import com.alibaba.fastjson.JSONObject;
|
|
|
+
|
|
|
+import cn.com.qmth.am.bean.ds.AutoScoreRequest;
|
|
|
+import cn.com.qmth.am.bean.ds.TrainData;
|
|
|
+import cn.com.qmth.am.entity.QuestionEntity;
|
|
|
+import cn.com.qmth.am.entity.StudentScoreEntity;
|
|
|
+import cn.com.qmth.am.enums.PromptTemplate;
|
|
|
+import cn.com.qmth.am.service.DsMarkingService;
|
|
|
+import cn.com.qmth.am.service.QuestionService;
|
|
|
+import cn.com.qmth.am.service.StudentScoreService;
|
|
|
+import cn.com.qmth.am.utils.FreeMarkerUtil;
|
|
|
+
|
|
|
+@Service
|
|
|
+public class ToolService {
|
|
|
+
|
|
|
+ private static final Logger log = LoggerFactory.getLogger(DsMarkingService.class);
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private QuestionService questionService;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private StudentScoreService studentScoreService;
|
|
|
+
|
|
|
+ public void expTrainData() {
|
|
|
+ log.warn("*************************expTrainData start");
|
|
|
+ List<QuestionEntity> qs = questionService.list();
|
|
|
+ for (QuestionEntity q : qs) {
|
|
|
+ List<StudentScoreEntity> ss = studentScoreService.findToTrain(q.getId());
|
|
|
+ List<StudentScoreEntity> ret = getUniformlyDistributedNumbers(getValid(ss, q.getFullScore()), 10000);
|
|
|
+ File f = new File("e:/files/" + q.getTitle() + ".jsonl");
|
|
|
+ if (f.exists()) {
|
|
|
+ f.delete();
|
|
|
+ }
|
|
|
+ try {
|
|
|
+ f.createNewFile();
|
|
|
+ List<String> lines = new ArrayList<>();
|
|
|
+ for (StudentScoreEntity s : ret) {
|
|
|
+ AutoScoreRequest req = new AutoScoreRequest();
|
|
|
+ req.setQuestionBody(q.getContent());
|
|
|
+ req.setStandardAnswer(q.getAnswer());
|
|
|
+ req.setStudentAnswer(s.getAnswer());
|
|
|
+ req.setSubjectName(q.getSubjectName());
|
|
|
+ req.setTotalScore(q.getFullScore());
|
|
|
+ req.setIntervalScore(0.1);
|
|
|
+ req.setQuestionTitle(q.getTitle());
|
|
|
+ req.setExt(q.getExt());
|
|
|
+ req.setScoreGrades(q.getScoreGrades());
|
|
|
+ TrainData td = new TrainData("作为" + q.getSubjectName() + "科目" + q.getTitle() + "试题评分员",
|
|
|
+ FreeMarkerUtil.getMarkingReq(req, PromptTemplate.TRAIN_TRANSLATION), s.getAiScore() + "");
|
|
|
+ lines.add(JSONObject.toJSONString(td));
|
|
|
+ }
|
|
|
+ FileUtils.writeLines(f, "utf-8", lines);
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ log.warn("*************************expTrainData end");
|
|
|
+ }
|
|
|
+
|
|
|
+ private List<StudentScoreEntity> getValid(List<StudentScoreEntity> ss, double score) {
|
|
|
+ List<StudentScoreEntity> ret = new ArrayList<>();
|
|
|
+ double roundedValue = Math.round(score * 0.3);
|
|
|
+ for (StudentScoreEntity s : ss) {
|
|
|
+ if (s.getAiScore() - s.getMarkingScore() < roundedValue
|
|
|
+ || s.getMarkingScore() - s.getAiScore() < roundedValue) {
|
|
|
+ ret.add(s);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return ret;
|
|
|
+ }
|
|
|
+
|
|
|
+ private List<StudentScoreEntity> getUniformlyDistributedNumbers(List<StudentScoreEntity> numbers, int count) {
|
|
|
+ if (numbers == null || numbers.isEmpty()) {
|
|
|
+ return new ArrayList<>();
|
|
|
+ }
|
|
|
+
|
|
|
+ // 如果需要的数字数量大于或等于列表长度,直接返回原列表
|
|
|
+ if (count >= numbers.size()) {
|
|
|
+ return new ArrayList<>(numbers);
|
|
|
+ }
|
|
|
+
|
|
|
+ // 对列表进行排序
|
|
|
+ numbers.sort(new Comparator<StudentScoreEntity>() {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int compare(StudentScoreEntity o1, StudentScoreEntity o2) {
|
|
|
+ Double c1 = o1.getAiScore();
|
|
|
+ Double c2 = o2.getAiScore();
|
|
|
+ if (c1 < c2) {
|
|
|
+ return -1;
|
|
|
+ } else if (c1 > c2) {
|
|
|
+ return 1;
|
|
|
+ } else {
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
+ // 计算步长(均匀分布的间隔)
|
|
|
+ double step = (double) (numbers.size() - 1) / (count - 1);
|
|
|
+
|
|
|
+ // 选取均匀分布的数字
|
|
|
+ List<StudentScoreEntity> result = new ArrayList<>();
|
|
|
+ for (int i = 0; i < count; i++) {
|
|
|
+ // 计算当前索引位置
|
|
|
+ int index = (int) Math.round(i * step);
|
|
|
+ // 确保索引不超过列表长度
|
|
|
+ if (index >= numbers.size()) {
|
|
|
+ index = numbers.size() - 1;
|
|
|
+ }
|
|
|
+ result.add(numbers.get(index));
|
|
|
+ }
|
|
|
+
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+}
|