瀏覽代碼

ocr api support ImageType

deason 1 月之前
父節點
當前提交
d15adfa548

+ 3 - 1
src/main/java/com/qmth/ops/api/controller/admin/OcrSupplierController.java

@@ -1,5 +1,6 @@
 package com.qmth.ops.api.controller.admin;
 
+import com.qmth.boot.core.ai.model.ocr.ImageType;
 import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.ops.api.constants.OpsApiConstants;
 import com.qmth.ops.api.security.AdminSession;
@@ -57,7 +58,8 @@ public class OcrSupplierController {
     public Object test(@RequestParam Long id, @RequestParam OcrType type, @RequestParam MultipartFile image)
             throws Exception {
         Map<String, String> result = new HashMap<>();
-        result.put("text", ocrClientService.forImage(id, type, image.getBytes()));
+        ImageType imageType = ImageType.find(image.getOriginalFilename());
+        result.put("text", ocrClientService.forImage(id, type, image.getBytes(), imageType));
         return result;
     }
 

+ 4 - 2
src/main/java/com/qmth/ops/api/controller/ai/OcrController.java

@@ -2,6 +2,7 @@ package com.qmth.ops.api.controller.ai;
 
 import com.qmth.boot.api.annotation.Aac;
 import com.qmth.boot.core.ai.model.AiConstants;
+import com.qmth.boot.core.ai.model.ocr.ImageType;
 import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.tools.signature.SignatureType;
 import com.qmth.ops.api.security.AccessOrg;
@@ -23,8 +24,9 @@ public class OcrController {
 
     @PostMapping(AiConstants.OCR_IMAGE_PATH)
     public String forImage(@RequestAttribute AccessOrg accessOrg, @RequestParam("type") OcrType type,
-            @RequestParam("image") MultipartFile file) throws Exception {
-        return ocrClientService.forImage(type, file.getBytes());
+                           @RequestParam("image") MultipartFile file) throws Exception {
+        ImageType imageType = ImageType.find(file.getOriginalFilename());
+        return ocrClientService.forImage(type, file.getBytes(), imageType);
     }
 
 }

+ 4 - 3
src/main/java/com/qmth/ops/biz/ai/client/OcrApiClient.java

@@ -1,6 +1,7 @@
 package com.qmth.ops.biz.ai.client;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.qmth.boot.core.ai.model.ocr.ImageType;
 import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.core.rateLimit.service.RateLimiter;
 import com.qmth.boot.core.rateLimit.service.impl.MemoryRateLimiter;
@@ -42,7 +43,7 @@ public abstract class OcrApiClient {
         return config;
     }
 
-    protected abstract Request buildRequest(OcrType type, byte[] image) throws Exception;
+    protected abstract Request buildRequest(OcrType type, byte[] image, ImageType imageType) throws Exception;
 
     protected abstract String buildResult(byte[] data, ObjectMapper mapper) throws IOException;
 
@@ -52,12 +53,12 @@ public abstract class OcrApiClient {
         return null;
     }
 
-    public String forImage(OcrType type, byte[] image) throws Exception {
+    public String forImage(OcrType type, byte[] image, ImageType imageType) throws Exception {
         if (queryRateLimiter != null && !queryRateLimiter.acquire()) {
             throw new OcrRateLimitExceeded(config.getQps());
         }
 
-        try (Response response = client.newCall(buildRequest(type, image)).execute()) {
+        try (Response response = client.newCall(buildRequest(type, image, imageType)).execute()) {
             byte[] data = response.body() != null ? response.body().bytes() : null;
             if (response.isSuccessful()) {
                 return buildResult(data, mapper);

+ 20 - 19
src/main/java/com/qmth/ops/biz/ai/client/aliyun/ocr/AliyunOcrClient.java

@@ -1,6 +1,7 @@
 package com.qmth.ops.biz.ai.client.aliyun.ocr;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.qmth.boot.core.ai.model.ocr.ImageType;
 import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.core.exception.ParameterException;
 import com.qmth.boot.core.exception.ReentrantException;
@@ -48,16 +49,16 @@ public class AliyunOcrClient extends OcrApiClient {
             }
         }
         switch (statusCode) {
-        case 400:
-        case 413:
-            throw new ParameterException(error != null ? error.getMessage() : "ocr request parameter error");
-        case 401:
-            throw new UnauthorizedException(error != null ? error.getMessage() : "ocr api unauthorized");
-        case 503:
-        case 504:
-            throw new ReentrantException(error != null ? error.getMessage() : "ocr api temporary faile");
-        default:
-            throw new StatusException(error != null ? error.getMessage() : "ocr api error");
+            case 400:
+            case 413:
+                throw new ParameterException(error != null ? error.getMessage() : "ocr request parameter error");
+            case 401:
+                throw new UnauthorizedException(error != null ? error.getMessage() : "ocr api unauthorized");
+            case 503:
+            case 504:
+                throw new ReentrantException(error != null ? error.getMessage() : "ocr api temporary faile");
+            default:
+                throw new StatusException(error != null ? error.getMessage() : "ocr api error");
         }
     }
 
@@ -123,14 +124,14 @@ public class AliyunOcrClient extends OcrApiClient {
     private Map<String, String> getCommonParameters(OcrType type) {
         String action;
         switch (type) {
-        case GENERAL:
-            action = "RecognizeGeneral";
-            break;
-        case HANDWRITING:
-            action = "RecognizeHandwriting";
-            break;
-        default:
-            throw new ParameterException("OcrType is invalid");
+            case GENERAL:
+                action = "RecognizeGeneral";
+                break;
+            case HANDWRITING:
+                action = "RecognizeHandwriting";
+                break;
+            default:
+                throw new ParameterException("OcrType is invalid");
         }
         return new HashMap<String, String>() {{
             put("Action", action); // 调用的接口名称,此处以 RecognizeGeneral 为例
@@ -172,7 +173,7 @@ public class AliyunOcrClient extends OcrApiClient {
     }
 
     @Override
-    protected Request buildRequest(OcrType type, byte[] image) throws Exception {
+    protected Request buildRequest(OcrType type, byte[] image, ImageType imageType) throws Exception {
         return new Request.Builder().url(buildUrl(type))
                 .post(RequestBody.create(MediaType.parse("application/octet-stream"), image)).build();
     }

+ 28 - 16
src/main/java/com/qmth/ops/biz/ai/client/aliyun/qwen_ocr/QwenOcrClient.java

@@ -2,15 +2,19 @@ package com.qmth.ops.biz.ai.client.aliyun.qwen_ocr;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.qmth.boot.core.ai.model.llm.ChatRole;
+import com.qmth.boot.core.ai.model.ocr.ImageType;
 import com.qmth.boot.core.ai.model.ocr.OcrType;
+import com.qmth.boot.core.exception.ParameterException;
+import com.qmth.boot.core.exception.StatusException;
+import com.qmth.boot.core.exception.UnauthorizedException;
 import com.qmth.boot.tools.codec.CodecUtils;
 import com.qmth.boot.tools.models.ByteArray;
 import com.qmth.ops.biz.ai.client.OcrApiClient;
 import com.qmth.ops.biz.ai.client.OcrApiConfig;
-import com.qmth.ops.biz.ai.exception.ChatRequestError;
 import okhttp3.MediaType;
 import okhttp3.Request;
 import okhttp3.RequestBody;
+import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -28,32 +32,37 @@ public class QwenOcrClient extends OcrApiClient {
     }
 
     @Override
-    protected Request buildRequest(OcrType type, byte[] image) throws Exception {
+    protected Request buildRequest(OcrType type, byte[] image, ImageType imageType) throws Exception {
         Map<String, Object> request = new HashMap<>();
         request.put("model", "qwen-vl-ocr");
+        // request.put("model", "qwen-vl-ocr-latest");
+
         List<Map<String, Object>> messages = new ArrayList<>();
         Map<String, Object> message = new HashMap<>();
         message.put("role", "user");
+
         List<Map<String, Object>> contents = new ArrayList<>();
         Map<String, Object> content1 = new HashMap<>();
         content1.put("type", "image_url");
         Map<String, String> urlBase64 = new HashMap<>();
-        String base64 = CodecUtils.toBase64(image);
-        urlBase64.put("url", "data:image/jpeg;base64," + base64);
+        String base64 = CodecUtils.toBase64(image);//默认:图片大小不超过10MB。
+        urlBase64.put("url", imageType.getBase64Prefix() + base64);
         content1.put("image_url", urlBase64);
-        // content1.put("min_pixels", 1000);
-        // content1.put("max_pixels", 1280000);
+        content1.put("min_pixels", 3136);//最小值:3136
+        content1.put("max_pixels", 2352000);//默认:1003520,最大值:28*28*30000
+        contents.add(content1);
+
         Map<String, Object> content2 = new HashMap<>();
         content2.put("type", "text");
         content2.put("text", "Read all the text in the image.");
-        contents.add(content1);
         contents.add(content2);
+
         messages.add(message);
         message.put("content", contents);
         request.put("messages", messages);
 
         String json = new ObjectMapper().writeValueAsString(request);
-        log.info("request:{}", json);
+        // log.info("request:{}", json);
         RequestBody requestBody = RequestBody.create(MediaType.parse("application/json"), json);
         return new Request.Builder().url(getConfig().getUrl())
                 .addHeader("Authorization", "Bearer " + getConfig().getSecret())
@@ -64,7 +73,10 @@ public class QwenOcrClient extends OcrApiClient {
     @Override
     protected String buildResult(byte[] data, ObjectMapper mapper) throws IOException {
         String json = data != null ? new String(data, StandardCharsets.UTF_8) : null;
-        log.info("response:{}", json);
+        // log.info("response:{}", json);
+        if (StringUtils.isEmpty(json)) {
+            return "";
+        }
 
         QwenOcrResult result = mapper.readValue(json, QwenOcrResult.class);
         return result.getChoices().stream().filter(choice -> choice.getMessage().getRole() == ChatRole.assistant)
@@ -74,7 +86,7 @@ public class QwenOcrClient extends OcrApiClient {
     @Override
     protected String handleError(byte[] data, int statusCode, ObjectMapper mapper) {
         String error = data != null ? new String(data, StandardCharsets.UTF_8) : null;
-        log.info("responseError:{}", error);
+        log.warn("responseError:{}", error);
 
         if (data != null) {
             try {
@@ -89,25 +101,25 @@ public class QwenOcrClient extends OcrApiClient {
 
         switch (statusCode) {
             case 400:
-                throw new ChatRequestError(Optional.ofNullable(error).orElse("chat request error"));
-            case 429:
-                throw new ChatRequestError(Optional.ofNullable(error).orElse("chat model rate limit exceeded"));
+                throw new ParameterException(Optional.ofNullable(error).orElse("ocr request parameter error"));
+            case 401:
+            case 403:
+                throw new UnauthorizedException(Optional.ofNullable(error).orElse("ocr api unauthorized"));
             default:
-                throw new ChatRequestError(Optional.ofNullable(error).orElse("chat model error"));
+                throw new StatusException(Optional.ofNullable(error).orElse("ocr api error"));
         }
     }
 
     public static void main(String[] args) throws Exception {
         OcrApiConfig config = new OcrApiConfig();
         config.setUrl("https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions");
-        // config.setUrl("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation");
         config.setSecret("sk-xxx");
         config.setQps(10);
 
         File file = new File("D:\\home\\test.png");
         byte[] image = ByteArray.fromFile(file).value();
         QwenOcrClient client = new QwenOcrClient(config);
-        String value = client.forImage(OcrType.HANDWRITING, image);
+        String value = client.forImage(OcrType.HANDWRITING, image, ImageType.find(file.getName()));
         System.out.println(value);
     }
 

+ 16 - 13
src/main/java/com/qmth/ops/biz/ai/client/baidu/ocr/BaiduOcrClient.java

@@ -1,6 +1,7 @@
 package com.qmth.ops.biz.ai.client.baidu.ocr;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.qmth.boot.core.ai.model.ocr.ImageType;
 import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.core.exception.ParameterException;
 import com.qmth.boot.core.exception.ReentrantException;
@@ -57,21 +58,21 @@ public class BaiduOcrClient extends OcrApiClient {
             }
         }
         switch (statusCode) {
-        case 400:
-        case 413:
-            throw new ParameterException(error != null ? error.getMessage() : "ocr request parameter error");
-        case 401:
-            throw new UnauthorizedException(error != null ? error.getMessage() : "ocr api unauthorized");
-        case 503:
-        case 504:
-            throw new ReentrantException(error != null ? error.getMessage() : "ocr api temporary faile");
-        default:
-            throw new StatusException(error != null ? error.getMessage() : "ocr api error");
+            case 400:
+            case 413:
+                throw new ParameterException(error != null ? error.getMessage() : "ocr request parameter error");
+            case 401:
+                throw new UnauthorizedException(error != null ? error.getMessage() : "ocr api unauthorized");
+            case 503:
+            case 504:
+                throw new ReentrantException(error != null ? error.getMessage() : "ocr api temporary faile");
+            default:
+                throw new StatusException(error != null ? error.getMessage() : "ocr api error");
         }
     }
 
     @Override
-    protected Request buildRequest(OcrType type, byte[] image) throws Exception {
+    protected Request buildRequest(OcrType type, byte[] image, ImageType imageType) throws Exception {
         String url = buildUrl(type);
         SimpleDateFormat format = new SimpleDateFormat(BceV1Signer.DATE_FORMAT_PATTERN);
         format.setTimeZone(TimeZone.getTimeZone("UTC"));
@@ -115,9 +116,11 @@ public class BaiduOcrClient extends OcrApiClient {
         config.setKey("");
         config.setSecret("");
         config.setQps(10);
+        File file = new File("/Users/luoshi/Downloads/test.jpg");
+        byte[] image = ByteArray.fromFile(file).value();
+
         BaiduOcrClient client = new BaiduOcrClient(config);
-        System.out.println(client.forImage(OcrType.HANDWRITING,
-                ByteArray.fromFile(new File("/Users/luoshi/Downloads/test.jpg")).value()));
+        System.out.println(client.forImage(OcrType.HANDWRITING, image, ImageType.find(file.getName())));
     }
 
 }

+ 7 - 4
src/main/java/com/qmth/ops/biz/service/OcrClientService.java

@@ -1,5 +1,6 @@
 package com.qmth.ops.biz.service;
 
+import com.qmth.boot.core.ai.model.ocr.ImageType;
 import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.ops.biz.ai.client.OcrApiClient;
 import com.qmth.ops.biz.ai.client.OcrApiConfig;
@@ -54,18 +55,20 @@ public class OcrClientService {
         }
     }
 
-    public String forImage(Long id, OcrType type, byte[] imageData) throws Exception {
+    public String forImage(Long id, OcrType type, byte[] imageData, ImageType imageType) throws Exception {
         OcrApiClient client = clientMap.get(id);
+
         if (defaultClient == null) {
             throw new OcrClientNotFound(id);
         }
-        return client.forImage(type, imageData);
+        return client.forImage(type, imageData, imageType);
     }
 
-    public String forImage(OcrType type, byte[] imageData) throws Exception {
+    public String forImage(OcrType type, byte[] imageData, ImageType imageType) throws Exception {
         if (defaultClient == null) {
             throw new OcrClientNotFound();
         }
-        return defaultClient.forImage(type, imageData);
+        return defaultClient.forImage(type, imageData, imageType);
     }
+
 }