# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# For dataset details visit: https://huggingface.co/datasets/samsum
"""Build prompt/answer datasets from local rule and record text files.

Each rule directory is expected to contain a ``rule.txt`` plus ``wrong`` and
``right`` sub-directories of record files.  Records are turned into
prompt/answer pairs and tokenized for training, evaluation or inference.
"""

import copy
import os
import datasets
import random
from datasets import load_dataset, concatenate_datasets
from llama_recipes.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
)
from unittest.mock import patch
from pathlib import Path
import re
import torch

# Seed the module-level RNG so the sub-sampling in load_bitext is repeatable.
random.seed(42)


def get_rule(rule_file):
    """Return the rule description stored in *rule_file*.

    The file is read as UTF-8 and searched for the first ``规则描述:`` marker;
    the remainder of that line, stripped of surrounding whitespace, is the
    description.

    Args:
        rule_file: Path of the rule text file.

    Returns:
        The stripped description (possibly an empty string), or ``None`` when
        the marker does not occur in the file.
    """
    with open(rule_file, "r", encoding="utf-8") as f:
        content = f.read()

    match = re.search(r"规则描述:(.*)", content)
    if match is None:
        return None
    return match.group(1).strip()


def _list_files(directory):
    """Return the names of the regular files in *directory*, sorted.

    ``Path.iterdir`` yields entries in arbitrary order; sorting keeps the
    dataset order (and the seeded sub-sampling that walks it) independent of
    the underlying filesystem.
    """
    return sorted(entry.name for entry in Path(directory).iterdir() if entry.is_file())


def _read_enote(path):
    """Read one record file as UTF-8 and drop every EN SPACE (U+2002)."""
    # The records contain Chinese text, so do not rely on the locale's default
    # encoding; rule.txt in the same tree is read as UTF-8 as well.
    with open(path, encoding="utf-8") as fin:
        return fin.read().replace('\u2002', '')


@patch('builtins.input', return_value="N")
def load_bitext(dataset_name, split, rule_names, _):
    """Collect the records of the given rules into a ``datasets.Dataset``.

    Files are looked up under
    ``<this file's dir>/../customer_data/<dataset_name>/<split>/<rule_name>/``.
    Rules whose ``rule.txt`` yields no (or an empty) description are skipped.

    Every file in ``wrong`` becomes a row whose ``rule_content`` carries the
    rule description.  Files in ``right`` become rows whose ``rule_content``
    says the check passed, but only for ``split == "train"`` and only for
    roughly 10% of the files (one ``random.random()`` draw per file).
    NOTE(review): for "valid"/"test" no passing records are emitted at all —
    confirm this is intended.

    Args:
        dataset_name: Directory name below ``customer_data``.
        split: One of ``"train"``, ``"valid"`` or ``"test"``.
        rule_names: Iterable of rule directory names to load.
        _: Mock injected by the ``patch`` decorator (``builtins.input`` is
            replaced by a mock returning ``"N"`` for the duration of the
            call).  Callers pass only the first three arguments.

    Returns:
        A ``datasets.Dataset`` with columns ``rule``, ``rule_content`` and
        ``enote_content``.
    """
    assert split in ["train", "valid", "test"], f"Unknown split: {split}"

    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    dir_name = os.path.join(current_dir, "..", "customer_data", dataset_name, split)

    output_dataset = []
    for rule_name in rule_names:
        rule_dir = os.path.join(dir_name, rule_name)
        rule_content = get_rule(os.path.join(rule_dir, "rule.txt"))
        if not rule_content:
            # No marker found, or the description is empty: nothing to learn.
            continue

        rule_wrong_content = "病例校验结果:{}".format(rule_content)
        rule_right_content = "病例校验结果:病例校验通过"

        # Records that violate the rule: always included.
        enote_wrong_dir = os.path.join(rule_dir, "wrong")
        for enote_wrong_file in _list_files(enote_wrong_dir):
            output_dataset.append(
                {
                    "rule": rule_name,
                    "rule_content": rule_wrong_content,
                    "enote_content": _read_enote(os.path.join(enote_wrong_dir, enote_wrong_file)),
                }
            )

        # Records that satisfy the rule: sub-sampled, training split only.
        enote_right_dir = os.path.join(rule_dir, "right")
        for enote_right_file in _list_files(enote_right_dir):
            # random.random() is evaluated first on purpose: one draw per file
            # regardless of the split, exactly as before.
            if random.random() < 0.1 and split == "train":
                output_dataset.append(
                    {
                        "rule": rule_name,
                        "rule_content": rule_right_content,
                        "enote_content": _read_enote(os.path.join(enote_right_dir, enote_right_file)),
                    }
                )

    return datasets.Dataset.from_list(output_dataset)


