Răsfoiți Sursa

update parseDoc api

deason 1 lună în urmă
părinte
comite
7354ff8e84

+ 77 - 0
src/main/java/com/qmth/ops/biz/ai/client/DocApiClient.java

@@ -0,0 +1,77 @@
+package com.qmth.ops.biz.ai.client;
+
+import com.qmth.boot.core.ai.model.ocr.ParseDocTask;
+import com.qmth.boot.core.ai.model.ocr.ParseDocTaskResult;
+import com.qmth.boot.core.exception.StatusException;
+import com.qmth.boot.core.rateLimit.service.RateLimiter;
+import com.qmth.boot.core.rateLimit.service.impl.MemoryRateLimiter;
+import okhttp3.*;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.time.Duration;
+
+/**
+ * Base implementation for document-parsing API clients.
+ * <p>
+ * Holds the supplier configuration, a single OkHttp client built in the constructor and,
+ * when the configured QPS is positive, a rate limiter that subclasses can consult before
+ * calling the remote API.
+ */
+public abstract class DocApiClient {
+
+    private static final Logger log = LoggerFactory.getLogger(DocApiClient.class);
+
+    private final OcrApiConfig config;
+
+    private final OkHttpClient client;
+
+    /** Rate limiter for remote calls; {@code null} when the configured QPS is not positive. */
+    private final RateLimiter queryRateLimiter;
+
+    /**
+     * Builds the shared HTTP client (10s connect timeout, 50s read timeout) and the optional
+     * rate limiter.
+     *
+     * @param config supplier configuration; a rate limiter is created only when its QPS is &gt; 0
+     */
+    public DocApiClient(OcrApiConfig config) {
+        this.config = config;
+
+        OkHttpClient.Builder builder = new OkHttpClient.Builder().connectionPool(new ConnectionPool())
+                .connectTimeout(Duration.ofSeconds(10)).readTimeout(Duration.ofSeconds(50));
+        // NOTE: getInterceptor() is invoked before the subclass constructor body has run, so an
+        // override may use getConfig() (already assigned above) but not fields the subclass sets later.
+        Interceptor interceptor = getInterceptor();
+        if (interceptor != null) {
+            builder.addInterceptor(interceptor);
+        }
+        this.client = builder.build();
+
+        this.queryRateLimiter = config.getQps() > 0 ? new MemoryRateLimiter(config.getQps(), 1000) : null;
+    }
+
+    /**
+     * Creates a document-parsing task for the given file.
+     *
+     * @param fileData raw bytes of the document
+     * @param fileName name of the document
+     * @return the created task
+     * @throws Exception if the implementation fails to create the task
+     */
+    public abstract ParseDocTask parseDocTask(byte[] fileData, String fileName) throws Exception;
+
+    /**
+     * Queries a previously created document-parsing task.
+     *
+     * @param taskId identifier of the task to query
+     * @return the task result as reported by the implementation
+     * @throws Exception if the implementation fails to query the task
+     */
+    public abstract ParseDocTaskResult parseDocTaskQuery(String taskId) throws Exception;
+
+    /**
+     * @return the supplier configuration passed to the constructor
+     */
+    protected OcrApiConfig getConfig() {
+        return config;
+    }
+
+    /**
+     * @return the HTTP client built in the constructor
+     */
+    protected OkHttpClient getClient() {
+        return client;
+    }
+
+    /**
+     * @return the rate limiter, or {@code null} when the configured QPS is not positive
+     */
+    protected RateLimiter getQueryRateLimiter() {
+        return queryRateLimiter;
+    }
+
+    /**
+     * Hook for subclasses to contribute an OkHttp interceptor; it is called once from the constructor.
+     *
+     * @return the interceptor to install, or {@code null} (the default) for none
+     */
+    protected Interceptor getInterceptor() {
+        return null;
+    }
+
+    /**
+     * Downloads the content at the given URL with a GET request.
+     *
+     * @param url address to fetch
+     * @return the response body bytes
+     * @throws StatusException if the response is not successful, has no body, or the request fails
+     */
+    protected byte[] download(String url) {
+        Request request = new Request.Builder().url(url).get().build();
+        try (Response response = this.getClient().newCall(request).execute()) {
+            ResponseBody respBody = response.body();
+            if (response.isSuccessful() && respBody != null) {
+                return respBody.bytes();
+            }
+            log.error("获取文件内容失败!responseCode:{}", response.code());
+            throw new StatusException("获取文件内容失败!");
+        } catch (StatusException e) {
+            // Already the exception type callers expect; rethrow as-is instead of wrapping it
+            // in a second, identical StatusException.
+            throw e;
+        } catch (Exception e) {
+            throw new StatusException("获取文件内容失败!", e);
+        }
+    }
+
+}

+ 26 - 19
src/main/java/com/qmth/ops/biz/ai/client/baidu/doc/BaiduParseDocClient.java

