dataset_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from typing import Callable, Optional

import torch

from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
from llama_recipes.utils.config_utils import get_dataloader_kwargs
from llama_recipes.datasets.translation_dataset import get_preprocessed_bitext
from llama_recipes.datasets.monolingual_dataset import get_preprocessed_monolingual_data
from llama_recipes.datasets.preference_dataset import get_preprocessed_preference_data
from llama_recipes.datasets.preference_z_dataset import get_preprocessed_preference_z_data
from llama_recipes.datasets.alma_r_preference_dataset import get_preprocessed_alma_r_preference_data
from llama_recipes.datasets.alma_r_list_padding_dataset import get_preprocessed_alma_r_list_padding_data
from llama_recipes.datasets.calibration_dataset import get_calibration_data
from llama_recipes.datasets.enote_dataset import get_preprocessed_enote


def get_enote_dataset(
    tokenizer, dataset_name, mode: str = "train", split: str = "train", rule_names: tuple = ("ryjlzs0001", "ryjlzs0002")
) -> torch.utils.data.Dataset:
    return get_preprocessed_enote(
        tokenizer,
        dataset_name,
        mode,
        split,
        rule_names,
    )


def get_translation_dataset(
    tokenizer, dataset_name, mode: str = "train", split: str = "train", lang_pairs: tuple = ("en-de", "en-zh", "en-ar")
) -> torch.utils.data.Dataset:
    return get_preprocessed_bitext(
        tokenizer,
        dataset_name,
        mode,
        split,
        lang_pairs,
    )


def get_prefernce_dataset(
    tokenizer, dataset_name, subset_name: str = "wmt-da-17-22.csv", split: str = "train",
    lang_pairs: tuple = ("en-de", "en-zh"), mode: Optional[str] = None,
    filter: Optional[str] = None, listwise: bool = False, batch_size: int = 15,
) -> Optional[torch.utils.data.Dataset]:
    """Dispatch to the preference-data loader matching `dataset_name`.

    Returns None when `dataset_name` is not recognized.
    """
    if dataset_name == "da_dataset":
        return get_preprocessed_preference_data(
            tokenizer,
            dataset_name,
            subset_name,
            split,
            lang_pairs,
            filter,
        )
    elif dataset_name == "haoranxu/ALMA-R-Preference" and not listwise:
        # ALMA-R-Preference only has a train split
        return get_preprocessed_alma_r_preference_data(
            tokenizer,
            dataset_name,
            "train",
            lang_pairs,
            mode,
            filter=filter,
        )
    elif dataset_name == "haoranxu/ALMA-R-Preference" and listwise:
        return get_preprocessed_alma_r_list_padding_data(
            tokenizer,
            dataset_name,
            "train",
            lang_pairs,
            batch_size,
        )
    elif dataset_name == "flores-gpt" and listwise:
        return get_calibration_data(
            tokenizer,
            dataset_name,
            subset_name,
            lang_pairs,
            batch_size,
        )
    elif dataset_name == "calibration" and listwise:
        return get_calibration_data(
            tokenizer,
            dataset_name,
            subset_name,
            split,
            lang_pairs,
            batch_size,
        )
    else:
        return None
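
# Dispatch sketch (hypothetical values): the same entry point yields pairwise
# or listwise preference data depending on `listwise`. For example:
#
#   pairwise = get_prefernce_dataset(
#       tokenizer, "haoranxu/ALMA-R-Preference", lang_pairs=("en-de",)
#   )
#   listwise = get_prefernce_dataset(
#       tokenizer, "haoranxu/ALMA-R-Preference", lang_pairs=("en-de",),
#       listwise=True, batch_size=15,
#   )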


def get_prefernce_z_dataset(
    tokenizer, dataset_name, mode: str = "train", split: str = "train", lang_pairs: tuple = ("en-de", "en-zh")
) -> torch.utils.data.Dataset:
    return get_preprocessed_preference_z_data(
        tokenizer,
        dataset_name,
        mode,
        split,
        lang_pairs,
    )


def get_monolingual_dataset(
    tokenizer, dataset_config, mode: str = "train", split: str = "train", lang_pairs: tuple = ("en-de", "en-zh", "en-ar")
) -> torch.utils.data.Dataset:
    return get_preprocessed_monolingual_data(
        tokenizer,
        dataset_config,
        mode,
        split,
        lang_pairs,
    )


def get_preprocessed_dataset(
    tokenizer, dataset_config, split: str = "train", lang_pairs: tuple = ("en-de", "en-zh", "en-ar")
) -> torch.utils.data.Dataset:
    if dataset_config.dataset not in DATASET_PREPROC:
        raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")

    def get_split():
        return (
            dataset_config.train_split
            if split == "train"
            else dataset_config.test_split
        )

    return DATASET_PREPROC[dataset_config.dataset](
        dataset_config,
        tokenizer,
        get_split(),
        lang_pairs,
    )


def get_custom_data_collator(
    dataset_processer, dataset_config
) -> Optional[Callable]:
    # Returns None when the dataset has no registered custom collator.
    if dataset_config.dataset not in DATALOADER_COLLATE_FUNC:
        return None
    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
        dataset_processer,
        dataset_config,
    )


def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)

    if split == "train" and train_config.batching_strategy == "packing":
        dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)

    # Create data loader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **dl_kwargs,
    )
    return dataloader
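

if __name__ == "__main__":
    # Usage sketch: build a training dataloader end to end. The config classes
    # and the samsum_dataset example follow the upstream llama_recipes layout;
    # whether this fork keeps them is an assumption, so treat the names below
    # as illustrative.
    from transformers import AutoTokenizer

    from llama_recipes.configs import train_config as TRAIN_CONFIG
    from llama_recipes.configs.datasets import samsum_dataset as DATASET_CONFIG

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

    train_cfg = TRAIN_CONFIG()
    dataset_cfg = DATASET_CONFIG()
    train_dataloader = get_dataloader(tokenizer, dataset_cfg, train_cfg, split="train")
    print(f"train batches: {len(train_dataloader)}")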