xiatian 8 ay önce
ebeveyn
işleme
257cc894d4

+ 1 - 1
pom.xml

@@ -22,7 +22,7 @@
         <maven.compiler.target>1.8</maven.compiler.target>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
         <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
-        <qmth-boot-version>1.0.4</qmth-boot-version>
+        <qmth-boot-version>1.0.5</qmth-boot-version>
     </properties>
 
 	<dependencies>

+ 35 - 0
src/main/java/cn/com/qmth/am/bean/OcrDto.java

@@ -0,0 +1,35 @@
+package cn.com.qmth.am.bean;
+
+import java.io.File;
+
+import com.qmth.boot.core.solar.model.OrgInfo;
+
+public class OcrDto {
+
+    private File file;
+
+    private OrgInfo org;
+
+    public File getFile() {
+        return file;
+    }
+
+    public void setFile(File file) {
+        this.file = file;
+    }
+
+    public OrgInfo getOrg() {
+        return org;
+    }
+
+    public void setOrg(OrgInfo org) {
+        this.org = org;
+    }
+
+    public OcrDto(File file, OrgInfo org) {
+        super();
+        this.file = file;
+        this.org = org;
+    }
+
+}

+ 54 - 0
src/main/java/cn/com/qmth/am/multithread/AopTargetUtils.java

@@ -0,0 +1,54 @@
+package cn.com.qmth.am.multithread;
+
+import java.lang.reflect.Field;
+
+import org.springframework.aop.framework.AdvisedSupport;
+import org.springframework.aop.framework.AopProxy;
+import org.springframework.aop.support.AopUtils;
+
+public class AopTargetUtils {
+
+    public static Object getTarget(Object obj) {
+        if (!AopUtils.isAopProxy(obj)) {
+            return obj;
+        }
+        try {
+
+            // 判断是jdk还是cglib代理
+            if (AopUtils.isJdkDynamicProxy(obj)) {
+                obj = getJdkDynamicProxyTargetObject(obj);
+            } else {
+                obj = getCglibDynamicProxyTargetObject(obj);
+            }
+
+        } catch (Exception e) {
+
+        }
+        return obj;
+
+    }
+
+    private static Object getCglibDynamicProxyTargetObject(Object obj) throws Exception {
+        Field h = obj.getClass().getDeclaredField("CGLIB$CALLBACK_0");
+        h.setAccessible(true);
+
+        Object dynamicAdvisedInterceptor = h.get(obj);
+        Field advised = dynamicAdvisedInterceptor.getClass().getDeclaredField("advised");
+        advised.setAccessible(true);
+        Object target = ((AdvisedSupport) advised.get(dynamicAdvisedInterceptor)).getTargetSource().getTarget();
+        return target;
+    }
+
+    private static Object getJdkDynamicProxyTargetObject(Object obj) throws Exception {
+
+        Field h = obj.getClass().getSuperclass().getDeclaredField("h");
+        h.setAccessible(true);
+
+        AopProxy aopProxy = (AopProxy) h.get(obj);
+        Field advised = aopProxy.getClass().getDeclaredField("advised");
+        advised.setAccessible(true);
+        Object target = ((AdvisedSupport) advised.get(aopProxy)).getTargetSource().getTarget();
+        return target;
+
+    }
+}

+ 162 - 0
src/main/java/cn/com/qmth/am/multithread/Basket.java

