alexnet_model.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import pytest
  5. import os
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import deepspeed
  10. import deepspeed.comm as dist
  11. import deepspeed.runtime.utils as ds_utils
  12. from deepspeed.utils.torch import required_torch_version
  13. from deepspeed.accelerator import get_accelerator
  14. from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec
  15. from .util import no_child_process_in_deepspeed_io
  16. class AlexNet(nn.Module):
  17. def __init__(self, num_classes=10):
  18. super(AlexNet, self).__init__()
  19. self.features = nn.Sequential(
  20. nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
  21. nn.ReLU(inplace=True),
  22. nn.MaxPool2d(kernel_size=2, stride=2),
  23. nn.Conv2d(64, 192, kernel_size=5, padding=2),
  24. nn.ReLU(inplace=True),
  25. nn.MaxPool2d(kernel_size=2, stride=2),
  26. nn.Conv2d(192, 384, kernel_size=3, padding=1),
  27. nn.ReLU(inplace=True),
  28. nn.Conv2d(384, 256, kernel_size=3, padding=1),
  29. nn.ReLU(inplace=True),
  30. nn.Conv2d(256, 256, kernel_size=3, padding=1),
  31. nn.ReLU(inplace=True),
  32. nn.MaxPool2d(kernel_size=2, stride=2),
  33. )
  34. self.classifier = nn.Linear(256, num_classes)
  35. self.loss_fn = nn.CrossEntropyLoss()
  36. def forward(self, x, y):
  37. x = self.features(x)
  38. x = x.view(x.size(0), -1)
  39. x = self.classifier(x)
  40. return self.loss_fn(x, y)
  41. class AlexNetPipe(AlexNet):
  42. def to_layers(self):
  43. layers = [*self.features, lambda x: x.view(x.size(0), -1), self.classifier]
  44. return layers
  45. class AlexNetPipeSpec(PipelineModule):
  46. def __init__(self, num_classes=10, **kwargs):
  47. self.num_classes = num_classes
  48. specs = [
  49. LayerSpec(nn.Conv2d, 3, 64, kernel_size=11, stride=4, padding=5),
  50. LayerSpec(nn.ReLU, inplace=True),
  51. LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
  52. LayerSpec(nn.Conv2d, 64, 192, kernel_size=5, padding=2),
  53. F.relu,
  54. LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
  55. LayerSpec(nn.Conv2d, 192, 384, kernel_size=3, padding=1),
  56. F.relu,
  57. LayerSpec(nn.Conv2d, 384, 256, kernel_size=3, padding=1),
  58. F.relu,
  59. LayerSpec(nn.Conv2d, 256, 256, kernel_size=3, padding=1),
  60. F.relu,
  61. LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
  62. lambda x: x.view(x.size(0), -1),
  63. LayerSpec(nn.Linear, 256, self.num_classes), # classifier
  64. ]
  65. super().__init__(layers=specs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
  66. # Define this here because we cannot pickle local lambda functions
  67. def cast_to_half(x):
  68. return x.half()
  69. def cifar_trainset(fp16=False):
  70. torchvision = pytest.importorskip("torchvision", minversion="0.5.0")
  71. import torchvision.transforms as transforms
  72. transform_list = [
  73. transforms.ToTensor(),
  74. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  75. ]
  76. if fp16:
  77. transform_list.append(torchvision.transforms.Lambda(cast_to_half))
  78. transform = transforms.Compose(transform_list)
  79. local_rank = get_accelerator().current_device()
  80. # Only one rank per machine downloads.
  81. dist.barrier()
  82. if local_rank != 0:
  83. dist.barrier()
  84. data_root = os.getenv("TEST_DATA_DIR", "/tmp/")
  85. if os.getenv("CIFAR10_DATASET_PATH"):
  86. data_root = os.getenv("CIFAR10_DATASET_PATH")
  87. download = False
  88. else:
  89. data_root = os.path.join(os.getenv("TEST_DATA_DIR", "/tmp"), "cifar10-data")
  90. download = True
  91. trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=download, transform=transform)
  92. if local_rank == 0:
  93. dist.barrier()
  94. return trainset
  95. def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
  96. if required_torch_version(min_version=2.1):
  97. fork_kwargs = {"device_type": get_accelerator().device_name()}
  98. else:
  99. fork_kwargs = {}
  100. with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()], **fork_kwargs):
  101. ds_utils.set_random_seed(seed)
  102. # disable dropout
  103. model.eval()
  104. trainset = cifar_trainset(fp16=fp16)
  105. config['local_rank'] = dist.get_rank()
  106. with no_child_process_in_deepspeed_io():
  107. engine, _, _, _ = deepspeed.initialize(config=config,
  108. model=model,
  109. model_parameters=[p for p in model.parameters()],
  110. training_data=trainset)
  111. losses = []
  112. for step in range(num_steps):
  113. loss = engine.train_batch()
  114. losses.append(loss.item())
  115. if step % 50 == 0 and dist.get_rank() == 0:
  116. print(f'STEP={step} LOSS={loss.item()}')
  117. if average_dp_losses:
  118. loss_tensor = torch.tensor(losses).to(get_accelerator().device_name())
  119. dist.all_reduce(loss_tensor)
  120. loss_tensor /= dist.get_world_size()
  121. losses = loss_tensor.tolist()
  122. return losses