inference_formal.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import shutil
import sys
import time

import fire
import torch
from tqdm import tqdm

from accelerate.utils import is_xpu_available
from llama_recipes.inference.model_utils import load_model, load_peft_model
from llama_recipes.inference.safety_utils import AgentType, get_safety_checker
from transformers import AutoTokenizer
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.utils.dataset_utils import get_enote_dataset
from llama_recipes.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from llama_recipes.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)


def create_clean_dir(path):
    """
    Create a clean directory. If the directory exists, remove it first.

    :param path: Path of the directory to create.
    """
    # Remove the directory if it exists
    if os.path.exists(path):
        shutil.rmtree(path)
    # Create the directory
    os.makedirs(path)


def main(
    model_name,
    peft_model: str = None,
    quantization: str = None,  # Options: 4bit, 8bit
    max_new_tokens=1000,  # The maximum number of tokens to generate
    prompt_file: str = None,
    seed: int = 42,  # Seed value for reproducibility
    do_sample: bool = True,  # Whether or not to use sampling; greedy decoding is used otherwise
    min_length: int = None,  # The minimum length of the sequence to be generated: input prompt + min_new_tokens
    use_cache: bool = True,
    # [optional] Whether or not the model should use the past key/values attentions (if applicable to the model) to speed up decoding
    top_p: float = 1.0,
    # [optional] If set to a float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation
    temperature: float = 1.0,  # [optional] The value used to modulate the next-token probabilities
    top_k: int = 50,  # [optional] The number of highest-probability vocabulary tokens to keep for top-k filtering
    repetition_penalty: float = 1.0,  # The parameter for repetition penalty; 1.0 means no penalty
    length_penalty: int = 1,
    # [optional] Exponential penalty to the length, used with beam-based generation
    enable_azure_content_safety: bool = False,  # Enable safety check with the Azure content safety API
    enable_sensitive_topics: bool = False,  # Enable check for sensitive topics using AuditNLG APIs
    enable_salesforce_content_safety: bool = True,  # Enable safety check with the Salesforce safety Flan-T5
    enable_llamaguard_content_safety: bool = False,
    max_padding_length: int = None,  # The max padding length used by the tokenizer when padding the prompts
    use_fast_kernels: bool = False,
    # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
    share_gradio: bool = False,  # Enable endpoint creation for gradio.live
    lang_pairs: str = None,
    output_dir: str = None,
    **kwargs,
):
    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(seed)
    else:
        torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    # Update the configuration for the training and sharding process
    test_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((test_config, fsdp_config), **kwargs)
    # dataset_config = generate_dataset_config(test_config, kwargs)

    model = load_model(model_name, quantization, use_fast_kernels, **kwargs)
    if test_config.preload_peft_dir is not None:
        # Merge the PEFT weights into the backbone; the merged model may not be 100% aligned with the adapter version
        print("Load and merge peft...")
        model = load_peft_model(model, test_config.preload_peft_dir)
        model = model.merge_and_unload()
    if peft_model:
        model = load_peft_model(model, peft_model)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    model.generation_config.pad_token_id = tokenizer.pad_token_id
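    # Left padding is required for batched generation with decoder-only models,
    # since new tokens are appended on the right of each prompt; the Llama
    # tokenizer ships without a pad token, so the EOS token is reused above.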

    # TODO, batch inference
    def inference_new(
        dataloader,
        temperature,
        top_p,
        top_k,
        max_new_tokens,
        config,
        pbar,
        **kwargs,
    ):
        output = []
        for step, batch in enumerate(dataloader):
            # TODO, dirty: labels are not needed for generation, so drop them
            batch.pop('labels')
            if is_xpu_available():
                batch = {k: v.to("xpu") for k, v in batch.items()}
            else:
                batch = {k: v.to("cuda") for k, v in batch.items()}
            with torch.no_grad():
                batch_output = model.generate(
                    **batch,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    top_p=top_p,
                    temperature=temperature,
                    min_length=min_length,
                    use_cache=use_cache,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    length_penalty=length_penalty,
                    num_beams=config.beam_size,
                    **kwargs,
                )
            # prompt_len = batch['attention_mask'].sum(-1)
            # generate() returns prompt + continuation; slice off the prompt tokens
            batch_len = batch['input_ids'].shape[-1]
            batch_output = batch_output[:, batch_len:]
            batch_output = [tokenizer.decode(output, skip_special_tokens=True) for output in batch_output]
            # Replace newlines with tabs so each output, even a hallucinated multi-line one, stays readable on a single line
            batch_output = [sent.replace("\n", "\t").strip() for sent in batch_output]
            output += batch_output
            pbar.update(1)
        return output

    # TODO, inference for each dataset
    output = {}
    rule_names = test_config.rule_names
    for rule_name in rule_names:
        # Get test data
        print("Processing {} ...".format(rule_name), flush=True)
        dataset_test = get_enote_dataset(
            tokenizer,
            test_config.dataset,
            mode="infer",
            split="valid",
            rule_names=[rule_name],  # only the current rule; passing the full rule_names list would load identical data on every iteration
        )
        print(f"--> Test Set Length = {len(dataset_test)}", flush=True)
        test_dl_kwargs = get_dataloader_kwargs(
            test_config, dataset_test, tokenizer, "infer"
        )
        # Create DataLoaders for inference
        test_dataloader = torch.utils.data.DataLoader(
            dataset_test,
            num_workers=0,
            pin_memory=True,
            shuffle=False,
            **test_dl_kwargs,
        )
        print(f"--> Num of Testing Set Batches loaded = {len(test_dataloader)}", flush=True)

        start = time.perf_counter()
        total_length = len(test_dataloader)
        pbar = tqdm(colour="blue", desc="Inference", total=total_length, dynamic_ncols=True)
        results = inference_new(test_dataloader, temperature, top_p, top_k, max_new_tokens, test_config, pbar=pbar)
        pbar.close()
        e2e_inference_time = (time.perf_counter() - start) * 1000
        print(f"The inference time is {e2e_inference_time} ms", flush=True)
        output[rule_name] = results

        # Dump results
        create_clean_dir(os.path.join(output_dir, rule_name))
        output_file = os.path.join(output_dir, rule_name, "hyp.{}".format(rule_name))
        with open(output_file, 'w') as fout:
            for line in results:
                fout.write(line.strip() + "\n")


if __name__ == "__main__":
    fire.Fire(main)
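
# Example invocation (paths and values below are hypothetical; python-fire
# exposes every keyword argument of main() as a command-line flag):
#
#   python inference_formal.py \
#       --model_name meta-llama/Llama-2-7b-hf \
#       --peft_model /path/to/peft_checkpoint \
#       --quantization 8bit \
#       --max_new_tokens 512 \
#       --output_dir /tmp/inference_outputs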