{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "(lightning_advanced_example)=\n", "\n", "# Finetune a BERT Text Classifier with LightningTrainer\n", "\n", ":::{note}\n", "\n", "This is an advanced example for {class}`LightningTrainer `, which demonstrates how to use LightningTrainer with {ref}`Dataset ` and {ref}`Batch Predictor `. \n", "\n", "If you just want to quickly convert your existing PyTorch Lightning scripts into Ray AIR, you can refer to this starter example:\n", "{ref}`Train a PyTorch Lightning Image Classifier `.\n", "\n", ":::\n", "\n", "This demo shows how to fine-tune a text classifier on the [CoLA (The Corpus of Linguistic Acceptability)](https://nyu-mll.github.io/CoLA/) dataset with a pretrained BERT model. \n", "In particular, we will:\n", "- Create Ray Data from the original CoLA dataset.\n", "- Define a preprocessor to tokenize the sentences.\n", "- Fine-tune a BERT model using LightningTrainer.\n", "- Construct a BatchPredictor with the checkpoint and preprocessor.\n", "- Do batch prediction on multiple GPUs, and evaluate the results." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "SMOKE_TEST = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Run the following line to install the necessary dependencies:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "!pip install numpy datasets \"transformers>=4.19.1\" \"pytorch_lightning>=1.6.5\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by importing the needed libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import ray\n", "import torch\n", "import pytorch_lightning as pl\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader, random_split\n", "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n", "from datasets import load_dataset, load_metric\n", "import numpy as np" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Pre-process CoLA Dataset\n", "\n", "CoLA is a binary sentence classification task with about 10.6K examples in total, roughly 8.5K of which are in the GLUE training split. First, we download the dataset and metrics using the HuggingFace API, and create Ray Data for each split accordingly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset = load_dataset(\"glue\", \"cola\")\n", "metric = load_metric(\"glue\", \"cola\")\n", "\n", "ray_datasets = ray.data.from_huggingface(dataset)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Next, define a preprocessor that tokenizes the input sentences with the `bert-base-cased` tokenizer and pads or truncates each ID sequence to length 128. The preprocessor transforms all datasets that we provide to the LightningTrainer later." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ray.data.preprocessors import BatchMapper\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n", "\n", "\n", "def tokenize_sentence(batch):\n", "    encoded_sent = tokenizer(\n", "        batch[\"sentence\"].tolist(),\n", "        max_length=128,\n", "        truncation=True,\n", "        padding=\"max_length\",\n", "        return_tensors=\"pt\",\n", "    )\n", "    batch[\"input_ids\"] = encoded_sent[\"input_ids\"].numpy()\n", "    batch[\"attention_mask\"] = encoded_sent[\"attention_mask\"].numpy()\n", "    batch[\"label\"] = np.array(batch[\"label\"])\n", "    batch.pop(\"sentence\")\n", "    return batch\n", "\n", "\n", "preprocessor = BatchMapper(tokenize_sentence, batch_format=\"numpy\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Define a PyTorch Lightning Model\n", "\n", "You don't have to change your `LightningModule` definition. Just copy and paste your code here:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class SentimentModel(pl.LightningModule):\n", "    def __init__(self, lr=2e-5, eps=1e-8):\n", "        super().__init__()\n", "        self.lr = lr\n", "        self.eps = eps\n", "        self.num_classes = 2\n", "        self.model = AutoModelForSequenceClassification.from_pretrained(\n", "            \"bert-base-cased\", num_labels=self.num_classes\n", "        )\n", "        self.metric = load_metric(\"glue\", \"cola\")\n", "        self.predictions = []\n", "        self.references = []\n", "\n", "    def forward(self, batch):\n", "        input_ids, attention_mask = batch[\"input_ids\"], batch[\"attention_mask\"]\n", "        outputs = self.model(input_ids, attention_mask=attention_mask)\n", "        logits = outputs.logits\n", "        return logits\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        labels = batch[\"label\"]\n", "        logits = self.forward(batch)\n", "        loss = F.cross_entropy(logits.view(-1, self.num_classes), labels)\n", "        self.log(\"train_loss\", loss)\n", 
"        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        labels = batch[\"label\"]\n", "        logits = self.forward(batch)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.predictions.append(preds)\n", "        self.references.append(labels)\n", "\n", "    def on_validation_epoch_end(self):\n", "        predictions = torch.concat(self.predictions).view(-1)\n", "        references = torch.concat(self.references).view(-1)\n", "        matthews_correlation = self.metric.compute(\n", "            predictions=predictions, references=references\n", "        )\n", "\n", "        # self.metric.compute() returns a dictionary:\n", "        # e.g. {\"matthews_correlation\": 0.53}\n", "        self.log_dict(matthews_correlation, sync_dist=True)\n", "        self.predictions.clear()\n", "        self.references.clear()\n", "\n", "    def configure_optimizers(self):\n", "        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Configure your LightningTrainer\n", "\n", "Define a LightningTrainer with the necessary configuration, including hyperparameters, checkpointing, and compute resource settings. \n", "\n", "You may find the API of {class}`LightningConfigBuilder ` and the discussion {ref}`here ` useful.\n", "\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from ray.train.lightning import LightningTrainer, LightningConfigBuilder\n", "from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig\n", "\n", "# Define the configs for LightningTrainer\n", "lightning_config = (\n", "    LightningConfigBuilder()\n", "    .module(cls=SentimentModel, lr=1e-5, eps=1e-8)\n", "    .trainer(max_epochs=5, accelerator=\"gpu\")\n", "    .checkpointing(save_on_train_epoch_end=False)\n", "    .build()\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ ":::{note}\n", "Note that the `lightning_config` is created on the head node and will be passed to the worker nodes later. 
Be aware that the environment variables and hardware settings may differ between the head node and worker nodes.\n", ":::\n", "\n", ":::{note}\n", "{meth}`LightningConfigBuilder.checkpointing() ` creates a [ModelCheckpoint](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint) callback. This callback defines the checkpoint frequency and saves checkpoint files in Lightning style. \n", "\n", "If you want to save AIR checkpoints for Batch Prediction, please also provide an AIR {class}`CheckpointConfig `.\n", ":::" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Save AIR checkpoints according to the performance on the validation set\n", "run_config = RunConfig(\n", "    name=\"ptl-sent-classification\",\n", "    checkpoint_config=CheckpointConfig(\n", "        num_to_keep=2,\n", "        checkpoint_score_attribute=\"matthews_correlation\",\n", "        checkpoint_score_order=\"max\",\n", "    ),\n", ")\n", "\n", "# Scale the DDP training workload across 4 GPUs\n", "# You can change this config based on your compute resources.\n", "scaling_config = ScalingConfig(\n", "    num_workers=4, use_gpu=True, resources_per_worker={\"CPU\": 1, \"GPU\": 1}\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [ "if SMOKE_TEST:\n", "    lightning_config = (\n", "        LightningConfigBuilder()\n", "        .module(cls=SentimentModel, lr=1e-5, eps=1e-8)\n", "        .trainer(max_epochs=2, accelerator=\"gpu\")\n", "        .checkpointing(save_on_train_epoch_end=False)\n", "        .build()\n", "    )\n", "\n", "    for split, ds in ray_datasets.items():\n", "        ray_datasets[split] = ds.random_sample(0.1)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-tune the model with LightningTrainer\n", "\n", "Train the model with the configuration we specified above. 
\n", "\n", "To feed data into LightningTrainer, we need to configure the following arguments:\n", "\n", "- `datasets`: A dictionary of the input Ray datasets, with special keys \"train\" and \"val\".\n", "- `datasets_iter_config`: The argument list of {meth}`iter_torch_batches() `. It defines the way we iterate dataset shards for each worker.\n", "- `preprocessor`: The preprocessor that will be applied to the input dataset.\n", "\n", ":::{note}\n", "Note that we are using Dataset for data ingestion for faster preprocessing here, but you can also continue to use the native `PyTorch DataLoader` or `LightningDataModule`. See {ref}`this example `. \n", "\n", ":::\n", "\n", "\n", "Now, call `trainer.fit()` to initiate the training process." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n", "  <h3>Tune Status</h3>\n", "  <table>\n", "    <tr><td>Current time:</td><td>2023-04-24 10:42:50</td></tr>\n", "    <tr><td>Running for:</td><td>00:06:26.94</td></tr>\n", "    <tr><td>Memory:</td><td>23.8/186.6 GiB</td></tr>\n", "  </table>\n", "  <h3>System Info</h3>\n", "  Using FIFO scheduling algorithm.<br>Logical resource usage: 0/48 CPUs, 0/4 GPUs (0.0/1.0 accelerator_type:T4)\n", "  <h3>Trial Status</h3>\n", "  <table>\n", "    <tr><th>Trial name</th><th>status</th><th>loc</th><th>iter</th><th>total time (s)</th><th>train_loss</th><th>matthews_correlation</th><th>epoch</th></tr>\n", "    <tr><td>LightningTrainer_87ecf_00000</td><td>TERMINATED</td><td>10.0.60.127:67819</td><td>5</td><td>376.028</td><td>0.0119807</td><td>0.589931</td><td>4</td></tr>\n", "  </table>\n", "</div>
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "(pid=67819) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", "(pid=67819) from pandas import MultiIndex, Int64Index\n", "(LightningTrainer pid=67819) 2023-04-24 10:36:31,679\tINFO backend_executor.py:128 -- Starting distributed worker processes: ['68396 (10.0.60.127)', '68397 (10.0.60.127)', '68398 (10.0.60.127)', '68399 (10.0.60.127)']\n", "(RayTrainWorker pid=68396) 2023-04-24 10:36:32,731\tINFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f9443dd2a6dc49029ef7fb4d7a596729", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=67819) - RandomizeBlockOrder 1: 0%| | 0/1 [00:00 TaskPoolMapOperator[BatchMapper] -> AllToAllOperator[RandomizeBlockOrder]\n", "(LightningTrainer pid=67819) 2023-04-24 10:36:34,052\tINFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n", "(LightningTrainer pid=67819) 2023-04-24 10:36:34,053\tINFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1.\n", "(RayTrainWorker pid=68396) /home/ray/anaconda3/lib/python3.9/site-packages/xgboost/compat.py:31: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. 
Use pandas.Index with the appropriate dtype instead.\n", "(RayTrainWorker pid=68396) from pandas import MultiIndex, Int64Index\n", "Downloading: 0%| | 0.00/416M [00:00 TaskPoolMapOperator[BatchMapper] -> AllToAllOperator[RandomizeBlockOrder]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:36:59,629\tINFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n", "(RayTrainWorker pid=68398) 2023-04-24 10:36:59,629\tINFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "70151d1b6133418fb5bf5e39b0089dd6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=68398) - RandomizeBlockOrder 1: 0%| | 0/1 [00:00 TaskPoolMapOperator[BatchMapper] -> AllToAllOperator[RandomizeBlockOrder] [repeated 3x across cluster]\n", "(RayTrainWorker pid=68399) 2023-04-24 10:36:59,628\tINFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False) [repeated 3x across cluster]\n", "(RayTrainWorker pid=68399) 2023-04-24 10:36:59,629\tINFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1. [repeated 3x across cluster]\n", "(RayTrainWorker pid=68398) [W reducer.cpp:1298] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. 
If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\n", "(RayTrainWorker pid=68396) 2023-04-24 10:37:27.091660: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 AVX512F AVX512_VNNI FMA\n", "(RayTrainWorker pid=68396) To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "(RayTrainWorker pid=68399) [W reducer.cpp:1298] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator()) [repeated 3x across cluster]\n", "(RayTrainWorker pid=68396) 2023-04-24 10:37:27.373013: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. 
To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", "(RayTrainWorker pid=68396) 2023-04-24 10:37:28.763569: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "(RayTrainWorker pid=68396) 2023-04-24 10:37:28.763761: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", "(RayTrainWorker pid=68396) 2023-04-24 10:37:28.763770: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n", "(RayTrainWorker pid=68398) 2023-04-24 10:38:01,220\tINFO streaming_executor.py:87 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[BatchMapper] -> AllToAllOperator[RandomizeBlockOrder]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:38:01,221\tINFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)\n", "(RayTrainWorker pid=68398) 2023-04-24 10:38:01,221\tINFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "50090e60317342e8a2fa5747b2dfc7dd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=68398) - RandomizeBlockOrder 1: 0%| | 0/1 [00:00\n", "

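<h3>Trial Progress</h3>\n", "<table>\n", "<tr><th>Trial name</th><th>_report_on</th><th>date</th><th>done</th><th>epoch</th><th>experiment_tag</th><th>hostname</th><th>iterations_since_restore</th><th>matthews_correlation</th><th>node_ip</th><th>pid</th><th>should_checkpoint</th><th>step</th><th>time_since_restore</th><th>time_this_iter_s</th><th>time_total_s</th><th>timestamp</th><th>train_loss</th><th>training_iteration</th><th>trial_id</th></tr>\n", "<tr><td>LightningTrainer_87ecf_00000</td><td>validation_end</td><td>2023-04-24_10-42-46</td><td>True</td><td>4</td><td>0</td><td>ip-10-0-60-127</td><td>5</td><td>0.589931</td><td>10.0.60.127</td><td>67819</td><td>True</td><td>670</td><td>376.028</td><td>70.6609</td><td>376.028</td><td>1682358165</td><td>0.0119807</td><td>5</td><td>87ecf_00000</td></tr>\n", "</table>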
\n", "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "(RayTrainWorker pid=68398) 2023-04-24 10:39:03,705\tINFO streaming_executor.py:87 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[BatchMapper] -> AllToAllOperator[RandomizeBlockOrder] [repeated 4x across cluster]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:39:03,706\tINFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False) [repeated 4x across cluster]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:39:03,706\tINFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1. [repeated 4x across cluster]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "125ccea4d26e48c0bf4e45610f9ae64a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=68398) - RandomizeBlockOrder 1: 0%| | 0/1 [00:00 TaskPoolMapOperator[BatchMapper] -> AllToAllOperator[RandomizeBlockOrder] [repeated 4x across cluster]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:40:09,873\tINFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False) [repeated 4x across cluster]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:40:09,873\tINFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1. 
[repeated 4x across cluster]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "db4c22b67b844a6d8ff3e1882540bce4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=68398) - RandomizeBlockOrder 1: 0%| | 0/1 [00:00 TaskPoolMapOperator[BatchMapper] -> AllToAllOperator[RandomizeBlockOrder] [repeated 4x across cluster]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:41:18,552\tINFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False) [repeated 4x across cluster]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:41:18,552\tINFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1. [repeated 4x across cluster]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ccc3d13c44b344e8891a81794fd17ffe", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=68398) - RandomizeBlockOrder 1: 0%| | 0/1 [00:00 TaskPoolMapOperator[BatchMapper] -> AllToAllOperator[RandomizeBlockOrder] [repeated 4x across cluster]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:42:29,325\tINFO streaming_executor.py:88 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False) [repeated 4x across cluster]\n", "(RayTrainWorker pid=68398) 2023-04-24 10:42:29,325\tINFO streaming_executor.py:90 -- Tip: To enable per-operator progress reporting, set RAY_DATA_VERBOSE_PROGRESS=1. [repeated 4x across cluster]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "55f6f7e8333341d1b57a890809bc90ad", "version_major": 2, "version_minor": 0 }, "text/plain": [ "(pid=68398) - RandomizeBlockOrder 1: 0%| | 0/1 [00:00`. 
\n", "\n", ":::" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Result(\n", " metrics={'_report_on': 'validation_end', 'train_loss': 0.011980690062046051, 'matthews_correlation': 0.5899314497879129, 'epoch': 4, 'step': 670, 'should_checkpoint': True, 'done': True, 'trial_id': '87ecf_00000', 'experiment_tag': '0'},\n", " path='/home/ray/ray_results/ptl-sent-classification/LightningTrainer_87ecf_00000_0_2023-04-24_10-36-23',\n", " checkpoint=LightningCheckpoint(local_path=/home/ray/ray_results/ptl-sent-classification/LightningTrainer_87ecf_00000_0_2023-04-24_10-36-23/checkpoint_000004)\n", ")" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Do Batch Inference with a Saved Checkpoint" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have fine-tuned the model, we can load the checkpoint into a BatchPredictor and run inference on multiple GPUs. Calling `predict()` distributes the inference workload across multiple workers, which run prediction on separate shards of the data in parallel. \n", "\n", "You can find more details in [Using Predictors for Inference](air-predictors)." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "tags": [] }, "outputs": [], "source": [ "from ray.train.batch_predictor import BatchPredictor\n", "from ray.train.lightning import LightningCheckpoint, LightningPredictor\n", "\n", "# Use in-memory checkpoint object\n", "checkpoint = result.checkpoint\n", "\n", "# You can also load a checkpoint from disk:\n", "# YOUR_CHECKPOINT_DIR = result.checkpoint.path\n", "# checkpoint = LightningCheckpoint.from_directory(YOUR_CHECKPOINT_DIR)\n", "\n", "batch_predictor = BatchPredictor(\n", "    checkpoint=checkpoint,\n", "    predictor_cls=LightningPredictor,\n", "    use_gpu=True,\n", "    model_class=SentimentModel,\n", "    preprocessor=preprocessor,\n", ")\n", "\n", "# Use 2 GPUs for batch inference\n", "predictions = batch_predictor.predict(\n", "    ray_datasets[\"validation\"],\n", "    feature_columns=[\"input_ids\", \"attention_mask\", \"label\"],\n", "    keep_columns=[\"label\"],\n", "    batch_size=16,\n", "    min_scoring_workers=2,\n", "    max_scoring_workers=2,\n", "    num_gpus_per_worker=1,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We obtained a Ray dataset containing predictions from `batch_predictor.predict()`. 
Now we can easily evaluate the results with just a few lines of code:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Internally, BatchPredictor calls the forward() method of the LightningModule.\n", "# Convert the logits tensor into labels with argmax.\n", "def argmax(batch):\n", "    batch[\"predictions\"] = batch[\"predictions\"].apply(lambda x: np.argmax(x))\n", "    return batch\n", "\n", "\n", "results = predictions.map_batches(argmax, batch_format=\"pandas\").to_pandas()" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "   predictions  label\n", "0            1      1\n", "1            1      1\n", "2            0      1\n", "3            1      1\n", "4            0      0\n", "5            1      0\n", "6            1      0\n", "7            1      1\n", "8            1      1\n", "9            1      1\n", "\n", "{'matthews_correlation': 0.5899314497879129}\n" ] } ], "source": [ "matthews_corr = metric.compute(\n", "    predictions=results[\"predictions\"], references=results[\"label\"]\n", ")\n", "print(results.head(10))\n", "print(matthews_corr)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## What's next?\n", "\n", "- {ref}`Fine-tune a Large Language Model with LightningTrainer and FSDP `\n", "- {ref}`Hyperparameter searching with LightningTrainer + Ray Tune. `\n", "- {ref}`Experiment Tracking with Wandb, CometML, MLFlow, and Tensorboard in LightningTrainer `" ] } ], "metadata": { "kernelspec": { "display_name": "build", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" }, "vscode": { "interpreter": { "hash": "178108d354ddc93ba36c4b7bfc5283800982aac0e7ca92cc0cf312ad1b8f8b20" } } }, "nbformat": 4, "nbformat_minor": 4 }