瀏覽代碼

补充ocr客户端初始化逻辑

luoshi 1 年之前
父節點
當前提交
f26f5c023b

+ 50 - 0
src/main/java/com/qmth/ops/api/controller/admin/OcrSupplierController.java

@@ -0,0 +1,50 @@
+package com.qmth.ops.api.controller.admin;
+
+import com.qmth.ops.api.constants.OpsApiConstants;
+import com.qmth.ops.api.security.AdminSession;
+import com.qmth.ops.api.security.Permission;
+import com.qmth.ops.api.vo.OcrSupplierVO;
+import com.qmth.ops.biz.domain.OcrSupplier;
+import com.qmth.ops.biz.service.OcrClientService;
+import com.qmth.ops.biz.service.OcrSupplierService;
+import org.springframework.web.bind.annotation.*;
+
+import javax.annotation.Resource;
+import java.util.List;
+import java.util.stream.Collectors;
+
+@RestController
+@RequestMapping(OpsApiConstants.ADMIN_URI_PREFIX + "/ocr/supplier")
+public class OcrSupplierController {
+
+    @Resource
+    private OcrSupplierService ocrSupplierService;
+
+    @Resource
+    private OcrClientService ocrClientService;
+
+    @PostMapping("/list")
+    public List<OcrSupplierVO> list() {
+        return ocrSupplierService.list().stream().map(OcrSupplierVO::new).collect(Collectors.toList());
+    }
+
+    @PostMapping("/detail")
+    public OcrSupplier detail(@RequestParam Long id) {
+        return ocrSupplierService.getById(id);
+    }
+
+    @PostMapping("/insert")
+    public OcrSupplier insert(@RequestAttribute AdminSession adminSession, OcrSupplier supplier) {
+        adminSession.hasPermission(Permission.OCR_SUPPLIER_INSERT);
+        return ocrSupplierService.insert(supplier);
+    }
+
+    @PostMapping("/update")
+    public OcrSupplier update(@RequestAttribute AdminSession adminSession, OcrSupplier supplier) {
+        adminSession.hasPermission(Permission.OCR_SUPPLIER_EDIT);
+        ocrSupplierService.update(supplier);
+        ocrClientService.init();
+        return ocrSupplierService.getById(supplier.getId());
+    }
+
+}

+ 3 - 0
src/main/java/com/qmth/ops/api/security/Permission.java

