# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
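"""Toy models, optimizers, and data/config helpers for DeepSpeed unit tests.

Each model below is intentionally tiny and exists to exercise a specific
engine feature: frozen parameters (SimpleFrozenModel), curriculum learning
(Curriculum_SimpleModel), MoE and PR-MoE (SimpleMoEModel, SimplePRMoEModel),
unused parameters (UnusedParametersModel), pipeline parallelism
(LinearStackPipe), and progressive layer drop (PLD_SimpleModel).
"""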

import os
import json
import argparse

import torch

from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.moe.layer import MoE
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist


class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super(SimpleModel, self).__init__()
        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad

    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)


class SimpleFrozenModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False):
        super(SimpleFrozenModel, self).__init__()
        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(2)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad
        # Freeze the first layer so tests can mix trainable and frozen parameters
        self.linears[0].weight.requires_grad = False
        self.linears[0].bias.requires_grad = False

    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)


class Curriculum_SimpleModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super(Curriculum_SimpleModel, self).__init__(hidden_dim, empty_grad)

    def forward(self, x, y, **kwargs):
        seqlen = kwargs.get('curriculum_seqlen', None)
        loss = super(Curriculum_SimpleModel, self).forward(x, y)
        return loss, seqlen


class SimpleMoEModel(torch.nn.Module):

    def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
        super(SimpleMoEModel, self).__init__()
        self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim)
        expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim))
        # using two MoE layers to check implications of sharing a single storage
        self.moe_1 = MoE(hidden_size=hidden_dim,
                         expert=expert,
                         ep_size=ep_size,
                         use_residual=use_residual,
                         num_experts=num_experts,
                         k=1)
        # interleaving MoE modules with dense to create an opportunity
        # for gradients to be merged in ZeRO stage 2 average_tensor reduce bucket
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.moe_2 = MoE(hidden_size=hidden_dim,
                         expert=expert,
                         ep_size=ep_size,
                         use_residual=use_residual,
                         num_experts=num_experts,
                         k=1)
        self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = self.linear1(x)
        output, _, _ = self.moe_1(hidden)
        output = self.linear2(output)
        output, _, _ = self.moe_2(output)
        output = self.linear3(output)
        hidden = hidden + output  # residual connection around the MoE stack
        # mean-pool over the sequence dimension; expects (batch, seq_len, hidden) input
        sentence_embed = hidden.mean(1)
        return self.cross_entropy_loss(sentence_embed, y)


class SimplePRMoEModel(torch.nn.Module):

    def __init__(self, hidden_dim, num_experts=2, ep_size=1, use_residual=False):
        super(SimplePRMoEModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = MoE(hidden_size=hidden_dim,
                           expert=linear2,
                           ep_size=ep_size,
                           use_residual=use_residual,
                           num_experts=num_experts,
                           k=1)
        linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        # PR-MoE: the deeper MoE layer gets twice as many experts
        self.linear3 = MoE(hidden_size=hidden_dim,
                           expert=linear3,
                           ep_size=ep_size,
                           use_residual=use_residual,
                           num_experts=int(2 * num_experts),
                           k=1)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = self.linear(x)
        output, _, _ = self.linear2(hidden)
        output, _, _ = self.linear3(output)
        hidden = hidden + output
        sentence_embed = hidden.mean(1)
        return self.cross_entropy_loss(sentence_embed, y)


class UnusedParametersModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super().__init__(hidden_dim, empty_grad)
        self.unused_linear = torch.nn.Linear(hidden_dim, hidden_dim)


class LinearStack(torch.nn.Module):

    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.input_layer = torch.nn.Linear(in_features=self.input_dim, out_features=self.hidden_dim)
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim, bias=False)
            for _ in range(num_layers)
        ])
        self.output_layer = torch.nn.Linear(in_features=self.hidden_dim, out_features=self.output_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        # returns logits; y is accepted for a uniform (x, y) interface but unused here
        x = self.input_layer(x)
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return x


class LinearStackPipe(PipelineModule):

    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4, **kwargs):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        layers = []
        layers.append(LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim))
        for _ in range(self.num_layers):
            layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.hidden_dim, bias=False))
        layers.append(lambda x: x)  # identity callable: pipeline stages need not be LayerSpecs
        layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim))

        super().__init__(layers=layers, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs)


class SimpleOptimizer(torch.optim.Optimizer):

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(SimpleOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SimpleOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # plain SGD update; add_(Number, Tensor) is deprecated, so use the alpha keyword
                p.data.add_(d_p, alpha=-group['lr'])

        return loss
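

# Like SimpleOptimizer, but each parameter also carries per-parameter state
# holding both a plain Python int and a torch tensor; presumably this mix is
# what the tests use to verify that either kind of optimizer state is handled
# correctly (e.g. across checkpoint save/load round trips).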


class HybridStateOptimizer(torch.optim.Optimizer):

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(HybridStateOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(HybridStateOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state['integer_step'] = 0
                    state['tensor_step'] = torch.zeros(1, device=p.device)

                d_p = p.grad.data
                # same SGD update as SimpleOptimizer, using the non-deprecated alpha keyword
                p.data.add_(d_p, alpha=-group['lr'])
                state['integer_step'] += 1
                state['tensor_step'] += 1

        return loss


class PLD_SimpleModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad)

    def forward(self, x, y, **kwargs):
        # progressive layer drop kwargs are accepted but intentionally unused by this toy model
        pld = kwargs.get('progressive_layer_drop', False)
        theta = kwargs.get('pld_theta', 1.0)
        loss = super(PLD_SimpleModel, self).forward(x, y)
        return loss


def random_dataset(total_samples, hidden_dim, device, dtype=torch.half):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset


def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader


def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=torch.half):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, seq_len, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path


def create_deepspeed_args():
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    if dist.is_initialized():
        # We assume up to one full node executing unit tests
        assert dist.get_world_size() <= get_accelerator().device_count()
        args.local_rank = dist.get_rank()
    return args


def args_from_dict(tmpdir, config_dict):
    args = create_deepspeed_args()
    config_path = create_config_from_dict(tmpdir, config_dict)
    args.deepspeed_config = config_path
    return args
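

# A minimal smoke test (a sketch, not part of the DeepSpeed test suite):
# builds SimpleModel on CPU with float32 data and runs one forward/backward
# pass. The sizes are illustrative only. Note that random_dataloader and
# sequence_dataloader still require an initialized DeepSpeed engine, since
# they call model.train_micro_batch_size_per_gpu().
if __name__ == "__main__":
    dim = 8
    model = SimpleModel(dim)
    dataset = random_dataset(total_samples=4, hidden_dim=dim, device=torch.device('cpu'), dtype=torch.float)
    x, y = dataset[:]
    loss = model(x, y)
    loss.backward()
    print(f"loss: {loss.item():.4f}")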