|
|
@@ -5,7 +5,10 @@ import okhttp3.*;
|
|
|
import org.slf4j.Logger;
|
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
|
|
+import javax.net.ssl.*;
|
|
|
import java.io.IOException;
|
|
|
+import java.security.cert.CertificateException;
|
|
|
+import java.security.cert.X509Certificate;
|
|
|
import java.util.*;
|
|
|
|
|
|
/**
|
|
|
@@ -17,17 +20,85 @@ public class XinferenceClient {
|
|
|
private final OkHttpClient client;
|
|
|
private final ObjectMapper objectMapper;
|
|
|
private final String baseUrl;
|
|
|
+ private final String apiKey;
|
|
|
|
|
|
/**
|
|
|
* 构造函数,初始化客户端
|
|
|
- * @param baseUrl Xinference 服务的基础 URL,例如 http://192.168.238.116:9997
|
|
|
+ * @param baseUrl 服务的基础 URL,例如 https://192.168.10.124:5443/open/router
|
|
|
+ * @param apiKey API 密钥,用于认证(Bearer 格式)
|
|
|
*/
|
|
|
- public XinferenceClient(String baseUrl) {
|
|
|
- this.baseUrl = baseUrl.endsWith("/") ? baseUrl : baseUrl + "/";
|
|
|
- this.client = new OkHttpClient();
|
|
|
+ public XinferenceClient(String baseUrl, String apiKey) {
|
|
|
+ if (baseUrl == null || baseUrl.trim().isEmpty()) {
|
|
|
+ throw new IllegalArgumentException("baseUrl 不能为空");
|
|
|
+ }
|
|
|
+ // 确保 baseUrl 以 / 结尾,方便拼接路径
|
|
|
+ this.baseUrl = baseUrl.trim().endsWith("/") ? baseUrl.trim() : baseUrl.trim() + "/";
|
|
|
+ this.apiKey = apiKey;
|
|
|
+ // 记录 API Key 信息(部分隐藏,用于调试)
|
|
|
+ if (apiKey != null && !apiKey.trim().isEmpty()) {
|
|
|
+ String trimmedKey = apiKey.trim();
|
|
|
+ String maskedKey = trimmedKey.length() > 8
|
|
|
+ ? trimmedKey.substring(0, 4) + "..." + trimmedKey.substring(trimmedKey.length() - 4)
|
|
|
+ : "****";
|
|
|
+ logger.info("初始化 XinferenceClient,baseUrl: {}, API Key长度: {}, API Key(部分): {}",
|
|
|
+ this.baseUrl, trimmedKey.length(), maskedKey);
|
|
|
+ } else {
|
|
|
+ logger.warn("初始化 XinferenceClient,baseUrl: {}, API Key 为空或未提供", this.baseUrl);
|
|
|
+ }
|
|
|
+ this.client = createUnsafeOkHttpClient();
|
|
|
this.objectMapper = new ObjectMapper();
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * 构造函数,初始化客户端(兼容旧代码,不传 apiKey)
|
|
|
+ * @param baseUrl 服务的基础 URL,例如 https://192.168.10.124:5443/open/router
|
|
|
+ */
|
|
|
+ public XinferenceClient(String baseUrl) {
|
|
|
+ this(baseUrl, null);
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 创建一个信任所有证书的 OkHttpClient(仅用于开发环境)
|
|
|
+ * 注意:生产环境应该使用正确的 SSL 证书
|
|
|
+ */
|
|
|
+ private OkHttpClient createUnsafeOkHttpClient() {
|
|
|
+ try {
|
|
|
+ // 创建一个信任所有证书的 TrustManager
|
|
|
+ final TrustManager[] trustAllCerts = new TrustManager[]{
|
|
|
+ new X509TrustManager() {
|
|
|
+ @Override
|
|
|
+ public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public X509Certificate[] getAcceptedIssuers() {
|
|
|
+ return new X509Certificate[]{};
|
|
|
+ }
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ // 创建 SSLContext 并使用信任所有证书的 TrustManager
|
|
|
+ final SSLContext sslContext = SSLContext.getInstance("SSL");
|
|
|
+ sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
|
|
|
+
|
|
|
+ // 创建 HostnameVerifier,接受所有主机名
|
|
|
+ final HostnameVerifier hostnameVerifier = (hostname, session) -> true;
|
|
|
+
|
|
|
+ // 创建 OkHttpClient,配置 SSL 和主机名验证
|
|
|
+ return new OkHttpClient.Builder()
|
|
|
+ .sslSocketFactory(sslContext.getSocketFactory(), (X509TrustManager) trustAllCerts[0])
|
|
|
+ .hostnameVerifier(hostnameVerifier)
|
|
|
+ .build();
|
|
|
+ } catch (Exception e) {
|
|
|
+ logger.error("创建 OkHttpClient 失败,使用默认配置", e);
|
|
|
+ return new OkHttpClient();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* 获取可用模型列表
|
|
|
* @return 模型 ID 列表
|
|
|
@@ -35,13 +106,22 @@ public class XinferenceClient {
|
|
|
*/
|
|
|
public List<String> getAvailableModels() throws IOException {
|
|
|
String url = baseUrl + "v1/models";
|
|
|
- System.out.println(url);
|
|
|
- Request request = new Request.Builder().url("url").get().build();
|
|
|
+ logger.info("获取模型列表,请求URL: {}", url);
|
|
|
+ Request.Builder requestBuilder = new Request.Builder()
|
|
|
+ .url(url)
|
|
|
+ .get();
|
|
|
+
|
|
|
+ // 添加 API Key 认证头(Bearer 格式)
|
|
|
+ if (apiKey != null && !apiKey.trim().isEmpty()) {
|
|
|
+ requestBuilder.addHeader("Authorization", "Bearer " + apiKey.trim());
|
|
|
+ }
|
|
|
+
|
|
|
+ Request request = requestBuilder.build();
|
|
|
try (Response response = client.newCall(request).execute()) {
|
|
|
if (!response.isSuccessful()) {
|
|
|
String errorBody = response.body() != null ? response.body().string() : "无响应体";
|
|
|
- logger.error("获取模型列表失败,状态码: {}, 响应: {}", response.code(), errorBody);
|
|
|
- throw new IOException("获取模型列表失败,状态码: " + response.code());
|
|
|
+ logger.error("获取模型列表失败,请求URL: {}, 状态码: {}, 响应: {}", url, response.code(), errorBody);
|
|
|
+ throw new IOException("获取模型列表失败,状态码: " + response.code() + ", 响应: " + errorBody);
|
|
|
}
|
|
|
String responseBody = response.body().string();
|
|
|
logger.debug("模型列表响应: {}", responseBody);
|
|
|
@@ -71,6 +151,8 @@ public class XinferenceClient {
|
|
|
throw new IllegalArgumentException("inputTexts 不能为空");
|
|
|
}
|
|
|
|
|
|
+ // 拼接完整的 API URL: baseUrl + v1/embeddings
|
|
|
+ // 例如: https://192.168.10.124:5443/open/router/ + v1/embeddings
|
|
|
String url = baseUrl + "v1/embeddings";
|
|
|
Map<String, Object> json = new HashMap<>();
|
|
|
json.put("model", modelName);
|
|
|
@@ -81,13 +163,67 @@ public class XinferenceClient {
|
|
|
}
|
|
|
json.putAll(extraParams);
|
|
|
|
|
|
- logger.info("发送嵌入请求,model: {}, inputTexts: {}, extraParams: {}", modelName, inputTexts, extraParams);
|
|
|
- RequestBody body = RequestBody.create(objectMapper.writeValueAsString(json), MediaType.parse("application/json"));
|
|
|
- Request request = new Request.Builder()
|
|
|
+ // 构建请求体 JSON
|
|
|
+ String jsonBody = objectMapper.writeValueAsString(json);
|
|
|
+ RequestBody body = RequestBody.create(jsonBody, MediaType.parse("application/json"));
|
|
|
+
|
|
|
+ Request.Builder requestBuilder = new Request.Builder()
|
|
|
.url(url)
|
|
|
.post(body)
|
|
|
.addHeader("accept", "application/json")
|
|
|
- .build();
|
|
|
+ .addHeader("Content-Type", "application/json");
|
|
|
+
|
|
|
+ // 添加 API Key 认证头(Bearer 格式)
|
|
|
+ String trimmedApiKey = null;
|
|
|
+ if (apiKey != null && !apiKey.trim().isEmpty()) {
|
|
|
+ trimmedApiKey = apiKey.trim();
|
|
|
+ requestBuilder.addHeader("Authorization", "Bearer " + trimmedApiKey);
|
|
|
+ }
|
|
|
+
|
|
|
+ Request request = requestBuilder.build();
|
|
|
+
|
|
|
+ // ========== 打印请求前的所有参数 ==========
|
|
|
+ logger.info("========== 发送嵌入向量请求 ==========");
|
|
|
+ logger.info("请求URL: {}", url);
|
|
|
+ logger.info("请求方法: POST");
|
|
|
+ logger.info("请求参数:");
|
|
|
+ logger.info(" - model: {}", modelName);
|
|
|
+ logger.info(" - inputTexts数量: {}", inputTexts != null ? inputTexts.size() : 0);
|
|
|
+ if (inputTexts != null && !inputTexts.isEmpty()) {
|
|
|
+ for (int i = 0; i < Math.min(inputTexts.size(), 3); i++) {
|
|
|
+ String text = inputTexts.get(i);
|
|
|
+ String preview = text.length() > 50 ? text.substring(0, 50) + "..." : text;
|
|
|
+ logger.info(" - inputTexts[{}]: {}", i, preview);
|
|
|
+ }
|
|
|
+ if (inputTexts.size() > 3) {
|
|
|
+ logger.info(" - ... (还有 {} 条文本)", inputTexts.size() - 3);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ logger.info(" - extraParams: {}", extraParams);
|
|
|
+ logger.info("请求体JSON: {}", jsonBody);
|
|
|
+ logger.info("请求头信息:");
|
|
|
+ request.headers().forEach(header -> {
|
|
|
+ String headerName = header.getFirst();
|
|
|
+ String headerValue = header.getSecond();
|
|
|
+ if ("Authorization".equalsIgnoreCase(headerName)) {
|
|
|
+ // 隐藏 API Key 的中间部分
|
|
|
+ String maskedAuth = headerValue.length() > 20
|
|
|
+ ? headerValue.substring(0, 15) + "..." + headerValue.substring(headerValue.length() - 5)
|
|
|
+ : "****";
|
|
|
+ logger.info(" - {}: {}", headerName, maskedAuth);
|
|
|
+ } else {
|
|
|
+ logger.info(" - {}: {}", headerName, headerValue);
|
|
|
+ }
|
|
|
+ });
|
|
|
+ if (trimmedApiKey != null) {
|
|
|
+ String maskedKey = trimmedApiKey.length() > 8
|
|
|
+ ? trimmedApiKey.substring(0, 4) + "..." + trimmedApiKey.substring(trimmedApiKey.length() - 4)
|
|
|
+ : "****";
|
|
|
+ logger.info("API Key信息: 长度={}, 值(部分)={}", trimmedApiKey.length(), maskedKey);
|
|
|
+ } else {
|
|
|
+ logger.warn("API Key: 未提供(可能导致 401 认证失败)");
|
|
|
+ }
|
|
|
+ logger.info("========================================");
|
|
|
|
|
|
try (Response response = client.newCall(request).execute()) {
|
|
|
if (!response.isSuccessful()) {
|