# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import Callable, Optional, Sequence

import torch

from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
from llama_recipes.utils.config_utils import get_dataloader_kwargs
from llama_recipes.datasets.translation_dataset import get_preprocessed_bitext
from llama_recipes.datasets.monolingual_dataset import get_preprocessed_monolingual_data
from llama_recipes.datasets.preference_dataset import get_preprocessed_preference_data
from llama_recipes.datasets.preference_z_dataset import get_preprocessed_preference_z_data
from llama_recipes.datasets.alma_r_preference_dataset import get_preprocessed_alma_r_preference_data
from llama_recipes.datasets.alma_r_list_padding_dataset import get_preprocessed_alma_r_list_padding_data
from llama_recipes.datasets.calibration_dataset import get_calibration_data
from llama_recipes.datasets.enote_dataset import get_preprocessed_enote


def get_enote_dataset(
    tokenizer, dataset_name, mode: str = "train", split: str = "train", rule_names: Sequence[str] = ("ryjlzs0001", "ryjlzs0002")
) -> torch.utils.data.Dataset:
    """Build the enote dataset via its preprocessing helper."""
    return get_preprocessed_enote(
        tokenizer,
        dataset_name,
        mode,
        split,
        rule_names,
    )


def get_translation_dataset(
    tokenizer, dataset_name, mode: str = "train", split: str = "train", lang_pairs: Sequence[str] = ("en-de", "en-zh", "en-ar")
) -> torch.utils.data.Dataset:
    """Build a bitext translation dataset via its preprocessing helper."""
    return get_preprocessed_bitext(
        tokenizer,
        dataset_name,
        mode,
        split,
        lang_pairs,
    )


def get_preference_dataset(
    tokenizer, dataset_name, subset_name: str = "wmt-da-17-22.csv", split: str = "train",
    lang_pairs: Sequence[str] = ("en-de", "en-zh"), mode: Optional[str] = None,
    filter: Optional[str] = None, listwise: bool = False, batch_size: int = 15) -> Optional[torch.utils.data.Dataset]:
    """Dispatch to the preference-data loader matching dataset_name.

    Returns None when no loader matches the requested combination of
    dataset_name and listwise.
    """
    if dataset_name == "da_dataset":
        return get_preprocessed_preference_data(
            tokenizer,
            dataset_name,
            subset_name,
            split,
            lang_pairs,
            filter,
        )
    elif dataset_name == "haoranxu/ALMA-R-Preference" and not listwise:
        # ALMA-R-Preference only has a train split
        return get_preprocessed_alma_r_preference_data(
            tokenizer,
            dataset_name,
            "train",
            lang_pairs,
            mode,
            filter=filter,
        )
    elif dataset_name == "haoranxu/ALMA-R-Preference" and listwise:
        return get_preprocessed_alma_r_list_padding_data(
            tokenizer,
            dataset_name,
            "train",
            lang_pairs,
            batch_size,
        )
    elif dataset_name == "flores-gpt" and listwise:
        return get_calibration_data(
            tokenizer,
            dataset_name,
            subset_name,
            lang_pairs,
            batch_size,
        )
    elif dataset_name == "calibration" and listwise:
        return get_calibration_data(
            tokenizer,
            dataset_name,
            subset_name,
            split,
            lang_pairs,
            batch_size,
        )
    else:
        return None


def get_preference_z_dataset(
    tokenizer, dataset_name, mode: str = "train", split: str = "train", lang_pairs: Sequence[str] = ("en-de", "en-zh")
) -> torch.utils.data.Dataset:
    """Build the preference-z dataset via its preprocessing helper."""
    return get_preprocessed_preference_z_data(
        tokenizer,
        dataset_name,
        mode,
        split,
        lang_pairs,
    )


def get_monolingual_dataset(
    tokenizer, dataset_config, mode: str = "train", split: str = "train", lang_pairs: Sequence[str] = ("en-de", "en-zh", "en-ar")
) -> torch.utils.data.Dataset:
    """Build the monolingual dataset via its preprocessing helper."""
    return get_preprocessed_monolingual_data(
        tokenizer,
        dataset_config,
        mode,
        split,
        lang_pairs,
    )


def get_preprocessed_dataset(
    tokenizer, dataset_config, split: str = "train", lang_pairs: Sequence[str] = ("en-de", "en-zh", "en-ar")
) -> torch.utils.data.Dataset:
    """Look up dataset_config.dataset in DATASET_PREPROC and run its preprocessor."""
    if dataset_config.dataset not in DATASET_PREPROC:
        raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")

    def get_split():
        return (
            dataset_config.train_split
            if split == "train"
            else dataset_config.test_split
        )

    return DATASET_PREPROC[dataset_config.dataset](
        dataset_config,
        tokenizer,
        get_split(),
        lang_pairs,
    )


def get_custom_data_collator(
    dataset_processor, dataset_config
) -> Optional[Callable]:
    """Return the dataset's custom collate function, or None if it has none."""
    if dataset_config.dataset not in DATALOADER_COLLATE_FUNC:
        return None
    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
        dataset_processor,
        dataset_config,
    )


def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
    """Build a DataLoader for the given split, packing sequences when configured."""
    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)

    if split == "train" and train_config.batching_strategy == "packing":
        dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)

    # Create the data loader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **dl_kwargs,
    )
    return dataloader
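

if __name__ == "__main__":
    # Hedged smoke-test sketch, not part of the library API. The tokenizer
    # checkpoint and both config imports are assumptions: substitute whichever
    # dataset config (its .dataset must be a key of DATASET_PREPROC) and
    # training config your checkout of llama_recipes actually ships.
    from transformers import AutoTokenizer

    from llama_recipes.configs.datasets import translation_dataset  # assumed config name
    from llama_recipes.configs.training import train_config as TrainConfig  # assumed config name

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

    train_dl = get_dataloader(tokenizer, translation_dataset(), TrainConfig(), split="train")
    print(f"train batches: {len(train_dl)}")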