Browse Source

tts以及stt模型调用接口

zhaohan 6 months ago
parent
commit
93d7fb2e77

+ 3 - 0
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/KnowledgeAttach.java

@@ -84,6 +84,9 @@ public class KnowledgeAttach extends BaseEntity {
    * 文件保存地址
    */
   private String publicUrl;
+  /**
+   * 文件本地保存地址
+   */
   private String filePath;
 
 

+ 3 - 0
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/domain/MedicalRecordView.java

@@ -42,3 +42,6 @@ public class MedicalRecordView implements Serializable {
 
 
 
+
+
+

+ 3 - 0
ruoyi-modules-api/ruoyi-knowledge-api/src/main/java/org/ruoyi/mapper/MedicalRecordViewMapper.java

@@ -10,3 +10,6 @@ public interface MedicalRecordViewMapper extends BaseMapperPlus<MedicalRecordVie
 
 
 
+
+
+

+ 14 - 0
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/controller/chat/ChatController.java

@@ -28,6 +28,11 @@ import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 @Slf4j
 @RequiredArgsConstructor
 @RequestMapping("/chat")
+@CrossOrigin(
+        origins = "${cors.file-parse.origins}",
+        allowedHeaders = "*",
+        methods = {RequestMethod.GET, RequestMethod.POST, RequestMethod.PUT, RequestMethod.DELETE}
+)
 public class ChatController {
 
     private final ISseService sseService;
@@ -82,4 +87,13 @@ public class ChatController {
         return sseService.textToSpeed(textToSpeech);
     }
 
+    /**
+     * 文本转语音代理(解决跨域)
+     */
+    @PostMapping("/speech/proxy")
+    @ResponseBody
+    public ResponseEntity<Resource> speechProxy(@RequestBody TextToSpeech textToSpeech) {
+        return sseService.proxySpeech(textToSpeech);
+    }
+
 }

+ 31 - 2
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/MedicalRecordQc123Service.java

@@ -675,6 +675,12 @@ public class MedicalRecordQc123Service {
             return;
         }
 
+        // 记录每个日期原始病程记录条数,用于后续输出时区分“病程记录部分”和“医嘱部分”
+        Map<String, Integer> courseRecordCountMap = new HashMap<>();
+        for (Map.Entry<String, List<String>> entry : mergedByDate.entrySet()) {
+            courseRecordCountMap.put(entry.getKey(), entry.getValue() == null ? 0 : entry.getValue().size());
+        }
+
         // 2. 解析医嘱记录,并融合到对应日期的病程记录中
         if (rule.getQcTypeList().contains(13)) {
             QcStep step = stepMap.get(13);
@@ -707,11 +713,34 @@ public class MedicalRecordQc123Service {
 
             for (String date : batchDates) {
                 List<String> records = mergedByDate.get(date);
+                if (records == null || records.isEmpty()) {
+                    continue;
+                }
+
                 batchContent.append("=== ").append(date).append(" ===\n\n");
 
-                for (int i = 0; i < records.size(); i++) {
-                    batchContent.append((i + 1)).append(". ").append(records.get(i)).append("\n\n");
+                // 根据原始病程记录条数,拆分为“病程记录部分”和“医嘱部分”
+                int courseCount = courseRecordCountMap.getOrDefault(date, records.size());
+                if (courseCount < 0 || courseCount > records.size()) {
+                    courseCount = Math.min(Math.max(courseCount, 0), records.size());
                 }
+
+                int number = 1;
+
+                if (courseCount > 0) {
+                    batchContent.append("以下是病程记录部分:").append("\n\n");
+                    for (int i = 0; i < courseCount && i < records.size(); i++) {
+                        batchContent.append(number++).append(". ").append(records.get(i)).append("\n\n");
+                    }
+                }
+
+                if (courseCount < records.size()) {
+                    batchContent.append("以下是医嘱部分:").append("\n\n");
+                    for (int i = courseCount; i < records.size(); i++) {
+                        batchContent.append(number++).append(". ").append(records.get(i)).append("\n\n");
+                    }
+                }
+
                 batchContent.append("\n");
             }
 

+ 5 - 0
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/ISseService.java

@@ -54,6 +54,11 @@ public interface ISseService {
      */
     ResponseEntity<Resource> textToSpeed(TextToSpeech textToSpeech);
 
+    /**
+     * 文字转语音(代理外部TTS)
+     */
+    ResponseEntity<Resource> proxySpeech(TextToSpeech textToSpeech);
+
     /**
      * 上传文件到服务器
      *

+ 123 - 1
ruoyi-modules/ruoyi-chat/src/main/java/org/ruoyi/chat/service/chat/impl/SseServiceImpl.java

@@ -19,6 +19,7 @@ import org.ruoyi.common.chat.entity.chat.Message;
 import org.ruoyi.common.chat.entity.chat.SchemaMessage;
 import org.ruoyi.common.chat.entity.chat.ThinkProject;
 import org.ruoyi.common.chat.entity.files.UploadFileResponse;
+import org.ruoyi.common.chat.entity.whisper.Transcriptions;
 import org.ruoyi.common.chat.entity.whisper.WhisperResponse;
 import org.ruoyi.common.chat.openai.OpenAiStreamClient;
 import org.ruoyi.common.chat.request.ChatRequest;
@@ -28,6 +29,7 @@ import org.ruoyi.common.core.utils.StringUtils;
 import org.ruoyi.common.core.utils.file.FileUtils;
 import org.ruoyi.common.core.utils.file.MimeTypeUtils;
 import org.ruoyi.common.satoken.utils.LoginHelper;
+import org.ruoyi.common.core.service.ConfigService;
 import org.ruoyi.core.page.PageQuery;
 import org.ruoyi.core.page.TableDataInfo;
 import org.ruoyi.domain.ThinkModel;
@@ -44,14 +46,19 @@ import org.ruoyi.service.IChatSessionService;
 import org.ruoyi.service.IKnowledgeInfoService;
 import org.ruoyi.service.VectorStoreService;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.core.io.ByteArrayResource;
 import org.springframework.core.io.InputStreamResource;
 import org.springframework.core.io.Resource;
 import org.springframework.http.MediaType;
 import org.springframework.http.ResponseEntity;
 import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.http.HttpStatus;
 import org.springframework.stereotype.Service;
+import org.springframework.http.HttpEntity;
+import org.springframework.http.HttpHeaders;
 import org.springframework.web.multipart.MultipartFile;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+import org.springframework.web.client.RestTemplate;
 
 import java.io.File;
 import java.io.FileOutputStream;
@@ -60,6 +67,7 @@ import java.io.InputStream;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -85,8 +93,12 @@ public class SseServiceImpl implements ISseService {
 
     private final IKnowledgeInfoService knowledgeInfoService;
 
+    private final RestTemplate restTemplate;
+
     private ChatModelVo chatModelVo;
 
+    private final ConfigService configService;
+
     @Autowired
     private JdbcTemplate jdbcTemplate;
 
@@ -99,6 +111,9 @@ public class SseServiceImpl implements ISseService {
 
     private String rerankModelName = "bge-reranker-v2-m3";
 
+    private static final String DEFAULT_TTS_PROXY_URL = "http://1.13.255.120:7899/v1/audio/speech";
+    private static final String DEFAULT_TTS_VOICE = "zh-CN-XiaoxiaoNeural";
+
     @Override
     public SseEmitter sseChat(ChatRequest chatRequest, HttpServletRequest request) {
         SseEmitter sseEmitter = new SseEmitter(0L);
@@ -382,6 +397,47 @@ public class SseServiceImpl implements ISseService {
         }
     }
 
+    /**
+     * 自定义TTS代理,解决前端跨域
+     */
+    @Override
+    public ResponseEntity<Resource> proxySpeech(TextToSpeech textToSpeech) {
+        if (textToSpeech == null || StringUtils.isEmpty(textToSpeech.getInput())) {
+            throw new IllegalArgumentException("input 不能为空");
+        }
+        String proxyUrl = resolveTtsProxyUrl();
+        String voice = StringUtils.isNotBlank(textToSpeech.getVoice()) ? textToSpeech.getVoice() : resolveTtsDefaultVoice();
+        String responseFormat = determineResponseFormat(textToSpeech);
+
+        Map<String, Object> payload = new HashMap<>();
+        payload.put("input", textToSpeech.getInput());
+        payload.put("voice", voice);
+        payload.put("response_format", responseFormat);
+        if (textToSpeech.getSpeed() != null) {
+            payload.put("speed", textToSpeech.getSpeed());
+        }
+        if (StringUtils.isNotBlank(textToSpeech.getModel())) {
+            payload.put("model", textToSpeech.getModel());
+        }
+
+        HttpHeaders headers = new HttpHeaders();
+        headers.setContentType(MediaType.APPLICATION_JSON);
+        HttpEntity<Map<String, Object>> entity = new HttpEntity<>(payload, headers);
+        ResponseEntity<byte[]> response = restTemplate.postForEntity(proxyUrl, entity, byte[].class);
+        if (!response.getStatusCode().is2xxSuccessful() || response.getBody() == null) {
+            throw new IllegalStateException("TTS代理服务调用失败:" + response.getStatusCode());
+        }
+        byte[] body = response.getBody();
+        MediaType mediaType = resolveMediaType(responseFormat, response.getHeaders().getContentType());
+        String filename = "speech." + responseFormat.toLowerCase();
+        return ResponseEntity.status(HttpStatus.OK)
+                .contentType(mediaType)
+                .contentLength(body.length)
+                .header(HttpHeaders.CONTENT_DISPOSITION, "inline; filename=" + filename)
+                .header(HttpHeaders.ACCEPT_RANGES, "bytes")
+                .body(new ByteArrayResource(body));
+    }
+
     /**
      * 语音转文字
      */
@@ -402,7 +458,23 @@ public class SseServiceImpl implements ISseService {
         } catch (IOException e) {
             throw new RuntimeException("Failed to convert MultipartFile to File", e);
         }
-        return openAiStreamClient.speechToTextTranscriptions(fileA);
+        Transcriptions transcriptions = Transcriptions.builder()
+                .model(resolveAudioModel())
+                .build();
+        return openAiStreamClient.speechToTextTranscriptions(fileA, transcriptions);
+    }
+
+    private String resolveAudioModel() {
+        String defaultModel = "whisper-large-v3";
+        try {
+            String audioModel = configService.getConfigValue("chat", "audioModel");
+            if (StringUtils.isNotBlank(audioModel)) {
+                return audioModel;
+            }
+        } catch (Exception e) {
+            log.warn("音频模型配置获取失败,使用默认模型 {} :{}", defaultModel, e.getMessage());
+        }
+        return defaultModel;
     }
 
 
@@ -452,6 +524,56 @@ public class SseServiceImpl implements ISseService {
         return file;
     }
 
+    private String resolveTtsProxyUrl() {
+        try {
+            String configValue = configService.getConfigValue("chat", "ttsProxyUrl");
+            if (StringUtils.isNotBlank(configValue)) {
+                return configValue;
+            }
+        } catch (Exception e) {
+            log.warn("获取ttsProxyUrl失败,使用默认地址 {} :{}", DEFAULT_TTS_PROXY_URL, e.getMessage());
+        }
+        return DEFAULT_TTS_PROXY_URL;
+    }
+
+    private String resolveTtsDefaultVoice() {
+        try {
+            String configValue = configService.getConfigValue("chat", "ttsDefaultVoice");
+            if (StringUtils.isNotBlank(configValue)) {
+                return configValue;
+            }
+        } catch (Exception e) {
+            log.warn("获取ttsDefaultVoice失败,使用默认值 {} :{}", DEFAULT_TTS_VOICE, e.getMessage());
+        }
+        return DEFAULT_TTS_VOICE;
+    }
+
+    private String determineResponseFormat(TextToSpeech textToSpeech) {
+        String format = textToSpeech.getResponseFormat();
+        if (StringUtils.isBlank(format)) {
+            format = "mp3";
+        }
+        textToSpeech.setResponseFormat(format);
+        return format;
+    }
+
+    private MediaType resolveMediaType(String responseFormat, MediaType upstreamType) {
+        String format = responseFormat != null ? responseFormat.toLowerCase() : "";
+        if (format.contains("mp3") || format.contains("mpeg")) {
+            return MediaType.parseMediaType("audio/mpeg");
+        }
+        if (format.contains("wav")) {
+            return MediaType.parseMediaType("audio/wav");
+        }
+        if (format.contains("ogg")) {
+            return MediaType.parseMediaType("audio/ogg");
+        }
+        if (upstreamType != null) {
+            return upstreamType;
+        }
+        return MediaType.APPLICATION_OCTET_STREAM;
+    }
+
     @Override
     public SchemaMessage getSchema(SchemaRequest schemaRequest, HttpServletRequest request) throws IOException {
         SchemaMessage schema = new SchemaMessage();