@@ -0,0 +1,162 @@
+package cn.com.qmth.am.multithread;
+
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.qmth.boot.core.exception.StatusException;
+
+import cn.com.qmth.am.utils.Calculator;
+
+public class Basket<T> {
+
+    private String taskName;
+
+    private Integer total = 0;
+
+    private AtomicInteger process = new AtomicInteger(0);
+
+    private List<String> result = new CopyOnWriteArrayList<>();
+
+    private List<T> failed = new CopyOnWriteArrayList<>();
+
+    /**
+     * 数据阻塞队列
+     */
+    private BlockingQueue<Object> queue;
+
+    /**
+     * 多线程计数器,子线程都结束后主线程才继续执行
+     */
+    private CountDownLatch endGate;
+
+    /**
+     * 消费者数量
+     */
+    private int consumerCount;
+
+    /**
+     * 判断线程执行是否有出错,生产者、消费者出错都需要修改此值为true
+     */
+    private boolean isExcuteError = false;
+
+    public Basket(int consumerCount, String taskName) {
+        this.consumerCount = consumerCount;
+        this.taskName = taskName;
+        queue = new ArrayBlockingQueue<Object>(consumerCount * 2);
+        endGate = new CountDownLatch(consumerCount);
+    }
+
+    /**
+     * 生产数据,不采用put方法防止消费线程全部异常后生产线程阻塞
+     * 
+     * @param value
+     * @throws InterruptedException
+     */
+    protected void offer(final Object value) throws InterruptedException {
+        if (isExcuteError) {
+            throw new StatusException("线程异常");
+        } else {
+            boolean ret = queue.offer(value, 5, TimeUnit.SECONDS);
+            if (!ret) {
+                this.offer(value);
+            }
+        }
+    }
+
+    /**
+     * 消费数据,不采用take方法防止生产线程全部异常后消费线程阻塞
+     * 
+     * @return
+     * @throws InterruptedException
+     */
+    protected Object consume() throws InterruptedException {
+        if (isExcuteError) {
+            return new EndObject();
+        } else {
+            Object ob = queue.poll(5, TimeUnit.SECONDS);
+            if (ob == null) {
+                return this.consume();
+            } else {
+                return ob;
+            }
+        }
+    }
+
+    protected void endGateReset() {
+        endGate = new CountDownLatch(consumerCount);
+    }
+
+    protected void await() throws InterruptedException {
+        endGate.await();
+    }
+
+    protected void countDown() {
+        endGate.countDown();
+    }
+
+    protected boolean isExcuteError() {
+        return isExcuteError;
+    }
+
+    protected void setExcuteError(boolean isExcuteError) {
+        this.isExcuteError = isExcuteError;
+    }
+
+    protected int getConsumerCount() {
+        return consumerCount;
+    }
+
+    protected void setConsumerCount(int consumerCount) {
+        this.consumerCount = consumerCount;
+    }
+
+    public Integer getTotal() {
+        return total;
+    }
+
+    protected void setTotal(Integer total) {
+        this.total = total;
+    }
+
+    public AtomicInteger getProcess() {
+        return process;
+    }
+
+    protected void updateProcess(int add) {
+        process.addAndGet(add);
+    }
+
+    public String getProgress() {
+        if (total == 0) {
+            return "0%";
+        }
+        Double d = Calculator.divide(process.doubleValue(), total.doubleValue(), 4);
+        Double f = Calculator.multiply(d, 100);
+        return f + "%";
+    }
+
+    public List<String> getMsgs() {
+        return result;
+    }
+
+    public void addMsg(String msg) {
+        result.add(msg);
+    }
+
+    public String getTaskName() {
+        return taskName;
+    }
+
+    public void addFailDto(T t) {
+        failed.add(t);
+    }
+
+    public List<T> getFaildDto() {
+        return failed;
+    }
+}

+ 83 - 0
src/main/java/cn/com/qmth/am/multithread/Consumer.java

