Browse Source

扩展aliyun实现ocr手写识别服务接口,增加ocr供应商数据库存储

luoshi 1 year ago
parent
commit
ab1240ac2a
20 changed files with 315 additions and 29 deletions
  1. 30 0
      src/main/java/com/qmth/ops/api/controller/ai/OcrController.java
  2. 8 6
      src/main/java/com/qmth/ops/biz/ai/client/OcrApiClient.java
  3. 13 0
      src/main/java/com/qmth/ops/biz/ai/client/OcrApiConfig.java
  4. 2 1
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatClient.java
  5. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatInput.java
  6. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatOutput.java
  7. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatRequest.java
  8. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatResult.java
  9. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunUsage.java
  10. 31 12
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/ocr/AliyunOcrClient.java
  11. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/ocr/AliyunOcrData.java
  12. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/aliyun/ocr/AliyunOcrResult.java
  13. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/azure/llm/AzureChatClient.java
  14. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/azure/llm/AzureChatResult.java
  15. 1 1
      src/main/java/com/qmth/ops/biz/ai/client/azure/llm/AzureUsage.java
  16. 13 0
      src/main/java/com/qmth/ops/biz/ai/exception/OcrClientNotFound.java
  17. 8 0
      src/main/java/com/qmth/ops/biz/dao/OcrSupplierDao.java
  18. 104 0
      src/main/java/com/qmth/ops/biz/domain/OcrSupplier.java
  19. 58 0
      src/main/java/com/qmth/ops/biz/service/OcrClientService.java
  20. 38 0
      src/main/java/com/qmth/ops/biz/service/OcrSupplierService.java

+ 30 - 0
src/main/java/com/qmth/ops/api/controller/ai/OcrController.java

@@ -0,0 +1,30 @@
+package com.qmth.ops.api.controller.ai;
+
+import com.qmth.boot.api.annotation.Aac;
+import com.qmth.boot.core.ai.model.AiConstants;
+import com.qmth.boot.core.ai.model.ocr.OcrType;
+import com.qmth.boot.tools.signature.SignatureType;
+import com.qmth.ops.api.security.AccessOrg;
+import com.qmth.ops.biz.service.OcrClientService;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestAttribute;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.multipart.MultipartFile;
+
+import javax.annotation.Resource;
+
+@RestController
+@Aac(auth = true, signType = SignatureType.SECRET)
+public class OcrController {
+
+    @Resource
+    private OcrClientService ocrClientService;
+
+    @PostMapping(AiConstants.OCR_IMAGE_PATH)
+    public String forImage(@RequestAttribute AccessOrg accessOrg, @RequestParam("type") OcrType type,
+            @RequestParam("image") MultipartFile file) throws Exception {
+        return ocrClientService.forImage(type, file.getBytes());
+    }
+
+}

+ 8 - 6
src/main/java/com/qmth/ops/biz/ai/client/OcrApiClient.java

@@ -1,10 +1,14 @@
 package com.qmth.ops.biz.ai.client;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.core.rateLimit.service.RateLimiter;
 import com.qmth.boot.core.rateLimit.service.impl.MemoryRateLimiter;
 import com.qmth.ops.biz.ai.exception.OcrRateLimitExceeded;
-import okhttp3.*;
+import okhttp3.ConnectionPool;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
+import okhttp3.Response;
 
 import java.io.IOException;
 
@@ -34,19 +38,17 @@ public abstract class OcrApiClient {
         return config;
     }
 
-    protected abstract String buildUrl() throws Exception;
+    protected abstract Request buildRequest(OcrType type, byte[] file) throws Exception;
 
     protected abstract String buildResult(byte[] data, ObjectMapper mapper) throws IOException;
 
     protected abstract String handleError(byte[] data, int statusCode, ObjectMapper mapper);
 
