Эх сурвалжийг харах

增加ocr供应商baidu;增加ocrsupplier的prior字段

luoshi 8 сар өмнө
parent
commit
ebaa5cc738

+ 11 - 5
src/main/java/com/qmth/ops/biz/ai/client/OcrApiClient.java

@@ -5,10 +5,7 @@ import com.qmth.boot.core.ai.model.ocr.OcrType;
 import com.qmth.boot.core.rateLimit.service.RateLimiter;
 import com.qmth.boot.core.rateLimit.service.impl.MemoryRateLimiter;
 import com.qmth.ops.biz.ai.exception.OcrRateLimitExceeded;
-import okhttp3.ConnectionPool;
-import okhttp3.OkHttpClient;
-import okhttp3.Request;
-import okhttp3.Response;
+import okhttp3.*;
 
 import java.io.IOException;
 
@@ -26,7 +23,12 @@ public abstract class OcrApiClient {
     private RateLimiter queryRateLimiter;
 
     public OcrApiClient(OcrApiConfig config) {
-        this.client = new OkHttpClient.Builder().connectionPool(new ConnectionPool()).build();
+        OkHttpClient.Builder buidler = new OkHttpClient.Builder().connectionPool(new ConnectionPool());
+        Interceptor interceptor = getInterceptor();
+        if (interceptor != null) {
+            buidler.addInterceptor(interceptor);
+        }
+        this.client = buidler.build();
         this.mapper = new ObjectMapper();
         this.config = config;
         if (config.getQps() > 0) {
@@ -44,6 +46,10 @@ public abstract class OcrApiClient {
 
     protected abstract String handleError(byte[] data, int statusCode, ObjectMapper mapper);
 
+    protected Interceptor getInterceptor() {
+        return null;
+    }
+
     public String forImage(OcrType type, byte[] image) throws Exception {
         if (queryRateLimiter != null && !queryRateLimiter.acquire()) {
             throw new OcrRateLimitExceeded(config.getQps());

+ 41 - 0
src/main/java/com/qmth/ops/biz/ai/client/baidu/BaiduError.java

@@ -0,0 +1,41 @@
+package com.qmth.ops.biz.ai.client.baidu;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class BaiduError {
+
+    @JsonProperty("log_id")
+    private String logId;
+
+    @JsonProperty("error_code")
+    private Integer code;
+
+    @JsonProperty("error_msg")
+    private String message;
+
+    public String getLogId() {
+        return logId;
+    }
+
+    public void setLogId(String logId) {
+        this.logId = logId;
+    }
+
+    public Integer getCode() {
+        return code;
+    }
+
+    public void setCode(Integer code) {
+        this.code = code;
+    }
+
+    public String getMessage() {
+        return message;
+    }
+
+    public void setMessage(String message) {
+        this.message = message;
+    }
+}

+ 176 - 0
src/main/java/com/qmth/ops/biz/ai/client/baidu/BceV1Signer.java

@@ -0,0 +1,176 @@
+package com.qmth.ops.biz.ai.client.baidu;
+
+import com.qmth.boot.tools.models.ByteArray;
+import okhttp3.Headers;
+import okhttp3.HttpUrl;
+import okhttp3.Request;
+import org.apache.commons.lang3.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.http.HttpHeaders;
+
+import javax.crypto.Mac;
+import javax.crypto.spec.SecretKeySpec;
+import java.nio.charset.StandardCharsets;
+import java.util.*;
+
+/**
+ * The V1 implementation of Signer with the BCE signing protocol.
+ */
+public class BceV1Signer {
+
+    private static final Logger logger = LoggerFactory.getLogger(BceV1Signer.class);
+
+    public static final String X_BCE_DATE = "x-bce-date";
+
+    private static final String BCE_PREFIX = "x-bce-";
+
+    private static final String BCE_AUTH_VERSION = "bce-auth-v1";
+
+    // Default headers to sign with the BCE signing protocol.
+    private static final Set<String> defaultHeadersToSign = new HashSet<>();
+
+    private static BitSet URI_UNRESERVED_CHARACTERS = new BitSet();
+
+    private static String[] PERCENT_ENCODED_STRINGS = new String[256];
+
+    private static final String headerJoiner = "\n";
+
+    private static final String singedHeaderJoiner = ";";
+
+    private static final String queryStringJoiner = "&";
+
+    private static final int DEFAULT_EXPIRATION_IN_SECONDS = 1800;
+
+    public static final String DATE_FORMAT_PATTERN = "yyyy-MM-dd'T'HH:mm:ss'Z'";
+
+    static {
+        BceV1Signer.defaultHeadersToSign.add(HttpHeaders.HOST.toLowerCase());
+        //BceV1Signer.defaultHeadersToSign.add(HttpHeaders.CONTENT_TYPE.toLowerCase());
+        //BceV1Signer.defaultHeadersToSign.add(HttpHeaders.CONTENT_LENGTH.toLowerCase());
+        //BceV1Signer.defaultHeadersToSign.add("content-md5");
+
+        for (int i = 'a'; i <= 'z'; i++) {
+            URI_UNRESERVED_CHARACTERS.set(i);
+        }
+        for (int i = 'A'; i <= 'Z'; i++) {
+            URI_UNRESERVED_CHARACTERS.set(i);
+        }
+        for (int i = '0'; i <= '9'; i++) {
+            URI_UNRESERVED_CHARACTERS.set(i);
+        }
+        URI_UNRESERVED_CHARACTERS.set('-');
+        URI_UNRESERVED_CHARACTERS.set('.');
+        URI_UNRESERVED_CHARACTERS.set('_');
+        URI_UNRESERVED_CHARACTERS.set('~');
+
+        for (int i = 0; i < PERCENT_ENCODED_STRINGS.length; ++i) {
+            PERCENT_ENCODED_STRINGS[i] = String.format("%%%02X", i);
+        }
+    }
+
+    public static String sign(Request request, String accessKey, String accessSecret) {
+        String timestamp = request.header(X_BCE_DATE);
+        String authString =
+                BceV1Signer.BCE_AUTH_VERSION + "/" + accessKey + "/" + timestamp + "/" + DEFAULT_EXPIRATION_IN_SECONDS;
+
+        String signingKey = sha256Hex(accessSecret, authString);
+        // Formatting the URL with signing protocol.
+        String canonicalURI = getCanonicalURIPath(request.url().uri().getPath());
+        // Formatting the query string with signing protocol.
+        String canonicalQueryString = getCanonicalQueryString(request.url());
+        // Formatting the headers from the request based on signing protocol.
+        String canonicalHeader = getCanonicalHeaders(request.headers());
+        String signedHeaders = getSignedHeaders(request.headers());
+
+        String canonicalRequest =
+                request.method().toUpperCase() + "\n" + canonicalURI + "\n" + canonicalQueryString + "\n"
+                        + canonicalHeader;
+
+        // Signing the canonical request using key with sha-256 algorithm.
+        String signature = sha256Hex(signingKey, canonicalRequest);
+
+        String authorizationHeader = authString + "/" + signedHeaders + "/" + signature;
+
+        //logger.debug("\nCanonicalRequest:\n{}\n-----------\nAuthorization:\n{}", canonicalRequest, authorizationHeader);
+
+        return authorizationHeader;
+    }
+
+    private static String getCanonicalURIPath(String path) {
+        if (path == null) {
+            return "/";
+        } else if (path.startsWith("/")) {
+            return normalizePath(path);
+        } else {
+            return "/" + normalizePath(path);
+        }
+    }
+
+    private static String getSignedHeaders(Headers headers) {
+        Set<String> headerStrings = new HashSet<>();
+        for (String name : headers.names()) {
+            name = name.toLowerCase();
+            if (name.startsWith(BCE_PREFIX) || defaultHeadersToSign.contains(name)) {
+                headerStrings.add(name);
+            }
+        }
+        List<String> list = new ArrayList<>(headerStrings);
+        Collections.sort(list);
+        return StringUtils.join(list, singedHeaderJoiner);
+    }
+
+    private static String getCanonicalHeaders(Headers headers) {
+        List<String> headerStrings = new LinkedList<>();
+        for (String name : headers.names()) {
+            String value = StringUtils.trimToEmpty(headers.get(name));
+            name = name.toLowerCase();
+            if (name.startsWith(BCE_PREFIX) || defaultHeadersToSign.contains(name)) {
+                headerStrings.add(normalize(name) + ':' + normalize(value));
+            }
+        }
+        Collections.sort(headerStrings);
+        return StringUtils.join(headerStrings, headerJoiner);
+    }
+
+    private static String sha256Hex(String signingKey, String stringToSign) {
+        try {
+            Mac mac = Mac.getInstance("HmacSHA256");
+            mac.init(new SecretKeySpec(signingKey.getBytes(StandardCharsets.UTF_8), "HmacSHA256"));
+            return ByteArray.fromArray(mac.doFinal(stringToSign.getBytes(StandardCharsets.UTF_8))).toHexString()
+                    .toLowerCase();
+        } catch (Exception e) {
+            throw new RuntimeException("Fail to generate the signature", e);
+        }
+    }
+
+    private static String normalizePath(String path) {
+        return normalize(path).replace("%2F", "/");
+    }
+
+    private static String normalize(String value) {
+        StringBuilder builder = new StringBuilder();
+        for (byte b : value.getBytes(StandardCharsets.UTF_8)) {
+            if (URI_UNRESERVED_CHARACTERS.get(b & 0xFF)) {
+                builder.append((char) b);
+            } else {
+                builder.append(PERCENT_ENCODED_STRINGS[b & 0xFF]);
+            }
+        }
+        return builder.toString();
+    }
+
+    private static String getCanonicalQueryString(HttpUrl url) {
+        List<String> parameterStrings = new ArrayList<>();
+        for (String name : url.queryParameterNames()) {
+            if (HttpHeaders.AUTHORIZATION.equalsIgnoreCase(name)) {
+                continue;
+            }
+            String value = url.queryParameter(name);
+            parameterStrings.add(normalize(name) + '=' + normalize(StringUtils.trimToEmpty(value)));
+        }
+        Collections.sort(parameterStrings);
+        return StringUtils.join(parameterStrings, queryStringJoiner);
+    }
+
+}

+ 124 - 0
src/main/java/com/qmth/ops/biz/ai/client/baidu/ocr/BaiduOcrClient.java

@@ -0,0 +1,124 @@
+package com.qmth.ops.biz.ai.client.baidu.ocr;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.qmth.boot.core.ai.model.ocr.OcrType;
+import com.qmth.boot.core.exception.ParameterException;
+import com.qmth.boot.core.exception.ReentrantException;
+import com.qmth.boot.core.exception.StatusException;
+import com.qmth.boot.core.exception.UnauthorizedException;
+import com.qmth.boot.tools.codec.CodecUtils;
+import com.qmth.boot.tools.models.ByteArray;
+import com.qmth.ops.biz.ai.client.OcrApiClient;
+import com.qmth.ops.biz.ai.client.OcrApiConfig;
+import com.qmth.ops.biz.ai.client.baidu.BaiduError;
+import com.qmth.ops.biz.ai.client.baidu.BceV1Signer;
+import okhttp3.FormBody;
+import okhttp3.HttpUrl;
+import okhttp3.Interceptor;
+import okhttp3.Request;
+import org.apache.commons.lang3.StringUtils;
+import org.springframework.http.HttpHeaders;
+
+import java.io.File;
+import java.io.IOException;
+import java.text.SimpleDateFormat;
+import java.util.Date;
+import java.util.TimeZone;
+import java.util.stream.Collectors;
+
+public class BaiduOcrClient extends OcrApiClient {
+
+    public BaiduOcrClient(OcrApiConfig config) {
+        super(config);
+    }
+
+    @Override
+    protected String buildResult(byte[] data, ObjectMapper mapper) throws IOException {
+        BaiduOcrResult result = mapper.readValue(data, BaiduOcrResult.class);
+        if (result != null && result.getWordsList() != null) {
+            if (!result.getWordsList().isEmpty()) {
+                return mapper.readValue(data, BaiduOcrResult.class).getWordsList().stream().map(BaiduOcrWords::getWords)
+                        .collect(Collectors.joining(" "));
+                //.replaceAll("☰", "").replaceAll("≡", "");
+            } else {
+                return StringUtils.EMPTY;
+            }
+        } else {
+            return handleError(data, 500, mapper);
+        }
+    }
+
+    @Override
+    protected String handleError(byte[] data, int statusCode, ObjectMapper mapper) {
+        BaiduError error = null;
+        if (data != null) {
+            try {
+                error = mapper.readValue(data, BaiduError.class);
+            } catch (Exception ignore) {
+            }
+        }
+        switch (statusCode) {
+        case 400:
+        case 413:
+            throw new ParameterException(error != null ? error.getMessage() : "ocr request parameter error");
+        case 401:
+            throw new UnauthorizedException(error != null ? error.getMessage() : "ocr api unauthorized");
+        case 503:
+        case 504:
+            throw new ReentrantException(error != null ? error.getMessage() : "ocr api temporary faile");
+        default:
+            throw new StatusException(error != null ? error.getMessage() : "ocr api error");
+        }
+    }
+
+    @Override
+    protected Request buildRequest(OcrType type, byte[] image) throws Exception {
+        String url = buildUrl(type);
+        SimpleDateFormat format = new SimpleDateFormat(BceV1Signer.DATE_FORMAT_PATTERN);
+        format.setTimeZone(TimeZone.getTimeZone("UTC"));
+        return new Request.Builder().url(url).addHeader(BceV1Signer.X_BCE_DATE, format.format(new Date()))
+                .addHeader(HttpHeaders.HOST, HttpUrl.parse(url).host())
+                .addHeader(HttpHeaders.CONTENT_TYPE, "application/x-www-form-urlencoded").post(buildForm(type, image))
+                .build();
+    }
+
+    private String buildUrl(OcrType type) {
+        StringBuilder url = new StringBuilder(getConfig().getUrl());
+        if (type == OcrType.GENERAL) {
+            url.append("general_basic");
+        } else if (type == OcrType.HANDWRITING) {
+            url.append("handwriting");
+        }
+        return url.toString();
+    }
+
+    private FormBody buildForm(OcrType type, byte[] image) {
+        FormBody.Builder builder = new FormBody.Builder();
+        builder.add("image", CodecUtils.toBase64(image));
+        //if (type == OcrType.HANDWRITING) {
+        //builder.add("detect_alteration", "true");
+        //}
+        return builder.build();
+    }
+
+    @Override
+    protected Interceptor getInterceptor() {
+        return chain -> {
+            Request request = chain.request();
+            return chain.proceed(request.newBuilder().addHeader(HttpHeaders.AUTHORIZATION,
+                    BceV1Signer.sign(request, getConfig().getKey(), getConfig().getSecret())).build());
+        };
+    }
+
+    public static void main(String[] args) throws Exception {
+        OcrApiConfig config = new OcrApiConfig();
+        config.setUrl("https://aip.baidubce.com/rest/2.0/ocr/v1/");
+        config.setKey("");
+        config.setSecret("");
+        config.setQps(10);
+        BaiduOcrClient client = new BaiduOcrClient(config);
+        System.out.println(client.forImage(OcrType.HANDWRITING,
+                ByteArray.fromFile(new File("/Users/luoshi/Downloads/test.jpg")).value()));
+    }
+
+}

+ 43 - 0
src/main/java/com/qmth/ops/biz/ai/client/baidu/ocr/BaiduOcrResult.java

@@ -0,0 +1,43 @@
+package com.qmth.ops.biz.ai.client.baidu.ocr;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.util.List;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class BaiduOcrResult {
+
+    @JsonProperty("log_id")
+    private String logId;
+
+    @JsonProperty("words_result_num")
+    private Integer wordsResultNumber;
+
+    @JsonProperty("words_result")
+    private List<BaiduOcrWords> wordsList;
+
+    public String getLogId() {
+        return logId;
+    }
+
+    public void setLogId(String logId) {
+        this.logId = logId;
+    }
+
+    public Integer getWordsResultNumber() {
+        return wordsResultNumber;
+    }
+
+    public void setWordsResultNumber(Integer wordsResultNumber) {
+        this.wordsResultNumber = wordsResultNumber;
+    }
+
+    public List<BaiduOcrWords> getWordsList() {
+        return wordsList;
+    }
+
+    public void setWordsList(List<BaiduOcrWords> wordsList) {
+        this.wordsList = wordsList;
+    }
+}

+ 19 - 0
src/main/java/com/qmth/ops/biz/ai/client/baidu/ocr/BaiduOcrWords.java

@@ -0,0 +1,19 @@
+package com.qmth.ops.biz.ai.client.baidu.ocr;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class BaiduOcrWords {
+
+    @JsonProperty("words")
+    private String words;
+
+    public String getWords() {
+        return words;
+    }
+
+    public void setWords(String words) {
+        this.words = words;
+    }
+}

+ 10 - 0
src/main/java/com/qmth/ops/biz/domain/OcrSupplier.java

@@ -28,6 +28,8 @@ public class OcrSupplier implements Serializable {
 
     private Integer qps;
 
+    private Boolean prior;
+
     private Long createTime;
 
     private Long updateTime;
@@ -88,6 +90,14 @@ public class OcrSupplier implements Serializable {
         this.qps = qps;
     }
 
+    public Boolean getPrior() {
+        return prior;
+    }
+
+    public void setPrior(Boolean prior) {
+        this.prior = prior;
+    }
+
     public Long getCreateTime() {
         return createTime;
     }

+ 2 - 2
src/main/java/com/qmth/ops/biz/service/OcrClientService.java

@@ -23,11 +23,11 @@ public class OcrClientService {
     private OcrSupplierService ocrSupplierService;
 
     public synchronized void init() {
-        //暂时读取第一个为默认客户端
+        //取第一个prior的为默认客户端,没有则取第一个
         defaultClient = null;
         List<OcrSupplier> list = ocrSupplierService.list();
         if (!list.isEmpty()) {
-            initApiClient(list.get(0));
+            initApiClient(list.stream().filter(OcrSupplier::getPrior).findFirst().orElse(list.get(0)));
         }
     }
 

+ 1 - 0
src/main/java/com/qmth/ops/biz/service/OcrSupplierService.java

@@ -31,6 +31,7 @@ public class OcrSupplierService extends ServiceImpl<OcrSupplierDao, OcrSupplier>
                 .set(supplier.getKey() != null, OcrSupplier::getKey, supplier.getKey())
                 .set(supplier.getSecret() != null, OcrSupplier::getSecret, supplier.getSecret())
                 .set(supplier.getQps() != null, OcrSupplier::getQps, supplier.getQps())
+                .set(supplier.getPrior() != null, OcrSupplier::getPrior, supplier.getPrior())
                 .set(OcrSupplier::getUpdateTime, System.currentTimeMillis()).eq(OcrSupplier::getId, supplier.getId()));
     }
 

+ 1 - 0
src/main/resources/script/init.sql

@@ -262,6 +262,7 @@ CREATE TABLE IF NOT EXISTS `ocr_supplier`
     `secret`       varchar(128)        NOT NULL,
     `client_class` varchar(128)        NOT NULL,
     `qps`          int(11)             NOT NULL,
+    `prior`        tinyint(1)          NOT NULL,
     `create_time`  bigint(20)          NOT NULL,
     `update_time`  bigint(20)          NOT NULL,
     PRIMARY KEY (`id`)