@@ -0,0 +1,83 @@
+package cn.com.qmth.am.multithread;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.qmth.boot.core.exception.StatusException;
+
+public abstract class Consumer<T> extends Thread {
+
+    private static final Logger LOG = LoggerFactory.getLogger(Consumer.class);
+
+    private Basket<T> basket;
+
+    private Consumer<T> consumer;
+
+    public Consumer() {
+    }
+
+    public Consumer<T> getConsumer() {
+        return consumer;
+    }
+
+    public void setConsumer(Consumer<T> consumer) {
+        this.consumer = consumer;
+    }
+
+    @Override
+    public void run() {
+        try {
+            while (true) {
+                // 先判断是否有异常结束
+                if (basket.isExcuteError()) {
+                    break;
+                }
+                // 取消费数据
+                Object o = basket.consume();
+                // 判断消费数据是否是结束
+                if (o instanceof EndObject) {
+                    break;
+                }
+                @SuppressWarnings("unchecked")
+                T t = (T) o;
+                // 消费数据实现
+                int disposeCount = consumer.consume(t);
+                if (basket.getTotal() > 0) {
+                    basket.updateProcess(disposeCount);
+                    processInfo();
+                }
+            }
+        } catch (StatusException e) {
+            LOG.error(e.getMessage(), e);
+            addMsg(e.getMessage());
+            basket.setExcuteError(true);
+        } catch (Exception e) {
+            LOG.error(e.getMessage(), e);
+            basket.setExcuteError(true);
+        } finally {
+            basket.countDown();
+        }
+    }
+
+    protected abstract int consume(T t);
+
+    protected void setBasket(Basket<T> basket) {
+        this.basket = basket;
+    }
+
+    protected Basket<T> getBasket() {
+        return this.basket;
+    }
+
+    protected void addMsg(String msg) {
+        this.basket.addMsg(msg);
+    }
+
+    protected void addFailDto(T t) {
+        this.basket.addFailDto(t);
+    }
+
+    protected void processInfo() {
+        LOG.info(basket.getTaskName() + " 处理进度" + basket.getProgress());
+    }
+}

+ 11 - 0
src/main/java/cn/com/qmth/am/multithread/EndObject.java

@@ -0,0 +1,11 @@
+package cn.com.qmth.am.multithread;
+
+/**
+ * 消费结束标识对象
+ * 
+ * @author xiatian
+ *
+ */
+public class EndObject {
+
+}

+ 167 - 0
src/main/java/cn/com/qmth/am/multithread/Producer.java