-    public String call(byte[] image) throws Exception {
+    public String call(OcrType type, byte[] file) throws Exception {
         if (queryRateLimiter != null && !queryRateLimiter.acquire()) {
             throw new OcrRateLimitExceeded(config.getQps());
         }
-        RequestBody body = RequestBody.create(MediaType.parse("application/octet-stream"), image);
-        Request httpRequest = new Request.Builder().url(buildUrl()).post(body).build();
-        Response response = client.newCall(httpRequest).execute();
+        Response response = client.newCall(buildRequest(type, file)).execute();
         byte[] data = response.body() != null ? response.body().bytes() : null;
         if (response.isSuccessful()) {
             return buildResult(data, mapper);

+ 13 - 0
src/main/java/com/qmth/ops/biz/ai/client/OcrApiConfig.java

@@ -1,5 +1,7 @@
 package com.qmth.ops.biz.ai.client;
 
+import com.qmth.ops.biz.domain.OcrSupplier;
+
 public class OcrApiConfig {
 
     private String url;
@@ -10,6 +12,17 @@ public class OcrApiConfig {
 
     private int qps;
 
+    public OcrApiConfig() {
+
+    }
+
+    public OcrApiConfig(OcrSupplier supplier) {
+        this.url = supplier.getUrl();
+        this.key = supplier.getKey();
+        this.secret = supplier.getSecret();
+        this.qps = supplier.getQps();
+    }
+
     public String getUrl() {
         return url;
     }

+ 2 - 1
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunChatClient.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatClient.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.llm;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.qmth.boot.core.ai.model.llm.ChatRequest;
@@ -9,6 +9,7 @@ import com.qmth.boot.core.exception.StatusException;
 import com.qmth.ops.biz.ai.client.ChatApiClient;
 import com.qmth.ops.biz.ai.client.ChatApiConfig;
 import com.qmth.ops.biz.ai.client.Credentials;
+import com.qmth.ops.biz.ai.client.aliyun.AliyunError;
 import com.qmth.ops.biz.ai.exception.ChatRequestError;
 
 import java.io.IOException;

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunChatInput.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatInput.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.llm;
 
 import com.qmth.boot.core.ai.model.llm.ChatMessage;
 

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunChatOutput.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatOutput.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.llm;
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.qmth.boot.core.ai.model.llm.ChatChoice;

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunChatRequest.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatRequest.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.llm;
 
 import com.qmth.boot.core.ai.model.llm.ChatRequest;
 

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunChatResult.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatResult.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.llm;
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.qmth.boot.core.ai.model.llm.ChatResult;

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunUsage.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunUsage.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.llm;
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;

+ 31 - 12
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunOcrClient.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/ocr/AliyunOcrClient.java

@@ -1,6 +1,7 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.ocr;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.core.exception.ParameterException;
 import com.qmth.boot.core.exception.ReentrantException;
 import com.qmth.boot.core.exception.StatusException;
@@ -9,6 +10,10 @@ import com.qmth.boot.tools.models.ByteArray;
 import com.qmth.boot.tools.uuid.FastUUID;
 import com.qmth.ops.biz.ai.client.OcrApiClient;
 import com.qmth.ops.biz.ai.client.OcrApiConfig;
+import com.qmth.ops.biz.ai.client.aliyun.AliyunError;
+import okhttp3.MediaType;
+import okhttp3.Request;
+import okhttp3.RequestBody;
 
 import javax.crypto.Mac;
 import javax.crypto.SecretKey;
@@ -26,11 +31,8 @@ import java.util.*;
 
 public class AliyunOcrClient extends OcrApiClient {
 
-    private String action;
-
-    public AliyunOcrClient(OcrApiConfig config, String action) {
+    public AliyunOcrClient(OcrApiConfig config) {
         super(config);
-        this.action = action;
     }
 
     @Override
@@ -44,7 +46,7 @@ public class AliyunOcrClient extends OcrApiClient {
         if (data != null) {
             try {
                 error = mapper.readValue(data, AliyunError.class);
-            } catch (Exception e) {
+            } catch (Exception ignore) {
             }
         }
         switch (statusCode) {
@@ -120,7 +122,18 @@ public class AliyunOcrClient extends OcrApiClient {
      *
      * @return 公共请求参数组成的字典
      */
-    private Map<String, String> getCommonParameters(String action) {
+    private Map<String, String> getCommonParameters(OcrType type) {
+        String action;
+        switch (type) {
+        case GENERAL:
+            action = "RecognizeGeneral";
+            break;
+        case HANDWRITING:
+            action = "RecognizeHandwriting";
+            break;
+        default:
+            action = "";
+        }
         return new HashMap<String, String>() {{
             put("Action", action); // 调用的接口名称,此处以 RecognizeGeneral 为例
             put("Version", "2021-07-07"); // API版本。OCR的固定值:2021-07-07
@@ -137,10 +150,9 @@ public class AliyunOcrClient extends OcrApiClient {
     /**
      * 识别本地文件代码示例。以 RecognizeGeneral 接口为例。
      */
-    @Override
-    protected String buildUrl() throws Exception {
+    protected String buildUrl(OcrType type) throws Exception {
         // 获取公共请求参数
-        Map<String, String> parametersMap = getCommonParameters(action);
+        Map<String, String> parametersMap = getCommonParameters(type);
         // 初始化请求URL
         StringBuilder urlBuilder = new StringBuilder(getConfig().getUrl());
         urlBuilder.append("?");
@@ -158,14 +170,21 @@ public class AliyunOcrClient extends OcrApiClient {
         return url;
     }
 
+    @Override
+    protected Request buildRequest(OcrType type, byte[] file) throws Exception {
+        return new Request.Builder().url(buildUrl(type))
+                .post(RequestBody.create(MediaType.parse("application/octet-stream"), file)).build();
+    }
+
     public static void main(String[] args) throws Exception {
         OcrApiConfig config = new OcrApiConfig();
         config.setUrl("https://ocr-api.cn-hangzhou.aliyuncs.com/");
         config.setKey("LTAI5t6D5p62tDjYgwSz2mTR");
         config.setSecret("twrXT7Dp1kG1bV5HZn6vgpoypu9PnZ");
         config.setQps(0);
-        AliyunOcrClient client = new AliyunOcrClient(config, "RecognizeHandwriting");
-        System.out.println(client.call(ByteArray.fromFile(new File("/Users/luoshi/Downloads/cache/1-1.jpg")).value()));
+        AliyunOcrClient client = new AliyunOcrClient(config);
+        System.out.println(client.call(OcrType.GENERAL,
+                ByteArray.fromFile(new File("/Users/luoshi/Downloads/cache/1-1.jpg")).value()));
     }
 
 }

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunOcrData.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/ocr/AliyunOcrData.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.ocr;
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/aliyun/AliyunOcrResult.java → src/main/java/com/qmth/ops/biz/ai/client/aliyun/ocr/AliyunOcrResult.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.aliyun;
+package com.qmth.ops.biz.ai.client.aliyun.ocr;
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/azure/AzureChatClient.java → src/main/java/com/qmth/ops/biz/ai/client/azure/llm/AzureChatClient.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.azure;
+package com.qmth.ops.biz.ai.client.azure.llm;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.qmth.boot.core.ai.model.llm.ChatRequest;

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/azure/AzureChatResult.java → src/main/java/com/qmth/ops/biz/ai/client/azure/llm/AzureChatResult.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.azure;
+package com.qmth.ops.biz.ai.client.azure.llm;
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.qmth.boot.core.ai.model.llm.ChatChoice;

+ 1 - 1
src/main/java/com/qmth/ops/biz/ai/client/azure/AzureUsage.java → src/main/java/com/qmth/ops/biz/ai/client/azure/llm/AzureUsage.java

@@ -1,4 +1,4 @@
-package com.qmth.ops.biz.ai.client.azure;
+package com.qmth.ops.biz.ai.client.azure.llm;
 
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.annotation.JsonProperty;

+ 13 - 0
src/main/java/com/qmth/ops/biz/ai/exception/OcrClientNotFound.java

@@ -0,0 +1,13 @@
+package com.qmth.ops.biz.ai.exception;
+
+import com.qmth.boot.core.ai.model.ocr.OcrType;
+import com.qmth.boot.core.exception.NotFoundException;
+
+public class OcrClientNotFound extends NotFoundException {
+
+    private static final long serialVersionUID = 6299687661502614806L;
+
+    public OcrClientNotFound(OcrType type) {
+        super("OCR api client not found for type: " + type);
+    }
+}

+ 8 - 0
src/main/java/com/qmth/ops/biz/dao/OcrSupplierDao.java

@@ -0,0 +1,8 @@
+package com.qmth.ops.biz.dao;
+
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import com.qmth.ops.biz.domain.OcrSupplier;
+
+public interface OcrSupplierDao extends BaseMapper<OcrSupplier> {
+
+}

+ 104 - 0
src/main/java/com/qmth/ops/biz/domain/OcrSupplier.java

@@ -0,0 +1,104 @@
+package com.qmth.ops.biz.domain;
+
+import com.baomidou.mybatisplus.annotation.IdType;
+import com.baomidou.mybatisplus.annotation.TableId;
+import com.baomidou.mybatisplus.annotation.TableName;
+
+import java.io.Serializable;
+
+@TableName("ocr_supplier")
+public class OcrSupplier implements Serializable {
+
+    private static final long serialVersionUID = 5460549877564943447L;
+
+    @TableId(type = IdType.AUTO)
+    private Long id;
+
+    private String name;
+
+    private String url;
+
+    private String key;
+
+    private String secret;
+
+    private String clientClass;
+
+    private Integer qps;
+
+    private Long createTime;
+
+    private Long updateTime;
+
+    public Long getId() {
+        return id;
+    }
+
+    public void setId(Long id) {
+        this.id = id;
+    }
+
+    public String getName() {
+        return name;
+    }
+
+    public void setName(String name) {
+        this.name = name;
+    }
+
+    public String getUrl() {
+        return url;
+    }
+
+    public void setUrl(String url) {
+        this.url = url;
+    }
+
+    public String getSecret() {
+        return secret;
+    }
+
+    public void setSecret(String secret) {
+        this.secret = secret;
+    }
+
+    public String getKey() {
+        return key;
+    }
+
+    public void setKey(String key) {
+        this.key = key;
+    }
+
+    public String getClientClass() {
+        return clientClass;
+    }
+
+    public void setClientClass(String clientClass) {
+        this.clientClass = clientClass;
+    }
+
+    public Integer getQps() {
+        return qps;
+    }
+
+    public void setQps(Integer qps) {
+        this.qps = qps;
+    }
+
+    public Long getCreateTime() {
+        return createTime;
+    }
+
+    public void setCreateTime(Long createTime) {
+        this.createTime = createTime;
+    }
+
+    public Long getUpdateTime() {
+        return updateTime;
+    }
+
+    public void setUpdateTime(Long updateTime) {
+        this.updateTime = updateTime;
+    }
+}

+ 58 - 0
src/main/java/com/qmth/ops/biz/service/OcrClientService.java

@@ -0,0 +1,58 @@
+package com.qmth.ops.biz.service;
+
+import com.qmth.boot.core.ai.model.ocr.OcrType;
+import com.qmth.ops.biz.ai.client.OcrApiClient;
+import com.qmth.ops.biz.ai.client.OcrApiConfig;
+import com.qmth.ops.biz.ai.exception.OcrClientNotFound;
+import com.qmth.ops.biz.domain.OcrSupplier;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.stereotype.Service;
+
+import javax.annotation.Resource;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+@Service
+public class OcrClientService {
+
+    private static final Logger log = LoggerFactory.getLogger(OcrClientService.class);
+
+    private OcrApiClient defaultClient;
+
+    private Map<Long, OcrApiClient> clientMap = new HashMap<>();
+
+    @Resource
+    private OcrSupplierService ocrSupplierService;
+
+    public void init() {
+        List<OcrSupplier> list = ocrSupplierService.list();
+        for (OcrSupplier supplier : list) {
+            initApiClient(supplier);
+        }
+    }
+
+    private void initApiClient(OcrSupplier supplier) {
+        try {
+            String className = OcrApiClient.class.getName().replace("OcrApiClient", supplier.getClientClass());
+            OcrApiConfig config = new OcrApiConfig(supplier);
+            Class<?> clientClass = Class.forName(className);
+            OcrApiClient client = (OcrApiClient) clientClass.getConstructor(OcrApiConfig.class).newInstance(config);
+            clientMap.put(supplier.getId(), client);
+            if (defaultClient == null) {
+                defaultClient = client;
+            }
+        } catch (Exception e) {
+            log.error("OCR api client init error, supplier={}, class={}", supplier.getName(),
+                    supplier.getClientClass());
+        }
+    }
+
+    public String forImage(OcrType type, byte[] imageData) throws Exception {
+        if (defaultClient == null) {
+            throw new OcrClientNotFound(type);
+        }
+        return defaultClient.call(type, imageData);
+    }
+}

+ 38 - 0
src/main/java/com/qmth/ops/biz/service/OcrSupplierService.java

@@ -0,0 +1,38 @@
+package com.qmth.ops.biz.service;
+
+import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
+import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
+import com.qmth.ops.biz.dao.OcrSupplierDao;
+import com.qmth.ops.biz.domain.OcrSupplier;
+import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
+
+import javax.annotation.Resource;
+
+@Service
+public class OcrSupplierService extends ServiceImpl<OcrSupplierDao, OcrSupplier> {
+
+    @Resource
+    private OcrSupplierDao supplierDao;
+
+    @Transactional
+    public OcrSupplier insert(OcrSupplier supplier) {
+        supplier.setCreateTime(System.currentTimeMillis());
+        supplier.setUpdateTime(supplier.getCreateTime());
+        supplierDao.insert(supplier);
+        return supplier;
+    }
+
+    @Transactional
+    public void update(OcrSupplier supplier) {
+        supplierDao.update(supplier, new LambdaUpdateWrapper<OcrSupplier>()
+                .set(supplier.getName() != null, OcrSupplier::getName, supplier.getName())
+                .set(supplier.getUrl() != null, OcrSupplier::getUrl, supplier.getUrl())
+                .set(supplier.getKey() != null, OcrSupplier::getKey, supplier.getKey())
+                .set(supplier.getSecret() != null, OcrSupplier::getSecret, supplier.getSecret())
+                .set(supplier.getQps() != null, OcrSupplier::getQps, supplier.getQps())
+                .set(OcrSupplier::getUpdateTime, System.currentTimeMillis()).eq(OcrSupplier::getId, supplier.getId()));
+    }
+
+}
+