simple_model.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. import os
  2. import json
  3. import argparse
  4. import torch
  5. from deepspeed.pipe import PipelineModule, LayerSpec
  6. class SimpleModel(torch.nn.Module):
  7. def __init__(self, hidden_dim, empty_grad=False, rank=0):
  8. super(SimpleModel, self).__init__()
  9. self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
  10. if empty_grad:
  11. self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
  12. self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
  13. self.rank = rank
  14. self.empty_grad = empty_grad
  15. def forward(self, x, y):
  16. hidden_dim = x
  17. if self.rank == 0 and self.empty_grad:
  18. hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
  19. else:
  20. hidden_dim = self.linear(hidden_dim)
  21. return self.cross_entropy_loss(hidden_dim, y)
  22. class LinearStack(torch.nn.Module):
  23. def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4):
  24. super().__init__()
  25. self.input_dim = input_dim
  26. self.output_dim = output_dim
  27. self.hidden_dim = hidden_dim
  28. self.input_layer = VerboseLinear(in_features=self.input_dim,
  29. out_features=self.hidden_dim)
  30. self.layers = torch.nn.ModuleList([
  31. torch.nn.Linear(in_features=self.hidden_dim,
  32. out_features=self.hidden_dim,
  33. bias=False) for x in range(num_layers)
  34. ])
  35. self.output_layer = torch.nn.Linear(in_features=self.hidden_dim,
  36. out_features=self.output_dim)
  37. self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
  38. def forward(self, x, y):
  39. x = self.input_layer(x)
  40. for layer in self.layers:
  41. x = layer(x)
  42. x = self.output_layer(x)
  43. return x
  44. class LinearStackPipe(PipelineModule):
  45. def __init__(self,
  46. input_dim=128,
  47. hidden_dim=128,
  48. output_dim=128,
  49. num_layers=4,
  50. **kwargs):
  51. self.input_dim = input_dim
  52. self.output_dim = output_dim
  53. self.hidden_dim = hidden_dim
  54. self.num_layers = num_layers
  55. layers = []
  56. layers.append(LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim))
  57. for x in range(self.num_layers):
  58. layers.append(
  59. LayerSpec(torch.nn.Linear,
  60. self.hidden_dim,
  61. self.hidden_dim,
  62. bias=False))
  63. layers.append(lambda x: x)
  64. layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim))
  65. super().__init__(layers=layers, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs)
  66. class SimpleOptimizer(torch.optim.Optimizer):
  67. def __init__(self, params, lr=0.11072018):
  68. defaults = dict(lr=lr)
  69. super(SimpleOptimizer, self).__init__(params, defaults)
  70. def __setstate__(self, state):
  71. super(SimpleOptimizer, self).__setstate__(state)
  72. def step(self, closure=None):
  73. loss = None
  74. if closure is not None:
  75. loss = closure()
  76. for group in self.param_groups:
  77. for p in group['params']:
  78. if p.grad is None:
  79. continue
  80. d_p = p.grad.data
  81. p.data.add_(-group['lr'], d_p)
  82. return loss
  83. def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
  84. batch_size = model.train_micro_batch_size_per_gpu()
  85. train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
  86. train_label = torch.empty(total_samples,
  87. dtype=torch.long,
  88. device=device).random_(hidden_dim)
  89. train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
  90. train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
  91. return train_loader
  92. def create_config_from_dict(tmpdir, config_dict):
  93. config_path = os.path.join(tmpdir, 'temp_config.json')
  94. with open(config_path, 'w') as fd:
  95. json.dump(config_dict, fd)
  96. return config_path
  97. def args_from_dict(tmpdir, config_dict):
  98. config_path = create_config_from_dict(tmpdir, config_dict)
  99. parser = argparse.ArgumentParser()
  100. args = parser.parse_args(args='')
  101. args.deepspeed = True
  102. args.deepspeed_config = config_path
  103. if torch.distributed.is_initialized():
  104. # We assume up to one full node executing unit tests
  105. assert torch.distributed.get_world_size() <= torch.cuda.device_count()
  106. args.local_rank = torch.distributed.get_rank()
  107. else:
  108. args.local_rank = 0
  109. return args