# nlp_server.py
"""
Medical NLP service.

Purpose: structure a doctor's dictated text into medical entities.
"""
  5. from flask import Flask, request, jsonify
  6. from flask_cors import CORS
  7. import torch
  8. from transformers import AutoTokenizer, AutoModelForTokenClassification
  9. import re
  10. app = Flask(__name__)
  11. CORS(app) # 允许跨域请求
  12. # 全局变量存储模型
  13. tokenizer = None
  14. model = None
  15. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  16. # 医疗实体标签映射
  17. LABEL_MAP = {
  18. 'B-CHIEF': '主诉',
  19. 'I-CHIEF': '主诉',
  20. 'B-HISTORY': '现病史',
  21. 'I-HISTORY': '现病史',
  22. 'B-DIAGNOSIS': '诊断',
  23. 'I-DIAGNOSIS': '诊断',
  24. 'B-MEDICATION': '用药',
  25. 'I-MEDICATION': '用药',
  26. 'O': '其他'
  27. }
  28. def load_model():
  29. """
  30. 加载医疗NER模型
  31. 这里使用示例,实际需要替换为真实的医疗模型
  32. """
  33. global tokenizer, model
  34. print('正在加载医疗NER模型...')
  35. # 方式1: 使用HuggingFace上的中文医疗模型
  36. # model_name = "HuatGPT/HuatGPT-medical-ner"
  37. # 方式2: 使用本地训练的模型
  38. model_path = "./models/medical-ner"
  39. try:
  40. tokenizer = AutoTokenizer.from_pretrained(model_path)
  41. model = AutoModelForTokenClassification.from_pretrained(model_path)
  42. model.to(device)
  43. model.eval()
  44. print(f'模型加载成功,使用设备: {device}')
  45. except Exception as e:
  46. print(f'模型加载失败: {e}')
  47. print('使用规则匹配作为降级方案...')
  48. def extract_entities_with_model(text):
  49. """
  50. 使用NER模型提取医疗实体
  51. """
  52. if model is None or tokenizer is None:
  53. raise Exception("模型未加载")
  54. # 分词
  55. inputs = tokenizer(
  56. text,
  57. return_tensors="pt",
  58. truncation=True,
  59. max_length=512,
  60. padding=True
  61. )
  62. inputs = {k: v.to(device) for k, v in inputs.items()}
  63. # 预测
  64. with torch.no_grad():
  65. outputs = model(**inputs)
  66. predictions = torch.argmax(outputs.logits, dim=2)
  67. # 解析实体
  68. tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
  69. labels = [LABEL_MAP.get(f"B-{p.item()}" if p.item() % 2 == 1 else f"I-{p.item()}", '其他')
  70. for p in predictions[0]]
  71. # 组合实体
  72. entities = {}
  73. current_entity = None
  74. current_text = ''
  75. for token, label in zip(tokens, labels):
  76. if label == '其他':
  77. if current_entity:
  78. entities.setdefault(current_entity, []).append(current_text.strip())
  79. current_entity = None
  80. current_text = ''
  81. else:
  82. if label != current_entity:
  83. if current_entity:
  84. entities.setdefault(current_entity, []).append(current_text.strip())
  85. current_entity = label
  86. current_text = token.replace('##', '')
  87. else:
  88. current_text += token.replace('##', '')
  89. # 最后一个实体
  90. if current_entity:
  91. entities.setdefault(current_entity, []).append(current_text.strip())
  92. return entities
  93. def extract_entities_with_rules(text):
  94. """
  95. 基于规则的医疗实体提取(降级方案)
  96. 适用于没有NER模型的情况
  97. """
  98. entities = {
  99. '主诉': '',
  100. '现病史': '',
  101. '诊断': '',
  102. '用药': ''
  103. }
  104. # 简单规则: 通过关键词分割
  105. patterns = {
  106. '主诉': r'(?:主诉|患者因)[,,。]*(.*?)(?:现病史|病史|诊断|$)',
  107. '现病史': r'(?:现病史|病史)[,,。]*(.*?)(?:诊断|用药|体检|$)',
  108. '诊断': r'(?:诊断|诊断为)[,,。]*(.*?)(?:用药|治疗|建议|$)',
  109. '用药': r'(?:用药|服用)[,,。]*(.*?)(?:建议|注意事项|$)'
  110. }
  111. for field, pattern in patterns.items():
  112. match = re.search(pattern, text, re.IGNORECASE)
  113. if match:
  114. entities[field] = match.group(1).strip()
  115. return entities
  116. @app.route('/extract', methods=['POST'])
  117. def extract_entities():
  118. """
  119. 接口: 提取医疗实体
  120. 请求示例:
  121. {
  122. "text": "患者主诉头痛三天,现病史显示有发热症状,诊断为上呼吸道感染,用药为阿司匹林"
  123. }
  124. 响应示例:
  125. {
  126. "主诉": "头痛三天",
  127. "现病史": "有发热症状",
  128. "诊断": "上呼吸道感染",
  129. "用药": "阿司匹林"
  130. }
  131. """
  132. try:
  133. data = request.json
  134. text = data.get('text', '')
  135. if not text:
  136. return jsonify({'error': '文本不能为空'}), 400
  137. # 优先使用模型,失败则使用规则
  138. try:
  139. entities = extract_entities_with_model(text)
  140. except Exception as e:
  141. print(f'模型推理失败,使用规则匹配: {e}')
  142. entities = extract_entities_with_rules(text)
  143. # 转换为前端期望的格式
  144. result = {}
  145. for field, values in entities.items():
  146. if isinstance(values, list):
  147. result[field] = ' '.join(values) if values else ''
  148. else:
  149. result[field] = values
  150. print(f'提取结果: {result}')
  151. return jsonify(result)
  152. except Exception as e:
  153. print(f'处理错误: {e}')
  154. return jsonify({'error': str(e)}), 500
  155. @app.route('/health', methods=['GET'])
  156. def health_check():
  157. """健康检查接口"""
  158. return jsonify({
  159. 'status': 'ok',
  160. 'model_loaded': model is not None,
  161. 'device': str(device)
  162. })
  163. if __name__ == '__main__':
  164. # 加载模型
  165. load_model()
  166. # 启动服务
  167. print('医疗NLP服务启动中...')
  168. print('监听地址: http://0.0.0.0:5001')
  169. app.run(host='0.0.0.0', port=5001, debug=False)