def get_preprocessed_enote(tokenizer, dataset_name, mode, split, rule_names):
    """Load the records for *rule_names* and tokenize them for *mode*.

    Args:
        tokenizer: Object providing ``encode``, ``bos_token`` and
            ``eos_token`` (a Hugging Face tokenizer is assumed — confirm
            against the caller).
        dataset_name: Forwarded to ``load_bitext``.
        mode: ``"infer"`` tokenizes the prompt only; ``"eval"`` tokenizes
            prompt plus answer; any other value is treated as training, which
            is the same as ``"eval"`` followed by a seeded shuffle.
        split: Forwarded to ``load_bitext``.
        rule_names: Forwarded to ``load_bitext``.

    Returns:
        A ``datasets.Dataset`` with ``input_ids`` and ``attention_mask``
        columns, plus ``labels`` unless *mode* is ``"infer"``.
    """
    dataset = load_bitext(dataset_name, split, rule_names)

    # The doubled braces leave a literal {enote_content} placeholder that is
    # filled per sample by str.format below.
    prompt = (
        f"检查如下病例,输出可能存在的问题。若无问题,回复\"病例校验通过\":\n{{enote_content}}\n"
    )

    def apply_prompt_template(sample):
        """Turn a raw row into its prompt text and expected answer."""
        return {
            "prompt": prompt.format(enote_content=sample["enote_content"]),
            "summary": sample["rule_content"],
        }

    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))

    def tokenize_prompt(sample):
        """Encode BOS + prompt only (no answer, no labels)."""
        prompt_ids = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
        return {
            "input_ids": prompt_ids,
            "attention_mask": [1] * len(prompt_ids),
        }

    def tokenize_add_label(sample):
        """Encode BOS + prompt and answer + EOS; mask the prompt in labels."""
        prompt_ids = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
        summary_ids = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False)
        return {
            "input_ids": prompt_ids + summary_ids,
            "attention_mask": [1] * (len(prompt_ids) + len(summary_ids)),
            # -100 on the prompt positions so only the answer tokens carry a
            # label value.
            "labels": [-100] * len(prompt_ids) + summary_ids,
        }

    if mode == "infer":
        return dataset.map(tokenize_prompt, remove_columns=list(dataset.features))

    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
    if mode != "eval":
        # Training: shuffle with a fixed seed.
        dataset = dataset.shuffle(seed=42)
    return dataset


# Example usage (kept for reference):
#
# tokenizer = AutoTokenizer.from_pretrained('haoranxu/ALMA-7B-Pretrain')
# dataset = get_preprocessed_enote(tokenizer, 'enote_dataset', 'train', 'train', ["ryjlzs0001"])
#
# print(dataset)
# print(dataset[:10])
# train_dataloader = torch.utils.data.DataLoader(
#     dataset,
#     batch_size=1,
#     num_workers=0,
#     pin_memory=True,
#     # collate_fn=lambda x: {
#     #     "input_ids": torch.stack([torch.tensor(item["input_ids"]) for item in x], dim=0),
#     #     "attention_mask": torch.stack([torch.tensor(item["attention_mask"]) for item in x], dim=0),
#     #     "labels": torch.stack([torch.tensor(item["labels"]) for item in x], dim=0),
#     # },
# )