gpt_model.py

# -*- coding: utf-8 -*-
"""
@author: XuMing(xuming624@qq.com)
@description: Multi-round conversation SFT model
"""
import math
import os
import random
from threading import Thread
from typing import List, Tuple, Optional, Union

import numpy as np
import torch
from loguru import logger
from tqdm import tqdm
from transformers import (
    AutoConfig,
    BloomTokenizerFast,
    BloomForCausalLM,
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModel,
    Trainer,
    TrainingArguments,
    TextIteratorStreamer,
    DataCollatorForSeq2Seq,
    BitsAndBytesConfig,
)

try:
    from transformers.integrations import is_deepspeed_zero3_enabled
except ImportError:
    from transformers.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import TRAINING_ARGS_NAME

from pycorrector.gpt.gpt_utils import GptSupervisedDataset, IGNORE_INDEX, GptArgs, get_conv_template

has_cuda = torch.cuda.is_available()
os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

MODEL_CLASSES = {
    "llama": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
    "chatglm": (AutoConfig, AutoModel, AutoTokenizer),
    "bloom": (AutoConfig, BloomForCausalLM, BloomTokenizerFast),
    "baichuan": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
    "auto": (AutoConfig, AutoModelForCausalLM, AutoTokenizer),
}
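
# Sketch of how the registry above is consumed (the same classes are looked up in GptModel.__init__ below):
#     config_cls, model_cls, tokenizer_cls = MODEL_CLASSES["auto"]
# An unknown model_type raises a KeyError, so "auto" is the safe fallback entry.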


class GptModel:
    def __init__(
            self,
            model_type,
            model_name,
            peft_name: Optional[str] = None,
            args: Optional[dict] = None,
            use_cuda: Optional[bool] = has_cuda,
            cuda_device: Optional[int] = -1,
            **kwargs,
    ):
        """
        Initializes a GptModel model.

        Args:
            model_type: The type of model (llama, chatglm, bloom, baichuan, auto)
            model_name: The exact architecture and trained weights to use. This may be a Hugging Face Transformers compatible pre-trained model, a community model, or the path to a directory containing model files.
            peft_name (optional): Peft model name
            args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
            use_cuda (optional): Use GPU if available. Setting to False will force model to use CPU only.
            cuda_device (int, optional): Specific GPU that should be used. Will use the first available GPU by default.
            **kwargs (optional): For providing proxies, force_download, resume_download, cache_dir and other options specific to the 'from_pretrained' implementation where this will be supplied.
        """  # noqa: ignore flake8
        model_type = model_type.lower()
        self.args = GptArgs()
        if isinstance(args, dict):
            self.args.update_from_dict(args)
        elif isinstance(args, GptArgs):
            self.args = args

        if self.args.manual_seed:
            random.seed(self.args.manual_seed)
            np.random.seed(self.args.manual_seed)
            torch.manual_seed(self.args.manual_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(self.args.manual_seed)

        self.device_map = "auto"
        if use_cuda:
            if torch.cuda.is_available():
                if cuda_device == -1:
                    self.device = torch.device("cuda")
                else:
                    self.device = torch.device(f"cuda:{cuda_device}")
                    self.device_map = {"": int(cuda_device)}
            else:
                raise ValueError(
                    "'use_cuda' set to True when cuda is unavailable."
                    " Make sure CUDA is available or set `use_cuda=False`."
                )
        else:
            if torch.backends.mps.is_available():
                self.device = torch.device("mps")
                self.device_map = {"": "mps"}
            else:
                self.device = "cpu"
                self.device_map = {"": "cpu"}
        logger.debug(f"Device: {self.device}")

        if not use_cuda:
            self.args.fp16 = False
            self.args.int8 = False
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        self.ddp = self.world_size != 1
        if self.ddp:
            self.device_map = {"": self.local_rank}
        self.results = {}

        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
        if model_name is None:
            model_name = self.args.model_name_or_path
        if self.args.bf16:
            self.args.fp16 = False
        if self.args.fp16:
            self.args.bf16 = False
        self.torch_dtype = torch.bfloat16 if self.args.bf16 else (torch.float16 if self.args.fp16 else torch.float32)
        self.config = config_class.from_pretrained(
            model_name,
            trust_remote_code=self.args.trust_remote_code,
            torch_dtype=self.torch_dtype,
            **kwargs
        )
        self.model = model_class.from_pretrained(
            model_name,
            config=self.config,
            load_in_8bit=self.args.int8,
            load_in_4bit=self.args.int4,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
            device_map=self.device_map,
            trust_remote_code=self.args.trust_remote_code,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=self.args.int4,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=self.torch_dtype,
            ) if self.args.qlora else None,
        )

        self.tokenizer_class = tokenizer_class
        if self.args.tokenizer_name:
            self.tokenizer = tokenizer_class.from_pretrained(
                self.args.tokenizer_name, trust_remote_code=self.args.trust_remote_code)
        else:
            self.tokenizer = tokenizer_class.from_pretrained(
                model_name, trust_remote_code=self.args.trust_remote_code)
            self.args.tokenizer_name = self.args.model_name
        if self.tokenizer.eos_token_id is None:
            self.tokenizer.eos_token = "</s>"  # eos token is required for SFT
            logger.debug("Add eos token: {}".format(self.tokenizer.eos_token))
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.unk_token_id is not None:
                self.tokenizer.pad_token = self.tokenizer.unk_token
            else:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            logger.debug("Add pad token: {}".format(self.tokenizer.pad_token))

        self.args.model_type = model_type
        if model_name is None:
            self.args.model_name = "Llama_from_scratch"
        else:
            self.args.model_name = model_name
        self.peft_name = peft_name
        if self.args.use_peft and self.peft_name:
            self.load_peft_model()
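
    # Constructor usage sketch (the model name below is a placeholder, not a project default):
    #     m = GptModel("auto", "some/causal-lm-checkpoint", args={"use_peft": True, "bf16": True})
    # `peft_name` may point at a previously trained LoRA adapter, which is merged into the base
    # weights at load time by load_peft_model() below.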

    def load_peft_model(self):
        """Load peft model"""
        from peft import PeftModel
        self.model = PeftModel.from_pretrained(
            self.model,
            self.peft_name,
            torch_dtype=self.torch_dtype,
            device_map=self.device_map,
        )
        self.model = self.model.merge_and_unload()
        logger.info(f"Loaded peft model from {self.peft_name}")

    def find_all_linear_names(self, int4=False, int8=False):
        cls = torch.nn.Linear
        if int4 or int8:
            import bitsandbytes as bnb
            if int4:
                cls = bnb.nn.Linear4bit
            elif int8:
                cls = bnb.nn.Linear8bitLt
        lora_module_names = set()
        for name, module in self.model.named_modules():
            if isinstance(module, cls):
                # The output head is not added to lora_module_names
                if 'lm_head' in name:
                    continue
                if 'output_layer' in name:
                    continue
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])
        return sorted(lora_module_names)
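
    # For LLaMA-style checkpoints this typically returns names such as
    # ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"];
    # train_model() substitutes this list when lora_target_modules contains "all".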

    def train_model(
            self,
            train_data,
            output_dir=None,
            args=None,
            eval_data=None,
            verbose=True,
            **kwargs,
    ):
        """
        Trains the model using 'train_data'

        Args:
            train_data: json file path or Pandas DataFrame containing one column - `conversations`.
                format: {"conversations":[{"from":"human","value":"Mike的妈妈有4个孩子; 其中3个是 Luis、Drake 和 Matilda。 第4个孩子叫什么?"},{"from":"gpt","value":"Mike。"}]}
            output_dir: The directory where model files will be saved. If not given, self.args.output_dir will be used.
            args (optional): Optional changes to the args dict of the model. Any changes made will persist for the model.
            eval_data (optional): A DataFrame against which evaluation will be performed. If it is not passed, evaluation will be skipped.
            verbose (optional): If True, all of the warnings related to data processing will be printed.
            **kwargs: Additional metrics that should be used. Pass in the metrics as keyword arguments (name of metric: function to use).
                A metric function should take in two parameters. The first parameter will be the true labels, and the second parameter will be
                the predictions. Both inputs will be lists of strings. Note that this will slow down training significantly as the predicted
                sequences need to be generated.

        Returns:
            global_step: Number of global steps trained
            training_loss: Average training loss
        """  # noqa: ignore flake8
        from peft import (
            get_peft_model,
            LoraConfig,
            TaskType,
            PeftModel,
            prepare_model_for_kbit_training,
            set_peft_model_state_dict,
        )
        if args:
            self.args.update_from_dict(args)
        if eval_data is None:
            logger.debug("eval_data is not specified. Pass eval_data to model.train_model() if using evaluate.")
        if not output_dir:
            output_dir = self.args.output_dir
        if (
                os.path.exists(output_dir)
                and os.listdir(output_dir)
                and not self.args.overwrite_output_dir
        ):
            raise ValueError(
                "Output directory ({}) already exists and is not empty."
                " Set args.overwrite_output_dir = True to overcome.".format(output_dir)
            )

        # Setup train args
        training_args = TrainingArguments(
            output_dir=output_dir,
            learning_rate=self.args.learning_rate,
            num_train_epochs=self.args.num_train_epochs,
            logging_dir=f"{output_dir}/logs",
            logging_steps=self.args.logging_steps,
            max_steps=self.args.max_steps,
            per_device_train_batch_size=self.args.per_device_train_batch_size,
            per_device_eval_batch_size=self.args.per_device_train_batch_size,
            gradient_checkpointing=self.args.gradient_checkpointing,
            torch_compile=self.args.torch_compile,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            warmup_steps=self.args.warmup_steps,
            save_steps=self.args.save_steps,
            optim=self.args.optimizer,
            save_strategy=self.args.save_strategy,
            evaluation_strategy='steps' if eval_data is not None else 'no',
            eval_steps=self.args.eval_steps if eval_data is not None else None,
            load_best_model_at_end=True if eval_data is not None else False,
            ddp_find_unused_parameters=False if self.ddp else None,
            save_total_limit=self.args.save_total_limit,
            fp16=self.args.fp16,
            bf16=self.args.bf16,
            remove_unused_columns=self.args.remove_unused_columns,
            report_to=self.args.report_to,
            overwrite_output_dir=self.args.overwrite_output_dir,
            no_cuda=True if self.device == "cpu" else False,
            **kwargs
        )
        resume_from_checkpoint = self.args.resume_from_checkpoint
        if self.args.qlora and (len(training_args.fsdp) > 0 or is_deepspeed_zero3_enabled()):
            logger.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")

        if 'all' in self.args.lora_target_modules:
            self.args.lora_target_modules = self.find_all_linear_names(self.args.int4, self.args.int8)
        # Setup peft
        if self.args.use_peft:
            if self.args.int8 or self.args.int4:
                self.model = prepare_model_for_kbit_training(self.model, self.args.gradient_checkpointing)
            peft_type = self.args.peft_type.upper()
            logger.info(f"Using PEFT type: {peft_type}")
            # Add peft config
            if peft_type == 'LORA':
                logger.debug(f"Using list modules for LoRA: {self.args.lora_target_modules}")
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    r=self.args.lora_r,
                    lora_alpha=self.args.lora_alpha,
                    lora_dropout=self.args.lora_dropout,
                    target_modules=self.args.lora_target_modules,
                    bias=self.args.lora_bias,
                )
            elif peft_type == 'ADALORA':
                from peft import AdaLoraConfig
                logger.debug(f"Using list modules for AdaLoRA: {self.args.lora_target_modules}")
                peft_config = AdaLoraConfig(
                    init_r=self.args.adalora_init_r,
                    r=self.args.lora_r,
                    beta1=self.args.lora_beta,
                    beta2=self.args.lora_beta,
                    tinit=self.args.adalora_tinit,
                    tfinal=self.args.adalora_tfinal,
                    deltaT=self.args.adalora_delta_t,
                    lora_alpha=self.args.lora_alpha,
                    lora_dropout=self.args.lora_dropout,
                    target_modules=self.args.lora_target_modules,
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                )
            elif peft_type == 'PROMPT_TUNING':
                from peft import PromptTuningConfig
                peft_config = PromptTuningConfig(
                    task_type=TaskType.CAUSAL_LM,
                    num_virtual_tokens=self.args.num_virtual_tokens,
                )
            elif peft_type == 'P_TUNING':
                from peft import PromptEncoderConfig
                peft_config = PromptEncoderConfig(
                    task_type=TaskType.CAUSAL_LM,
                    num_virtual_tokens=self.args.num_virtual_tokens,
                    encoder_hidden_size=self.args.prompt_encoder_hidden_size
                )
            elif peft_type == 'PREFIX_TUNING':
                from peft import PrefixTuningConfig
                peft_config = PrefixTuningConfig(
                    task_type=TaskType.CAUSAL_LM,
                    num_virtual_tokens=self.args.num_virtual_tokens,
                    encoder_hidden_size=self.args.prompt_encoder_hidden_size,
                    prefix_projection=True,
                )
                self.model.gradient_checkpointing_disable()
            else:
                logger.warning("Unsupported peft type, falling back to default LoRA.")
                logger.debug(f"Using list modules for LoRA: {self.args.lora_target_modules}")
                peft_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    inference_mode=False,
                    r=self.args.lora_r,
                    lora_alpha=self.args.lora_alpha,
                    lora_dropout=self.args.lora_dropout,
                    target_modules=self.args.lora_target_modules,
                    bias=self.args.lora_bias,
                )

            if isinstance(self.model, PeftModel):
                logger.debug("Merge peft weights to base model")
                self.model = self.model.merge_and_unload()
            self.model = get_peft_model(self.model, peft_config)

            # Set trainable parameters data type to float32
            for param in filter(lambda p: p.requires_grad, self.model.parameters()):
                param.data = param.data.to(torch.float32)

            if resume_from_checkpoint:
                # Check the available weights and load them
                checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")  # Full checkpoint
                if not os.path.exists(checkpoint_name):
                    checkpoint_name = os.path.join(
                        resume_from_checkpoint, "adapter_model.bin")  # only LoRA model - LoRA config above has to fit
                    resume_from_checkpoint = (
                        False  # So the trainer won't try loading its state
                    )
                # The two files above have a different name depending on how they were saved, but are actually the same.
                if os.path.exists(checkpoint_name):
                    logger.info(f"Restarting from {checkpoint_name}")
                    adapters_weights = torch.load(checkpoint_name, map_location='cpu')
                    set_peft_model_state_dict(self.model, adapters_weights)
                else:
                    logger.warning(f"Checkpoint {checkpoint_name} not found")
                    resume_from_checkpoint = None
            self.model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
        else:
            logger.info("Fine-tuning method: Full parameters training")
            # self.model = self.model.float()

        os.makedirs(output_dir, exist_ok=True)
        logger.debug(f"Tokenizer: {self.tokenizer}")

        # Load dataset
        train_dataset = self.load_and_cache_examples(train_data)
        if verbose:
            logger.debug(f"train_dataset len: {len(train_dataset)}, train_dataset[0]: {train_dataset[0]}")
            logger.debug("Tokenized training example:")
            logger.debug(f"Decode input_ids[0]: {self.tokenizer.decode(train_dataset[0]['input_ids'])}")
            replaced_labels = [label if label != IGNORE_INDEX else self.tokenizer.pad_token_id
                               for label in list(train_dataset[0]['labels'])]
            logger.debug(f"Decode labels[0]: {self.tokenizer.decode(replaced_labels)}")
        eval_dataset = None
        if eval_data is not None:
            eval_dataset = self.load_and_cache_examples(eval_data, evaluate=True)
            if verbose:
                logger.debug(f"eval_dataset len: {len(eval_dataset)}, eval_dataset[0]: {eval_dataset[0]}")

        # Log on each process the small summary:
        logger.warning(
            f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
            + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
        )
        if training_args.local_rank <= 0:
            logger.info(f"Training/evaluation parameters {training_args}")

        # Update model train config
        if self.args.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
            self.model.config.use_cache = False
        else:
            self.model.config.use_cache = True
        self.model.enable_input_require_grads()
        if not self.ddp and torch.cuda.device_count() > 1:
            # Keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
            self.model.is_parallelizable = True
            self.model.model_parallel = True

        # Initialize our Trainer
        data_collator = DataCollatorForSeq2Seq(self.tokenizer, label_pad_token_id=IGNORE_INDEX)
        trainer = SavePeftModelTrainer(
            model=self.model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset if eval_data is not None else None,
            args=training_args,
            tokenizer=self.tokenizer,
            data_collator=data_collator,
        )

        # Training
        logger.info("*** Train ***")
        sample = next(iter(trainer.get_train_dataloader()))
        logger.debug(f"Train dataloader example: {sample}")
        logger.debug(f"Detail input_ids: {sample['input_ids'][:3]}, \nlabels: {sample['labels'][:3]}")
        logger.debug(f"Decode input_ids[0]: {self.tokenizer.decode(sample['input_ids'][0])}")
        replaced_labels = [label if label != IGNORE_INDEX else
                           self.tokenizer.pad_token_id for label in sample['labels'][0]]
        logger.debug(f"Decode labels[0]: {self.tokenizer.decode(replaced_labels)}")

        (global_step, training_loss, metrics) = trainer.train(resume_from_checkpoint=resume_from_checkpoint)
        self.results.update(metrics)
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        self.model.config.use_cache = True  # enable cache after training
        trainer.save_state()
        self.save_model(model=self.model)

        if eval_data is not None:
            logger.info("*** Evaluate ***")
            if self.args.fp16:
                self.model.half()
            metrics = trainer.evaluate(metric_key_prefix="eval")
            metrics['eval_samples'] = len(eval_dataset)
            try:
                perplexity = math.exp(metrics["eval_loss"])
            except OverflowError:
                perplexity = float("inf")
            metrics["perplexity"] = perplexity
            logger.debug(f"eval metrics: {metrics}")
            self.results.update(metrics)
            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", metrics)

        if verbose and training_args.local_rank <= 0:
            logger.debug(f"metrics: {self.results}")
        logger.info(
            " Training of {} model complete. Saved to {}.".format(
                self.args.model_name, output_dir
            )
        )
        return global_step, training_loss
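
    # Training usage sketch (file names are placeholders): assuming `train.jsonl` follows the
    # `conversations` format documented above, a LoRA run might look like:
    #     m = GptModel("auto", "some/causal-lm-checkpoint",
    #                  args={"use_peft": True, "output_dir": "outputs-sft"})
    #     m.train_model("train.jsonl", eval_data="dev.jsonl")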

    @torch.inference_mode()
    def predict(
            self,
            sentences: List[str],
            skip_prompt: bool = True,
            prompt_template_name: str = 'vicuna',
            max_length: int = None,
            do_sample: bool = None,
            temperature: float = None,
            repetition_penalty: float = None,
            eval_batch_size: int = None,
            **kwargs
    ) -> List[str]:
        """
        Performs predictions on a list of text.

        Args:
            sentences: A python list of text (str) to be sent to the model for prediction. Note that the prefix should be prepended to the text.
            skip_prompt: Whether to skip the prompt when generating text.
            prompt_template_name: The name of the prompt template to use.
            max_length: The maximum length of the generated text.
            do_sample: Whether or not to use sampling; use greedy decoding otherwise.
            temperature: The value used to modulate the next token probabilities.
            repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty.
            eval_batch_size: Batch size to use for evaluation.
            **kwargs: Additional arguments for generating sequences.

        Returns:
            preds: A python list of the generated sequences.
        """  # noqa: ignore flake8
        self.model.eval()
        if self.args.fp16:
            self.model.half()
        prompt_template = get_conv_template(prompt_template_name or self.args.prompt_template_name)
        if not eval_batch_size:
            eval_batch_size = self.args.eval_batch_size
        all_outputs = []
        # Batching
        for batch in tqdm(
                [
                    sentences[i: i + eval_batch_size]
                    for i in range(0, len(sentences), eval_batch_size)
                ],
                desc="Generating outputs",
                disable=self.args.silent,
        ):
            if prompt_template_name:
                batch = [prompt_template.get_prompt(messages=[[s, '']]) for s in batch]
            inputs = self.tokenizer(batch, padding=True, return_tensors='pt')
            input_ids = inputs['input_ids'].to(self.device)
            generation_kwargs = dict(
                max_new_tokens=max_length if max_length is not None else self.args.max_length,
                do_sample=do_sample if do_sample is not None else self.args.do_sample,
                temperature=temperature if temperature is not None else self.args.temperature,
                repetition_penalty=repetition_penalty if repetition_penalty is not None else self.args.repetition_penalty,
            )
            outputs = self.model.generate(
                input_ids=input_ids,
                **generation_kwargs,
                **kwargs,
            )
            for prompt, generated_sequence in zip(batch, outputs):
                # Decode text
                prompt_len = len(input_ids[0])
                generated_sequence = generated_sequence[prompt_len:]
                gen_text = self.tokenizer.decode(generated_sequence, skip_special_tokens=True)
                stop_str = self.tokenizer.eos_token or prompt_template.stop_str
                pos = gen_text.find(stop_str)
                if pos != -1:
                    gen_text = gen_text[:pos]
                if not skip_prompt:
                    gen_text = prompt + gen_text
                all_outputs.append(gen_text)
        return all_outputs
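
    # Prediction usage sketch (the input string is a placeholder):
    #     outputs = m.predict(["your input text"], prompt_template_name='vicuna', max_length=128)
    # Each sentence is wrapped in the chosen conversation template before generation, and the
    # prompt portion is stripped from the decoded output when skip_prompt=True.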

    @torch.inference_mode()
    def chat(
            self,
            query: str,
            history: Union[List, List[Tuple[str, str]]] = None,
            stream: bool = False,
            skip_prompt: bool = True,
            prompt_template_name: str = "vicuna",
            max_new_tokens: int = None,
            do_sample: bool = None,
            temperature: float = None,
            repetition_penalty: float = None,
            context_len: int = 2048,
            **kwargs
    ):
        """Chat model with multi turn conversation."""
        prompt_template = get_conv_template(prompt_template_name or self.args.prompt_template_name)
        if history is None:
            history = []
        history.append([query, ''])
        prompt = prompt_template.get_prompt(messages=history)
        input_ids = self.tokenizer(prompt).input_ids
        max_new_tokens = max_new_tokens if max_new_tokens is not None else self.args.max_length
        max_src_len = context_len - max_new_tokens - 8
        input_ids = input_ids[-max_src_len:]
        if stream:
            streamer = TextIteratorStreamer(
                self.tokenizer, timeout=60.0, skip_prompt=skip_prompt, skip_special_tokens=True
            )
            generation_kwargs = dict(
                input_ids=torch.as_tensor([input_ids]).to(self.device),
                max_new_tokens=max_new_tokens,
                do_sample=do_sample if do_sample is not None else self.args.do_sample,
                temperature=temperature if temperature is not None else self.args.temperature,
                repetition_penalty=repetition_penalty if repetition_penalty is not None else self.args.repetition_penalty,
                streamer=streamer,
                **kwargs,
            )
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()
            yield from streamer
        else:
            generation_kwargs = dict(
                max_new_tokens=max_new_tokens if max_new_tokens is not None else self.args.max_length,
                do_sample=do_sample if do_sample is not None else self.args.do_sample,
                temperature=temperature if temperature is not None else self.args.temperature,
                repetition_penalty=repetition_penalty if repetition_penalty is not None else self.args.repetition_penalty,
            )
            outputs = self.model.generate(
                input_ids=torch.as_tensor([input_ids]).to(self.device),
                **generation_kwargs,
                **kwargs,
            )
            output_tensor = outputs[0][len(input_ids):] if skip_prompt else outputs[0]
            response = self.tokenizer.decode(output_tensor, skip_special_tokens=True)
            history[-1][1] = response
            return response, history
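
    # Chat usage sketch: because the body contains `yield from`, Python treats chat() as a
    # generator function, so the streaming path is the natural way to consume it, e.g.
    #     for token_text in m.chat("your question", stream=True):
    #         print(token_text, end="", flush=True)
    # (assumption: `m` is a GptModel instance built as in the constructor sketch above)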

    def load_and_cache_examples(
            self, data, evaluate=False, no_cache=False, verbose=True, silent=False
    ):
        """
        Creates a GptSupervisedDataset from data.

        Utility function for train() and eval() methods. Not intended to be used directly.
        """
        tokenizer = self.tokenizer
        args = self.args
        if not no_cache:
            no_cache = args.no_cache
        if not no_cache:
            os.makedirs(self.args.cache_dir, exist_ok=True)
        mode = "dev" if evaluate else "train"
        if args.dataset_class:
            CustomDataset = args.dataset_class
            return CustomDataset(tokenizer, args, data, mode)
        else:
            return GptSupervisedDataset(tokenizer, args, data, mode)

    def save_model(
            self, output_dir=None, optimizer=None, scheduler=None, model=None, results=None
    ):
        """Save the model and the tokenizer."""
        if not output_dir:
            output_dir = self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        if model and not self.args.no_save:
            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
            # Take care of distributed/parallel training
            model_to_save = model.module if hasattr(model, "module") else model
            model_to_save.save_pretrained(output_dir)
            self.tokenizer.save_pretrained(output_dir)


class SavePeftModelTrainer(Trainer):
    """
    Trainer for lora models
    """

    def save_model(self, output_dir=None, _internal_call=False):
        """Save the LoRA model."""
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        self.model.save_pretrained(output_dir)
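

if __name__ == "__main__":
    # End-to-end usage sketch. The checkpoint path is a placeholder, not a project default;
    # pass any causal LM that transformers can load (or a local directory of model files).
    demo = GptModel("auto", "path/or/name-of-causal-lm", use_cuda=has_cuda)
    print(demo.predict(["your input text"], prompt_template_name="vicuna", max_length=128))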