Эх сурвалжийг харах

chat接口及实现类增加appType参数,控制请求参数

luoshi 11 сар өмнө
parent
commit
1aa65bc697

+ 2 - 2
src/main/java/com/qmth/ops/api/controller/ai/LlmController.java

@@ -41,7 +41,7 @@ public class LlmController {
             throw new ForbiddenException(
                     "Chat api is disabled or exhausted for org=" + accessOrg.getOrg().getCode() + ", app_type=" + type);
         }
-        ChatResult result = llmClientService.chat(request, config.getModelId());
+        ChatResult result = llmClientService.chat(request, config.getModelId(), type);
         llmOrgConfigService.consume(config);
         return result;
     }
@@ -69,7 +69,7 @@ public class LlmController {
         if (StringUtils.isNotBlank(userMessage)) {
             request.addMessage(ChatRole.user, userMessage);
         }
-        ChatResult result = llmClientService.chat(request, config.getModelId());
+        ChatResult result = llmClientService.chat(request, config.getModelId(), type);
         llmOrgConfigService.consume(config);
         return result;
     }

+ 10 - 7
src/main/java/com/qmth/ops/biz/ai/client/ChatApiClient.java

@@ -3,12 +3,14 @@ package com.qmth.ops.biz.ai.client;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.qmth.boot.core.ai.model.llm.ChatRequest;
 import com.qmth.boot.core.ai.model.llm.ChatResult;
+import com.qmth.boot.core.ai.model.llm.LlmAppType;
 import com.qmth.boot.core.rateLimit.service.RateLimiter;
 import com.qmth.boot.core.rateLimit.service.impl.MemoryRateLimiter;
 import com.qmth.ops.biz.ai.exception.ChatRateLimitExceeded;
 import okhttp3.*;
 
 import java.io.IOException;
