#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-training/Fine-tuning the library models for causal language modeling
(GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
from dataclasses import dataclass, field
import functools
from itertools import chain
import json
import logging
import os
from statistics import mean
import time
from typing import Optional

import datasets
from datasets import Dataset, load_dataset
import numpy as np
from tqdm import tqdm

import alpa
from alpa.global_env import global_config
from alpa.model.model_util import DynamicScale, TrainState
import jax
import jax.numpy as jnp
import optax
import transformers
import tensorflow as tf
from transformers import (
    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForCausalLM,
    HfArgumentParser,
)

logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def setup_logging():
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Set up logging; we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO)
    datasets.utils.logging.set_verbosity_warning()
    # Set the verbosity of the Transformers logger to INFO (on the main process only):
    transformers.utils.logging.set_verbosity_info()


@dataclass
class TrainingArguments:
    output_dir: str = field(
        metadata={
            "help": "The output directory where the model and checkpoints are saved."
        },
    )
    per_device_train_batch_size: int = field(
        default=1, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    num_micro_batches: int = field(
        default=1,
        metadata={"help": "The number of micro batches for gradient accumulation."},
    )
    operator_parallel: int = field(
        default=1, metadata={"help": "The degree of operator model parallelism."}
    )
    pipeline_parallel: int = field(
        default=1, metadata={"help": "The degree of pipeline model parallelism."}
    )
    learning_rate: float = field(
        default=5e-5, metadata={"help": "The initial learning rate for AdamW."}
    )
    num_train_epochs: int = field(
        default=1, metadata={"help": "Total number of training epochs to perform."}
    )
    logging_steps: int = field(
        default=10, metadata={"help": "Log every X update steps."}
    )
    save_steps: int = field(
        default=100, metadata={"help": "Save a checkpoint every X update steps."}
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune,
    or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        metadata={"help": "The model checkpoint for weights initialization."},
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "If training from scratch, pass a model type from the list: "
            + ", ".join(MODEL_TYPES)
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )


@dataclass
class DataTrainingArguments:
    """Arguments pertaining to what data we are going to use for training and eval."""

    train_file: Optional[str] = field(
        metadata={"help": "The input training data file (a text file)."}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of "
                "training examples to this value if set."
            )
        },
    )
    block_size: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated into blocks of this size. "
                "Defaults to the model max input length for single-sentence inputs "
                "(taking special tokens into account)."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=1,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )


def data_loader(dataset: Dataset, batch_size: int, shuffle: bool = False):
    """Yield numpy batches of size `batch_size` from `dataset`, dropping the last
    incomplete batch. Batches are shuffled if `shuffle` is `True`.
    """
    data_collator = transformers.DefaultDataCollator("np")
    tf_dataset = dataset.to_tf_dataset(
        batch_size=batch_size,
        columns=dataset.column_names,
        collate_fn=data_collator,
        shuffle=shuffle,
        drop_remainder=True,
    )
    for batch in tf_dataset:
        batch = {k: v._numpy() for k, v in batch.items()}
        yield batch


# Main data processing function that concatenates all texts from the dataset
# and generates chunks of `block_size` tokens.
def group_texts(block_size, examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*v)) for k, v in examples.items()}
    # Length of the first concatenated column (e.g. "input_ids").
    total_length = len(next(iter(concatenated_examples.values())))
    # We drop the small remainder. We could pad instead if the model supported
    # it; customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
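
# Illustrative example of `group_texts` (not executed): with block_size=4 and a
# mapped batch whose concatenated "input_ids" are [5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
# (10 tokens), it returns two chunks, [5, 6, 7, 8] and [9, 10, 11, 12]; the trailing
# [13, 14] remainder is dropped, and "labels" is a plain copy of "input_ids"
# (the one-token shift for next-token prediction happens later, in the loss function).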


def preprocess(tokenizer, dataset, data_args):
    """Tokenize a single dataset and group it into blocks of `block_size` tokens."""
    text_column_name = (
        "text" if "text" in dataset.column_names else dataset.column_names[0]
    )
    print("Tokenize dataset ...")
    tokenized_dataset = dataset.map(
        lambda row: tokenizer(row[text_column_name]),
        # With `batched=True`, this map tokenizes batches of rows at a time
        # (1,000 rows by default).
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=dataset.column_names,
        load_from_cache_file=False,
    )
    print("Build dataset ...")
    block_size = min(data_args.block_size, tokenizer.model_max_length)
    lm_dataset = tokenized_dataset.map(
        # `group_texts` throws away the remainder of each mapped batch that does
        # not fill a complete block.
        functools.partial(group_texts, block_size),
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=False,
    )
    if data_args.max_train_samples is not None and data_args.max_train_samples > 0:
        max_samples = min(len(lm_dataset), data_args.max_train_samples)
        lm_dataset = lm_dataset.select(range(max_samples))
    return lm_dataset


def build_datasets(tokenizer, data_args):
    # TODO(jungong) : replace huggingface dataset with Ray dataset.
    # Manually create train split.
    dataset = load_dataset(
        "text",
        data_files={
            "train": data_args.train_file,
        },
        keep_linebreaks=False,
    )
    train_dataset = preprocess(tokenizer, dataset["train"], data_args)
    return train_dataset


# Define gradient update step fn
def train_step(state, batch):
    """Main training step function."""

    def loss_fn(logits, labels):
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        loss = optax.softmax_cross_entropy(
            shift_logits, jax.nn.one_hot(shift_labels, logits.shape[-1])
        )
        return loss.mean()
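
    # A note on `loss_fn` above: this is standard next-token prediction. The
    # logit at position t is scored against the token at position t + 1, so the
    # last logit and the first label of every sequence are dropped by the shift.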

    def compute_loss(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, deterministic=True)[0]
        loss = loss_fn(logits, labels)
        return loss

    dynamic_scale = state.dynamic_scale
    grad_fn = dynamic_scale.value_and_grad(compute_loss)
    dynamic_scale, is_fin, loss, grads = grad_fn(state.params)
    new_state = state.apply_gradients(grads=grads)
    # Only keep the update if all gradients are finite; otherwise fall back to
    # the previous optimizer state, params, and fp32 master copy (the usual
    # dynamic loss-scaling behavior for fp16 training).
    new_state = new_state.replace(
        opt_state=jax.tree_map(
            functools.partial(jnp.where, is_fin),
            new_state.opt_state,
            state.opt_state,
        ),
        params=jax.tree_map(
            functools.partial(jnp.where, is_fin), new_state.params, state.params
        ),
        master_copy=jax.tree_map(
            functools.partial(jnp.where, is_fin),
            new_state.master_copy,
            state.master_copy,
        ),
        dynamic_scale=dynamic_scale,
    )
    metrics = {"loss": loss}
    return new_state, metrics


def save_checkpoint(state, model, tokenizer, training_args):
    """Util to checkpoint model in output_dir."""
    alpa.prefetch(state.params)
    params = alpa.util.map_to_nparray(state.params)
    model.save_pretrained(training_args.output_dir, params=params)
    tokenizer.save_pretrained(training_args.output_dir)


def log_metrics(
    config, epochs, metrics_to_report, batch, latency, epoch, step, train_metric
):
    """Log metrics to stdout."""
    throughput_tokens = np.prod(batch["input_ids"].shape) / latency
    throughput_tflops = alpa.util.compute_gpt_tflops(
        batch_size=batch["input_ids"].shape[0],
        seq_len=batch["input_ids"].shape[1],
        num_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        vocab_size=config.vocab_size,
        num_gpus=alpa.get_global_num_devices(),
        latency=latency,
    )
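    # compute_gpt_tflops reports an analytical estimate derived from the model
    # shape (layers, hidden size, vocab size), the device count, and the measured
    # step latency, rather than a profiled hardware counter.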
    train_metric = jax.tree_map(np.mean, train_metric)
    # Metrics we report from the release test.
    metrics_to_report["tokens"].append(throughput_tokens)
    metrics_to_report["tflops"].append(throughput_tflops)
    epochs.write(
        f"Epoch: {epoch} | "
        f"Step: {step} | "
        f"Loss: {train_metric['loss'].mean():.4f}, "
        f"Throughput: {throughput_tokens:.2f} token/s, "
        f"{throughput_tflops:.2f} TFLOP/s"
    )


def save_json_metrics(metrics):
    # Skip the first couple of data points for a more accurate throughput.
    to_report = {
        "throughput_tokens": (
            mean(metrics["tokens"][2:]) if len(metrics["tokens"]) > 2 else 0.0
        ),
        "throughput_tflops": (
            mean(metrics["tflops"][2:]) if len(metrics["tflops"]) > 2 else 0.0
        ),
    }
    test_output_json = os.environ.get(
        "TEST_OUTPUT_JSON", "/tmp/alpa_opt_2_7b_sanity_check.json"
    )
    print("Writing metrics: ", to_report, f" to {test_output_json}")
    with open(test_output_json, "wt") as f:
        json.dump(to_report, f)
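    # The resulting file is a flat JSON dict, e.g. (values illustrative only):
    #   {"throughput_tokens": 12345.6, "throughput_tflops": 42.0}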


def main():
    # Global initialization.
    alpa.init(cluster="ray")
    # Hide GPUs from TensorFlow: it is only used for the tf.data input pipeline
    # and should not grab device memory that Alpa/XLA needs.
    tf.config.experimental.set_visible_devices([], "GPU")
    # "cupy" doesn't really work, use "xla_extension" instead.
    global_config.nccl_mode = "xla_extension"
    # See all possible arguments by passing the --help flag to this script.
    # We keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if os.path.exists(training_args.output_dir) and os.listdir(
        training_args.output_dir
    ):
        raise ValueError(
            f"Directory ({training_args.output_dir}) already exists and is not empty."
        )
    setup_logging()
    logger.info(f"Training/evaluation parameters {training_args}")
    # Load the pretrained model and tokenizer.
    assert model_args.model_name_or_path, "model_name_or_path is required"
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=True,
    )
    model = FlaxAutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=jnp.float16,
        use_auth_token=None,
    )
    # Training dataset.
    train_dataset = build_datasets(tokenizer, data_args)
    # Derive the parallelism degrees and batch-size constants.
    num_devices = alpa.get_global_num_devices()
    num_epochs = training_args.num_train_epochs
    data_parallel = num_devices // (
        training_args.operator_parallel * training_args.pipeline_parallel
    )
    train_batch_size = training_args.per_device_train_batch_size * data_parallel
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs
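    # Illustrative numbers (hypothetical): with 8 GPUs, operator_parallel=2 and
    # pipeline_parallel=2 leave data_parallel = 8 // (2 * 2) = 2, so
    # per_device_train_batch_size=4 yields a global train_batch_size of 4 * 2 = 8.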
    # Create the AdamW optimizer with global-norm gradient clipping.
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=training_args.learning_rate),
    )
    # Setup train state.
    state = TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer,
        dynamic_scale=DynamicScale(),
        use_master_copy=True,
    )
    # Create the parallel version of the train step. data_parallel=-1 asks Alpa
    # to infer the data-parallel degree from the devices left over after operator
    # and pipeline parallelism.
    method = alpa.get_3d_parallel_method(
        num_micro_batches=training_args.num_micro_batches,
        data_parallel=-1,
        operator_parallel=training_args.operator_parallel,
        pipeline_parallel=training_args.pipeline_parallel,
    )
    p_train_step = alpa.parallelize(train_step, method=method, donate_argnums=(0,))
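    # donate_argnums=(0,) donates the old TrainState's buffers to the parallelized
    # call, letting XLA reuse that memory for the updated state rather than keeping
    # two full copies of the parameters and optimizer state alive.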
    logger.info("***** Training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        " Batch size per device (w. accumulation) = "
        f"{training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Global train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f" Total optimization steps = {total_train_steps}")
    logger.info(f" NCCL mode = {global_config.nccl_mode}")
    step_ct = 0
    last_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    epochs.write("Initial compilation. This might take some minutes...")
    # Track and report throughput per iteration. These are the metrics we
    # care about over time.
    metrics_to_report = {
        "tokens": [],
        "tflops": [],
    }
    for epoch in epochs:
        # Generate an epoch by shuffling sampling indices from the train dataset.
        train_loader = data_loader(train_dataset, train_batch_size, shuffle=True)
        last_time = time.time()
        for step in tqdm(
            range(steps_per_epoch), desc="Training...", position=1, leave=False
        ):
            batch = next(train_loader)
            # Build 0-based position ids from the attention mask: the cumulative
            # sum gives 1-based positions for real tokens, and padding positions
            # end up as -1.
            batch["position_ids"] = (
                batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"]
            ) - 1
            state, train_metric = p_train_step(state, batch)
            cur_step = epoch * steps_per_epoch + step
            step_ct += 1
            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                latency = (time.time() - last_time) / step_ct
                log_metrics(
                    config,
                    epochs,
                    metrics_to_report,
                    batch,
                    latency,
                    epoch,
                    cur_step,
                    train_metric,
                )
                step_ct = 0
                last_time = time.time()
            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Save a checkpoint every `save_steps` steps.
                epochs.write("\nSave checkpoint...")
                save_checkpoint(state, model, tokenizer, training_args)
    # Save the final model.
    epochs.write("\nSave the final model...")
    save_checkpoint(state, model, tokenizer, training_args)
    # Save JSON metrics.
    save_json_metrics(metrics_to_report)


if __name__ == "__main__":
    main()
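

# Example invocation (illustrative; the checkpoint name, file paths, and
# parallelism degrees are placeholders to adapt to your cluster and data):
#
#   python train_opt_2_7b_minimum.py \
#       --model_name_or_path facebook/opt-2.7b \
#       --train_file /path/to/train.txt \
#       --output_dir /tmp/opt_2_7b_output \
#       --per_device_train_batch_size 4 \
#       --num_micro_batches 4 \
#       --operator_parallel 2 \
#       --pipeline_parallel 2 \
#       --num_train_epochs 1
#
# The flags map one-to-one onto the ModelArguments, DataTrainingArguments, and
# TrainingArguments dataclasses above. An Alpa-enabled Ray cluster must already be
# running, since main() calls alpa.init(cluster="ray").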