|
|
@@ -10,19 +10,30 @@ import dev.langchain4j.store.embedding.EmbeddingMatch;
|
|
|
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
|
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
|
import dev.langchain4j.store.embedding.weaviate.WeaviateEmbeddingStore;
|
|
|
+import io.weaviate.client.Config;
|
|
|
+import io.weaviate.client.WeaviateClient;
|
|
|
+import io.weaviate.client.base.Result;
|
|
|
+import io.weaviate.client.v1.data.model.WeaviateObject;
|
|
|
+import io.weaviate.client.v1.schema.model.Property;
|
|
|
+import io.weaviate.client.v1.schema.model.Schema;
|
|
|
+import io.weaviate.client.v1.schema.model.WeaviateClass;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
import lombok.SneakyThrows;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.ruoyi.chain.xinference.XinferenceClient;
|
|
|
import org.ruoyi.common.core.service.ConfigService;
|
|
|
+import org.ruoyi.domain.bo.KnowledgeAttachBo;
|
|
|
import org.ruoyi.domain.bo.QueryVectorBo;
|
|
|
import org.ruoyi.domain.bo.StoreEmbeddingBo;
|
|
|
+import org.ruoyi.domain.vo.KnowledgeAttachVo;
|
|
|
+import org.ruoyi.service.IKnowledgeAttachService;
|
|
|
import org.ruoyi.service.VectorStoreService;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
-import java.util.ArrayList;
|
|
|
-import java.util.List;
|
|
|
+import java.lang.reflect.Field;
|
|
|
+import java.util.*;
|
|
|
|
|
|
/**
|
|
|
* 向量库管理
|
|
|
@@ -38,49 +49,162 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|
|
|
|
|
private EmbeddingStore<TextSegment> embeddingStore;
|
|
|
|
|
|
+ @Autowired
|
|
|
+ private IKnowledgeAttachService knowledgeAttachService;
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
|
|
|
@Override
|
|
|
public void createSchema(String kid, String modelName) {
|
|
|
String protocol = configService.getConfigValue("weaviate", "protocol");
|
|
|
String host = configService.getConfigValue("weaviate", "host");
|
|
|
String className = configService.getConfigValue("weaviate", "classname");
|
|
|
+ String fullClassName = className + kid;
|
|
|
+
|
|
|
+ // Initialize Weaviate client
|
|
|
+ WeaviateClient client = new WeaviateClient(new Config(protocol, host));
|
|
|
+
|
|
|
+ // Check if schema exists
|
|
|
+ Result<Schema> existingSchema = client.schema().getter().run();
|
|
|
+ if (existingSchema.hasErrors()) {
|
|
|
+ throw new RuntimeException("Failed to get schema: " + existingSchema.getError().toString());
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check if class exists
|
|
|
+ boolean classExists = existingSchema.getResult().getClasses().stream()
|
|
|
+ .anyMatch(c -> c.getClassName().equals(fullClassName));
|
|
|
+ if (!classExists) {
|
|
|
+ WeaviateClass weaviateClass = WeaviateClass.builder()
|
|
|
+ .className(fullClassName)
|
|
|
+ .properties(Arrays.asList(
|
|
|
+ Property.builder().name("text").dataType(Arrays.asList("text")).build(),
|
|
|
+ Property.builder().name("indexFilterable").dataType(Arrays.asList("boolean")).build(),
|
|
|
+ Property.builder().name("indexSearchable").dataType(Arrays.asList("boolean")).build(),
|
|
|
+ Property.builder().name("docId").dataType(Arrays.asList("string")).build()
|
|
|
+ ))
|
|
|
+ .vectorizer("none") // Set to "none" since vectors are provided externally
|
|
|
+ .build();
|
|
|
+ Schema schema = Schema.builder()
|
|
|
+ .classes(Collections.singletonList(weaviateClass))
|
|
|
+ .build();
|
|
|
+ Result<Boolean> result = client.schema().classCreator().withClass(weaviateClass).run();
|
|
|
+ } else {
|
|
|
+ // Check if docId property exists
|
|
|
+ WeaviateClass existingClass = existingSchema.getResult().getClasses().stream()
|
|
|
+ .filter(c -> c.getClassName().equals(fullClassName))
|
|
|
+ .findFirst()
|
|
|
+ .orElse(null);
|
|
|
+ boolean hasDocId = existingClass.getProperties().stream()
|
|
|
+ .anyMatch(p -> p.getName().equals("docId"));
|
|
|
+ if (!hasDocId) {
|
|
|
+ // Add docId property
|
|
|
+ Property docIdProperty = Property.builder()
|
|
|
+ .name("docId")
|
|
|
+ .dataType(Arrays.asList("string"))
|
|
|
+ .build();
|
|
|
+ Result<Boolean> updateResult = client.schema().propertyCreator()
|
|
|
+ .withClassName(fullClassName)
|
|
|
+ .withProperty(docIdProperty)
|
|
|
+ .run();
|
|
|
+ }
|
|
|
+ }
|
|
|
embeddingStore = WeaviateEmbeddingStore.builder()
|
|
|
.scheme(protocol)
|
|
|
.host(host)
|
|
|
- .objectClass(className+kid)
|
|
|
- .scheme(protocol)
|
|
|
+ .objectClass(fullClassName)
|
|
|
.avoidDups(true)
|
|
|
.consistencyLevel("ALL")
|
|
|
.build();
|
|
|
}
|
|
|
-
|
|
|
@Override
|
|
|
- public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
|
|
- createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
|
|
|
- EmbeddingModel embeddingModel = getEmbeddingModel(storeEmbeddingBo.getEmbeddingModelName(),
|
|
|
- storeEmbeddingBo.getApiKey(), storeEmbeddingBo.getBaseUrl());
|
|
|
+ public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
|
|
+ String protocol = configService.getConfigValue("weaviate", "protocol");
|
|
|
+ String host = configService.getConfigValue("weaviate", "host");
|
|
|
+ String className = configService.getConfigValue("weaviate", "classname");
|
|
|
+ String kid = storeEmbeddingBo.getKid();
|
|
|
+ String fullClassName = className + kid;
|
|
|
+
|
|
|
+ // 初始化 Weaviate 客户端
|
|
|
+ WeaviateClient client = new WeaviateClient(new Config(protocol, host));
|
|
|
+ createSchema(kid, storeEmbeddingBo.getVectorModelName());
|
|
|
+
|
|
|
List<String> chunkList = storeEmbeddingBo.getChunkList();
|
|
|
+ String docId = storeEmbeddingBo.getDocId();
|
|
|
+ if (docId == null || docId.trim().isEmpty()) {
|
|
|
+ throw new IllegalArgumentException("docId 不能为空");
|
|
|
+ }
|
|
|
|
|
|
XinferenceClient xinferenceClient = new XinferenceClient(storeEmbeddingBo.getBaseUrl());
|
|
|
String modelName = storeEmbeddingBo.getEmbeddingModelName();
|
|
|
- if("quentinz/bge-large-zh-v1.5".contains(modelName)){
|
|
|
+ if ("quentinz/bge-large-zh-v1.5".contains(modelName)) {
|
|
|
modelName = "bge-large-zh-v1.5";
|
|
|
}
|
|
|
- for (String text : chunkList) {
|
|
|
+
|
|
|
+ for (int i = 0; i < chunkList.size(); i++) {
|
|
|
+ String text = chunkList.get(i);
|
|
|
+ if (text == null || text.trim().isEmpty()) {
|
|
|
+ System.out.println("跳过空文本: " + i);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
List<Double> vector;
|
|
|
try {
|
|
|
vector = xinferenceClient.getEmbedding(modelName, text);
|
|
|
} catch (IOException e) {
|
|
|
- throw new RuntimeException(e);
|
|
|
+ throw new RuntimeException("生成向量失败: " + e.getMessage(), e);
|
|
|
}
|
|
|
- Embedding embedding = new Embedding(toFloatArray(vector));
|
|
|
- if (text != null && !text.trim().isEmpty()) {
|
|
|
- TextSegment segment = TextSegment.from(text);
|
|
|
- embeddingStore.add(embedding, segment);
|
|
|
+
|
|
|
+ // 创建 Weaviate 对象
|
|
|
+ Map<String, Object> properties = new HashMap<>();
|
|
|
+ properties.put("text", text);
|
|
|
+ properties.put("indexFilterable", true);
|
|
|
+ properties.put("indexSearchable", true);
|
|
|
+ properties.put("docId", docId); // 显式设置 docId
|
|
|
+
|
|
|
+ try {
|
|
|
+ Result<WeaviateObject> result = client.data().creator()
|
|
|
+ .withClassName(fullClassName)
|
|
|
+ .withProperties(properties)
|
|
|
+ .withVector(toFloatArray(vector))
|
|
|
+ .run();
|
|
|
+ if (result.hasErrors()) {
|
|
|
+ throw new RuntimeException("存储对象失败: " + result.getError().toString());
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ throw new RuntimeException("存储向量失败: " + e.getMessage(), e);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+// @Override
|
|
|
+// public void storeEmbeddings(StoreEmbeddingBo storeEmbeddingBo) {
|
|
|
+// createSchema(storeEmbeddingBo.getKid(), storeEmbeddingBo.getVectorModelName());
|
|
|
+// List<String> chunkList = storeEmbeddingBo.getChunkList();
|
|
|
+// String docId = storeEmbeddingBo.getDocId();
|
|
|
+// XinferenceClient xinferenceClient = new XinferenceClient(storeEmbeddingBo.getBaseUrl());
|
|
|
+// String modelName = storeEmbeddingBo.getEmbeddingModelName();
|
|
|
+// if ("quentinz/bge-large-zh-v1.5".contains(modelName)) {
|
|
|
+// modelName = "bge-large-zh-v1.5";
|
|
|
+// }
|
|
|
+// for (int i = 0; i < chunkList.size(); i++) {
|
|
|
+// List<Double> vector;
|
|
|
+// String text = chunkList.get(i);
|
|
|
+// try {
|
|
|
+// vector = xinferenceClient.getEmbedding(modelName, text);
|
|
|
+// } catch (IOException e) {
|
|
|
+// throw new RuntimeException(e);
|
|
|
+// }
|
|
|
+// Metadata metadata = new Metadata();
|
|
|
+// metadata.put("docId", docId);
|
|
|
+// Embedding embedding = new Embedding(toFloatArray(vector));
|
|
|
+// if (text != null && !text.trim().isEmpty()) {
|
|
|
+// TextSegment segment = TextSegment.from(text,metadata);
|
|
|
+// final String add = embeddingStore.add(embedding, segment);
|
|
|
+// System.out.println(add);
|
|
|
+// }
|
|
|
+// }
|
|
|
+// }
|
|
|
+
|
|
|
@Override
|
|
|
public List<String> getQueryVector(QueryVectorBo queryVectorBo) {
|
|
|
createSchema(queryVectorBo.getKid(), queryVectorBo.getVectorModelName());
|
|
|
@@ -132,6 +256,43 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|
|
return null;
|
|
|
}
|
|
|
|
|
|
+ public void removeByDocId(String docId) {
|
|
|
+ KnowledgeAttachBo knowledgeAttachBo = new KnowledgeAttachBo();
|
|
|
+ knowledgeAttachBo.setDocId(docId);
|
|
|
+ List<KnowledgeAttachVo> knowledgeAttachVoList = knowledgeAttachService.queryList(knowledgeAttachBo);
|
|
|
+ String kid = knowledgeAttachVoList.get(0).getKid();
|
|
|
+ createSchema(kid, "");
|
|
|
+
|
|
|
+ try {
|
|
|
+ // 反射拿到私有字段 client
|
|
|
+ Field clientField = embeddingStore.getClass().getDeclaredField("client");
|
|
|
+ clientField.setAccessible(true);
|
|
|
+ Object client = clientField.get(embeddingStore);
|
|
|
+
|
|
|
+ // 反射拿到私有字段 objectClass
|
|
|
+ Field objectClassField = embeddingStore.getClass().getDeclaredField("objectClass");
|
|
|
+ objectClassField.setAccessible(true);
|
|
|
+ String objectClass = (String) objectClassField.get(embeddingStore);
|
|
|
+
|
|
|
+ // 下面需要强转 client 为 io.weaviate.client.WeaviateClient
|
|
|
+ io.weaviate.client.WeaviateClient weaviateClient = (io.weaviate.client.WeaviateClient) client;
|
|
|
+
|
|
|
+ // 构造删除请求
|
|
|
+ weaviateClient.batch()
|
|
|
+ .objectsBatchDeleter()
|
|
|
+ .withClassName(objectClass)
|
|
|
+ .withWhere(io.weaviate.client.v1.filters.WhereFilter.builder()
|
|
|
+ .path("docId")
|
|
|
+ .operator("Equal")
|
|
|
+ .valueText(docId)
|
|
|
+ .build())
|
|
|
+ .run();
|
|
|
+
|
|
|
+ } catch (NoSuchFieldException | IllegalAccessException e) {
|
|
|
+ throw new RuntimeException(e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
|
|
|
@Override
|
|
|
public void removeById(String id, String modelName) {
|
|
|
@@ -162,6 +323,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|
|
}
|
|
|
return embeddingModel;
|
|
|
}
|
|
|
+
|
|
|
public static List<Float> toFloatList(List<Double> doubleList) {
|
|
|
List<Float> floatList = new ArrayList<>(doubleList.size());
|
|
|
for (Double d : doubleList) {
|
|
|
@@ -169,8 +331,9 @@ public class VectorStoreServiceImpl implements VectorStoreService {
|
|
|
}
|
|
|
return floatList;
|
|
|
}
|
|
|
- public static float[] toFloatArray(List<Double> doubleList) {
|
|
|
- float[] floatArray = new float[doubleList.size()];
|
|
|
+
|
|
|
+ public static Float[] toFloatArray(List<Double> doubleList) {
|
|
|
+ Float[] floatArray = new Float[doubleList.size()];
|
|
|
for (int i = 0; i < doubleList.size(); i++) {
|
|
|
floatArray[i] = doubleList.get(i).floatValue();
|
|
|
}
|