Browse Source

修改OcrSupplier模型增加enable字段控制;增加OcrSupplier的测试接口

luoshi 7 months ago
parent
commit
c757d8a502

+ 15 - 1
src/main/java/com/qmth/ops/api/controller/admin/OcrSupplierController.java

@@ -1,5 +1,6 @@
 package com.qmth.ops.api.controller.admin;
 
+import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.ops.api.constants.OpsApiConstants;
 import com.qmth.ops.api.security.AdminSession;
 import com.qmth.ops.api.security.Permission;
@@ -8,9 +9,12 @@ import com.qmth.ops.biz.domain.OcrSupplier;
 import com.qmth.ops.biz.service.OcrClientService;
 import com.qmth.ops.biz.service.OcrSupplierService;
 import org.springframework.web.bind.annotation.*;
+import org.springframework.web.multipart.MultipartFile;
 
 import javax.annotation.Resource;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 
 @RestController
@@ -36,7 +40,9 @@ public class OcrSupplierController {
     @PostMapping("/insert")
     public OcrSupplier insert(@RequestAttribute AdminSession adminSession, OcrSupplier supplier) {
         adminSession.hasPermission(Permission.OCR_SUPPLIER_INSERT);
-        return ocrSupplierService.insert(supplier);
+        ocrSupplierService.insert(supplier);
+        ocrClientService.init();
+        return ocrSupplierService.getById(supplier.getId());
     }
 
     @PostMapping("/update")
@@ -47,4 +53,12 @@ public class OcrSupplierController {
         return ocrSupplierService.getById(supplier.getId());
     }
 
+    @PostMapping("/test")
+    public Object test(@RequestParam Long id, @RequestParam OcrType type, @RequestParam MultipartFile image)
+            throws Exception {
+        Map<String, String> result = new HashMap<>();
+        result.put("text", ocrClientService.forImage(id, type, image.getBytes()));
+        return result;
+    }
+
 }

+ 11 - 0
src/main/java/com/qmth/ops/api/vo/OcrSupplierVO.java

