alexnet_model.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import deepspeed
import deepspeed.comm as dist
import deepspeed.runtime.utils as ds_utils
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec

class AlexNet(nn.Module):
    """AlexNet-style CNN sized for CIFAR-10; ``forward`` returns the loss directly."""

    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # For 3x32x32 CIFAR inputs the conv stack reduces each image to a
        # 256x1x1 feature map, hence the 256-dim classifier input.
        self.classifier = nn.Linear(256, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, y):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return self.loss_fn(x, y)

class AlexNetPipe(AlexNet):
    """AlexNet exposed as a flat list of layers for pipeline parallelism."""

    def to_layers(self):
        layers = [*self.features, lambda x: x.view(x.size(0), -1), self.classifier]
        return layers
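
# Sketch (an assumption about how `to_layers()` is consumed, mirroring
# DeepSpeed's pipeline tests): the flat layer list is handed to PipelineModule,
# which partitions it across pipeline stages. `num_stages=2` is illustrative
# and requires at least two ranks.
#
#   net = PipelineModule(layers=AlexNetPipe().to_layers(),
#                        loss_fn=nn.CrossEntropyLoss(),
#                        num_stages=2)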

class AlexNetPipeSpec(PipelineModule):
    """The same network expressed with LayerSpec, which defers layer
    construction until the pipeline is partitioned, so each rank only
    instantiates the layers of its own stage."""

    def __init__(self, num_classes=10, **kwargs):
        self.num_classes = num_classes
        specs = [
            LayerSpec(nn.Conv2d, 3, 64, kernel_size=11, stride=4, padding=5),
            LayerSpec(nn.ReLU, inplace=True),
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 64, 192, kernel_size=5, padding=2),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 192, 384, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 384, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 256, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            lambda x: x.view(x.size(0), -1),
            LayerSpec(nn.Linear, 256, self.num_classes),  # classifier
        ]
        super().__init__(layers=specs, loss_fn=nn.CrossEntropyLoss(), **kwargs)

# Define this here because we cannot pickle local lambda functions
def cast_to_half(x):
    return x.half()

def cifar_trainset(fp16=False):
    torchvision = pytest.importorskip("torchvision", minversion="0.5.0")
    import torchvision.transforms as transforms

    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    if fp16:
        transform_list.append(torchvision.transforms.Lambda(cast_to_half))

    transform = transforms.Compose(transform_list)

    local_rank = get_accelerator().current_device()

    # Only one rank per machine downloads.
    dist.barrier()
    if local_rank != 0:
        dist.barrier()
    data_root = os.getenv("TEST_DATA_DIR", "/tmp/")
    trainset = torchvision.datasets.CIFAR10(root=os.path.join(data_root, "cifar10-data"),
                                            train=True,
                                            download=True,
                                            transform=transform)
    if local_rank == 0:
        dist.barrier()
    return trainset

def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
    with get_accelerator().random().fork_rng(devices=[get_accelerator().current_device_name()]):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        config['local_rank'] = dist.get_rank()

        # deepspeed_io defaults to creating a dataloader that uses a
        # multiprocessing pool. Our tests use pools and we cannot nest pools in
        # python. Therefore we're injecting this kwarg to ensure that no pools
        # are used in the dataloader.
        old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io

        def new_method(*args, **kwargs):
            kwargs["num_local_io_workers"] = 0
            return old_method(*args, **kwargs)

        deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method

        engine, _, _, _ = deepspeed.initialize(config=config,
                                               model=model,
                                               model_parameters=[p for p in model.parameters()],
                                               training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss = engine.train_batch()
            losses.append(loss.item())
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss.item()}')

        if average_dp_losses:
            loss_tensor = torch.tensor(losses).to(get_accelerator().device_name())
            dist.all_reduce(loss_tensor)
            loss_tensor /= dist.get_world_size()
            losses = loss_tensor.tolist()

    return losses
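
# Hedged usage sketch (not part of the original test module). Running this file
# directly only makes sense under a distributed launcher, e.g.:
#   deepspeed --num_gpus 2 alexnet_model.py
# The config keys are standard DeepSpeed settings, but the specific values and
# the two-stage split below are illustrative assumptions, not tuned choices.
if __name__ == "__main__":
    deepspeed.init_distributed()
    ds_config = {
        "train_batch_size": 16,
        "train_micro_batch_size_per_gpu": 4,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001
            }
        },
        "fp16": {
            "enabled": True
        },
    }
    # Two pipeline stages; assumes at least two ranks are available.
    model = AlexNetPipeSpec(num_classes=10, num_stages=2)
    losses = train_cifar(model, ds_config, num_steps=100, fp16=True)
    if dist.get_rank() == 0:
        print(f"final averaged loss: {losses[-1]:.4f}")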