| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
- # For dataset details visit: https://huggingface.co/datasets/samsum
- import copy
- import os
- import datasets
- import random
- from datasets import load_dataset, concatenate_datasets
- from llama_recipes.configs import (
- fsdp_config as FSDP_CONFIG,
- quantization_config as QUANTIZATION_CONFIG,
- train_config as TRAIN_CONFIG,
- )
- from transformers import (
- AutoConfig,
- AutoProcessor,
- AutoTokenizer,
- )
- from unittest.mock import patch
- from pathlib import Path
- import re
- import torch
- random.seed(42)
def get_rule(rule_file):
    """Extract the rule description from a rule file.

    The file is read as UTF-8 and searched for the first occurrence of the
    marker ``规则描述:``; the rest of that line is taken as the description.

    Args:
        rule_file: Path of the rule text file to read.

    Returns:
        The description with surrounding whitespace stripped, or ``None``
        when the marker does not occur in the file.
    """
    with open(rule_file, "r", encoding="utf-8") as handle:
        text = handle.read()

    found = re.search(r"规则描述:(.*)", text)
    if found is None:
        return None
    # `.` does not cross newlines, so only the marker's own line is captured.
    return found.group(1).strip()
def _read_enote(path):
    """Return the text of the e-note at *path* with EN SPACE (U+2002) removed."""
    # Decode explicitly as UTF-8 (the same encoding get_rule uses for rule.txt)
    # instead of relying on the platform's locale default.
    with open(path, encoding="utf-8") as fin:
        return fin.read().replace('\u2002', '')


def _list_file_names(directory):
    """Return the names of the regular files directly inside *directory*, sorted."""
    # Path.iterdir() yields entries in arbitrary order; sorting keeps both the
    # row order and the seeded random sampling below reproducible.
    return sorted(entry.name for entry in Path(directory).iterdir() if entry.is_file())


@patch('builtins.input', return_value="N")
def load_bitext(dataset_name, split, rule_names, _):
    """Build a dataset of e-notes paired with their rule-check result.

    Files are read from
    ``<this file's dir>/../customer_data/<dataset_name>/<split>/<rule_name>/``,
    which is expected to contain ``rule.txt`` plus ``wrong/`` and ``right/``
    sub-directories of e-note text files.

    For every rule whose ``rule.txt`` yields a non-empty description:

    * each file in ``wrong/`` becomes a row labelled with that description;
    * each file in ``right/`` becomes a row labelled "check passed", but only
      when ``split == "train"`` and only with probability 0.1 (drawn from the
      module-level ``random`` generator). For other splits no ``right/`` rows
      are produced.

    Rules whose description is missing or empty are skipped silently.

    Args:
        dataset_name: Name of the directory under ``customer_data``.
        split: One of ``"train"``, ``"valid"`` or ``"test"``.
        rule_names: Iterable of rule directory names to load.
        _: Mock injected by the ``@patch`` decorator; callers pass only the
            first three arguments.

    Returns:
        A ``datasets.Dataset`` built from rows with the keys ``rule``,
        ``rule_content`` and ``enote_content``.

    Raises:
        AssertionError: If *split* is not one of the accepted values.
    """
    # NOTE(review): nothing in this function calls input(); the patch looks
    # like a leftover from another loader - confirm before removing it.
    assert split in ["train", "valid", "test"], f"Unknown split: {split}"

    module_dir = os.path.dirname(os.path.abspath(__file__))
    dir_name = os.path.join(module_dir, "..", "customer_data", dataset_name, split)

    rule_right_content = "病例校验结果:病例校验通过"
    output_dataset = []
    for rule_name in rule_names:
        rule_content = get_rule(os.path.join(dir_name, rule_name, "rule.txt"))
        if not rule_content:
            continue
        rule_wrong_content = "病例校验结果:{}".format(rule_content)

        # E-notes that violate the rule: always included.
        enote_wrong_dir = os.path.join(dir_name, rule_name, "wrong")
        for file_name in _list_file_names(enote_wrong_dir):
            output_dataset.append({
                "rule": rule_name,
                "rule_content": rule_wrong_content,
                "enote_content": _read_enote(os.path.join(enote_wrong_dir, file_name)),
            })

        # E-notes that satisfy the rule: sub-sampled, training split only.
        enote_right_dir = os.path.join(dir_name, rule_name, "right")
        for file_name in _list_file_names(enote_right_dir):
            # random.random() is evaluated first on purpose so the generator
            # advances once per file regardless of the split, as before.
            if random.random() < 0.1 and split == "train":
                output_dataset.append({
                    "rule": rule_name,
                    "rule_content": rule_right_content,
                    "enote_content": _read_enote(os.path.join(enote_right_dir, file_name)),
                })

    return datasets.Dataset.from_list(output_dataset)
