|
|
@@ -3,13 +3,10 @@ package org.ruoyi.service.impl;
|
|
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
import com.google.protobuf.ServiceException;
|
|
|
-import dev.langchain4j.data.embedding.Embedding;
|
|
|
import dev.langchain4j.data.segment.TextSegment;
|
|
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
|
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
|
|
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
|
|
-import dev.langchain4j.store.embedding.EmbeddingMatch;
|
|
|
-import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
|
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
|
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
|
|
|
import io.weaviate.client.Config;
|
|
|
@@ -25,7 +22,6 @@ import io.weaviate.client.v1.schema.model.WeaviateClass;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
import lombok.SneakyThrows;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
-import org.apache.rocketmq.client.QueryResult;
|
|
|
import org.ruoyi.chain.xinference.XinferenceClient;
|
|
|
import org.ruoyi.common.core.service.ConfigService;
|
|
|
import org.ruoyi.domain.bo.KnowledgeAttachBo;
|
|
|
@@ -251,14 +247,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|
|
String fullClassName = className + kid;
|
|
|
// 初始化 Weaviate 客户端
|
|
|
WeaviateClient client = new WeaviateClient(new Config(protocol, host));
|
|
|
- XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
|
|
|
- String modelUid = "bge-large-zh-v1.5";
|
|
|
- List<Double> queryVector;
|
|
|
- try {
|
|
|
- queryVector = xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());
|
|
|
- } catch (IOException e) {
|
|
|
- throw new RuntimeException("生成向量失败: " + e.getMessage(), e);
|
|
|
- }
|
|
|
+ final List<Double> queryVector = getDoubles(queryVectorBo);
|
|
|
List<Float> floatList = toFloatList(queryVector);
|
|
|
|
|
|
// 手动构造 GraphQL 查询字符串,包含 nearVector 和 limit
|
|
|
@@ -311,6 +300,21 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|
|
return results;
|
|
|
|
|
|
}
|
|
|
+
|
|
|
+ private static List<Double> getDoubles(QueryVectorBo queryVectorBo) {
|
|
|
+ XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
|
|
|
+ String vectorModelName = queryVectorBo.getVectorModelName();
|
|
|
+ String[] split = vectorModelName.split("/");
|
|
|
+ String modelUid = split[split.length-1];
|
|
|
+ List<Double> queryVector;
|
|
|
+ try {
|
|
|
+ queryVector = xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());
|
|
|
+ } catch (IOException e) {
|
|
|
+ throw new RuntimeException("生成向量失败: " + e.getMessage(), e);
|
|
|
+ }
|
|
|
+ return queryVector;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public List<Map<String,Object>> getQueryVector(QueryVectorBo queryVectorBo,double score) {
|
|
|
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
|
|
|
@@ -321,14 +325,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|
|
String fullClassName = className + kid;
|
|
|
// 初始化 Weaviate 客户端
|
|
|
WeaviateClient client = new WeaviateClient(new Config(protocol, host));
|
|
|
- XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
|
|
|
- String modelUid = "bge-large-zh-v1.5";
|
|
|
- List<Double> queryVector;
|
|
|
- try {
|
|
|
- queryVector = xinferenceClient.getEmbedding(modelUid, queryVectorBo.getQuery());
|
|
|
- } catch (IOException e) {
|
|
|
- throw new RuntimeException("生成向量失败: " + e.getMessage(), e);
|
|
|
- }
|
|
|
+ final List<Double> queryVector = getDoubles(queryVectorBo);
|
|
|
List<Float> floatList = toFloatList(queryVector);
|
|
|
|
|
|
// 手动构造 GraphQL 查询字符串,包含 nearVector 和 limit
|