
Modify the LlmPromptTemplate data structure to support multiple prompt templates

luoshi · 1 year ago
commit 4b640f0a04

+ 12 - 3
src/main/java/com/qmth/ops/api/controller/admin/LlmModelController.java

@@ -1,5 +1,6 @@
 package com.qmth.ops.api.controller.admin;
 
+import com.qmth.boot.core.ai.model.llm.LlmAppType;
 import com.qmth.ops.api.constants.OpsApiConstants;
 import com.qmth.ops.api.security.AdminSession;
 import com.qmth.ops.api.security.Permission;
@@ -48,8 +49,16 @@ public class LlmModelController {
     }
 
     @PostMapping("/prompt_template/list")
-    public List<LlmPromptTemplate> getPromptTemplate(@RequestParam Long modelId) {
-        return llmPromptTemplateService.findByModel(modelId);
+    public List<LlmPromptTemplate> getPromptTemplate(@RequestParam Long modelId, @RequestParam LlmAppType appType) {
+        return llmPromptTemplateService.findByModelAndAppType(modelId, appType);
+    }
+
+    @PostMapping("/prompt_template/insert")
+    public LlmPromptTemplate insertPromptTemplate(@RequestAttribute AdminSession adminSession,
+            LlmPromptTemplate template) {
+        adminSession.hasPermission(Permission.LLM_MODEL_EDIT);
+        llmPromptTemplateService.insert(template);
+        return template;
     }
 
     @PostMapping("/prompt_template/update")
@@ -57,7 +66,7 @@ public class LlmModelController {
             LlmPromptTemplate template) {
         adminSession.hasPermission(Permission.LLM_MODEL_EDIT);
         llmPromptTemplateService.update(template);
-        return llmPromptTemplateService.findByModelAndAppType(template.getModelId(), template.getAppType());
+        return llmPromptTemplateService.findById(template.getId());
     }
 
 }

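The list endpoint now filters by both model id and app type, and new templates are created through the added insert endpoint. A minimal client-side sketch of calling the list endpoint, with a hypothetical host, path prefix and appType value, and with the required admin authentication omitted:

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class PromptTemplateListExample {

    public static void main(String[] args) throws Exception {
        // Hypothetical host, path prefix and appType value; the real endpoint also
        // requires an authenticated admin session, which is not shown here.
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8090/api/admin/llm/model/prompt_template/list"))
                .header("Content-Type", "application/x-www-form-urlencoded")
                .POST(HttpRequest.BodyPublishers.ofString("modelId=1&appType=CHAT"))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.body()); // JSON array of LlmPromptTemplate rows
    }
}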
+ 3 - 3
src/main/java/com/qmth/ops/api/controller/admin/LlmOrgConfigController.java

@@ -54,10 +54,10 @@ public class LlmOrgConfigController {
         return config;
     }
 
-    @PostMapping("/update/model")
-    public LlmOrgConfig updateModel(@RequestAttribute AdminSession adminSession, LlmOrgConfig config) {
+    @PostMapping("/update/model_prompt")
+    public LlmOrgConfig updateModelAndPrompt(@RequestAttribute AdminSession adminSession, LlmOrgConfig config) {
         adminSession.hasPermission(Permission.LLM_ORG_CONFIG_EDIT);
-        llmOrgConfigService.updateModel(config);
+        llmOrgConfigService.updateModelAndPrompt(config);
         return llmOrgConfigService.findByOrgAndAppType(config.getOrgId(), config.getAppType());
     }
 

+ 10 - 9
src/main/java/com/qmth/ops/api/controller/ai/LlmController.java

@@ -13,11 +13,11 @@ import com.qmth.ops.biz.domain.LlmPromptTemplate;
 import com.qmth.ops.biz.service.LlmClientService;
 import com.qmth.ops.biz.service.LlmOrgConfigService;
 import com.qmth.ops.biz.service.LlmPromptTemplateService;
+import org.apache.commons.lang3.StringUtils;
 import org.springframework.validation.annotation.Validated;
 import org.springframework.web.bind.annotation.*;
 
 import javax.annotation.Resource;
-import java.util.Map;
 
 @RestController
 @Aac(auth = true, signType = SignatureType.SECRET)
@@ -48,24 +48,26 @@ public class LlmController {
 
     @PostMapping(AiConstants.LLM_CHAT_TEMPLATE_PATH)
     public ChatResult chatTemplate(@RequestAttribute AccessOrg accessOrg,
-            @RequestHeader(AiConstants.LLM_APP_TYPE_HEADER) LlmAppType type, @RequestBody Map<String, Object> param)
+            @RequestHeader(AiConstants.LLM_APP_TYPE_HEADER) LlmAppType type, @RequestBody Object param)
             throws Exception {
         LlmOrgConfig config = llmOrgConfigService.findByOrgAndAppType(accessOrg.getOrg().getId(), type);
         if (config == null || config.getLeftCount() <= 0) {
             throw new ForbiddenException(
                     "Chat api is disabled or exhausted for org=" + accessOrg.getOrg().getCode() + ", app_type=" + type);
         }
-        LlmPromptTemplate llmPromptTemplate = llmPromptTemplateService.findByModelAndAppType(config.getModelId(), type);
+        LlmPromptTemplate llmPromptTemplate = llmPromptTemplateService.findById(config.getPromptId());
         if (llmPromptTemplate == null) {
             throw new NotFoundException(
                     "Chat prompt template not found for app_type=" + type + ", modelId=" + config.getModelId());
         }
         ChatRequest request = new ChatRequest();
-        if (llmPromptTemplate.getSystem() != null) {
-            request.addMessage(ChatRole.system, FreemarkerUtil.getValue(llmPromptTemplate.getSystem(), param, ""));
+        String systemMessage = FreemarkerUtil.getValue(llmPromptTemplate.getSystem(), param, null);
+        String userMessage = FreemarkerUtil.getValue(llmPromptTemplate.getUser(), param, null);
+        if (StringUtils.isNotBlank(systemMessage)) {
+            request.addMessage(ChatRole.system, systemMessage);
         }
-        if (llmPromptTemplate.getUser() != null) {
-            request.addMessage(ChatRole.user, FreemarkerUtil.getValue(llmPromptTemplate.getUser(), param, ""));
+        if (StringUtils.isNotBlank(userMessage)) {
+            request.addMessage(ChatRole.user, userMessage);
         }
         ChatResult result = llmClientService.chat(request, config.getModelId());
         llmOrgConfigService.consume(config);
@@ -90,8 +92,7 @@ public class LlmController {
         PromptTemplate template = new PromptTemplate();
         LlmOrgConfig config = llmOrgConfigService.findByOrgAndAppType(accessOrg.getOrg().getId(), type);
         if (config != null) {
-            LlmPromptTemplate llmPromptTemplate = llmPromptTemplateService
-                    .findByModelAndAppType(config.getModelId(), type);
+            LlmPromptTemplate llmPromptTemplate = llmPromptTemplateService.findById(config.getPromptId());
             if (llmPromptTemplate != null) {
                 template.setSystem(llmPromptTemplate.getSystem());
                 template.setUser(llmPromptTemplate.getUser());

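chatTemplate now renders both messages up front and only adds the ones that are not blank, so a template may deliberately leave either part empty. FreemarkerUtil is project-internal; a minimal sketch of the kind of rendering getValue(template, param, defaultValue) appears to perform, written directly against the FreeMarker library with an assumed fall-back-on-error behaviour:

import freemarker.template.Configuration;
import freemarker.template.Template;

import java.io.StringReader;
import java.io.StringWriter;
import java.util.Map;

public class PromptRenderSketch {

    // Roughly what FreemarkerUtil.getValue(template, param, defaultValue) seems to do:
    // render the template text against the request payload, fall back on null or error.
    static String render(String templateText, Object dataModel, String defaultValue) {
        if (templateText == null) {
            return defaultValue;
        }
        try {
            Configuration cfg = new Configuration(Configuration.VERSION_2_3_31);
            Template template = new Template("prompt", new StringReader(templateText), cfg);
            StringWriter out = new StringWriter();
            template.process(dataModel, out);
            return out.toString();
        } catch (Exception e) {
            return defaultValue;
        }
    }

    public static void main(String[] args) {
        // Hypothetical template text and payload mirroring the chatTemplate flow above.
        String user = render("Score this answer: ${answer}", Map.of("answer", "42"), null);
        System.out.println(user);
    }
}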
+ 0 - 7
src/main/java/com/qmth/ops/biz/dao/LlmPromptTemplateDao.java

@@ -1,15 +1,8 @@
 package com.qmth.ops.biz.dao;
 
 import com.baomidou.mybatisplus.core.mapper.BaseMapper;
-import com.qmth.boot.core.ai.model.llm.LlmAppType;
 import com.qmth.ops.biz.domain.LlmPromptTemplate;
-import org.apache.ibatis.annotations.Param;
-import org.apache.ibatis.annotations.Update;
 
 public interface LlmPromptTemplateDao extends BaseMapper<LlmPromptTemplate> {
 
-    @Update("replace into llm_prompt_template (model_id, app_type, system, user) "
-            + "values (#{modelId}, #{appType}, #{system}, #{user})")
-    void replace(@Param("modelId") Long modelId, @Param("appType") LlmAppType appType, @Param("system") String system,
-            @Param("user") String user);
 }

+ 11 - 1
src/main/java/com/qmth/ops/biz/domain/LlmOrgConfig.java

@@ -5,7 +5,7 @@ import com.qmth.boot.core.ai.model.llm.LlmAppType;
 
 import java.io.Serializable;
 
-@TableName("llm_org_config")
+@TableName(value = "llm_org_config", autoResultMap = true)
 public class LlmOrgConfig implements Serializable {
 
     private static final long serialVersionUID = -593409647805304621L;
@@ -16,6 +16,8 @@ public class LlmOrgConfig implements Serializable {
 
     private Long modelId;
 
+    private Long promptId;
+
     private Integer permitCount;
 
     private Integer leftCount;
@@ -44,6 +46,14 @@ public class LlmOrgConfig implements Serializable {
         this.modelId = modelId;
     }
 
+    public Long getPromptId() {
+        return promptId;
+    }
+
+    public void setPromptId(Long promptId) {
+        this.promptId = promptId;
+    }
+
     public Integer getPermitCount() {
         return permitCount;
     }

+ 23 - 0
src/main/java/com/qmth/ops/biz/domain/LlmPromptTemplate.java

@@ -1,5 +1,7 @@
 package com.qmth.ops.biz.domain;
 
+import com.baomidou.mybatisplus.annotation.IdType;
+import com.baomidou.mybatisplus.annotation.TableId;
 import com.baomidou.mybatisplus.annotation.TableName;
 import com.qmth.boot.core.ai.model.llm.LlmAppType;
 
@@ -10,6 +12,9 @@ public class LlmPromptTemplate implements Serializable {
 
     private static final long serialVersionUID = -2886605312034160681L;
 
+    @TableId(type = IdType.AUTO)
+    private Long id;
+
     private Long modelId;
 
     private LlmAppType appType;
@@ -18,6 +23,16 @@ public class LlmPromptTemplate implements Serializable {
 
     private String user;
 
+    private String remark;
+
+    public Long getId() {
+        return id;
+    }
+
+    public void setId(Long id) {
+        this.id = id;
+    }
+
     public LlmAppType getAppType() {
         return appType;
     }
@@ -49,4 +64,12 @@ public class LlmPromptTemplate implements Serializable {
     public void setUser(String user) {
         this.user = user;
     }
+
+    public String getRemark() {
+        return remark;
+    }
+
+    public void setRemark(String remark) {
+        this.remark = remark;
+    }
 }

+ 2 - 1
src/main/java/com/qmth/ops/biz/service/LlmOrgConfigService.java

@@ -29,9 +29,10 @@ public class LlmOrgConfigService extends ServiceImpl<LlmOrgConfigDao, LlmOrgConf
     }
 
     @Transactional
-    public void updateModel(LlmOrgConfig llmOrgConfig) {
+    public void updateModelAndPrompt(LlmOrgConfig llmOrgConfig) {
         llmOrgConfigDao.update(llmOrgConfig,
                 new LambdaUpdateWrapper<LlmOrgConfig>().set(LlmOrgConfig::getModelId, llmOrgConfig.getModelId())
+                        .set(LlmOrgConfig::getPromptId, llmOrgConfig.getPromptId())
                         .eq(LlmOrgConfig::getOrgId, llmOrgConfig.getOrgId())
                         .eq(LlmOrgConfig::getAppType, llmOrgConfig.getAppType()));
     }

+ 21 - 13
src/main/java/com/qmth/ops/biz/service/LlmPromptTemplateService.java

@@ -1,7 +1,7 @@
 package com.qmth.ops.biz.service;
 
 import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
-import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
+import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
 import com.qmth.boot.core.ai.model.llm.LlmAppType;
 import com.qmth.ops.biz.dao.LlmPromptTemplateDao;
 import com.qmth.ops.biz.domain.LlmPromptTemplate;
@@ -14,30 +14,38 @@ import javax.annotation.Resource;
 import java.util.List;
 
 @Service
-public class LlmPromptTemplateService extends ServiceImpl<LlmPromptTemplateDao, LlmPromptTemplate> {
+public class LlmPromptTemplateService {
 
     private static final String CACHE_NAME = "llm_prompt_template";
 
     @Resource
     private LlmPromptTemplateDao llmPromptTemplateDao;
 
-    @Cacheable(value = CACHE_NAME, key = "#modelId+'_'+#appType")
-    public LlmPromptTemplate findByModelAndAppType(Long modelId, LlmAppType appType) {
-        return llmPromptTemplateDao.selectOne(
-                new LambdaQueryWrapper<LlmPromptTemplate>().eq(LlmPromptTemplate::getModelId, modelId)
-                        .eq(LlmPromptTemplate::getAppType, appType));
+    @Cacheable(value = CACHE_NAME, key = "#id")
+    public LlmPromptTemplate findById(Long id) {
+        return llmPromptTemplateDao.selectById(id);
     }
 
-    public List<LlmPromptTemplate> findByModel(Long modelId) {
-        return llmPromptTemplateDao
-                .selectList(new LambdaQueryWrapper<LlmPromptTemplate>().eq(LlmPromptTemplate::getModelId, modelId));
+    public List<LlmPromptTemplate> findByModelAndAppType(Long modelId, LlmAppType appType) {
+        return llmPromptTemplateDao.selectList(
+                new LambdaQueryWrapper<LlmPromptTemplate>().eq(modelId != null, LlmPromptTemplate::getModelId, modelId)
+                        .eq(appType != null, LlmPromptTemplate::getAppType, appType));
     }
 
     @Transactional
-    @CachePut(value = CACHE_NAME, key = "#template.modelId+'_'+template.appType", unless = "#template==null")
+    @CachePut(value = CACHE_NAME, key = "#template.id", unless = "#template==null")
+    public void insert(LlmPromptTemplate template) {
+        llmPromptTemplateDao.insert(template);
+    }
+
+    @Transactional
+    @CachePut(value = CACHE_NAME, key = "#template.id", unless = "#template==null")
     public void update(LlmPromptTemplate template) {
-        llmPromptTemplateDao
-                .replace(template.getModelId(), template.getAppType(), template.getSystem(), template.getUser());
+        llmPromptTemplateDao.update(template,
+                new LambdaUpdateWrapper<LlmPromptTemplate>().set(LlmPromptTemplate::getSystem, template.getSystem())
+                        .set(LlmPromptTemplate::getUser, template.getUser())
+                        .set(LlmPromptTemplate::getRemark, template.getRemark())
+                        .eq(LlmPromptTemplate::getId, template.getId()));
     }
 
 }

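With the surrogate id as cache key, insert, update and findById all operate on a single template row. A hypothetical caller, assuming the service is injected the same way as in the controllers above:

import com.qmth.boot.core.ai.model.llm.LlmAppType;
import com.qmth.ops.biz.domain.LlmPromptTemplate;
import com.qmth.ops.biz.service.LlmPromptTemplateService;

import javax.annotation.Resource;
import org.springframework.stereotype.Component;

@Component
public class PromptTemplateLifecycleSketch {

    @Resource
    private LlmPromptTemplateService llmPromptTemplateService;

    public LlmPromptTemplate createAndRevise(Long modelId, LlmAppType appType) {
        LlmPromptTemplate template = new LlmPromptTemplate();
        template.setModelId(modelId);
        template.setAppType(appType);
        template.setSystem("You are a grading assistant.");    // sample prompt text
        template.setUser("Score this answer: ${answer}");
        template.setRemark("default scoring prompt");

        // IdType.AUTO: MyBatis-Plus back-fills the generated key into template.getId().
        llmPromptTemplateService.insert(template);

        template.setRemark("scoring prompt v2");
        llmPromptTemplateService.update(template);              // updates system/user/remark by id

        return llmPromptTemplateService.findById(template.getId());
    }
}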
+ 1 - 1
src/main/resources/application.properties

@@ -1,4 +1,4 @@
-server.port=8080
+server.port=8090
 ##spring.resources.static-locations=classpath:/static/,file:/Users/luoshi/develop/project/ops-web/,file:/Users/luoshi/develop/data/
 
 com.qmth.api.http-trace=true

+ 3 - 0
src/main/resources/script/init.sql

@@ -231,6 +231,7 @@ CREATE TABLE IF NOT EXISTS `llm_org_config`
     `org_id`       bigint(20) unsigned NOT NULL,
     `app_type`     varchar(32)         NOT NULL,
     `model_id`     bigint(20) unsigned NOT NULL,
+    `prompt_id`    bigint(20) unsigned NOT NULL,
     `permit_count` int(11)             NOT NULL,
     `left_count`   int(11)             NOT NULL,
     PRIMARY KEY (`org_id`, `app_type`)
@@ -240,10 +241,12 @@ CREATE TABLE IF NOT EXISTS `llm_org_config`
 
 CREATE TABLE IF NOT EXISTS `llm_prompt_template`
 (
+    `id`       bigint(20) unsigned NOT NULL AUTO_INCREMENT,
     `model_id` bigint(20) unsigned NOT NULL,
     `app_type` varchar(32)         NOT NULL,
     `system`   text                NOT NULL,
     `user`     text                NOT NULL,
+    `remark`   varchar(32)         NOT NULL,
     PRIMARY KEY (`model_id`, `app_type`)
 ) ENGINE = InnoDB
   DEFAULT CHARSET = utf8mb4;
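
init.sql only covers fresh installs; an already-deployed database would need an equivalent migration. A sketch, not part of this commit, assuming MySQL and that placeholder defaults are acceptable until real values are assigned:

-- prompt_id for existing org configs; real values must be assigned afterwards.
ALTER TABLE `llm_org_config`
    ADD COLUMN `prompt_id` bigint(20) unsigned NOT NULL DEFAULT 0 AFTER `model_id`;

-- MySQL requires an AUTO_INCREMENT column to be covered by an index, hence the extra key.
ALTER TABLE `llm_prompt_template`
    ADD COLUMN `id` bigint(20) unsigned NOT NULL AUTO_INCREMENT FIRST,
    ADD COLUMN `remark` varchar(32) NOT NULL DEFAULT '' AFTER `user`,
    ADD KEY `idx_llm_prompt_template_id` (`id`);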