@@ -10,6 +10,8 @@ public class OcrSupplierVO {
 
     private Integer qps;
 
+    private Boolean enable;
+
     private Long createTime;
 
     private Long updateTime;
@@ -18,6 +20,7 @@ public class OcrSupplierVO {
         this.id = supplier.getId();
         this.name = supplier.getName();
         this.qps = supplier.getQps();
+        this.enable = supplier.getEnable();
         this.createTime = supplier.getCreateTime();
         this.updateTime = supplier.getUpdateTime();
     }
@@ -46,6 +49,14 @@ public class OcrSupplierVO {
         this.qps = qps;
     }
 
+    public Boolean getEnable() {
+        return enable;
+    }
+
+    public void setEnable(Boolean enable) {
+        this.enable = enable;
+    }
+
     public Long getCreateTime() {
         return createTime;
     }

+ 6 - 3
src/main/java/com/qmth/ops/biz/ai/exception/OcrClientNotFound.java

@@ -1,13 +1,16 @@
 package com.qmth.ops.biz.ai.exception;
 
-import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.core.exception.NotFoundException;
 
 public class OcrClientNotFound extends NotFoundException {
 
     private static final long serialVersionUID = 6299687661502614806L;
 
-    public OcrClientNotFound(OcrType type) {
-        super("OCR api client not found for type: " + type);
+    public OcrClientNotFound() {
+        super("OCR api client not found");
+    }
+
+    public OcrClientNotFound(Long id) {
+        super("OCR api client not found for supplierId=" + id);
     }
 }

+ 5 - 5
src/main/java/com/qmth/ops/biz/domain/OcrSupplier.java

@@ -28,7 +28,7 @@ public class OcrSupplier implements Serializable {
 
     private Integer qps;
 
-    private Boolean prior;
+    private Boolean enable;
 
     private Long createTime;
 
@@ -90,12 +90,12 @@ public class OcrSupplier implements Serializable {
         this.qps = qps;
     }
 
-    public Boolean getPrior() {
-        return prior;
+    public Boolean getEnable() {
+        return enable;
     }
 
-    public void setPrior(Boolean prior) {
-        this.prior = prior;
+    public void setEnable(Boolean enable) {
+        this.enable = enable;
     }
 
     public Long getCreateTime() {

+ 21 - 5
src/main/java/com/qmth/ops/biz/service/OcrClientService.java

@@ -9,8 +9,11 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.stereotype.Service;
 
+import javax.annotation.PostConstruct;
 import javax.annotation.Resource;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 @Service
 public class OcrClientService {
@@ -19,15 +22,18 @@ public class OcrClientService {
 
     private OcrApiClient defaultClient;
 
+    private Map<Long, OcrApiClient> clientMap;
+
     @Resource
     private OcrSupplierService ocrSupplierService;
 
+    @PostConstruct
     public synchronized void init() {
-        //取第一个prior的为默认客户端,没有则取第一个
         defaultClient = null;
+        clientMap = new HashMap<>();
         List<OcrSupplier> list = ocrSupplierService.list();
-        if (!list.isEmpty()) {
-            initApiClient(list.stream().filter(OcrSupplier::getPrior).findFirst().orElse(list.get(0)));
+        for (OcrSupplier supplier : list) {
+            initApiClient(supplier);
         }
     }
 
@@ -37,7 +43,9 @@ public class OcrClientService {
             OcrApiConfig config = new OcrApiConfig(supplier);
             Class<?> clientClass = Class.forName(className);
             OcrApiClient client = (OcrApiClient) clientClass.getConstructor(OcrApiConfig.class).newInstance(config);
-            if (defaultClient == null) {
+            clientMap.put(supplier.getId(), client);
+            //取第一个enable=true的为默认客户端
+            if (supplier.getEnable() && defaultClient == null) {
                 defaultClient = client;
             }
         } catch (Exception e) {
@@ -46,9 +54,17 @@ public class OcrClientService {
         }
     }
 
+    public String forImage(Long id, OcrType type, byte[] imageData) throws Exception {
+        OcrApiClient client = clientMap.get(id);
+        if (defaultClient == null) {
+            throw new OcrClientNotFound(id);
+        }
+        return client.forImage(type, imageData);
+    }
+
     public String forImage(OcrType type, byte[] imageData) throws Exception {
         if (defaultClient == null) {
-            throw new OcrClientNotFound(type);
+            throw new OcrClientNotFound();
         }
         return defaultClient.forImage(type, imageData);
     }

+ 1 - 1
src/main/java/com/qmth/ops/biz/service/OcrSupplierService.java

@@ -31,7 +31,7 @@ public class OcrSupplierService extends ServiceImpl<OcrSupplierDao, OcrSupplier>
                 .set(supplier.getKey() != null, OcrSupplier::getKey, supplier.getKey())
                 .set(supplier.getSecret() != null, OcrSupplier::getSecret, supplier.getSecret())
                 .set(supplier.getQps() != null, OcrSupplier::getQps, supplier.getQps())
-                .set(supplier.getPrior() != null, OcrSupplier::getPrior, supplier.getPrior())
+                .set(supplier.getEnable() != null, OcrSupplier::getEnable, supplier.getEnable())
                 .set(OcrSupplier::getUpdateTime, System.currentTimeMillis()).eq(OcrSupplier::getId, supplier.getId()));
     }
 

+ 1 - 1
src/main/resources/script/init.sql

@@ -262,7 +262,7 @@ CREATE TABLE IF NOT EXISTS `ocr_supplier`
     `secret`       varchar(128)        NOT NULL,
     `client_class` varchar(128)        NOT NULL,
     `qps`          int(11)             NOT NULL,
-    `prior`        tinyint(1)          NOT NULL,
+    `enable`       tinyint(1)          NOT NULL,
     `create_time`  bigint(20)          NOT NULL,
     `update_time`  bigint(20)          NOT NULL,
     PRIMARY KEY (`id`)