@@ -6,36 +6,41 @@ import com.qmth.boot.core.ai.model.ocr.ParseDocTaskResult;
 import com.qmth.boot.core.ai.model.ocr.ParseDocTaskStatus;
 import com.qmth.boot.core.exception.StatusException;
 import com.qmth.boot.tools.codec.CodecUtils;
+import com.qmth.ops.biz.ai.client.DocApiClient;
 import com.qmth.ops.biz.ai.client.OcrApiConfig;
 import com.qmth.ops.biz.ai.client.baidu.BceV1Signer;
+import com.qmth.ops.biz.ai.exception.OcrRateLimitExceeded;
 import okhttp3.*;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.http.HttpHeaders;
 
+import java.nio.charset.StandardCharsets;
 import java.text.SimpleDateFormat;
 import java.util.Date;
 import java.util.TimeZone;
-import java.util.concurrent.TimeUnit;
 
-public class BaiduParseDocClient {
+public class BaiduParseDocClient extends DocApiClient {
 
     private static final Logger log = LoggerFactory.getLogger(BaiduParseDocClient.class);
 
-    private OcrApiConfig config;
-
     public BaiduParseDocClient(OcrApiConfig config) {
-        this.config = config;
+        super(config);
     }
 
+    @Override
     public ParseDocTask parseDocTask(byte[] fileData, String fileName) throws Exception {
+        if (getQueryRateLimiter() != null && !getQueryRateLimiter().acquire()) {
+            throw new OcrRateLimitExceeded(getConfig().getQps());
+        }
+
         FormBody.Builder formBuilder = new FormBody.Builder();
         formBuilder.add("file_data", CodecUtils.toBase64(fileData));
         formBuilder.add("file_name", CodecUtils.urlEncode(fileName));
-        String url = config.getUrl() + "/rest/2.0/brain/online/v2/parser/task";
+        String url = getConfig().getUrl() + "/rest/2.0/brain/online/v2/parser/task";
         Request request = this.buildRequest(url, formBuilder);
 
-        try (Response response = this.getHttpClient().newCall(request).execute()) {
+        try (Response response = super.getClient().newCall(request).execute()) {
             ResponseBody respBody = response.body();
             String respBodyStr = respBody != null ? respBody.string() : "";
             log.info(respBodyStr);
@@ -54,13 +59,18 @@ public class BaiduParseDocClient {
         }
     }
 
+    @Override
     public ParseDocTaskResult parseDocTaskQuery(String taskId) throws Exception {
+        if (getQueryRateLimiter() != null && !getQueryRateLimiter().acquire()) {
+            throw new OcrRateLimitExceeded(getConfig().getQps());
+        }
+
         FormBody.Builder formBuilder = new FormBody.Builder();
         formBuilder.add("task_id", taskId);
-        String url = config.getUrl() + "/rest/2.0/brain/online/v2/parser/task/query";
+        String url = getConfig().getUrl() + "/rest/2.0/brain/online/v2/parser/task/query";
         Request request = this.buildRequest(url, formBuilder);
 
-        try (Response response = this.getHttpClient().newCall(request).execute()) {
+        try (Response response = super.getClient().newCall(request).execute()) {
             ResponseBody respBody = response.body();
             String respBodyStr = respBody != null ? respBody.string() : "";
             log.info(respBodyStr);
@@ -73,7 +83,11 @@ public class BaiduParseDocClient {
 
                     ParseDocTaskResult result = new ParseDocTaskResult();
                     result.setStatus(status != null ? status : ParseDocTaskStatus.FAILED);
-                    result.setContent(respResult.getMarkdownUrl());
+                    if (ParseDocTaskStatus.SUCCESS == status) {
+                        byte[] data = super.download(respResult.getMarkdownUrl());
+                        result.setContent(new String(data, StandardCharsets.UTF_8));
+                    }
+
                     return result;
                 }
             }
@@ -98,14 +112,7 @@ public class BaiduParseDocClient {
                 .build();
 
         return request.newBuilder().addHeader(HttpHeaders.AUTHORIZATION,
-                BceV1Signer.sign(request, config.getKey(), config.getSecret())).build();
-    }
-
-    private OkHttpClient getHttpClient() {
-        return new OkHttpClient.Builder()
-                .readTimeout(60, TimeUnit.SECONDS)
-                .connectTimeout(60, TimeUnit.SECONDS)
-                .build();
+                BceV1Signer.sign(request, getConfig().getKey(), getConfig().getSecret())).build();
     }
 
     public static void main(String[] args) throws Exception {
@@ -114,7 +121,7 @@ public class BaiduParseDocClient {
         config.setKey("xxx");
         config.setSecret("xxx");
         config.setQps(10);
-        BaiduParseDocClient client = new BaiduParseDocClient(config);
+        DocApiClient client = new BaiduParseDocClient(config);
 
         // File file = new File("D:\\home\\大纲.pdf");
         // byte[] fileData = ByteArray.fromFile(file).value();

+ 29 - 17
src/main/java/com/qmth/ops/biz/service/OcrClientService.java

@@ -4,9 +4,9 @@ import com.qmth.boot.core.ai.model.ocr.ImageType;
 import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.core.ai.model.ocr.ParseDocTask;
 import com.qmth.boot.core.ai.model.ocr.ParseDocTaskResult;
+import com.qmth.ops.biz.ai.client.DocApiClient;
 import com.qmth.ops.biz.ai.client.OcrApiClient;
 import com.qmth.ops.biz.ai.client.OcrApiConfig;
-import com.qmth.ops.biz.ai.client.baidu.doc.BaiduParseDocClient;
 import com.qmth.ops.biz.ai.exception.OcrClientNotFound;
 import com.qmth.ops.biz.domain.OcrSupplier;
 import org.slf4j.Logger;
@@ -46,22 +46,26 @@ public class OcrClientService {
             String className = OcrApiClient.class.getName().replace("OcrApiClient", supplier.getClientClass());
             OcrApiConfig config = new OcrApiConfig(supplier);
             Class<?> clientClass = Class.forName(className);
-            OcrApiClient client = (OcrApiClient) clientClass.getConstructor(OcrApiConfig.class).newInstance(config);
-            clientMap.put(supplier.getId(), client);
-            //取第一个enable=true的为默认客户端
-            if (supplier.getEnable() && defaultClient == null) {
-                defaultClient = client;
+
+            Object clientInstance = clientClass.getConstructor(OcrApiConfig.class).newInstance(config);
+            if (clientInstance instanceof OcrApiClient) {
+                OcrApiClient client = (OcrApiClient) clientInstance;
+                clientMap.put(supplier.getId(), client);
+                // 取第一个enable=true的为默认客户端
+                if (supplier.getEnable() && defaultClient == null) {
+                    defaultClient = client;
+                }
+            } else {
+                log.warn("DocApiClient supplier:{} class:{}", supplier.getName(), supplier.getClientClass());
             }
         } catch (Exception e) {
-            log.error("OCR api client init error, supplier={}, class={}", supplier.getName(),
-                    supplier.getClientClass());
+            log.error("OcrApiClient init error, supplier:{} class:{}", supplier.getName(), supplier.getClientClass());
         }
     }
 
     public String forImage(Long id, OcrType type, byte[] imageData, ImageType imageType) throws Exception {
         OcrApiClient client = clientMap.get(id);
-
-        if (defaultClient == null) {
+        if (client == null) {
             throw new OcrClientNotFound(id);
         }
         return client.forImage(type, imageData, imageType);
@@ -75,20 +79,28 @@ public class OcrClientService {
     }
 
     public ParseDocTask parseDocTask(byte[] fileData, String fileName) throws Exception {
-        OcrApiConfig apiConfig = this.getApiConfig(3L);//todo
-        BaiduParseDocClient client = new BaiduParseDocClient(apiConfig);
+        DocApiClient client = this.getDefaultDocApiClient();
         return client.parseDocTask(fileData, fileName);
     }
 
     public ParseDocTaskResult parseDocTaskQuery(String taskId) throws Exception {
-        OcrApiConfig apiConfig = this.getApiConfig(3L);//todo
-        BaiduParseDocClient client = new BaiduParseDocClient(apiConfig);
+        DocApiClient client = this.getDefaultDocApiClient();
         return client.parseDocTaskQuery(taskId);
     }
 
-    private OcrApiConfig getApiConfig(Long id) {
-        OcrSupplier supplier = ocrSupplierService.getById(id);
-        return new OcrApiConfig(supplier);
+    /**
+     * Instantiates the document-parsing client for the default supplier.
+     * <p>
+     * The class name is derived from the fully qualified name of {@link DocApiClient} by replacing
+     * its simple name with the supplier's configured client class; the class is then constructed
+     * reflectively with an {@link OcrApiConfig}.
+     * NOTE(review): a new client is built on every call, so the rate limiter each client creates
+     * in its constructor presumably does not carry over between calls — consider caching the instance.
+     *
+     * @return a newly created client
+     * @throws OcrClientNotFound if the supplier is missing or the client cannot be created
+     */
+    private DocApiClient getDefaultDocApiClient() {
+        OcrSupplier supplier = ocrSupplierService.getById(3L);//todo
+        if (supplier == null) {
+            // Fail with the domain exception instead of a NullPointerException further down.
+            log.error("DocApiClient init error, supplier not found");
+            throw new OcrClientNotFound();
+        }
+
+        try {
+            String className = DocApiClient.class.getName().replace(DocApiClient.class.getSimpleName(), supplier.getClientClass());
+            OcrApiConfig config = new OcrApiConfig(supplier);
+            Class<?> clientClass = Class.forName(className);
+
+            return (DocApiClient) clientClass.getConstructor(OcrApiConfig.class).newInstance(config);
+        } catch (Exception e) {
+            // Pass the exception as the trailing argument so SLF4J records the root cause and stack trace.
+            log.error("DocApiClient init error, supplier:{} class:{}", supplier.getName(), supplier.getClientClass(), e);
+            throw new OcrClientNotFound();
+        }
     }
 
 }