feng 9 місяців тому
батько
коміт
0c771f5d8d

+ 17 - 20
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java

@@ -3,13 +3,10 @@ package org.ruoyi.service.impl;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.protobuf.ServiceException;
-import dev.langchain4j.data.embedding.Embedding;
 import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.embedding.EmbeddingModel;
 import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
 import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
-import dev.langchain4j.store.embedding.EmbeddingMatch;
-import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
 import dev.langchain4j.store.embedding.EmbeddingStore;
 import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
 import io.weaviate.client.Config;
@@ -25,7 +22,6 @@ import io.weaviate.client.v1.schema.model.WeaviateClass;
 import lombok.RequiredArgsConstructor;
 import lombok.SneakyThrows;
 import lombok.extern.slf4j.Slf4j;
-import org.apache.rocketmq.client.QueryResult;
 import org.ruoyi.chain.xinference.XinferenceClient;
 import org.ruoyi.common.core.service.ConfigService;
 import org.ruoyi.domain.bo.KnowledgeAttachBo;
@@ -251,14 +247,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         String fullClassName = className + kid;
         // 初始化 Weaviate 客户端
         WeaviateClient client = new WeaviateClient(new Config(protocol, host));
-        XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
-        String modelUid = "bge-large-zh-v1.5";
-        List<Double> queryVector;
-        try {
-            queryVector = xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());
-        } catch (IOException e) {
-            throw new RuntimeException("生成向量失败: " + e.getMessage(), e);
-        }
+        final List<Double> queryVector = getDoubles(queryVectorBo);
         List<Float> floatList = toFloatList(queryVector);
 
         // 手动构造 GraphQL 查询字符串,包含 nearVector 和 limit
@@ -311,6 +300,21 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         return results;
 
     }
+
+    /**
+     * Generates the embedding vector for the query text carried by {@code queryVectorBo}.
+     *
+     * <p>A {@link XinferenceClient} is created for {@code queryVectorBo.getBaseUrl()} and asked
+     * to embed {@code queryVectorBo.getQuery()}. The model UID passed to the client is the last
+     * {@code /}-separated segment of {@code queryVectorBo.getVectorModelName()}.
+     *
+     * @param queryVectorBo query parameters supplying the base URL, model name and query text
+     * @return the embedding returned by the client
+     * @throws IllegalArgumentException if the model name is null, blank, or has no usable segment
+     * @throws RuntimeException if the client fails with an {@link IOException}
+     */
+    private static List<Double> getDoubles(QueryVectorBo queryVectorBo) {
+        XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
+        String modelUid = resolveModelUid(queryVectorBo.getVectorModelName());
+        try {
+            return xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());
+        } catch (IOException e) {
+            throw new RuntimeException("生成向量失败: " + e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Extracts the model UID: the last non-empty {@code /}-separated segment of the model name.
+     *
+     * <p>Fails with a descriptive error instead of the {@code NullPointerException} (null name)
+     * or {@code ArrayIndexOutOfBoundsException} (name made only of slashes) that a plain
+     * {@code split("/")} lookup would raise.
+     */
+    private static String resolveModelUid(String vectorModelName) {
+        if (vectorModelName == null || vectorModelName.trim().isEmpty()) {
+            throw new IllegalArgumentException("vectorModelName must not be null or blank");
+        }
+        // Ignore trailing slashes, matching String.split, which drops trailing empty segments.
+        int end = vectorModelName.length();
+        while (end > 0 && vectorModelName.charAt(end - 1) == '/') {
+            end--;
+        }
+        if (end == 0) {
+            throw new IllegalArgumentException(
+                    "vectorModelName contains no model identifier: " + vectorModelName);
+        }
+        // lastIndexOf returns -1 when there is no slash, so the whole prefix is the UID.
+        return vectorModelName.substring(vectorModelName.lastIndexOf('/', end - 1) + 1, end);
+    }
+
     @Override
     public List<Map<String,Object>> getQueryVector(QueryVectorBo queryVectorBo,double score) {
         createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
@@ -321,14 +325,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
         String fullClassName = className + kid;
         // 初始化 Weaviate 客户端
         WeaviateClient client = new WeaviateClient(new Config(protocol, host));
-        XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
-        String modelUid = "bge-large-zh-v1.5";
-        List<Double> queryVector;
-        try {
-            queryVector = xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());
-        } catch (IOException e) {
-            throw new RuntimeException("生成向量失败: " + e.getMessage(), e);
-        }
+        final List<Double> queryVector = getDoubles(queryVectorBo);
         List<Float> floatList = toFloatList(queryVector);
 
         // 手动构造 GraphQL 查询字符串,包含 nearVector 和 limit