Browse Source

添加重排序模型

feng 10 months ago
parent
commit
6f6d70c090

+ 75 - 0
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/chain/xinference/RerankUtil.java

@@ -0,0 +1,75 @@
+package org.ruoyi.chain.xinference;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import okhttp3.*;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Reorders vector-search hits using the relevance scores returned by a rerank HTTP endpoint.
+ */
+public class RerankUtil {
+
+    private static final OkHttpClient client = new OkHttpClient();
+    private static final ObjectMapper objectMapper = new ObjectMapper();
+
+    /**
+     * Sends the {@code text} entry of every hit to the rerank endpoint and returns the hits
+     * ordered by descending {@code relevance_score}. Only hits listed in the response are
+     * returned, so the result may be shorter than {@code nearestList}.
+     *
+     * @param query       text the documents are scored against
+     * @param nearestList hits to reorder; each map's {@code text} entry is sent as a document
+     * @param rerankUrl   URL the rerank request is POSTed to
+     * @param modelName   value sent as the {@code model} field of the request
+     * @return the reordered hits (same map instances as in {@code nearestList}); an empty list
+     *         when {@code nearestList} is null or empty
+     * @throws IOException      if the HTTP call or the JSON (de)serialization fails
+     * @throws RuntimeException if the endpoint answers with a non-2xx status, the response has
+     *                          no {@code results} array, or a result carries an invalid index
+     */
+    public static List<Map<String, Object>> rerankNearestList(
+            String query,
+            List<Map<String, Object>> nearestList,
+            String rerankUrl,String modelName) throws IOException {
+
+        // Nothing to rank: skip the network round trip entirely.
+        if (nearestList == null || nearestList.isEmpty()) {
+            return new ArrayList<>();
+        }
+
+        // 1. Collect the text of every hit in order, so response indexes map back to nearestList.
+        List<String> documents = new ArrayList<>(nearestList.size());
+        for (Map<String, Object> item : nearestList) {
+            documents.add(String.valueOf(item.get("text")));
+        }
+
+        // 2. Build the JSON request.
+        Map<String, Object> payload = new HashMap<>();
+        payload.put("query", query);
+        payload.put("documents", documents);
+        payload.put("model", modelName);
+
+        RequestBody requestBody = RequestBody.create(
+                objectMapper.writeValueAsString(payload), MediaType.parse("application/json"));
+
+        Request request = new Request.Builder()
+                .url(rerankUrl)
+                .addHeader("accept", "application/json")
+                .post(requestBody)
+                .build();
+
+        // 3. Call the endpoint. try-with-resources closes the response (and frees the pooled
+        //    connection) on every path, including the failure branch below.
+        JsonNode results;
+        try (Response response = client.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                throw new RuntimeException("Rerank API 调用失败: " + response);
+            }
+
+            // 4. Parse the response. The code reads a "results" array whose entries carry
+            //    "index" and "relevance_score".
+            ResponseBody body = response.body();
+            JsonNode root = body == null ? null : objectMapper.readTree(body.string());
+            results = root == null ? null : root.get("results");
+        }
+        if (results == null || !results.isArray()) {
+            throw new RuntimeException("Rerank API 返回缺少 results 数组");
+        }
+
+        // 5. Order by descending score. List.sort is stable, so equal scores keep the order
+        //    in which the endpoint returned them. A missing score is treated as 0.
+        List<JsonNode> ranked = new ArrayList<>();
+        results.forEach(ranked::add);
+        ranked.sort((a, b) -> Double.compare(
+                b.path("relevance_score").asDouble(), a.path("relevance_score").asDouble()));
+
+        List<Map<String, Object>> sortedList = new ArrayList<>(ranked.size());
+        for (JsonNode node : ranked) {
+            // -1 marks a missing or non-numeric index so it fails the bounds check below.
+            int idx = node.path("index").asInt(-1);
+            if (idx < 0 || idx >= nearestList.size()) {
+                throw new RuntimeException("Rerank API 返回了无效的 index: " + node.get("index"));
+            }
+            sortedList.add(nearestList.get(idx));
+        }
+
+        return sortedList;
+    }
+}

+ 4 - 6
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java

