- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import pytest
- import os
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import deepspeed
- import deepspeed.comm as dist
- import deepspeed.runtime.utils as ds_utils
- from deepspeed.utils.torch import required_torch_version
- from deepspeed.accelerator import get_accelerator
- from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
- from .util import no_child_process_in_deepspeed_io
class AlexNet(nn.Module):
    """AlexNet-style CNN sized for CIFAR images.

    Unlike a standard classifier, ``forward`` takes both inputs and labels
    and returns the cross-entropy loss directly — the contract the pipeline
    tests expect from a loss-producing module.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional feature extractor; for 32x32 inputs this reduces to a
        # 256-channel 1x1 map, so the classifier input width is 256.
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Linear(256, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, y):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        logits = self.classifier(flat)
        return self.loss_fn(logits, y)
class AlexNetPipe(AlexNet):
    """AlexNet variant that can be flattened into a layer list for pipeline parallelism."""

    def to_layers(self):
        """Return the model as a flat sequence of callables (features, flatten, classifier)."""
        flatten = lambda x: x.view(x.size(0), -1)
        layers = list(self.features)
        layers.append(flatten)
        layers.append(self.classifier)
        return layers
class AlexNetPipeSpec(PipelineModule):
    """AlexNet expressed as ``LayerSpec`` entries so that each pipeline stage
    only materializes the layers it owns (lazy construction)."""

    def __init__(self, num_classes=10, **kwargs):
        # Must be assigned before super().__init__(): the specs list below
        # reads ``self.num_classes`` and super().__init__ consumes the list.
        self.num_classes = num_classes
        specs = [
            LayerSpec(nn.Conv2d, 3, 64, kernel_size=11, stride=4, padding=5),
            LayerSpec(nn.ReLU, inplace=True),
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 64, 192, kernel_size=5, padding=2),
            # Plain callables (F.relu, the flatten lambda) are also valid
            # pipeline stages; they carry no parameters.
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 192, 384, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 384, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 256, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            lambda x: x.view(x.size(0), -1),
            LayerSpec(nn.Linear, 256, self.num_classes),  # classifier
        ]
        super().__init__(layers=specs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
# Defined at module level because local lambda functions cannot be pickled
# (e.g. when the transform is shipped to DataLoader worker processes).
def cast_to_half(x):
    """Cast tensor ``x`` to half precision (fp16)."""
    return x.to(torch.float16)
def cifar_trainset(fp16=False):
    """Build the CIFAR-10 training set with normalization transforms.

    Skips the calling test (pytest.importorskip) if torchvision is missing.

    Args:
        fp16: when True, append a transform casting each image tensor to half
            precision (uses module-level ``cast_to_half`` — a local lambda
            could not be pickled).

    Returns:
        torchvision.datasets.CIFAR10 training split.
    """
    torchvision = pytest.importorskip("torchvision", minversion="0.5.0")
    import torchvision.transforms as transforms
    transform_list = [
        transforms.ToTensor(),
        # Map each RGB channel from [0, 1] to roughly [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    if fp16:
        transform_list.append(torchvision.transforms.Lambda(cast_to_half))
    transform = transforms.Compose(transform_list)
    local_rank = get_accelerator().current_device()
    # Only one rank per machine downloads.
    # Barrier protocol: everyone syncs once, then non-zero ranks wait on a
    # second barrier that rank 0 enters only after its CIFAR10 construction
    # (and any download) finishes — serializing the download to rank 0.
    dist.barrier()
    if local_rank != 0:
        dist.barrier()
    # Default root; note both branches below overwrite it.
    data_root = os.getenv("TEST_DATA_DIR", "/tmp/")
    if os.getenv("CIFAR10_DATASET_PATH"):
        # A pre-downloaded dataset was provided; never hit the network.
        data_root = os.getenv("CIFAR10_DATASET_PATH")
        download = False
    else:
        data_root = os.path.join(os.getenv("TEST_DATA_DIR", "/tmp"), "cifar10-data")
        download = True
    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform)
    if local_rank == 0:
        dist.barrier()
    return trainset
def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
    """Train ``model`` on CIFAR-10 via a DeepSpeed engine and return per-step losses.

    Args:
        model: module to train (e.g. AlexNet or a pipeline module).
        config: DeepSpeed config dict; ``local_rank`` is injected below.
        num_steps: number of ``engine.train_batch()`` iterations.
        average_dp_losses: when True, all-reduce losses across ranks and
            return the world-size mean.
        fp16: forwarded to ``cifar_trainset`` (half-precision inputs).
        seed: seed applied inside a forked RNG scope so the caller's RNG
            state is restored afterwards.

    Returns:
        list of per-step loss floats (rank-averaged when requested).
    """
    if required_torch_version(min_version=2.1):
        # The ``device_type`` kwarg to fork_rng only exists on newer torch,
        # hence the version gate.
        fork_kwargs = {"device_type": get_accelerator().device_name()}
    else:
        fork_kwargs = {}
    with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], **fork_kwargs):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        config['local_rank'] = dist.get_rank()

        # Prevent DeepSpeed's data loader from spawning child worker
        # processes during initialize (see the .util helper).
        with no_child_process_in_deepspeed_io():
            engine, _, _, _ = deepspeed.initialize(config=config,
                                                   model=model,
                                                   model_parameters=[p for p in model.parameters()],
                                                   training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss = engine.train_batch()
            losses.append(loss.item())
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss.item()}')

        if average_dp_losses:
            # Average the loss trajectory across data-parallel ranks.
            loss_tensor = torch.tensor(losses).to(get_accelerator().device_name())
            dist.all_reduce(loss_tensor)
            loss_tensor /= dist.get_world_size()
            losses = loss_tensor.tolist()
        return losses