@@ -0,0 +1,167 @@
+package cn.com.qmth.am.multithread;
+
+import java.lang.reflect.ParameterizedType;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.qmth.boot.core.exception.StatusException;
+
+import cn.com.qmth.am.utils.SpringContextHolder;
+
+public abstract class Producer<T, C extends Consumer<T>> {
+
+    private static final Logger LOG = LoggerFactory.getLogger(Producer.class);
+
+    private Basket<T> basket;
+
+    /**
+     * 消费线程class
+     */
+    private List<Consumer<T>> consumers;
+
+    public void startDispose(int consumerCount) {
+        startDispose(consumerCount, null, 0);
+    }
+
+    /**
+     * 处理开始方法
+     */
+    public void startDispose(int consumerCount, Map<String, Object> param) {
+        startDispose(consumerCount, param, 0);
+    }
+
+    public void startDispose(int consumerCount, int total) {
+        startDispose(consumerCount, null, total);
+    }
+
+    public void startDispose(int consumerCount, Map<String, Object> param, int total) {
+        // 启动消费者
+        startConsumer(consumerCount, total);
+        // 开始处理
+        dispose(param);
+    }
+
+    @SuppressWarnings("unchecked")
+    private void startConsumer(int consumerCount, int total) {
+        if (consumerCount <= 0) {
+            consumerCount = 1;
+        }
+        ParameterizedType pt = (ParameterizedType) this.getClass().getGenericSuperclass();
+        Class<C> clazz = (Class<C>) pt.getActualTypeArguments()[1];
+        consumers = new ArrayList<>();
+        this.basket = new Basket<T>(consumerCount, getTaskName());
+        basket.setTotal(total);
+        // 启动消费者
+        int count = basket.getConsumerCount();
+        for (int i = 0; i < count; i++) {
+            Consumer<T> co = SpringContextHolder.getBean(clazz);
+            co.setBasket(basket);
+            co.setConsumer(co);
+            co.start();
+            consumers.add((Consumer<T>) AopTargetUtils.getTarget(co));
+        }
+    }
+
+    private void dispose(Map<String, Object> param) {
+        try {
+            // 生产数据
+            int index = 0;
+            for (;;) {
+                T dto = findData(param, index);
+                if (dto == null) {
+                    // 拿不到数据,结束消费
+                    break;
+                }
+                offer(dto);
+                index++;
+            }
+            // 发送生产结束信息
+            endConsumer();
+            // 等待子线程结束
+            await();
+            // 判断子线程是否正常结束
+            if (basket.isExcuteError()) {
+                throw new StatusException("处理失败,线程异常");
+            }
+        } catch (StatusException e) {
+            LOG.error(e.getMessage(), e);
+            // 获取异常时发送异常结束信息
+            endConsumerAsError();
+            throw e;
+        } catch (Exception e) {
+            LOG.error(e.getMessage(), e);
+            // 获取异常时发送异常结束信息
+            endConsumerAsError();
+            throw new StatusException("处理失败", e);
+        }
+    }
+
+    /**
+     * 出异常后修改标识
+     * 
+     */
+    private void endConsumerAsError() {
+        basket.setExcuteError(true);
+    }
+
+    /**
+     * 正常结束消费者
+     * 
+     * @throws InterruptedException
+     */
+    private void endConsumer() throws InterruptedException {
+        int count = basket.getConsumerCount();
+        EndObject eo = new EndObject();
+        for (int i = 0; i < count; i++) {
+            basket.offer(eo);
+        }
+
+    }
+
+    /**
+     * 生产数据
+     * 
+     * @param ob
+     * @throws InterruptedException
+     */
+    private void offer(Object ob) throws InterruptedException {
+        synchronized (basket) {
+            basket.offer(ob);
+        }
+    }
+
+    /**
+     * 等待所有消费者结束
+     * 
+     * @throws InterruptedException
+     */
+    private void await() throws InterruptedException {
+        basket.await();
+    }
+
+    protected abstract T findData(Map<String, Object> param, int index);
+
+    protected abstract String getTaskName();
+
+    public List<String> getMsgs() {
+        return this.basket.getMsgs();
+    }
+
+    public Integer getTotal() {
+        return this.basket.getTotal();
+    }
+
+    public AtomicInteger getProcess() {
+        return this.basket.getProcess();
+    }
+
+    public List<T> getFaildDto() {
+        return this.basket.getFaildDto();
+    }
+
+}

+ 102 - 0
src/main/java/cn/com/qmth/am/multithread/consumer/OcrConsumer.java