@@ -188,7 +188,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
             try {
                 String jsonString = mapper.writeValueAsString(hashMap);
                 properties.put("metadata",jsonString);
-            } catch (JsonProcessingException e) {
+              } catch (JsonProcessingException e) {
                 throw new RuntimeException(e);
             }
 
@@ -294,8 +294,6 @@ public class VectorStoreServiceImpl implements VectorStoreService {
                     }
                 }
             }
-        } else {
-            System.err.println("Data is not a Map: " + data);
         }
         return results;
 
@@ -303,9 +301,9 @@ public class VectorStoreServiceImpl implements VectorStoreService {
 
     private static List<Double> getDoubles(QueryVectorBo queryVectorBo) {
         XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
-        String vectorModelName = queryVectorBo.getVectorModelName();
-        String[] split = vectorModelName.split("/");
-        String modelUid = split[split.length-1];
+        String embeddingModelName = queryVectorBo.getEmbeddingModelName();
+        String[] embeddingModelNameArr = embeddingModelName.split("/");
+        String modelUid = embeddingModelNameArr[embeddingModelNameArr.length-1];
         List<Double> queryVector;
         try {
             queryVector = xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());

+ 2 - 1
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/ISseService.java

@@ -18,6 +18,7 @@ import org.springframework.http.ResponseEntity;
 import org.springframework.web.multipart.MultipartFile;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
+import java.io.IOException;
 import java.util.List;
 import java.util.Map;
 
@@ -61,7 +62,7 @@ public interface ISseService {
      */
     UploadFileResponse upload(MultipartFile file);
 
-    SchemaMessage getSchema(SchemaRequest schemaRequest, HttpServletRequest request);
+    SchemaMessage getSchema(SchemaRequest schemaRequest, HttpServletRequest request) throws IOException;
 
     @DS("open-db")
     int saveRag(ThinkProject thinkProject, HttpServletRequest request);

+ 39 - 25
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java

@@ -8,6 +8,7 @@ import jakarta.servlet.http.HttpServletRequest;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import okhttp3.ResponseBody;
+import org.ruoyi.chain.xinference.RerankUtil;
 import org.ruoyi.chat.factory.ChatServiceFactory;
 import org.ruoyi.chat.service.chat.IChatCostService;
 import org.ruoyi.chat.service.chat.IChatService;
@@ -59,7 +60,6 @@ import java.io.InputStream;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -97,6 +97,7 @@ public class SseServiceImpl implements ISseService {
     @Autowired
     private ThinkRagMapper thinkRagMapper;
 
+    private String rerankModelName = "bge-reranker-v2-m3";
 
     @Override
     public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
@@ -183,7 +184,7 @@ public class SseServiceImpl implements ISseService {
     /**
      *  构建消息列表
      */
-    private void buildChatMessageList(ChatRequest chatRequest){
+    private void buildChatMessageList(ChatRequest chatRequest) throws IOException {
         String sysPrompt;
         // 矫正模型名称 如果是gpt-image 则查询image类型模型 获取模型名称
         if(chatRequest.getModel().equals("gpt-image")) {
@@ -215,7 +216,10 @@ public class SseServiceImpl implements ISseService {
             queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
             queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
             List<Map<String,Object>> nearestList = vectorStoreService.getQueryVector(queryVectorBo);
-            for (Map prompt : nearestList) {
+            // 查询重排序模型 重排序模型现在只用到了bge-reranker-v2-m3,后续可以动态获取
+            ChatModelVo rerankModelVo = chatModelService.selectModelByName(rerankModelName);
+            List<Map<String, Object>> maps = RerankUtil.rerankNearestList(content, nearestList, rerankModelVo.getApiHost(),rerankModelName);
+            for (Map prompt : maps) {
                 StringBuilder context = new StringBuilder();
                 String text = String.valueOf(prompt.get("text"));
                 String filename = String.valueOf(prompt.get("filename"));
@@ -228,28 +232,34 @@ public class SseServiceImpl implements ISseService {
             // 设置知识库系统提示词
             sysPrompt = knowledgeInfoVo.getSystemPrompt();
             if(StringUtils.isEmpty(sysPrompt)){
-                sysPrompt = "你是一个专业的问答助手,根据提供的资料进行回答。\n" +
+                sysPrompt = "你是一个专业的问答助手,根据提供的资料进行回答。\n" +
                         "\n" +
-                        "资料来源如下,每条资料包含正文内容和来源文件名:\n" +
-                        "每段资料格式如下:\n" +
-                        "内容:<正文文本>\n" +
-                        "来源:<文件名>\n" +
+                        "资料格式如下,每条资料包含正文内容和来源文件名:  \n" +
+                        "内容:<正文文本>  \n" +
+                        "来源:<文件名>  \n" +
                         "\n" +
-                        "在回答时请遵循以下规则:\n" +
-                        "1. 只能引用提供的资料进行回答,不要编造信息。\n" +
-                        "2. 如果引用了某段资料,请在回答中对应位置用 ①②③ 等标注。- 内容与来源要一一对应,多处引用相同资料,也都统一使用标注。\n" +
-                        "3. 回答结尾,请根据标注,列出来源。例如:\n" +
-                        "   来源:\n" +
-                        "   ① 文件A.pdf\n" +
-                        "   ② 文件B.docx   - 多处引用相同资料,也使用同一的来源 "  +
-                        "当前时间:"+ DateUtils.getDate();
-//                sysPrompt ="###角色设定\n" +
-//                        "你是一个智能知识助手,专注于利用上下文中的信息来提供准确和相关的回答。\n" +
-//                        "###指令\n" +
-//                        "当用户的问题与上下文知识匹配时,利用上下文信息进行回答,在。如果问题与上下文不匹配,运用自身的推理能力生成合适的回答。\n" +
-//                        "###限制\n" +
-//                        "确保回答清晰简洁,避免提供不必要的细节。始终保持语气友好" +
-//                        "当前时间:"+ DateUtils.getDate();
+                        "回答时请严格遵守以下规则:  \n" +
+                        "1. 只能引用提供的资料,不得编造或添加无关信息。  \n" +
+                        "2. 回答中引用资料时,必须使用①②③……等标注符号。  \n" +
+                        "3. **同一来源文件必须使用唯一且固定的标注号,无论引用该文件多少次或多个片段,都只能用一个标注号。**  \n" +
+                        "4. 如果回答中有多个要点,每个要点后都必须加标注。  \n" +
+                        "5. 回答末尾必须列出所有用到的标注号和对应的唯一来源文件,格式如下:\n" +
+                        "\n" +
+                        "来源:  \n" +
+                        "① 文件A.docx  \n" +
+                        "② 文件B.pdf  \n" +
+                        "\n" +
+                        "6. 请严格按照以下示例格式回复,不要偏离示例:\n" +
+                        "\n" +
+                        "示例:  \n" +
+                        "1. 内容描述①  \n" +
+                        "2. 另一个内容①  \n" +
+                        "\n" +
+                        "来源:  \n" +
+                        "① 文件A.docx" +
+                        "注意,不能出现同一个来源出现不同的标注和来源"+
+                        "当前时间:" + DateUtils.getDate();
+
             }
         }else {
             sysPrompt = chatModelVo.getSystemPrompt();
@@ -370,7 +380,7 @@ public class SseServiceImpl implements ISseService {
     }
 
     @Override
-    public SchemaMessage getSchema(SchemaRequest schemaRequest, HttpServletRequest request) {
+    public SchemaMessage getSchema(SchemaRequest schemaRequest, HttpServletRequest request) throws IOException {
         SchemaMessage schema = new SchemaMessage();
         String prompt = schemaRequest.getMessages();
         float maxDistance = schemaRequest.getMaxDistance();
@@ -389,7 +399,11 @@ public class SseServiceImpl implements ISseService {
             queryVectorBo.setEmbeddingModelName(knowledgeInfoVo.getEmbeddingModelName());
             queryVectorBo.setMaxResults(knowledgeInfoVo.getRetrieveLimit());
             List<Map<String,Object>> nearestList = vectorStoreService.getQueryVector(queryVectorBo, maxDistance);
-            schema.setNearest(nearestList);
+            // 查询重排序模型 重排序模型现在只用到了bge-reranker-v2-m3,后续可以动态获取
+            ChatModelVo rerankModelVo = chatModelService.selectModelByName(rerankModelName);
+            List<Map<String, Object>> maps = RerankUtil.rerankNearestList(prompt, nearestList, rerankModelVo.getApiHost(),rerankModelName);
+
+            schema.setNearest(maps);
 
         }
         return schema;