enote_dataset.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
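  3. # Data is read from local files under customer_data/<dataset_name>/<split>/ (see load_bitext); the samsum dataset (https://huggingface.co/datasets/samsum) is not used by this module.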
  4. import copy
  5. import os
  6. import datasets
  7. import random
  8. from datasets import load_dataset, concatenate_datasets
  9. from llama_recipes.configs import (
  10. fsdp_config as FSDP_CONFIG,
  11. quantization_config as QUANTIZATION_CONFIG,
  12. train_config as TRAIN_CONFIG,
  13. )
  14. from transformers import (
  15. AutoConfig,
  16. AutoProcessor,
  17. AutoTokenizer,
  18. )
  19. from unittest.mock import patch
  20. from pathlib import Path
  21. import re
  22. import torch
  23. random.seed(42)
  24. def get_rule(rule_file):
  25. with open(rule_file, "r", encoding="utf-8") as f:
  26. content = f.read()
  27. match = re.search(r"规则描述:(.*)", content)
  28. if match:
  29. rule_description = match.group(1)
  30. else:
  31. return None
  32. return rule_description.strip()
  33. @patch('builtins.input', return_value="N")
  34. def load_bitext(dataset_name, split, rule_names, _):
  35. assert split in ["train", "valid", "test"], f"Unknown split: {split}"
  36. current_file_path = os.path.abspath(__file__)
  37. current_dir = os.path.dirname(current_file_path)
  38. dir_name = os.path.join(current_dir, "..", "customer_data", dataset_name, split)
  39. output_dataset = []
  40. for rule_name in rule_names:
  41. rule_file = os.path.join(dir_name, rule_name, "rule.txt")
  42. rule_content = get_rule(rule_file)
  43. if rule_content:
  44. rule_wrong_content = "病例校验结果:{}".format(rule_content)
  45. rule_right_content = "病例校验结果:病例校验通过"
  46. enote_wrong_dir = os.path.join(dir_name, rule_name, "wrong")
  47. enote_wrong_files = [f.name for f in Path(enote_wrong_dir).iterdir() if f.is_file()]
  48. # get wrong file
  49. for enote_wrong_file in enote_wrong_files:
  50. with open(os.path.join(enote_wrong_dir, enote_wrong_file)) as fin:
  51. enote_wrong_content = fin.read().replace('\u2002', '') # Replace all occurrences of \u2002 (EN SPACE)
  52. row = {"rule": rule_name, "rule_content": rule_wrong_content, "enote_content": enote_wrong_content}
  53. output_dataset.append(row)
  54. # get right file
  55. enote_right_dir = os.path.join(dir_name, rule_name, "right")
  56. enote_right_files = [f.name for f in Path(enote_right_dir).iterdir() if f.is_file()]
  57. for enote_right_file in enote_right_files:
  58. if random.random() < 0.1 and split == "train":
  59. with open(os.path.join(enote_right_dir, enote_right_file)) as fin:
  60. enote_right_content = fin.read().replace('\u2002', '') # Replace all occurrences of \u2002 (EN SPACE)
  61. row = {"rule": rule_name, "rule_content": rule_right_content, "enote_content": enote_right_content}
  62. output_dataset.append(row)
  63. dataset = datasets.Dataset.from_list(output_dataset)
  64. return dataset
  65. def get_preprocessed_enote(tokenizer, dataset_name, mode, split, rule_names):
  66. dataset = load_bitext(dataset_name, split, rule_names)
  67. prompt = (
  68. f"检查如下病例,输出可能存在的问题。若无问题,回复\"病例校验通过\":\n{{enote_content}}\n"
  69. )
  70. def apply_prompt_template(sample):
  71. return {
  72. "prompt": prompt.format(enote_content=sample["enote_content"]),
  73. "summary": sample["rule_content"],
  74. }
  75. dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
  76. def tokenize_add_label(sample):
  77. prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
  78. summary = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
  79. sample = {
  80. "input_ids": prompt + summary,
  81. "attention_mask": [1] * (len(prompt) + len(summary)),
  82. "labels": [-100] * len(prompt) + summary,
  83. }
  84. return sample
  85. def tokenize_prompt(sample):
  86. prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
  87. sample = {
  88. "input_ids": prompt,
  89. "attention_mask": [1] * len(prompt),
  90. }
  91. return sample
  92. if mode == "infer":
  93. dataset = dataset.map(tokenize_prompt, remove_columns=list(dataset.features))
  94. elif mode == "eval":
  95. dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
  96. else: # Train
  97. dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
  98. dataset = dataset.shuffle(seed=42)
  99. return dataset
  100. # tokenizer = AutoTokenizer.from_pretrained('haoranxu/ALMA-7B-Pretrain')
  101. # dataset = get_preprocessed_enote(tokenizer, 'enote_dataset', 'train', 'train', ["ryjlzs0001"])
  102. #
  103. # print(dataset)
  104. # print(dataset[:10])
  105. # train_dataloader = torch.utils.data.DataLoader(
  106. # dataset,
  107. # batch_size=1,
  108. # num_workers=0,
  109. # pin_memory=True,
  110. # # collate_fn=lambda x: {
  111. # # "input_ids": torch.stack([torch.tensor(item["input_ids"]) for item in x], dim=0),
  112. # # "attention_mask": torch.stack([torch.tensor(item["attention_mask"]) for item in x], dim=0),
  113. # # "labels": torch.stack([torch.tensor(item["labels"]) for item in x], dim=0),
  114. # # },
  115. # )