@@ -0,0 +1,102 @@
+package cn.com.qmth.am.multithread.consumer;
+
+import java.awt.image.BufferedImage;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.IOException;
+
+import javax.imageio.ImageIO;
+
+import org.apache.commons.io.FileUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Scope;
+import org.springframework.stereotype.Service;
+
+import com.qmth.boot.core.ai.client.OcrApiClient;
+import com.qmth.boot.core.ai.model.ocr.OcrType;
+import com.qmth.boot.core.exception.StatusException;
+import com.qmth.boot.core.retrofit.exception.RetrofitResponseError;
+import com.qmth.boot.core.retrofit.utils.SignatureInfo;
+import com.qmth.boot.core.retrofit.utils.UploadFile;
+import com.qmth.boot.core.solar.model.OrgInfo;
+
+import cn.com.qmth.am.bean.OcrDto;
+import cn.com.qmth.am.bean.StudentScoreImageDto;
+import cn.com.qmth.am.multithread.Consumer;
+
+@Scope("prototype")
+@Service
+public class OcrConsumer extends Consumer<OcrDto> {
+
+    private static final Logger log = LoggerFactory.getLogger(OcrConsumer.class);
+
+    @Autowired
+    private OcrApiClient ocrApiClient;
+
+    @Override
+    public int consume(OcrDto ocrDto) {
+        File file = ocrDto.getFile();
+        String name = file.getName().substring(0, file.getName().lastIndexOf("."));
+        File txt = new File(file.getParentFile().getAbsolutePath() + "/" + name + ".txt");
+        if (!txt.exists()) {
+            try {
+                StudentScoreImageDto dto = new StudentScoreImageDto();
+                dto.setImage(fileToByte(file));
+                String content = ocrDispose(dto, ocrDto.getOrg());
+                FileUtils.write(txt, content, "utf-8");
+            } catch (Exception e) {
+                log.error("ocr异常", e);
+                addFailDto(ocrDto);
+            }
+        }
+        return 1;
+    }
+
+    private byte[] fileToByte(File img) {
+        ByteArrayOutputStream baos = new ByteArrayOutputStream();
+        try {
+            BufferedImage bi;
+            bi = ImageIO.read(img);
+            ImageIO.write(bi, "jpg", baos);
+            byte[] bytes = baos.toByteArray();
+            return bytes;
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        } finally {
+            try {
+                baos.close();
+            } catch (IOException e) {
+            }
+        }
+    }
+
+    private String ocrDispose(StudentScoreImageDto dto, OrgInfo org) {
+        SignatureInfo signature = SignatureInfo.secret(org.getAccessKey(), org.getAccessSecret());
+        try {
+            return ocrApiClient.forImage(signature, OcrType.HANDWRITING, UploadFile.build("image", "", dto.getImage()));
+        } catch (Exception e) {
+            log.error("ocr异常", e);
+            if (e instanceof RetrofitResponseError) {
+                RetrofitResponseError tem = (RetrofitResponseError) e;
+                if (tem.getCode() == 503) {
+                    if (dto.getRetry() <= 3) {
+                        try {
+                            Thread.sleep(3000);
+                        } catch (InterruptedException e1) {
+                        }
+                        dto.setRetry(dto.getRetry() + 1);
+                        return ocrDispose(dto, org);
+                    } else {
+                        throw new StatusException("重试次数过多");
+                    }
+                } else {
+                    throw e;
+                }
+            } else {
+                throw e;
+            }
+        }
+    }
+}

+ 31 - 0
src/main/java/cn/com/qmth/am/multithread/producer/OcrProducer.java

@@ -0,0 +1,31 @@
+package cn.com.qmth.am.multithread.producer;
+
+import java.util.List;
+import java.util.Map;
+
+import org.springframework.stereotype.Service;
+
+import cn.com.qmth.am.bean.OcrDto;
+import cn.com.qmth.am.multithread.Producer;
+import cn.com.qmth.am.multithread.consumer.OcrConsumer;
+
+@Service
+public class OcrProducer extends Producer<OcrDto, OcrConsumer> {
+
+    @SuppressWarnings("unchecked")
+    @Override
+    protected OcrDto findData(Map<String, Object> param, int index) {
+        List<OcrDto> ret = (List<OcrDto>) param.get("files");
+
+        if (index >= ret.size()) {
+            return null;
+        }
+        return ret.get(index);
+    }
+
+    @Override
+    protected String getTaskName() {
+        return "本地OCR处理";
+    }
+
+}

+ 28 - 79
src/main/java/cn/com/qmth/am/service/impl/OcrServiceImpl.java

@@ -1,28 +1,22 @@
 package cn.com.qmth.am.service.impl;
 
-import java.awt.image.BufferedImage;
-import java.io.ByteArrayOutputStream;
 import java.io.File;
-import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
-import javax.imageio.ImageIO;
-
-import org.apache.commons.io.FileUtils;
+import org.apache.commons.collections4.CollectionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 
-import com.qmth.boot.core.ai.client.OcrApiClient;
-import com.qmth.boot.core.ai.model.ocr.OcrType;
-import com.qmth.boot.core.exception.StatusException;
-import com.qmth.boot.core.retrofit.exception.RetrofitResponseError;
-import com.qmth.boot.core.retrofit.utils.SignatureInfo;
-import com.qmth.boot.core.retrofit.utils.UploadFile;
 import com.qmth.boot.core.solar.model.OrgInfo;
 import com.qmth.boot.core.solar.service.SolarService;
 
