test_pipe.py
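
# Tests that pipeline-parallel training of AlexNet on CIFAR-10 converges to
# the same loss as an equivalent data-parallel baseline.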
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepspeed.comm as dist
import pytest

import deepspeed
import deepspeed.runtime.utils as ds_utils
from deepspeed.runtime.pipe.topology import PipeDataParallelTopology
PipeTopo = PipeDataParallelTopology
from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec

from .common import distributed_test


# Relative difference of B from the reference A; used as the convergence metric.
def rel_diff(A, B):
    return abs(A - B) / abs(A)


# All models
from .simple_model import args_from_dict


class AlexNet(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Linear(256, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, y):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return self.loss_fn(x, y)
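

# Two pipeline-friendly expressions of the same network: AlexNetPipe flattens
# the model into a plain list of callables for PipelineModule, while
# AlexNetPipeSpec describes it with LayerSpecs so construction can be deferred.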
class AlexNetPipe(AlexNet):
    def to_layers(self):
        layers = [*self.features, lambda x: x.view(x.size(0), -1), self.classifier]
        return layers


class AlexNetPipeSpec(PipelineModule):
    def __init__(self, num_classes=10, **kwargs):
        self.num_classes = num_classes
        specs = [
            LayerSpec(nn.Conv2d, 3, 64, kernel_size=11, stride=4, padding=5),
            LayerSpec(nn.ReLU, inplace=True),
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 64, 192, kernel_size=5, padding=2),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            LayerSpec(nn.Conv2d, 192, 384, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 384, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.Conv2d, 256, 256, kernel_size=3, padding=1),
            F.relu,
            LayerSpec(nn.MaxPool2d, kernel_size=2, stride=2),
            lambda x: x.view(x.size(0), -1),
            LayerSpec(nn.Linear, 256, self.num_classes),  # classifier
        ]
        super().__init__(layers=specs, loss_fn=nn.CrossEntropyLoss(), **kwargs)
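

# Note: a LayerSpec stores the layer class plus its constructor arguments and
# only instantiates the module when PipelineModule assigns the layer to a
# stage. A minimal sketch of the idea:
#     spec = LayerSpec(nn.Linear, 256, 10)
#     layer = spec.build()  # builds nn.Linear(256, 10) on demand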


def cifar_trainset(fp16=False):
    import torchvision
    import torchvision.transforms as transforms

    transform_list = [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    if fp16:
        transform_list.append(torchvision.transforms.Lambda(lambda x: x.half()))

    transform = transforms.Compose(transform_list)

    local_rank = torch.cuda.current_device()

    # Only one rank per machine downloads: all ranks sync, then the non-zero
    # ranks wait at the second barrier while rank 0 downloads; rank 0 releases
    # them by reaching that barrier once the download is complete.
    dist.barrier()
    if local_rank != 0:
        dist.barrier()
    trainset = torchvision.datasets.CIFAR10(root='/tmp/cifar10-data',
                                            train=True,
                                            download=True,
                                            transform=transform)
    if local_rank == 0:
        dist.barrier()
    return trainset
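

# Runs `num_steps` of engine.train_batch() under a fixed seed and returns the
# per-step losses, optionally averaged across data-parallel ranks so that runs
# with different parallelism layouts can be compared step by step.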
def train_cifar(model, args, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
    # fork_rng restores the global RNG state on exit, so each call sees
    # identical seeding regardless of what ran before it.
    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        args.local_rank = dist.get_rank()

        engine, _, _, _ = deepspeed.initialize(args=args,
                                               model=model,
                                               model_parameters=[p for p in model.parameters()],
                                               training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss = engine.train_batch()
            losses.append(loss.item())
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss.item()}')

        if average_dp_losses:
            loss_tensor = torch.tensor(losses).cuda()
            dist.all_reduce(loss_tensor)
            loss_tensor /= dist.get_world_size()
            losses = loss_tensor.tolist()

    return losses
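

# Trains the same initial AlexNet twice on 4 GPUs: once with pure data
# parallelism as a baseline, and once with the parametrized topology
# (num_pp=1/num_dp=4 is pure data parallel, num_pp=2/num_dp=2 is a 2-stage
# pipeline replicated twice, num_pp=4/num_dp=1 is a pure 4-stage pipeline).
# The two loss curves should settle to nearly the same value.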
@pytest.mark.skip(reason="been seeing nondeterministic failures, skipping for now")
@pytest.mark.parametrize('topo',
                         [
                             PipeTopo(num_pp=1, num_dp=4),
                             PipeTopo(num_pp=2, num_dp=2),
                             PipeTopo(num_pp=4, num_dp=1),
                         ])
def test_pipe_cifar10(topo, tmpdir):
    config_dict = {
        "train_batch_size": 16,
        "train_micro_batch_size_per_gpu": 4,
        "steps_per_print": 20,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001,
                "betas": [0.9, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7
            }
        },
        "zero_optimization": {
            "stage": 0
        },
        "fp16": {
            "enabled": False
        },
        "pipeline": {
            "seed_layers": True,  # use a different seed for each layer
            "activation_checkpoint_interval": 1
        }
    }
    args = args_from_dict(tmpdir, config_dict)

    # Allocate model for consistent initial weights.
    init_net = AlexNetPipe()
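
    # _helper runs on all 4 ranks; it deep-copies init_net so the baseline and
    # pipeline-parallel models start from identical weights.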
    @distributed_test(world_size=4)
    def _helper(topo, tmpdir, steps=500):
        assert steps >= 100

        base_net = copy.deepcopy(init_net)
        base_model = PipelineModule(layers=base_net.to_layers(),
                                    num_stages=1,
                                    loss_fn=nn.CrossEntropyLoss())

        # Train with just data parallelism
        base_losses = train_cifar(base_model,
                                  args,
                                  num_steps=steps,
                                  fp16=config_dict['fp16']['enabled'])

        test_net = copy.deepcopy(init_net)
        test_model = PipelineModule(layers=test_net.to_layers(),
                                    topology=topo,
                                    loss_fn=nn.CrossEntropyLoss())
        #test_model = AlexNetPipe(num_classes=10,
        #                         topology=test_topo,
        #                         seed_layers=config_dict['pipeline']['seed_layers'])
        test_losses = train_cifar(test_model,
                                  args,
                                  num_steps=steps,
                                  fp16=config_dict['fp16']['enabled'])

        abs_diffs = [l0 - l1 for l0, l1 in zip(base_losses, test_losses)]
        rel_diffs = [rel_diff(l0, l1) for l0, l1 in zip(base_losses, test_losses)]
        if dist.get_rank() == 0:
            print(
                f'abs min={min(abs_diffs)} max={max(abs_diffs)} avg={sum(abs_diffs)/len(abs_diffs)}'
            )
            print(
                f'rel min={min(rel_diffs)} max={max(rel_diffs)} avg={sum(rel_diffs)/len(rel_diffs)}'
            )
            print(
                f'first: base={base_losses[0]} test={test_losses[0]} abs={abs_diffs[0]} rel={rel_diffs[0]}'
            )

            for lastX in [1, 10, 100]:
                base_avg = sum(base_losses[-lastX:]) / lastX
                test_avg = sum(test_losses[-lastX:]) / lastX
                print(
                    f'last-{lastX}: base={base_avg} test={test_avg} abs={base_avg - test_avg} rel={rel_diff(base_avg, test_avg)}'
                )

        # Require the last-100-step loss averages to agree within 3% relative
        # difference.
        lastX = 100
        base = base_losses[-lastX:]
        base_avg = sum(base) / len(base)
        test = test_losses[-lastX:]
        test_avg = sum(test) / len(test)
        assert rel_diff(base_avg, test_avg) < 0.03

    _helper(topo, tmpdir)