+import java.time.Duration;
 
 /**
  * 大模型chat类接口基础实现
@@ -24,7 +26,8 @@ public abstract class ChatApiClient {
     private RateLimiter queryRateLimiter;
 
     public ChatApiClient(ChatApiConfig config) {
-        this.client = new OkHttpClient.Builder().connectionPool(new ConnectionPool()).build();
+        this.client = new OkHttpClient.Builder().connectionPool(new ConnectionPool())
+                .connectTimeout(Duration.ofSeconds(10)).readTimeout(Duration.ofSeconds(30)).build();
         this.mapper = new ObjectMapper();
         this.config = config;
         if (config.getQpm() > 0) {
@@ -36,22 +39,22 @@ public abstract class ChatApiClient {
         return config;
     }
 
-    protected abstract Headers buildHeader(Headers.Builder headerBuilder);
+    protected abstract Headers buildHeader(Headers.Builder headerBuilder, LlmAppType appType);
 
-    protected abstract Object buildRequest(ChatRequest request);
+    protected abstract Object buildRequest(ChatRequest request, LlmAppType appType);
 
     protected abstract ChatResult buildResult(byte[] data, ObjectMapper mapper) throws IOException;
 
     protected abstract ChatResult handleError(byte[] data, int statusCode, ObjectMapper mapper);
 
-    public ChatResult call(ChatRequest request) throws Exception {
+    public ChatResult call(ChatRequest request, LlmAppType appType) throws Exception {
         if (queryRateLimiter != null && !queryRateLimiter.acquire()) {
             throw new ChatRateLimitExceeded(config.getSupplier(), config.getModel(), config.getQpm());
         }
         RequestBody body = RequestBody
-                .create(MediaType.parse("application/json"), mapper.writeValueAsBytes(buildRequest(request)));
-        Request httpRequest = new Request.Builder().url(config.getUrl()).headers(buildHeader(new Headers.Builder()))
-                .post(body).build();
+                .create(MediaType.parse("application/json"), mapper.writeValueAsBytes(buildRequest(request, appType)));
+        Request httpRequest = new Request.Builder().url(config.getUrl())
+                .headers(buildHeader(new Headers.Builder(), appType)).post(body).build();
         Response response = client.newCall(httpRequest).execute();
         byte[] data = response.body() != null ? response.body().bytes() : null;
         if (response.isSuccessful()) {

+ 32 - 5
src/main/java/com/qmth/ops/biz/ai/client/aliyun/llm/AliyunChatClient.java

@@ -3,6 +3,8 @@ package com.qmth.ops.biz.ai.client.aliyun.llm;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.qmth.boot.core.ai.model.llm.ChatRequest;
 import com.qmth.boot.core.ai.model.llm.ChatResult;
+import com.qmth.boot.core.ai.model.llm.ChatRole;
+import com.qmth.boot.core.ai.model.llm.LlmAppType;
 import com.qmth.boot.core.exception.ReentrantException;
 import com.qmth.boot.core.exception.StatusException;
 import com.qmth.ops.biz.ai.client.ChatApiClient;
@@ -24,14 +26,23 @@ public class AliyunChatClient extends ChatApiClient {
     }
 
     @Override
-    protected Headers buildHeader(Headers.Builder headerBuilder) {
-        return headerBuilder.add(AUTH_HEADER_NAME, AUTH_HEADER_VALUE + getConfig().getSecret())
-                .add("X-DashScope-DataInspection", "{\"input\":\"disable\", \"output\":\"disable\"}").build();
+    protected Headers buildHeader(Headers.Builder headerBuilder, LlmAppType appType) {
+        headerBuilder.add(AUTH_HEADER_NAME, AUTH_HEADER_VALUE + getConfig().getSecret());
+        if (appType == LlmAppType.AUTO_SCORE) {
+            headerBuilder.add("X-DashScope-DataInspection", "{\"input\":\"disable\", \"output\":\"disable\"}");
+        }
+        return headerBuilder.build();
     }
 
     @Override
-    protected Object buildRequest(ChatRequest request) {
-        return new AliyunChatRequest(request, getConfig().getModel());
+    protected Object buildRequest(ChatRequest request, LlmAppType appType) {
+        AliyunChatRequest chatRequest = new AliyunChatRequest(request, getConfig().getModel());
+        if (appType == LlmAppType.AUTO_SCORE) {
+            chatRequest.getParameters().put("top_p", 0.1);
+        } else if (appType == LlmAppType.AUTO_GENERATE_QUESTION) {
+            chatRequest.getParameters().put("top_p", 0.9);
+        }
+        return chatRequest;
     }
 
     @Override
@@ -58,4 +69,20 @@ public class AliyunChatClient extends ChatApiClient {
         }
     }
 
+    public static void main(String[] args) throws Exception {
+        ChatApiConfig config = new ChatApiConfig();
+        config.setSupplier("aliyun");
+        config.setUrl("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation");
+        config.setSecret("");
+        config.setModel("qwen-turbo");
+        config.setQpm(0);
+        AliyunChatClient client = new AliyunChatClient(config);
+        ChatRequest request = new ChatRequest();
+        request.addMessage(ChatRole.user,
+                "作为高等数学科目的命题老师,请按照下列要求出1道单选试题\n" + "试题题干前用单独一行'【题干】'作为内容\n" + "试题答案前用单独一行'【答案】'作为内容\n"
+                        + "试题答案解析前用单独一行'【解析】'作为内容\n" + "试题包含4个选项,选项内容前用单独一行'【选项】'作为内容,且每个选项前用大写英文字母开头\n"
+                        + "请按照上述要求出1道高等数学的单选试题");
+        System.out.println(
+                new ObjectMapper().writeValueAsString(client.call(request, LlmAppType.AUTO_GENERATE_QUESTION)));
+    }
 }

+ 3 - 2
src/main/java/com/qmth/ops/biz/ai/client/azure/llm/AzureChatClient.java

@@ -3,6 +3,7 @@ package com.qmth.ops.biz.ai.client.azure.llm;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.qmth.boot.core.ai.model.llm.ChatRequest;
 import com.qmth.boot.core.ai.model.llm.ChatResult;
+import com.qmth.boot.core.ai.model.llm.LlmAppType;
 import com.qmth.boot.core.exception.NotFoundException;
 import com.qmth.boot.core.exception.StatusException;
 import com.qmth.ops.biz.ai.client.ChatApiClient;
@@ -25,12 +26,12 @@ public class AzureChatClient extends ChatApiClient {
     }
 
     @Override
-    protected Headers buildHeader(Headers.Builder headerBuilder) {
+    protected Headers buildHeader(Headers.Builder headerBuilder, LlmAppType appType) {
         return headerBuilder.add(AUTH_HEADER_NAME, getConfig().getSecret()).build();
     }
 
     @Override
-    protected Object buildRequest(ChatRequest request) {
+    protected Object buildRequest(ChatRequest request, LlmAppType appType) {
         return request;
     }
 

+ 3 - 2
src/main/java/com/qmth/ops/biz/service/LlmClientService.java

@@ -2,6 +2,7 @@ package com.qmth.ops.biz.service;
 
 import com.qmth.boot.core.ai.model.llm.ChatRequest;
 import com.qmth.boot.core.ai.model.llm.ChatResult;
+import com.qmth.boot.core.ai.model.llm.LlmAppType;
 import com.qmth.ops.biz.ai.client.ChatApiClient;
 import com.qmth.ops.biz.ai.client.ChatApiConfig;
 import com.qmth.ops.biz.ai.exception.ChatClientNotFound;
@@ -65,11 +66,11 @@ public class LlmClientService {
         }
     }
 
-    public ChatResult chat(ChatRequest request, Long modelId) throws Exception {
+    public ChatResult chat(ChatRequest request, Long modelId, LlmAppType appType) throws Exception {
         ChatApiClient client = chatApiClientMap.get(modelId);
         if (client == null) {
             throw new ChatClientNotFound(modelId);
         }
-        return client.call(request);
+        return client.call(request, appType);
     }
 }