-import cn.com.qmth.am.bean.StudentScoreImageDto;
+import cn.com.qmth.am.bean.OcrDto;
+import cn.com.qmth.am.multithread.producer.OcrProducer;
 import cn.com.qmth.am.service.OcrService;
 
 @Service
@@ -30,86 +24,41 @@ public class OcrServiceImpl implements OcrService {
 
     private static final Logger log = LoggerFactory.getLogger(OcrService.class);
 
-    @Autowired
-    private OcrApiClient ocrApiClient;
-
     @Autowired
     private SolarService solarService;
 
     @Override
     public void ocr() {
-        File dir = new File("d:/ocr");
+        log.warn("OcrService ocr start*************");
+        File dir = new File("d:/ocr/data");
         OrgInfo org = solarService.getOrgList().get(0);
-        disposeFile(dir, org);
+        List<OcrDto> files = new ArrayList<>();
+        disposeFile(files, dir, org);
+        OcrProducer producer = new OcrProducer();
+        Map<String, Object> param = new HashMap<>();
+        param.put("files", files);
+        producer.startDispose(4, param, files.size());
+        while (true) {
+            List<OcrDto> failed = producer.getFaildDto();
+            if (CollectionUtils.isEmpty(failed)) {
+                break;
+            }
+            param = new HashMap<>();
+            param.put("files", failed);
+            producer = new OcrProducer();
+            producer.startDispose(4, param, failed.size());
+        }
         log.warn("OcrService ocr finish*************");
     }
 
-    private void disposeFile(File file, OrgInfo org) {
+    private void disposeFile(List<OcrDto> files, File file, OrgInfo org) {
         if (file.isFile() && file.getName().toLowerCase().endsWith(".jpg")) {
-            StudentScoreImageDto dto = new StudentScoreImageDto();
-            dto.setImage(fileToByte(file));
-            String content = ocrDispose(dto, org);
-            String name = file.getName().substring(0, file.getName().lastIndexOf("."));
-            File txt = new File(file.getParentFile().getAbsolutePath() + "/" + name + ".txt");
-            if (txt.exists()) {
-                txt.delete();
-            }
-            try {
-                FileUtils.write(txt, content, "utf-8");
-            } catch (IOException e) {
-                throw new RuntimeException(e);
-            }
+            files.add(new OcrDto(file, org));
         } else {
             if (file.isDirectory()) {
                 for (File subFile : file.listFiles()) {
-                    disposeFile(subFile, org);
-                }
-            }
-        }
-    }
-
-    private byte[] fileToByte(File img) {
-        ByteArrayOutputStream baos = new ByteArrayOutputStream();
-        try {
-            BufferedImage bi;
-            bi = ImageIO.read(img);
-            ImageIO.write(bi, "jpg", baos);
-            byte[] bytes = baos.toByteArray();
-            return bytes;
-        } catch (Exception e) {
-            throw new RuntimeException(e);
-        } finally {
-            try {
-                baos.close();
-            } catch (IOException e) {
-            }
-        }
-    }
-
-    private String ocrDispose(StudentScoreImageDto dto, OrgInfo org) {
-        SignatureInfo signature = SignatureInfo.secret(org.getAccessKey(), org.getAccessSecret());
-        try {
-            return ocrApiClient.forImage(signature, OcrType.HANDWRITING, UploadFile.build("image", "", dto.getImage()));
-        } catch (Exception e) {
-            log.error("ocr异常", e);
-            if (e instanceof RetrofitResponseError) {
-                RetrofitResponseError tem = (RetrofitResponseError) e;
-                if (tem.getCode() == 503) {
-                    if (dto.getRetry() <= 3) {
-                        try {
-                            Thread.sleep(3000);
-                        } catch (InterruptedException e1) {
-                        }
-                        dto.setRetry(dto.getRetry() + 1);
-                        return ocrDispose(dto, org);
-                    } else {
-                        throw new StatusException("重试次数过多");
-                    }
-                } else {
-                    throw e;
+                    disposeFile(files, subFile, org);
                 }
-            } else {
-                throw e;
             }
         }
     }