import cn.com.qmth.stmms.ms.Application; import cn.com.qmth.stmms.ms.marking.utils.RandomUtil; import cn.com.qmth.stmms.ms.core.domain.Level; import cn.com.qmth.stmms.ms.core.domain.MarkSubject; import cn.com.qmth.stmms.ms.core.domain.Paper; import cn.com.qmth.stmms.ms.core.domain.task.MarkTaskLevel; import cn.com.qmth.stmms.ms.core.repository.LevelRepo; import cn.com.qmth.stmms.ms.core.repository.MarkSubjectRepo; import cn.com.qmth.stmms.ms.core.repository.MarkTaskLevelRepo; import cn.com.qmth.stmms.ms.core.repository.PaperRepo; import cn.com.qmth.stmms.ms.core.vo.Subject; import org.junit.Test; import org.junit.runner.RunWith; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.test.context.junit4.SpringRunner; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @RunWith(SpringRunner.class) @SpringBootTest(classes = {Application.class}) // 指定启动类 public class BatchLevelTest { private static final Logger logger = LoggerFactory.getLogger(BatchLevelTest.class); @Autowired private PaperRepo paperRepo; @Autowired private MarkSubjectRepo markSubjectRepo; @Autowired private MarkTaskLevelRepo markTaskLevelRepo; @Autowired private LevelRepo levelRepo; @Autowired RandomUtil randomUtil; // 工作id private static Long WORK_ID = 1L; // 生成科目 private List subjects = Arrays.asList(Subject.SX); /** * 批量分档 * * @throws Exception */ // @Test public void level() { List levels = levelRepo.findByWorkIdOrderByCode(WORK_ID); doLevel(levels); } private void doLevel(List levels) { if (CollectionUtils.isEmpty(levels)) { throw new RuntimeException("档位值不存在"); } for (Subject subject : subjects) { List list = new ArrayList<>(); List listTask = new ArrayList<>(); MarkSubject markSubject = markSubjectRepo.findOne(WORK_ID + "-" + subject.name()); List papers = paperRepo.findByWorkIdAndSubject(WORK_ID, subject); papers = papers.stream().filter(m -> Objects.nonNull(m.getBatchNo()) && StringUtils.isEmpty(m.getLevel())).collect(Collectors.toList()); for (Paper paper : papers) { String level = levels.get(getRandom(levels.size())).getCode(); paper.setLevel(level); list.add(paper); List markTasks = markTaskLevelRepo.findByPaperIdAndStage(paper.getId(), markSubject.getStage()); for (MarkTaskLevel markTask : markTasks) { markTask.setResult(level); markTask.setLevel(level); listTask.add(markTask); } } List data1 = new ArrayList<>(); for (Paper p : list) { if (data1.size() == 2000) { paperRepo.save(data1); data1.clear(); } data1.add(p); } //将剩下的数据也导入 if (!data1.isEmpty()) { paperRepo.save(data1); } List data2 = new ArrayList<>(); for (MarkTaskLevel p : listTask) { if (data2.size() == 2000) { markTaskLevelRepo.save(data2); data2.clear(); } data2.add(p); } //将剩下的数据也导入 if (!data2.isEmpty()) { markTaskLevelRepo.save(data2); } } } private int getRandom(int size) { return (int) (Math.random() * size); } }