@@ -37,6 +37,9 @@ public enum Permission {
     USER_VIEW("用户管理", null),
     USER_INSERT("用户新增", null),
     USER_EDIT("用户修改", null),
+    OCR_SUPPLIER_VIEW("OCR服务商管理", null),
+    OCR_SUPPLIER_INSERT("OCR服务商新增", null),
+    OCR_SUPPLIER_EDIT("OCR服务商修改", null),
     LLM_SUPPLIER_VIEW("大模型服务商管理", null),
     LLM_SUPPLIER_INSERT("大模型服务商新增", null),
     LLM_SUPPLIER_EDIT("大模型服务商修改", null),

+ 53 - 0
src/main/java/com/qmth/ops/api/vo/OcrSupplierVO.java

@@ -0,0 +1,53 @@
+package com.qmth.ops.api.vo;
+
+import com.qmth.ops.biz.domain.OcrSupplier;
+
+public class OcrSupplierVO {
+
+    private Long id;
+
+    private String name;
+
+    private Long createTime;
+
+    private Long updateTime;
+
+    public OcrSupplierVO(OcrSupplier supplier) {
+        this.id = supplier.getId();
+        this.name = supplier.getName();
+        this.createTime = supplier.getCreateTime();
+        this.updateTime = supplier.getUpdateTime();
+    }
+
+    public Long getId() {
+        return id;
+    }
+
+    public void setId(Long id) {
+        this.id = id;
+    }
+
+    public String getName() {
+        return name;
+    }
+
+    public void setName(String name) {
+        this.name = name;
+    }
+
+    public Long getCreateTime() {
+        return createTime;
+    }
+
+    public void setCreateTime(Long createTime) {
+        this.createTime = createTime;
+    }
+
+    public Long getUpdateTime() {
+        return updateTime;
+    }
+
+    public void setUpdateTime(Long updateTime) {
+        this.updateTime = updateTime;
+    }
+}

+ 3 - 3
src/main/java/com/qmth/ops/biz/ai/client/OcrApiClient.java

@@ -38,17 +38,17 @@ public abstract class OcrApiClient {
         return config;
     }
 
-    protected abstract Request buildRequest(OcrType type, byte[] file) throws Exception;
+    protected abstract Request buildRequest(OcrType type, byte[] image) throws Exception;
 
     protected abstract String buildResult(byte[] data, ObjectMapper mapper) throws IOException;
 
     protected abstract String handleError(byte[] data, int statusCode, ObjectMapper mapper);
 
-    public String call(OcrType type, byte[] file) throws Exception {
+    public String forImage(OcrType type, byte[] image) throws Exception {
         if (queryRateLimiter != null && !queryRateLimiter.acquire()) {
             throw new OcrRateLimitExceeded(config.getQps());
         }
-        Response response = client.newCall(buildRequest(type, file)).execute();
+        Response response = client.newCall(buildRequest(type, image)).execute();
         byte[] data = response.body() != null ? response.body().bytes() : null;
         if (response.isSuccessful()) {
             return buildResult(data, mapper);

+ 6 - 6
src/main/java/com/qmth/ops/biz/ai/client/aliyun/ocr/AliyunOcrClient.java

@@ -90,7 +90,7 @@ public class AliyunOcrClient extends OcrApiClient {
         }
     }
 
-    public String getSignature(String url, String httpMethod) throws Exception {
+    private String getSignature(String url, String httpMethod) throws Exception {
         // 解析url中的参数部分
         URL u = new URL(url);
         String query = u.getQuery();
@@ -132,7 +132,7 @@ public class AliyunOcrClient extends OcrApiClient {
             action = "RecognizeHandwriting";
             break;
         default:
-            action = "";
+            throw new ParameterException("OcrType is invalid");
         }
         return new HashMap<String, String>() {{
             put("Action", action); // 调用的接口名称,此处以 RecognizeGeneral 为例
@@ -150,7 +150,7 @@ public class AliyunOcrClient extends OcrApiClient {
     /**
      * 识别本地文件代码示例。以 RecognizeGeneral 接口为例。
      */
-    protected String buildUrl(OcrType type) throws Exception {
+    private String buildUrl(OcrType type) throws Exception {
         // 获取公共请求参数
         Map<String, String> parametersMap = getCommonParameters(type);
         // 初始化请求URL
@@ -171,9 +171,9 @@ public class AliyunOcrClient extends OcrApiClient {
     }
 
     @Override
-    protected Request buildRequest(OcrType type, byte[] file) throws Exception {
+    protected Request buildRequest(OcrType type, byte[] image) throws Exception {
         return new Request.Builder().url(buildUrl(type))
-                .post(RequestBody.create(MediaType.parse("application/octet-stream"), file)).build();
+                .post(RequestBody.create(MediaType.parse("application/octet-stream"), image)).build();
     }
 
     public static void main(String[] args) throws Exception {
@@ -183,7 +183,7 @@ public class AliyunOcrClient extends OcrApiClient {
         config.setSecret("twrXT7Dp1kG1bV5HZn6vgpoypu9PnZ");
         config.setQps(0);
         AliyunOcrClient client = new AliyunOcrClient(config);
-        System.out.println(client.call(OcrType.GENERAL,
+        System.out.println(client.forImage(OcrType.GENERAL,
                 ByteArray.fromFile(new File("/Users/luoshi/Downloads/cache/1-1.jpg")).value()));
     }
 

+ 4 - 0
src/main/java/com/qmth/ops/biz/service/InitService.java

@@ -28,6 +28,9 @@ public class InitService implements SqlProvider, CommandLineRunner {
     @Resource
     private LlmClientService llmClientService;
 
+    @Resource
+    private OcrClientService ocrClientService;
+
     @Override
     public String get() {
         try {
@@ -59,6 +62,7 @@ public class InitService implements SqlProvider, CommandLineRunner {
         }
 
         llmClientService.init();
+        ocrClientService.init();
         log.info("LLM Client初始化完成");
     }
 

+ 1 - 0
src/main/java/com/qmth/ops/biz/service/LlmClientService.java

@@ -30,6 +30,7 @@ public class LlmClientService {
     private Map<Long, ChatApiClient> chatApiClientMap = new HashMap<>();
 
     public synchronized void init() {
+        chatApiClientMap.clear();
         List<LlmSupplier> supplierList = supplierService.list();
         for (LlmSupplier supplier : supplierList) {
             List<LlmModel> modelList = modelService.listBySupplier(supplier.getId());

+ 6 - 9
src/main/java/com/qmth/ops/biz/service/OcrClientService.java

@@ -10,9 +10,7 @@ import org.slf4j.LoggerFactory;
 import org.springframework.stereotype.Service;
 
 import javax.annotation.Resource;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 
 @Service
 public class OcrClientService {
@@ -21,15 +19,15 @@ public class OcrClientService {
 
     private OcrApiClient defaultClient;
 
-    private Map<Long, OcrApiClient> clientMap = new HashMap<>();
-
     @Resource
     private OcrSupplierService ocrSupplierService;
 
-    public void init() {
+    public synchronized void init() {
+        //暂时读取第一个为默认客户端
+        defaultClient = null;
         List<OcrSupplier> list = ocrSupplierService.list();
-        for (OcrSupplier supplier : list) {
-            initApiClient(supplier);
+        if (!list.isEmpty()) {
+            initApiClient(list.get(0));
         }
     }
 
@@ -39,7 +37,6 @@ public class OcrClientService {
             OcrApiConfig config = new OcrApiConfig(supplier);
             Class<?> clientClass = Class.forName(className);
             OcrApiClient client = (OcrApiClient) clientClass.getConstructor(OcrApiConfig.class).newInstance(config);
-            clientMap.put(supplier.getId(), client);
             if (defaultClient == null) {
                 defaultClient = client;
             }
@@ -53,6 +50,6 @@ public class OcrClientService {
         if (defaultClient == null) {
             throw new OcrClientNotFound(type);
         }
-        return defaultClient.call(type, imageData);
+        return defaultClient.forImage(type, imageData);
     }
 }