simple_model.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import json
import argparse
import torch
from collections import OrderedDict

from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.moe.layer import MoE
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .common import preferred_dtype


class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super(SimpleModel, self).__init__()
        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for i in range(nlayers)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad

    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)
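

def _simple_model_example(hidden_dim=4, batch_size=2):
    # Illustrative usage sketch (hypothetical helper, not referenced by any
    # test): SimpleModel treats its input as per-sample logits over
    # `hidden_dim` classes and returns a scalar cross-entropy loss, which is
    # all the unit tests need from it.
    model = SimpleModel(hidden_dim, nlayers=3)
    x = torch.randn(batch_size, hidden_dim)
    y = torch.randint(0, hidden_dim, (batch_size,))
    loss = model(x, y)
    assert loss.dim() == 0  # CrossEntropyLoss reduces to a scalar by default
    return loss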


class SimpleFrozenModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False):
        super(SimpleFrozenModel, self).__init__()
        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for i in range(2)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad
        # Freeze first layer
        self.linears[0].weight.requires_grad = False
        self.linears[0].bias.requires_grad = False

    def custom_state_dict(self, *args, **kwargs):
        state_dict = super(SimpleFrozenModel, self).state_dict(*args, **kwargs)
        custom = OrderedDict()
        for k, v in state_dict.items():
            if 'linears.0.weight' not in k:
                custom[k] = v
        return custom

    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)
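

def _frozen_state_dict_example(hidden_dim=4):
    # Illustrative sketch (hypothetical helper, not part of the original
    # suite): custom_state_dict() filters out the frozen weight, so checkpoint
    # tests can assert that an excluded parameter is never persisted. Note
    # that only the weight is dropped; the frozen bias still appears.
    model = SimpleFrozenModel(hidden_dim)
    sd = model.custom_state_dict()
    assert 'linears.0.weight' not in sd
    assert 'linears.0.bias' in sd
    return sd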


class Curriculum_SimpleModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super(Curriculum_SimpleModel, self).__init__(hidden_dim, empty_grad)

    def forward(self, x, y, **kwargs):
        seqlen = kwargs.get('curriculum_seqlen', None)
        loss = super(Curriculum_SimpleModel, self).forward(x, y)
        return loss, seqlen


class SimpleMoEModel(torch.nn.Module):

    def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
        super(SimpleMoEModel, self).__init__()
        self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim)
        expert = torch.nn.Sequential(torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.Linear(hidden_dim, hidden_dim))
        # using two MoE layers to check implications of sharing a single storage
        self.moe_1 = MoE(hidden_size=hidden_dim,
                         expert=expert,
                         ep_size=ep_size,
                         use_residual=use_residual,
                         num_experts=num_experts,
                         k=1)
        # interleaving MoE modules with dense to create an opportunity
        # for gradients to be merged in ZeRO stage 2 average_tensor reduce bucket
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.moe_2 = MoE(hidden_size=hidden_dim,
                         expert=expert,
                         ep_size=ep_size,
                         use_residual=use_residual,
                         num_experts=num_experts,
                         k=1)
        self.linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden_dim = self.linear1(x)
        output, _, _ = self.moe_1(hidden_dim)
        output = self.linear2(output)
        output, _, _ = self.moe_2(output)
        output = self.linear3(output)
        hidden_dim = hidden_dim + output
        sentence_embed = hidden_dim.mean(1)
        return self.cross_entropy_loss(sentence_embed, y)


class SimplePRMoEModel(torch.nn.Module):

    def __init__(self, hidden_dim, num_experts=2, ep_size=1, use_residual=False):
        super(SimplePRMoEModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = MoE(hidden_size=hidden_dim,
                           expert=linear2,
                           ep_size=ep_size,
                           use_residual=use_residual,
                           num_experts=num_experts,
                           k=1)
        linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = MoE(hidden_size=hidden_dim,
                           expert=linear3,
                           ep_size=ep_size,
                           use_residual=use_residual,
                           num_experts=int(2 * num_experts),
                           k=1)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden_dim = x
        hidden_dim = self.linear(hidden_dim)
        output, _, _ = self.linear2(hidden_dim)
        output, _, _ = self.linear3(output)
        hidden_dim = hidden_dim + output
        sentence_embed = hidden_dim.mean(1)
        return self.cross_entropy_loss(sentence_embed, y)


class UnusedParametersModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super().__init__(hidden_dim, empty_grad)
        self.unused_linear = torch.nn.Linear(hidden_dim, hidden_dim)


class LinearStack(torch.nn.Module):

    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.input_layer = torch.nn.Linear(in_features=self.input_dim, out_features=self.hidden_dim)
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim, bias=False)
            for x in range(num_layers)
        ])
        self.output_layer = torch.nn.Linear(in_features=self.hidden_dim, out_features=self.output_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        x = self.input_layer(x)
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return x


class LinearStackPipe(PipelineModule):

    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4, **kwargs):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        layers = []
        layers.append(LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim))
        for x in range(self.num_layers):
            layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.hidden_dim, bias=False))
            layers.append(lambda x: x)
        layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim))

        super().__init__(layers=layers, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs)


class SimpleOptimizer(torch.optim.Optimizer):

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(SimpleOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SimpleOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # plain SGD update: p <- p - lr * grad
                p.data.add_(d_p, alpha=-group['lr'])

        return loss


class HybridStateOptimizer(torch.optim.Optimizer):

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(HybridStateOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(HybridStateOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    # deliberately mix a plain Python int and a tensor in the
                    # per-parameter state to exercise checkpointing of both
                    state['integer_step'] = 0
                    state['tensor_step'] = torch.zeros(1, device=p.device)

                d_p = p.grad.data
                p.data.add_(d_p, alpha=-group['lr'])
                state['integer_step'] += 1
                state['tensor_step'] += 1

        return loss
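

def _hybrid_optimizer_example(hidden_dim=4, batch_size=2):
    # Illustrative sketch (hypothetical helper, not referenced by any test):
    # after a single step, every parameter's state holds both a Python int and
    # a tensor counter, exactly the mix that optimizer-state checkpoint
    # round-trips need to preserve.
    model = SimpleModel(hidden_dim)
    optimizer = HybridStateOptimizer(model.parameters())
    x = torch.randn(batch_size, hidden_dim)
    y = torch.randint(0, hidden_dim, (batch_size,))
    model(x, y).backward()
    optimizer.step()
    state = optimizer.state[model.linears[0].weight]
    assert state['integer_step'] == 1
    assert torch.is_tensor(state['tensor_step'])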


class PLD_SimpleModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad)

    def forward(self, x, y, **kwargs):
        # kwargs carry the progressive layer drop state the engine passes in;
        # this toy model accepts them but does not act on them
        pld = kwargs.get('progressive_layer_drop', False)
        theta = kwargs.get('pld_theta', 1.0)
        hidden_dim = super(PLD_SimpleModel, self).forward(x, y)
        return hidden_dim


def random_dataset(total_samples, hidden_dim, device, dtype=preferred_dtype()):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset


def random_dataloader(model, total_samples, hidden_dim, device, dtype=preferred_dtype()):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader


def sequence_dataloader(model, total_samples, hidden_dim, device, seq_len: int = 32, dtype=preferred_dtype()):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, seq_len, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path


def create_deepspeed_args():
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    if dist.is_initialized():
        # We assume up to one full node executing unit tests
        assert dist.get_world_size() <= get_accelerator().device_count()
        args.local_rank = dist.get_rank()
    return args


def args_from_dict(tmpdir, config_dict):
    args = create_deepspeed_args()
    config_path = create_config_from_dict(tmpdir, config_dict)
    args.deepspeed_config = config_path
    return args
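

def _training_loop_example(tmpdir, hidden_dim=10):
    # End-to-end sketch (illustrative only; the config below is an assumption,
    # not a fixture defined in this file, and it assumes a single-GPU run):
    # wire args_from_dict, SimpleModel and random_dataloader into a minimal
    # DeepSpeed training loop.
    import deepspeed
    config_dict = {
        "train_batch_size": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            }
        },
    }
    args = args_from_dict(tmpdir, config_dict)
    model = SimpleModel(hidden_dim)
    engine, _, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters())
    # fp32 data to match the fp32 engine above; tests normally pick the dtype
    # to match their config via preferred_dtype()
    data_loader = random_dataloader(engine, total_samples=8, hidden_dim=hidden_dim, device=engine.device,
                                    dtype=torch.float32)
    for x, y in data_loader:
        loss = engine(x, y)
        engine.backward(loss)
        engine.step()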