feng 9 mesi fa
parent
commit
c21303df56

+ 22 - 22
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java

@@ -3,13 +3,10 @@ package org.ruoyi.service.impl;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.protobuf.ServiceException;
-import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
 import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
-import dev.langchain4j.store.embedding.EmbeddingMatch;
-import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
 import dev.langchain4j.store.embedding.EmbeddingStore;
 import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
 import io.weaviate.client.Config;
@@ -25,7 +22,6 @@ import io.weaviate.client.v1.schema.model.WeaviateClass;
 import lombok.RequiredArgsConstructor;
 import lombok.SneakyThrows;
 import lombok.extern.slf4j.Slf4j;
-import org.apache.rocketmq.client.QueryResult;
 import org.ruoyi.chain.xinference.XinferenceClient;
 import org.ruoyi.common.core.service.ConfigService;
 import org.ruoyi.domain.bo.KnowledgeAttachBo;
@@ -60,8 +56,6 @@ public class VectorStoreServiceImpl implements VectorStoreService {
     private IKnowledgeAttachService knowledgeAttachService;
 
 
-
-
     @Override
     public void createSchema(String kid, String modelName) {
         String protocol = configService.getConfigValue("weaviate", "protocol");
@@ -138,6 +132,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
                 .consistencyLevel("ALL")
                 .build();
     }
+
     @Override
     public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
         String protocol = configService.getConfigValue("weaviate", "protocol");
@@ -183,12 +178,12 @@ public class VectorStoreServiceImpl implements VectorStoreService {
             properties.put("docId", docId);
             properties.put("filename", storeEmbeddingBo.getFileName());
             HashMap<String, Object> hashMap = new HashMap<>();
-            hashMap.put("docId",docId);
-            hashMap.put("filename",storeEmbeddingBo.getFileName());
+            hashMap.put("docId", docId);
+            hashMap.put("filename", storeEmbeddingBo.getFileName());
             ObjectMapper mapper = new ObjectMapper();
             try {
                 String jsonString = mapper.writeValueAsString(hashMap);
-                properties.put("metadata",jsonString);
+                properties.put("metadata", jsonString);
             } catch (JsonProcessingException e) {
                 throw new RuntimeException(e);
             }
@@ -239,7 +234,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
 //    }
 
     @Override
-    public List<Map<String,Object>> getQueryVector(QueryVectorBo queryVectorBo) {
+    public List<Map<String, Object>> getQueryVector(QueryVectorBo queryVectorBo) {
         createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
         String protocol = configService.getConfigValue("weaviate", "protocol");
         String host = configService.getConfigValue("weaviate", "host");
@@ -249,7 +244,10 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         // 初始化 Weaviate 客户端
         WeaviateClient client = new WeaviateClient(new Config(protocol, host));
         XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
-        String modelUid = "bge-large-zh-v1.5";
+        String embeddingModelName = queryVectorBo.getEmbeddingModelName();
+        //对齐模型,在数据库中存储为:quentinz/bge-large-zh-v1.5
+        String[] split = embeddingModelName.split("/");
+        String modelUid = split[split.length - 1];
         List<Double> queryVector;
         try {
             queryVector = xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());
@@ -282,7 +280,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         Result<GraphQLResponse> result = raw.run();
         // 获取 Get 下的值
         Object data = result.getResult().getData();
-        List<Map<String,Object>> results = new ArrayList<>();
+        List<Map<String, Object>> results = new ArrayList<>();
         if (data instanceof Map) {
             Map<String, Object> dataMap = (Map<String, Object>) data;
             Map<String, Object> getMap = (Map<String, Object>) dataMap.get("Get");
@@ -294,8 +292,8 @@ public class VectorStoreServiceImpl implements VectorStoreService {
                         String text = (String) item.get("text");
                         String docId = (String) item.get("docId");
                         String filename = (String) item.get("filename");
-                        hashMap.put("filename",filename);
-                        hashMap.put("text",text);
+                        hashMap.put("filename", filename);
+                        hashMap.put("text", text);
                         results.add(hashMap);
                     }
                 }
@@ -306,8 +304,9 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         return results;
 
     }
+
     @Override
-    public List<Map<String,Object>> getQueryVector(QueryVectorBo queryVectorBo,double score) {
+    public List<Map<String, Object>> getQueryVector(QueryVectorBo queryVectorBo, double score) {
         createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
         String protocol = configService.getConfigValue("weaviate", "protocol");
         String host = configService.getConfigValue("weaviate", "host");
@@ -317,7 +316,10 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         // 初始化 Weaviate 客户端
         WeaviateClient client = new WeaviateClient(new Config(protocol, host));
         XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
-        String modelUid = "bge-large-zh-v1.5";
+        String embeddingModelName = queryVectorBo.getEmbeddingModelName();
+        //对齐模型,在数据库中存储为:quentinz/bge-large-zh-v1.5
+        String[] split = embeddingModelName.split("/");
+        String modelUid = split[split.length - 1];
         List<Double> queryVector;
         try {
             queryVector = xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());
@@ -333,7 +335,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         String graphqlQuery = String.format(
                 "{" +
                         "  Get {%n" +
-                        "    %s(nearVector: {vector: [%s], certainty:"+score+"}, limit: %d) {%n" +
+                        "    %s(nearVector: {vector: [%s], certainty:" + score + "}, limit: %d) {%n" +
                         "      text%n" +
                         "      docId%n" +
                         "      filename%n" +
@@ -350,7 +352,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         Result<GraphQLResponse> result = raw.run();
         // 获取 Get 下的值
         Object data = result.getResult().getData();
-        List<Map<String,Object>> results = new ArrayList<>();
+        List<Map<String, Object>> results = new ArrayList<>();
         if (data instanceof Map) {
             Map<String, Object> dataMap = (Map<String, Object>) data;
             Map<String, Object> getMap = (Map<String, Object>) dataMap.get("Get");
@@ -362,14 +364,12 @@ public class VectorStoreServiceImpl implements VectorStoreService {
                         String text = (String) item.get("text");
                         String docId = (String) item.get("docId");
                         String filename = (String) item.get("filename");
-                        hashMap.put("filename",filename);
-                        hashMap.put("text",text);
+                        hashMap.put("filename", filename);
+                        hashMap.put("text", text);
                         results.add(hashMap);
                     }
                 }
             }
-        } else {
-            System.err.println("Data is not a Map: " + data);
         }
         return results;