# simple_model.py

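"""Toy models, optimizers, and data/config helpers used by DeepSpeed's unit tests."""
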
import os
import json
import argparse

import torch
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.moe.layer import MoE
import deepspeed.comm as dist

class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super(SimpleModel, self).__init__()
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad

    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)

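# Note: like the other toy models below, forward() takes both inputs and
# labels and returns the loss directly; the unit tests drive these models
# through a DeepSpeed engine that way.
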
class Curriculum_SimpleModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super(Curriculum_SimpleModel, self).__init__(hidden_dim, empty_grad)

    def forward(self, x, y, **kwargs):
        # Report the curriculum sequence length back to the caller alongside the loss.
        seqlen = kwargs.get('curriculum_seqlen', None)
        loss = super(Curriculum_SimpleModel, self).forward(x, y)
        return loss, seqlen

class SimpleMoEModel(torch.nn.Module):

    def __init__(self, hidden_dim, num_experts=4, ep_size=1, use_residual=False):
        super(SimpleMoEModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        # A single Linear serves as the expert module wrapped by DeepSpeed's MoE layer.
        linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = MoE(hidden_size=hidden_dim,
                           expert=linear2,
                           ep_size=ep_size,
                           use_residual=use_residual,
                           num_experts=num_experts,
                           k=1)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = self.linear(x)
        # MoE returns (output, l_aux, exp_counts); only the output is used here.
        output, _, _ = self.linear2(hidden)
        hidden = hidden + output  # residual connection
        sentence_embed = hidden.mean(1)  # mean-pool over the sequence dimension
        return self.cross_entropy_loss(sentence_embed, y)

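# Construction sketch (illustrative values; using ep_size > 1 additionally
# requires DeepSpeed's expert-parallel process groups, which
# deepspeed.initialize() sets up):
#
#   moe_model = SimpleMoEModel(hidden_dim=16, num_experts=4, ep_size=1)
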
class SimplePRMoEModel(torch.nn.Module):

    def __init__(self, hidden_dim, num_experts=2, ep_size=1, use_residual=False):
        super(SimplePRMoEModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = MoE(hidden_size=hidden_dim,
                           expert=linear2,
                           ep_size=ep_size,
                           use_residual=use_residual,
                           num_experts=num_experts,
                           k=1)
        # The second MoE layer has twice as many experts, giving the pyramid
        # structure used by Pyramid-Residual MoE (PR-MoE).
        linear3 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = MoE(hidden_size=hidden_dim,
                           expert=linear3,
                           ep_size=ep_size,
                           use_residual=use_residual,
                           num_experts=int(2 * num_experts),
                           k=1)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = self.linear(x)
        output, _, _ = self.linear2(hidden)
        output, _, _ = self.linear3(output)
        hidden = hidden + output
        sentence_embed = hidden.mean(1)
        return self.cross_entropy_loss(sentence_embed, y)

class UnusedParametersModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super().__init__(hidden_dim, empty_grad)
        # Deliberately never referenced in forward(), to exercise
        # unused-parameter handling.
        self.unused_linear = torch.nn.Linear(hidden_dim, hidden_dim)

class LinearStack(torch.nn.Module):

    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.input_layer = torch.nn.Linear(in_features=self.input_dim,
                                           out_features=self.hidden_dim)
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(in_features=self.hidden_dim,
                            out_features=self.hidden_dim,
                            bias=False) for _ in range(num_layers)
        ])
        self.output_layer = torch.nn.Linear(in_features=self.hidden_dim,
                                            out_features=self.output_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        # Unlike the models above, this returns logits; `y` is accepted only
        # for interface parity.
        x = self.input_layer(x)
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

class LinearStackPipe(PipelineModule):

    def __init__(self,
                 input_dim=128,
                 hidden_dim=128,
                 output_dim=128,
                 num_layers=4,
                 **kwargs):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        layers = []
        layers.append(LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim))
        for _ in range(self.num_layers):
            layers.append(
                LayerSpec(torch.nn.Linear,
                          self.hidden_dim,
                          self.hidden_dim,
                          bias=False))
        # A plain callable is also a valid pipeline stage; this one is an identity.
        layers.append(lambda x: x)
        layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim))

        super().__init__(layers=layers, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs)

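# Pipeline construction sketch (illustrative; PipelineModule needs an
# initialized distributed backend, e.g. via deepspeed.init_distributed(),
# before it can partition stages):
#
#   pipe = LinearStackPipe(num_stages=2)
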
class SimpleOptimizer(torch.optim.Optimizer):

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(SimpleOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SimpleOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # Plain SGD update: p <- p - lr * grad.
                p.data.add_(d_p, alpha=-group['lr'])

        return loss

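# Standalone sanity check for SimpleOptimizer (a minimal sketch with
# illustrative values):
#
#   model = SimpleModel(hidden_dim=4)
#   opt = SimpleOptimizer(model.parameters(), lr=0.1)
#   loss = model(torch.randn(2, 4), torch.randint(0, 4, (2,)))
#   loss.backward()
#   opt.step()
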
class HybridStateOptimizer(torch.optim.Optimizer):

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(HybridStateOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(HybridStateOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    # Keep both a Python int and a tensor in the per-parameter
                    # state, so that save/load covers both kinds of state.
                    state['integer_step'] = 0
                    state['tensor_step'] = torch.zeros(1, device=p.device)

                d_p = p.grad.data
                p.data.add_(d_p, alpha=-group['lr'])
                state['integer_step'] += 1
                state['tensor_step'] += 1

        return loss

class PLD_SimpleModel(SimpleModel):

    def __init__(self, hidden_dim, empty_grad=False):
        super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad)

    def forward(self, x, y, **kwargs):
        # Progressive layer drop arguments are accepted but unused in this toy model.
        pld = kwargs.get('progressive_layer_drop', False)
        theta = kwargs.get('pld_theta', 1.0)
        loss = super(PLD_SimpleModel, self).forward(x, y)
        return loss

def random_dataset(total_samples, hidden_dim, device, dtype=torch.half):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset

def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    # `model` is expected to be a DeepSpeed engine, which exposes the
    # configured micro-batch size.
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader

def sequence_dataloader(model,
                        total_samples,
                        hidden_dim,
                        device,
                        seq_len: int = 32,
                        dtype=torch.half):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples,
                             seq_len,
                             hidden_dim,
                             device=device,
                             dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader

def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path

def create_deepspeed_args():
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    if dist.is_initialized():
        # We assume up to one full node executing unit tests
        assert dist.get_world_size() <= torch.cuda.device_count()
        args.local_rank = dist.get_rank()
    return args

def args_from_dict(tmpdir, config_dict):
    args = create_deepspeed_args()
    config_path = create_config_from_dict(tmpdir, config_dict)
    args.deepspeed_config = config_path
    return args

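# End-to-end usage sketch (a minimal example; it assumes `deepspeed` is
# installed, a GPU is available, and `tmpdir` is a pytest fixture; the
# config values are illustrative only):
#
#   import deepspeed
#
#   config_dict = {
#       "train_micro_batch_size_per_gpu": 8,
#       "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
#       "fp16": {"enabled": True},
#   }
#   args = args_from_dict(tmpdir, config_dict)
#   model = SimpleModel(hidden_dim=16)
#   engine, _, _, _ = deepspeed.initialize(args=args,
#                                          model=model,
#                                          model_parameters=model.parameters())
#   for x, y in random_dataloader(engine, total_samples=50,
#                                 hidden_dim=16, device=engine.device):
#       loss = engine(x, y)
#       engine.backward(loss)
#       engine.step()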