def get_preprocessed_enote(tokenizer, dataset_name, mode, split, rule_names):
    """Load the e-note dataset and tokenize it for the requested mode.

    Each row from ``load_bitext`` is turned into a prompt (the instruction
    template filled with the e-note text) and a summary (the rule-check
    result), then tokenized:

    * ``mode == "infer"``: only the prompt is encoded; rows carry
      ``input_ids`` and ``attention_mask``.
    * ``mode == "eval"``: prompt and summary are encoded and concatenated;
      rows additionally carry ``labels`` in which every prompt position is
      -100 (the index PyTorch's cross-entropy loss ignores by default).
    * any other value: same as ``"eval"``, followed by a shuffle with seed 42.

    Args:
        tokenizer: Tokenizer providing ``encode``, ``bos_token`` and
            ``eos_token``.
        dataset_name: Passed through to ``load_bitext``.
        mode: ``"infer"``, ``"eval"``, or anything else for training.
        split: Passed through to ``load_bitext``.
        rule_names: Passed through to ``load_bitext``.

    Returns:
        The tokenized dataset.
    """
    template = (
        f"检查如下病例,输出可能存在的问题。若无问题,回复\"病例校验通过\":\n{{enote_content}}\n"
    )

    def to_prompt_and_summary(record):
        return {
            "prompt": template.format(enote_content=record["enote_content"]),
            "summary": record["rule_content"],
        }

    def encode_prompt(record):
        # BOS is prepended by hand, so special tokens are not added again.
        return tokenizer.encode(tokenizer.bos_token + record["prompt"], add_special_tokens=False)

    def with_labels(record):
        prompt_ids = encode_prompt(record)
        summary_ids = tokenizer.encode(record["summary"] + tokenizer.eos_token, add_special_tokens=False)
        input_ids = prompt_ids + summary_ids
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            # Only the summary tokens keep real label ids.
            "labels": [-100] * len(prompt_ids) + summary_ids,
        }

    def prompt_only(record):
        prompt_ids = encode_prompt(record)
        return {
            "input_ids": prompt_ids,
            "attention_mask": [1] * len(prompt_ids),
        }

    raw = load_bitext(dataset_name, split, rule_names)
    templated = raw.map(to_prompt_and_summary, remove_columns=list(raw.features))
    stale_columns = list(templated.features)

    if mode == "infer":
        return templated.map(prompt_only, remove_columns=stale_columns)

    tokenized = templated.map(with_labels, remove_columns=stale_columns)
    if mode != "eval":
        # Training: randomize row order deterministically.
        tokenized = tokenized.shuffle(seed=42)
    return tokenized
- # tokenizer = AutoTokenizer.from_pretrained('haoranxu/ALMA-7B-Pretrain')
- # dataset = get_preprocessed_enote(tokenizer, 'enote_dataset', 'train', 'train', ["ryjlzs0001"])
- #
- # print(dataset)
- # print(dataset[:10])
- # train_dataloader = torch.utils.data.DataLoader(
- # dataset,
- # batch_size=1,
- # num_workers=0,
- # pin_memory=True,
- # # collate_fn=lambda x: {
- # # "input_ids": torch.stack([torch.tensor(item["input_ids"]) for item in x], dim=0),
- # # "attention_mask": torch.stack([torch.tensor(item["attention_mask"]) for item in x], dim=0),
- # # "labels": torch.stack([torch.tensor(item["labels"]) for item in x], dim=0),
- # # },
- # )
|