- #!/usr/bin/env python
- # coding=utf-8
- # Copyright 2021 The HuggingFace Team All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Pre-training/Fine-tuning the library models for causal language modeling
- (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
- Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
- https://huggingface.co/models?filter=text-generation
- """
- from dataclasses import dataclass, field
- import functools
- from itertools import chain
- import json
- import logging
- import os
- from statistics import mean
- import time
- from typing import Optional
- import datasets
- from datasets import Dataset, load_dataset
- import numpy as np
- from tqdm import tqdm
- import alpa
- from alpa.global_env import global_config
- from alpa.model.model_util import DynamicScale, TrainState
- import jax
- import jax.numpy as jnp
- import optax
- import transformers
- import tensorflow as tf
- from transformers import (
- FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
- AutoConfig,
- AutoTokenizer,
- FlaxAutoModelForCausalLM,
- HfArgumentParser,
- )
- logger = logging.getLogger(__name__)
- MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
- MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
- def setup_logging():
- # Make one log on every process with the configuration for debugging.
- logging.basicConfig(
- format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
- datefmt="%m/%d/%Y %H:%M:%S",
- level=logging.INFO,
- )
- # Set up logging: we only want one process per machine to log to the screen.
- logger.setLevel(logging.INFO)
- datasets.utils.logging.set_verbosity_warning()
- # Set the Transformers logger verbosity to info:
- transformers.utils.logging.set_verbosity_info()
- @dataclass
- class TrainingArguments:
- output_dir: str = field(
- metadata={
- "help": "The output directory where the model and checkpoints are saved."
- },
- )
- per_device_train_batch_size: int = field(
- default=1, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
- )
- num_micro_batches: int = field(
- default=1,
- metadata={"help": "The number of micro batches for gradient accumulation."},
- )
- operator_parallel: int = field(
- default=1, metadata={"help": "The degree of operator model parallelism."}
- )
- pipeline_parallel: int = field(
- default=1, metadata={"help": "The degree of pipeline model parallelism."}
- )
- learning_rate: float = field(
- default=5e-5, metadata={"help": "The initial learning rate for AdamW."}
- )
- num_train_epochs: int = field(
- default=1, metadata={"help": "Total number of training epochs to perform."}
- )
- logging_steps: int = field(
- default=10, metadata={"help": "Log every X update steps."}
- )
- save_steps: int = field(
- default=100, metadata={"help": "Save a checkpoint every X update steps."}
- )
- @dataclass
- class ModelArguments:
- """
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune,
- or train from scratch.
- """
- model_name_or_path: Optional[str] = field(
- metadata={"help": "The model checkpoint for weights initialization."},
- )
- model_type: Optional[str] = field(
- default=None,
- metadata={
- "help": "If training from scratch, pass a model type from the list: "
- + ", ".join(MODEL_TYPES)
- },
- )
- tokenizer_name: Optional[str] = field(
- default=None,
- metadata={
- "help": "Pretrained tokenizer name or path if not the same as model_name"
- },
- )
- @dataclass
- class DataTrainingArguments:
- """Arguments pertaining to what data we are going to use for training and eval."""
- train_file: Optional[str] = field(
- metadata={"help": "The input training data file (a text file)."}
- )
- max_train_samples: Optional[int] = field(
- default=None,
- metadata={
- "help": (
- "For debugging purposes or quicker training, truncate the number of "
- "training examples to this value if set."
- )
- },
- )
- block_size: Optional[int] = field(
- default=1024,
- metadata={
- "help": (
- "Optional input sequence length after tokenization. "
- "The training dataset will be truncated in block of this size. "
- "Default to the model max input length for single sentence inputs "
- "(take into account special tokens)."
- )
- },
- )
- preprocessing_num_workers: Optional[int] = field(
- default=1,
- metadata={"help": "The number of processes to use for the preprocessing."},
- )
- def data_loader(dataset: Dataset, batch_size: int, shuffle: bool = False):
- """Returns batches of size `batch_size` from truncated `dataset`,
- sharded over all local devices. Shuffle batches if `shuffle` is `True`.
- """
- data_collator = transformers.DefaultDataCollator("np")
- tf_dataset = dataset.to_tf_dataset(
- batch_size=batch_size,
- columns=dataset.column_names,
- collate_fn=data_collator,
- shuffle=shuffle,
- drop_remainder=True,
- )
- for batch in tf_dataset:
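- # Convert TF EagerTensors to plain NumPy arrays so the batch can be fed to JAX/Alpa.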
- batch = {k: v._numpy() for k, v in batch.items()}
- yield batch
- # Main data processing function that will concatenate all texts from
- # our dataset and generate chunks of block_size.
- def group_texts(block_size, examples):
- # Concatenate all texts.
- concatenated_examples = {k: list(chain(*v)) for k, v in examples.items()}
- # Total length of the concatenated texts (measured on the first column).
- total_length = len(next(iter(concatenated_examples.values())))
- # We drop the small remainder; we could pad instead if the model supported it.
- # You can customize this part to your needs.
- if total_length >= block_size:
- total_length = (total_length // block_size) * block_size
- # Split by chunks of max_len.
- result = {
- k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
- for k, t in concatenated_examples.items()
- }
- result["labels"] = result["input_ids"].copy()
- return result
- def preprocess(tokenizer, dataset, data_args):
- """Tokenize a single dataset."""
- text_column_name = (
- "text" if "text" in dataset.column_names else dataset.column_names[0]
- )
- print("Tokenize dataset ...")
- tokenized_dataset = dataset.map(
- lambda row: tokenizer(row[text_column_name]),
- # With `batched=True`, this map tokenizes a batch of rows (1,000 by
- # default) per call; the remainder is dropped later in `group_texts`.
- batched=True,
- num_proc=data_args.preprocessing_num_workers,
- remove_columns=dataset.column_names,
- load_from_cache_file=False,
- )
- print("Build dataset ...")
- block_size = min(data_args.block_size, tokenizer.model_max_length)
- lm_dataset = tokenized_dataset.map(
- functools.partial(group_texts, block_size),
- batched=True,
- num_proc=data_args.preprocessing_num_workers,
- load_from_cache_file=False,
- )
- if data_args.max_train_samples is not None:
- max_samples = min(len(lm_dataset), data_args.max_train_samples)
- lm_dataset = lm_dataset.select(range(max_samples))
- return lm_dataset
- def build_datasets(tokenizer, data_args):
- # TODO(jungong) : replace huggingface dataset with Ray dataset.
- # Manually create train split.
- dataset = load_dataset(
- "text",
- data_files={
- "train": data_args.train_file,
- },
- keep_linebreaks=False,
- )
- train_dataset = preprocess(tokenizer, dataset["train"], data_args)
- return train_dataset
- # Define gradient update step fn
- def train_step(state, batch):
- """Main training step function."""
- def loss_fn(logits, labels):
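- # Standard causal-LM shift: predict token t+1 from tokens up to t,
- # so drop the last position from the logits and the first from the labels.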
- shift_logits = logits[..., :-1, :]
- shift_labels = labels[..., 1:]
- loss = optax.softmax_cross_entropy(
- shift_logits, jax.nn.one_hot(shift_labels, logits.shape[-1])
- )
- return loss.mean()
- def compute_loss(params):
- labels = batch.pop("labels")
- logits = state.apply_fn(**batch, params=params, deterministic=True)[0]
- loss = loss_fn(logits, labels)
- return loss
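- # Mixed-precision training: DynamicScale scales the loss before the backward
- # pass and reports whether all resulting gradients are finite (`is_fin`).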
- dynamic_scale = state.dynamic_scale
- grad_fn = dynamic_scale.value_and_grad(compute_loss)
- dynamic_scale, is_fin, loss, grads = grad_fn(state.params)
- new_state = state.apply_gradients(grads=grads)
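- # If any gradient overflowed (`is_fin` is False), keep the previous optimizer
- # state, parameters, and fp32 master copy instead of the bad update.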
- new_state = new_state.replace(
- opt_state=jax.tree_map(
- functools.partial(jnp.where, is_fin),
- new_state.opt_state,
- state.opt_state,
- ),
- params=jax.tree_map(
- functools.partial(jnp.where, is_fin), new_state.params, state.params
- ),
- master_copy=jax.tree_map(
- functools.partial(jnp.where, is_fin),
- new_state.master_copy,
- state.master_copy,
- ),
- dynamic_scale=dynamic_scale,
- )
- metrics = {"loss": loss}
- return new_state, metrics
- def save_checkpoint(state, model, tokenizer, training_args):
- """Util to checkpoint model in output_dir."""
- alpa.prefetch(state.params)
- params = alpa.util.map_to_nparray(state.params)
- model.save_pretrained(training_args.output_dir, params=params)
- tokenizer.save_pretrained(training_args.output_dir)
- def log_metrics(
- config, epochs, metrics_to_report, batch, latency, epoch, step, train_metric
- ):
- """Log metrics to stdout."""
- throughput_tokens = np.prod(batch["input_ids"].shape) / latency
- throughput_tflops = alpa.util.compute_gpt_tflops(
- batch_size=batch["input_ids"].shape[0],
- seq_len=batch["input_ids"].shape[1],
- num_layers=config.num_hidden_layers,
- hidden_size=config.hidden_size,
- vocab_size=config.vocab_size,
- num_gpus=alpa.get_global_num_devices(),
- latency=latency,
- )
- train_metric = jax.tree_map(np.mean, train_metric)
- # Metrics we report from the release test.
- metrics_to_report["tokens"].append(throughput_tokens)
- metrics_to_report["tflops"].append(throughput_tflops)
- epochs.write(
- f"Epoch: {epoch} | "
- f"Step: {step} | "
- f"Loss: {train_metric['loss'].mean():.4f}, "
- f"Throughput: {throughput_tokens:.2f} token/s, "
- f"{throughput_tflops:.2f} TFLOP/s"
- )
- def save_json_metrics(metrics):
- # Skip the first couple of data points for a more accurate throughput.
- to_report = {
- "throughput_tokens": (
- mean(metrics["tokens"][2:]) if len(metrics["tokens"]) > 2 else 0.0
- ),
- "throughput_tflops": (
- mean(metrics["tflops"][2:]) if len(metrics["tflops"]) > 2 else 0.0
- ),
- }
- test_output_json = os.environ.get(
- "TEST_OUTPUT_JSON", "/tmp/alpa_opt_2_7b_sanity_check.json"
- )
- print("Writing metrics: ", to_report, f" to {test_output_json}")
- with open(test_output_json, "wt") as f:
- json.dump(to_report, f)
- def main():
- # Global initialization.
- alpa.init(cluster="ray")
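- # Hide GPUs from TensorFlow so tf.data does not reserve device memory that
- # Alpa/XLA needs.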
- tf.config.experimental.set_visible_devices([], "GPU")
- # "cupy" doesn't really work, use "xla_extension" instead.
- global_config.nccl_mode = "xla_extension"
- # See all possible arguments in src/transformers/training_args.py
- # or by passing the --help flag to this script.
- # We now keep distinct sets of args, for a cleaner separation of concerns.
- parser = HfArgumentParser(
- (ModelArguments, DataTrainingArguments, TrainingArguments)
- )
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
- if os.path.exists(training_args.output_dir) and os.listdir(
- training_args.output_dir
- ):
- raise ValueError(
- f"Directory ({training_args.output_dir}) already exists and is not empty."
- )
- logger.info(f"Training/evaluation parameters {training_args}")
- setup_logging()
- # Load pretrained model and tokenizer
- # Distributed training:
- assert model_args.model_name_or_path, "model_name_or_path is required"
- config = AutoConfig.from_pretrained(model_args.model_name_or_path)
- tokenizer = AutoTokenizer.from_pretrained(
- model_args.model_name_or_path,
- use_fast=True,
- )
- model = FlaxAutoModelForCausalLM.from_pretrained(
- model_args.model_name_or_path,
- config=config,
- dtype=getattr(jnp, "float16"),
- use_auth_token=None,
- )
- # Training dataset.
- train_dataset = build_datasets(tokenizer, data_args)
- # Total number of devices available in the Ray cluster.
- num_devices = alpa.get_global_num_devices()
- # Store some constants.
- num_epochs = training_args.num_train_epochs
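- # Devices not used by operator or pipeline parallelism are used for data
- # parallelism: num_devices = data_parallel * operator_parallel * pipeline_parallel.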
- data_parallel = num_devices // (
- training_args.operator_parallel * training_args.pipeline_parallel
- )
- train_batch_size = training_args.per_device_train_batch_size * data_parallel
- steps_per_epoch = len(train_dataset) // train_batch_size
- total_train_steps = steps_per_epoch * num_epochs
- # create adam optimizer
- optimizer = optax.chain(
- optax.clip_by_global_norm(1.0),
- optax.adamw(learning_rate=training_args.learning_rate),
- )
- # Setup train state
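- # The model weights are float16 (see `dtype` above); `use_master_copy=True`
- # keeps a float32 master copy for the optimizer update and `DynamicScale`
- # provides dynamic loss scaling.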
- state = TrainState.create(
- apply_fn=model.__call__,
- params=model.params,
- tx=optimizer,
- dynamic_scale=DynamicScale(),
- use_master_copy=True,
- )
- # Create parallel version of the train and eval step
- method = alpa.get_3d_parallel_method(
- num_micro_batches=training_args.num_micro_batches,
- data_parallel=-1,
- operator_parallel=training_args.operator_parallel,
- pipeline_parallel=training_args.pipeline_parallel,
- )
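- # `data_parallel=-1` lets Alpa infer the data-parallel degree from the
- # remaining devices; `donate_argnums=(0,)` donates the old TrainState buffers
- # to the compiled step so memory can be reused in place.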
- p_train_step = alpa.parallelize(train_step, method=method, donate_argnums=(0,))
- logger.info("***** Training *****")
- logger.info(f" Num examples = {len(train_dataset)}")
- logger.info(f" Num Epochs = {num_epochs}")
- logger.info(
- " Batch size per device (w. accumulation) = "
- f"{training_args.per_device_train_batch_size}"
- )
- logger.info(
- f" Global train batch size (w. parallel & distributed) = {train_batch_size}"
- )
- logger.info(f" Total optimization steps = {total_train_steps}")
- logger.info(f" NCCL mode = {global_config.nccl_mode}")
- step_ct = 0
- last_time = 0
- epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
- epochs.write("Initial compilation. This might take some minutes...")
- # Track and report throughput per iteration. These are the metrics we
- # care about over time.
- metrics_to_report = {
- "tokens": [],
- "tflops": [],
- }
- for epoch in epochs:
- # Create a (shuffled) data loader over the train dataset for this epoch.
- train_loader = data_loader(train_dataset, train_batch_size, shuffle=True)
- last_time = time.time()
- for step in tqdm(
- range(steps_per_epoch), desc="Training...", position=1, leave=False
- ):
- batch = next(train_loader)
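- # Derive position ids from the attention mask: real tokens get 0-based
- # positions, padded positions become -1.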
- batch["position_ids"] = (
- batch["attention_mask"].cumsum(axis=1) * batch["attention_mask"]
- ) - 1
- state, train_metric = p_train_step(state, batch)
- cur_step = epoch * steps_per_epoch + step
- step_ct += 1
- if cur_step % training_args.logging_steps == 0 and cur_step > 0:
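- # Average per-step latency since the last logging point.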
- latency = (time.time() - last_time) / step_ct
- log_metrics(
- config,
- epochs,
- metrics_to_report,
- batch,
- latency,
- epoch,
- cur_step,
- train_metric,
- )
- step_ct = 0
- last_time = time.time()
- if cur_step % training_args.save_steps == 0 and cur_step > 0:
- # Save a checkpoint every `save_steps` steps.
- epochs.write("\nSave checkpoint...")
- save_checkpoint(state, model, tokenizer, training_args)
- # Save the final model
- epochs.write("\nSave the final model...")
- save_checkpoint(state, model, tokenizer, training_args)
- # Save JSON metrics
- save_json_metrics(metrics_to_report)
- if __name__ == "__main__":
- main()