Răsfoiți Sursa

省中适配模型的https

ligao 4 luni în urmă
părinte
comite
8f22d19dc7

+ 148 - 12
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/chain/xinference/XinferenceClient.java

@@ -5,7 +5,10 @@ import okhttp3.*;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.net.ssl.*;
 import java.io.IOException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
 import java.util.*;
 
 /**
@@ -17,17 +20,85 @@ public class XinferenceClient {
     private final OkHttpClient client;
     private final ObjectMapper objectMapper;
     private final String baseUrl;
+    private final String apiKey;
 
     /**
      * 构造函数,初始化客户端
-     * @param baseUrl Xinference 服务的基础 URL,例如 http://192.168.238.116:9997
+     * @param baseUrl 服务的基础 URL,例如 https://192.168.10.124:5443/open/router
+     * @param apiKey API 密钥,用于认证(Bearer 格式)
      */
-    public XinferenceClient(String baseUrl) {
-        this.baseUrl = baseUrl.endsWith("/") ? baseUrl : baseUrl + "/";
-        this.client = new OkHttpClient();
+    public XinferenceClient(String baseUrl, String apiKey) {
+        if (baseUrl == null || baseUrl.trim().isEmpty()) {
+            throw new IllegalArgumentException("baseUrl 不能为空");
+        }
+        // 确保 baseUrl 以 / 结尾,方便拼接路径
+        this.baseUrl = baseUrl.trim().endsWith("/") ? baseUrl.trim() : baseUrl.trim() + "/";
+        this.apiKey = apiKey;
+        // 记录 API Key 信息(部分隐藏,用于调试)
+        if (apiKey != null && !apiKey.trim().isEmpty()) {
+            String trimmedKey = apiKey.trim();
+            String maskedKey = trimmedKey.length() > 8 
+                ? trimmedKey.substring(0, 4) + "..." + trimmedKey.substring(trimmedKey.length() - 4)
+                : "****";
+            logger.info("初始化 XinferenceClient,baseUrl: {}, API Key长度: {}, API Key(部分): {}", 
+                    this.baseUrl, trimmedKey.length(), maskedKey);
+        } else {
+            logger.warn("初始化 XinferenceClient,baseUrl: {}, API Key 为空或未提供", this.baseUrl);
+        }
+        this.client = createUnsafeOkHttpClient();
         this.objectMapper = new ObjectMapper();
     }
 
+    /**
+     * 构造函数,初始化客户端(兼容旧代码,不传 apiKey)
+     * @param baseUrl 服务的基础 URL,例如 https://192.168.10.124:5443/open/router
+     */
+    public XinferenceClient(String baseUrl) {
+        this(baseUrl, null);
+    }
+
+    /**
+     * 创建一个信任所有证书的 OkHttpClient(仅用于开发环境)
+     * 注意:生产环境应该使用正确的 SSL 证书
+     */
+    private OkHttpClient createUnsafeOkHttpClient() {
+        try {
+            // 创建一个信任所有证书的 TrustManager
+            final TrustManager[] trustAllCerts = new TrustManager[]{
+                new X509TrustManager() {
+                    @Override
+                    public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
+                    }
+
+                    @Override
+                    public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
+                    }
+
+                    @Override
+                    public X509Certificate[] getAcceptedIssuers() {
+                        return new X509Certificate[]{};
+                    }
+                }
+            };
+
+            // 创建 SSLContext 并使用信任所有证书的 TrustManager
+            final SSLContext sslContext = SSLContext.getInstance("SSL");
+            sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
+
+            // 创建 HostnameVerifier,接受所有主机名
+            final HostnameVerifier hostnameVerifier = (hostname, session) -> true;
+
+            // 创建 OkHttpClient,配置 SSL 和主机名验证
+            return new OkHttpClient.Builder()
+                    .sslSocketFactory(sslContext.getSocketFactory(), (X509TrustManager) trustAllCerts[0])
+                    .hostnameVerifier(hostnameVerifier)
+                    .build();
+        } catch (Exception e) {
+            logger.error("创建 OkHttpClient 失败,使用默认配置", e);
+            return new OkHttpClient();
+        }
+    }
+
     /**
      * 获取可用模型列表
      * @return 模型 ID 列表
@@ -35,13 +106,22 @@ public class XinferenceClient {
      */
     public List<String> getAvailableModels() throws IOException {
         String url = baseUrl + "v1/models";
-        System.out.println(url);
-        Request request = new Request.Builder().url("url").get().build();
+        logger.info("获取模型列表,请求URL: {}", url);
+        Request.Builder requestBuilder = new Request.Builder()
+                .url(url)
+                .get();
+        
+        // 添加 API Key 认证头(Bearer 格式)
+        if (apiKey != null && !apiKey.trim().isEmpty()) {
+            requestBuilder.addHeader("Authorization", "Bearer " + apiKey.trim());
+        }
+        
+        Request request = requestBuilder.build();
         try (Response response = client.newCall(request).execute()) {
             if (!response.isSuccessful()) {
                 String errorBody = response.body() != null ? response.body().string() : "无响应体";
-                logger.error("获取模型列表失败,状态码: {}, 响应: {}", response.code(), errorBody);
-                throw new IOException("获取模型列表失败,状态码: " + response.code());
+                logger.error("获取模型列表失败,请求URL: {}, 状态码: {}, 响应: {}", url, response.code(), errorBody);
+                throw new IOException("获取模型列表失败,状态码: " + response.code() + ", 响应: " + errorBody);
             }
             String responseBody = response.body().string();
             logger.debug("模型列表响应: {}", responseBody);
@@ -71,6 +151,8 @@ public class XinferenceClient {
             throw new IllegalArgumentException("inputTexts 不能为空");
         }
 
+        // 拼接完整的 API URL: baseUrl + v1/embeddings
+        // 例如: https://192.168.10.124:5443/open/router/ + v1/embeddings
         String url = baseUrl + "v1/embeddings";
         Map<String, Object> json = new HashMap<>();
         json.put("model", modelName);
@@ -81,13 +163,67 @@ public class XinferenceClient {
         }
         json.putAll(extraParams);
 
-        logger.info("发送嵌入请求,model: {}, inputTexts: {}, extraParams: {}", modelName, inputTexts, extraParams);
-        RequestBody body = RequestBody.create(objectMapper.writeValueAsString(json), MediaType.parse("application/json"));
-        Request request = new Request.Builder()
+        // 构建请求体 JSON
+        String jsonBody = objectMapper.writeValueAsString(json);
+        RequestBody body = RequestBody.create(jsonBody, MediaType.parse("application/json"));
+        
+        Request.Builder requestBuilder = new Request.Builder()
                 .url(url)
                 .post(body)
                 .addHeader("accept", "application/json")
-                .build();
+                .addHeader("Content-Type", "application/json");
+        
+        // 添加 API Key 认证头(Bearer 格式)
+        String trimmedApiKey = null;
+        if (apiKey != null && !apiKey.trim().isEmpty()) {
+            trimmedApiKey = apiKey.trim();
+            requestBuilder.addHeader("Authorization", "Bearer " + trimmedApiKey);
+        }
+        
+        Request request = requestBuilder.build();
+        
+        // ========== 打印请求前的所有参数 ==========
+        logger.info("========== 发送嵌入向量请求 ==========");
+        logger.info("请求URL: {}", url);
+        logger.info("请求方法: POST");
+        logger.info("请求参数:");
+        logger.info("  - model: {}", modelName);
+        logger.info("  - inputTexts数量: {}", inputTexts != null ? inputTexts.size() : 0);
+        if (inputTexts != null && !inputTexts.isEmpty()) {
+            for (int i = 0; i < Math.min(inputTexts.size(), 3); i++) {
+                String text = inputTexts.get(i);
+                String preview = text.length() > 50 ? text.substring(0, 50) + "..." : text;
+                logger.info("  - inputTexts[{}]: {}", i, preview);
+            }
+            if (inputTexts.size() > 3) {
+                logger.info("  - ... (还有 {} 条文本)", inputTexts.size() - 3);
+            }
+        }
+        logger.info("  - extraParams: {}", extraParams);
+        logger.info("请求体JSON: {}", jsonBody);
+        logger.info("请求头信息:");
+        request.headers().forEach(header -> {
+            String headerName = header.getFirst();
+            String headerValue = header.getSecond();
+            if ("Authorization".equalsIgnoreCase(headerName)) {
+                // 隐藏 API Key 的中间部分
+                String maskedAuth = headerValue.length() > 20 
+                    ? headerValue.substring(0, 15) + "..." + headerValue.substring(headerValue.length() - 5)
+                    : "****";
+                logger.info("  - {}: {}", headerName, maskedAuth);
+            } else {
+                logger.info("  - {}: {}", headerName, headerValue);
+            }
+        });
+        if (trimmedApiKey != null) {
+            String maskedKey = trimmedApiKey.length() > 8 
+                ? trimmedApiKey.substring(0, 4) + "..." + trimmedApiKey.substring(trimmedApiKey.length() - 4)
+                : "****";
+            logger.info("API Key信息: 长度={}, 值(部分)={}", trimmedApiKey.length(), maskedKey);
+        } else {
+            logger.warn("API Key: 未提供(可能导致 401 认证失败)");
+        }
+        logger.info("========================================");
 
         try (Response response = client.newCall(request).execute()) {
             if (!response.isSuccessful()) {

+ 11 - 4
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/service/impl/VectorStoreServiceImpl.java

@@ -173,10 +173,17 @@ public class VectorStoreServiceImpl implements VectorStoreService {
             throw new IllegalArgumentException("docId 不能为空");
         }
 
-        XinferenceClient xinferenceClient = new XinferenceClient(storeEmbeddingBo.getBaseUrl());
+        XinferenceClient xinferenceClient = new XinferenceClient(storeEmbeddingBo.getBaseUrl(), storeEmbeddingBo.getApiKey());
         String modelName = storeEmbeddingBo.getEmbeddingModelName();
-        if ("quentinz/bge-large-zh-v1.5".contains(modelName)) {
-            modelName = "bge-large-zh-v1.5";
+        
+        // 模型名称映射:将 bge-large-zh-v1.5 映射为 bge-large-zh(API 期望的模型名称)
+        if (modelName != null && (modelName.contains("bge-large-zh-v1.5") || modelName.equals("bge-large-zh-v1.5"))) {
+            modelName = "bge-large-zh";
+        }
+        
+        // 兼容旧逻辑:处理 quentinz/bge-large-zh-v1.5 格式
+        if (modelName != null && modelName.contains("quentinz/bge-large-zh-v1.5")) {
+            modelName = "bge-large-zh";
         }
 
         for (int i = 0; i < chunkList.size(); i++) {
@@ -411,7 +418,7 @@ public class VectorStoreServiceImpl implements VectorStoreService {
     }
 
     private static List<Double> getDoubles(QueryVectorBo queryVectorBo) {
-        XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl());
+        XinferenceClient xinferenceClient = new XinferenceClient(queryVectorBo.getBaseUrl(), queryVectorBo.getApiKey());
         String embeddingModelName = queryVectorBo.getEmbeddingModelName();
         String[] embeddingModelNameArr = embeddingModelName.split("/");
         String modelUid = embeddingModelNameArr[embeddingModelNameArr.length - 1];