浏览代码

增加prompt_template数据存储与访问接口

luoshi 1 年之前
父节点
当前提交
fc17231c76

+ 16 - 2
src/main/java/com/qmth/ops/api/controller/ai/LlmController.java

@@ -7,8 +7,10 @@ import com.qmth.boot.core.exception.ForbiddenException;
 import com.qmth.boot.tools.signature.SignatureType;
 import com.qmth.boot.tools.signature.SignatureType;
 import com.qmth.ops.api.security.AccessOrg;
 import com.qmth.ops.api.security.AccessOrg;
 import com.qmth.ops.biz.domain.LlmOrgConfig;
 import com.qmth.ops.biz.domain.LlmOrgConfig;
+import com.qmth.ops.biz.domain.LlmPromptTemplate;
 import com.qmth.ops.biz.service.LlmClientService;
 import com.qmth.ops.biz.service.LlmClientService;
 import com.qmth.ops.biz.service.LlmOrgConfigService;
 import com.qmth.ops.biz.service.LlmOrgConfigService;
+import com.qmth.ops.biz.service.LlmPromptTemplateService;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
 import org.springframework.web.bind.annotation.*;
 
 
@@ -24,6 +26,9 @@ public class LlmController {
     @Resource
     @Resource
     private LlmClientService llmClientService;
     private LlmClientService llmClientService;
 
 
+    @Resource
+    private LlmPromptTemplateService llmPromptTemplateService;
+
     @PostMapping(AiConstants.LLM_CHAT_PATH)
     @PostMapping(AiConstants.LLM_CHAT_PATH)
     public ChatResult chat(@RequestAttribute AccessOrg accessOrg,
     public ChatResult chat(@RequestAttribute AccessOrg accessOrg,
             @RequestHeader(AiConstants.LLM_APP_TYPE_HEADER) LlmAppType type,
             @RequestHeader(AiConstants.LLM_APP_TYPE_HEADER) LlmAppType type,
@@ -42,8 +47,8 @@ public class LlmController {
     @PostMapping(AiConstants.LLM_BALANCE_PATH)
     @PostMapping(AiConstants.LLM_BALANCE_PATH)
     public LlmAppBalance balance(@RequestAttribute AccessOrg accessOrg,
     public LlmAppBalance balance(@RequestAttribute AccessOrg accessOrg,
             @RequestHeader(AiConstants.LLM_APP_TYPE_HEADER) LlmAppType type) {
             @RequestHeader(AiConstants.LLM_APP_TYPE_HEADER) LlmAppType type) {
-        LlmOrgConfig config = llmOrgConfigService.findByOrgAndAppType(accessOrg.getOrg(), type);
         LlmAppBalance balance = new LlmAppBalance();
         LlmAppBalance balance = new LlmAppBalance();
+        LlmOrgConfig config = llmOrgConfigService.findByOrgAndAppType(accessOrg.getOrg(), type);
         if (config != null) {
         if (config != null) {
             balance.setPermitCount(config.getPermitCount());
             balance.setPermitCount(config.getPermitCount());
             balance.setLeftCount(config.getLeftCount());
             balance.setLeftCount(config.getLeftCount());
@@ -54,6 +59,15 @@ public class LlmController {
     @PostMapping(AiConstants.LLM_PROMPT_TEMPLATE_PATH)
     @PostMapping(AiConstants.LLM_PROMPT_TEMPLATE_PATH)
     public PromptTemplate getPromptTemplate(@RequestAttribute AccessOrg accessOrg,
     public PromptTemplate getPromptTemplate(@RequestAttribute AccessOrg accessOrg,
             @RequestHeader(AiConstants.LLM_APP_TYPE_HEADER) LlmAppType type) {
             @RequestHeader(AiConstants.LLM_APP_TYPE_HEADER) LlmAppType type) {
-        return null;
+        PromptTemplate template = new PromptTemplate();
+        LlmOrgConfig config = llmOrgConfigService.findByOrgAndAppType(accessOrg.getOrg(), type);
+        if (config != null) {
+            LlmPromptTemplate llmPromptTemplate = llmPromptTemplateService.findByAppType(type, config.getModelId());
+            if (llmPromptTemplate != null) {
+                template.setSystem(llmPromptTemplate.getSystem());
+                template.setUser(llmPromptTemplate.getUser());
+            }
+        }
+        return template;
     }
     }
 }
 }

+ 8 - 0
src/main/java/com/qmth/ops/biz/dao/LlmPromptTemplateDao.java

@@ -0,0 +1,8 @@
+package com.qmth.ops.biz.dao;
+
+import com.baomidou.mybatisplus.core.mapper.BaseMapper;
+import com.qmth.ops.biz.domain.LlmPromptTemplate;
+
+public interface LlmPromptTemplateDao extends BaseMapper<LlmPromptTemplate> {
+
+}

+ 52 - 0
src/main/java/com/qmth/ops/biz/domain/LlmPromptTemplate.java

@@ -0,0 +1,52 @@
+package com.qmth.ops.biz.domain;
+
+import com.baomidou.mybatisplus.annotation.TableName;
+import com.qmth.boot.core.ai.model.llm.LlmAppType;
+
+import java.io.Serializable;
+
+@TableName("llm_prompt_template")
+public class LlmPromptTemplate implements Serializable {
+
+    private static final long serialVersionUID = -2886605312034160681L;
+
+    private LlmAppType appType;
+
+    private Long modelId;
+
+    private String system;
+
+    private String user;
+
+    public LlmAppType getAppType() {
+        return appType;
+    }
+
+    public void setAppType(LlmAppType appType) {
+        this.appType = appType;
+    }
+
+    public Long getModelId() {
+        return modelId;
+    }
+
+    public void setModelId(Long modelId) {
+        this.modelId = modelId;
+    }
+
+    public String getSystem() {
+        return system;
+    }
+
+    public void setSystem(String system) {
+        this.system = system;
+    }
+
+    public String getUser() {
+        return user;
+    }
+
+    public void setUser(String user) {
+        this.user = user;
+    }
+}

+ 25 - 0
src/main/java/com/qmth/ops/biz/service/LlmPromptTemplateService.java

@@ -0,0 +1,25 @@
+package com.qmth.ops.biz.service;
+
+import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
+import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
+import com.qmth.boot.core.ai.model.llm.LlmAppType;
+import com.qmth.ops.biz.dao.LlmPromptTemplateDao;
+import com.qmth.ops.biz.domain.LlmPromptTemplate;
+import org.springframework.stereotype.Service;
+
+import javax.annotation.Resource;
+
+@Service
+public class LlmPromptTemplateService extends ServiceImpl<LlmPromptTemplateDao, LlmPromptTemplate> {
+
+    @Resource
+    private LlmPromptTemplateDao llmPromptTemplateDao;
+
+    public LlmPromptTemplate findByAppType(LlmAppType appType, Long modelId) {
+        return llmPromptTemplateDao.selectOne(
+                new LambdaQueryWrapper<LlmPromptTemplate>().eq(LlmPromptTemplate::getAppType, appType)
+                        .eq(LlmPromptTemplate::getModelId, modelId));
+    }
+
+}
+

+ 11 - 0
src/main/resources/script/init.sql

@@ -236,3 +236,14 @@ CREATE TABLE IF NOT EXISTS `llm_org_config`
     PRIMARY KEY (`org_id`, `app_type`)
     PRIMARY KEY (`org_id`, `app_type`)
 ) ENGINE = InnoDB
 ) ENGINE = InnoDB
   DEFAULT CHARSET = utf8mb4;
   DEFAULT CHARSET = utf8mb4;
+
+
+CREATE TABLE IF NOT EXISTS `llm_prompt_template`
+(
+    `app_type` varchar(32)         NOT NULL,
+    `model_id` bigint(20) unsigned NOT NULL,
+    `system`   text                NOT NULL,
+    `user`     text                NOT NULL,
+    PRIMARY KEY (`app_type`, `model_id`)
+) ENGINE = InnoDB
+  DEFAULT CHARSET = utf8mb4;