# test_cuda_backward.py

import argparse
import numpy as np
import torch
import torch.nn.functional as F
import pytest
import json
import random
import time
import copy
import math  # used to size activation-checkpointing chunks in DSEncoder.forward
import sys
from torch import nn
from torch.utils import checkpoint  # used by the checkpoint_activations path in DSEncoder.forward
from modelingpreln import BertEncoder as BertEncoderPreln
from modeling import BertEncoder as BertEncoderPostln
from modeling import BertConfig, BertLayerNorm
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


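# check_equal pairs up the recorded (gradient, tag) tuples from the baseline and
# DeepSpeed encoders and asserts element-wise closeness, with the tolerance
# scaled by the average gradient magnitude of the baseline tensor.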
def check_equal(first, second, atol=1e-2, verbose=False):
    diction_x = {}
    diction_y = {}

    for i, (x, y) in enumerate(zip(first, second)):
        print(x[1], y[1])

    for i, (x, y) in enumerate(zip(first, second)):
        k = 0
        while diction_x.get((k, x[1])) is not None:
            k = k + 1
        diction_x[k, x[1]] = x[0]
        k = 0
        while diction_y.get((k, y[1])) is not None:
            k = k + 1
        diction_y[k, y[1]] = y[0]

    if verbose:
        print()
        for i, (x, y) in enumerate(zip(diction_x, diction_y)):
            print(x, y)

    for i, (x, y) in enumerate(zip(diction_x, diction_y)):
        if x[0] == 1:
            continue
        print("checking ", x[1], ":")
        y = diction_y[x[0], x[1]]
        x = diction_x[x[0], x[1]]

        x = x.cpu().detach().numpy()
        y = y.cpu().detach().numpy()
        print(x)
        print(y)

        avgx = np.sum(abs(x), dtype=float)
        countx = x.shape[0]
        for j in range(len(x.shape) - 1):
            countx *= x.shape[j + 1]
        avgx = np.sum(avgx)

        tolerance = 1
        if avgx != float('inf') and avgx != -float('inf'):
            avgx = avgx / countx
            tolerance = avgx * atol
        print("tolerance is ", tolerance)
        if verbose:
            print("x = {}".format(x.flatten()))
            print("y = {}".format(y.flatten()))
            print('-' * 80)
        np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i), atol=tolerance)


def zero_grad(variables):
    for variable in variables:
        variable.grad.zero_()


device = torch.device("cuda")
kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True}
kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True}


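# DSEncoder mirrors the baseline BertEncoder interface but stacks
# DeepSpeedTransformerLayer modules and records activation gradients via
# backward hooks so they can be compared against the baseline.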
class DSEncoder(nn.Module):
    def __init__(self, config, weights, biases):
        super(DSEncoder, self).__init__()
        self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.layer = nn.ModuleList([
            copy.deepcopy(DeepSpeedTransformerLayer(i, config, weights, biases))
            for i in range(config.num_hidden_layers)
        ])
        self.grads = []
        self.pre_or_post = config.pre_layer_norm

    def forward(self,
                hidden_states,
                attention_mask,
                output_all_encoded_layers=True,
                checkpoint_activations=False):
        all_encoder_layers = []

        def custom(start, end):
            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        if checkpoint_activations:
            l = 0
            num_layers = len(self.layer)
            chunk_length = math.ceil(math.sqrt(num_layers))
            while l < num_layers:
                hidden_states = checkpoint.checkpoint(custom(l, l + chunk_length),
                                                      hidden_states,
                                                      attention_mask * 1)
                l += chunk_length
            # decoder layers
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(hidden_states, attention_mask, self.grads)
                hidden_states.register_hook(
                    lambda x, self=self: self.grads.append([x, "hidden_state"]))

                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)

        if not output_all_encoded_layers or checkpoint_activations:
            if self.pre_or_post:
                hidden_states = self.FinalLayerNorm(hidden_states)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

    def get_grads(self):
        return self.grads


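# create_models builds the reference BertEncoder (pre- or post-LayerNorm) and
# the DeepSpeed DSEncoder from the same randomly initialized weights and
# biases, so that any gradient mismatch comes from the kernels, not the
# parameters.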
def create_models(ds_config):
    bert_config = BertConfig(vocab_size_or_config_json_file=119547,
                             hidden_size=ds_config.hidden_size,
                             num_hidden_layers=ds_config.num_hidden_layers,
                             num_attention_heads=ds_config.heads,
                             intermediate_size=4 * ds_config.hidden_size,
                             hidden_act="gelu",
                             hidden_dropout_prob=ds_config.hidden_dropout_ratio,
                             attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
                             max_position_embeddings=ds_config.max_seq_length,
                             type_vocab_size=2,
                             initializer_range=ds_config.initializer_range)

    weights = []
    biases = []

    for i in range(4):
        weights.append(
            nn.Parameter(torch.Tensor(ds_config.hidden_size, ds_config.hidden_size)))
        weights[i].data.normal_(mean=0.0, std=ds_config.initializer_range)

    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    weights[4].data.fill_(1.0)
    weights.append(
        nn.Parameter(torch.Tensor(4 * ds_config.hidden_size, ds_config.hidden_size)))
    weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
    weights.append(
        nn.Parameter(torch.Tensor(ds_config.hidden_size, 4 * ds_config.hidden_size)))
    weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    weights[7].data.fill_(1.0)

    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[0].data.zero_()
    for i in range(4):
        biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
        biases[i + 1].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size)))
    biases[5].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[6].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[7].data.zero_()

    if ds_config.pre_layer_norm:
        bert_encoder = BertEncoderPreln(bert_config, weights, biases)
    else:
        bert_encoder = BertEncoderPostln(bert_config, weights, biases)
    ds_encoder = DSEncoder(ds_config, weights, biases)

    if ds_config.fp16:
        bert_encoder.half()
        ds_encoder.half()

    bert_encoder.cuda()
    ds_encoder.cuda()

    return bert_encoder, ds_encoder


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


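# run_backward drives a forward and backward pass through both encoders on
# identical random inputs and targets, then compares the gradients recorded by
# each model's hooks.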
def run_backward(ds_config, atol=1e-2, verbose=False):
    set_seed(123)
    bert_encoder, ds_encoder = create_models(ds_config)

    # prepare test data
    kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
    hidden_states = torch.randn(ds_config.batch_size,
                                ds_config.max_seq_length,
                                ds_config.hidden_size,
                                **kwargs)
    input_mask = torch.randn(ds_config.batch_size,
                             1,
                             1,
                             ds_config.max_seq_length,
                             **kwargs)
    Y = torch.randn(ds_config.batch_size,
                    ds_config.max_seq_length,
                    ds_config.hidden_size,
                    **kwargs)

    # run baseline
    base_results = bert_encoder(hidden_states,
                                input_mask,
                                output_all_encoded_layers=False,
                                checkpoint_activations=False)

    loss = (Y - base_results[0]).pow(2).sum()
    loss.backward()
    base_grads = bert_encoder.get_grads()

    # run ds
    ds_results = ds_encoder(hidden_states,
                            input_mask,
                            output_all_encoded_layers=False,
                            checkpoint_activations=False)

    loss = (Y - ds_results[0]).pow(2).sum()
    loss.backward()
    ds_grads = ds_encoder.get_grads()

    # check grads
    check_equal(base_grads, ds_grads, atol=atol, verbose=verbose)


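# Each case runs 24 layers at hidden size 1024 and sequence length 128; the
# tolerance is looser for post-LayerNorm and fp16 configurations.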
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
                         [
                             (3,1024,128,16,24,True,False, 0.05),
                             (3,1024,128,16,24,True,True, 0.05),
                             (3,1024,128,16,24,False,False, 0.1),
                             (3,1024,128,16,24,False,True, 0.2),
                         ]) # yapf: disable
def test_backward(batch_size,
                  hidden_size,
                  seq_len,
                  heads,
                  num_layers,
                  is_preln,
                  use_fp16,
                  atol):
    # Only run fp16 test cases on devices with 7+ capability.
    major, _ = torch.cuda.get_device_capability()
    if major < 7 and (use_fp16 is True or is_preln is False):
        return

    ds_config = DeepSpeedTransformerConfig()
    ds_config.layer_id = None
    ds_config.batch_size = batch_size
    ds_config.hidden_size = hidden_size
    ds_config.max_seq_length = seq_len
    ds_config.heads = heads
    ds_config.attn_dropout_ratio = 0.0
    ds_config.hidden_dropout_ratio = 0.0
    ds_config.num_hidden_layers = num_layers
    ds_config.pre_layer_norm = is_preln
    ds_config.initializer_range = 0.02
    ds_config.fp16 = use_fp16

    run_backward(ds_config, atol=atol)


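# A stochastic-mode variant of the backward test, kept disabled below.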
#@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
#                         [
#                             (3,1024,128,16,24,True,False, 0.07),
#                             (3,1024,128,16,24,True,True, 0.05),
#                             (3,1024,128,16,24,False,False, 0.1),
#                             (3,1024,128,16,24,False,True, 0.2),
#                         ]) # yapf: disable
#def test_backward_stochastic(batch_size,
#                             hidden_size,
#                             seq_len,
#                             heads,
#                             num_layers,
#                             is_preln,
#                             use_fp16,
#                             atol):
#    # Only run fp16 test cases on devices with 7+ capability.
#    major, _ = torch.cuda.get_device_capability()
#    if major < 7 and (use_fp16 is True or is_preln is False):
#        return
#
#    ds_config = DeepSpeedTransformerConfig()
#    ds_config.layer_id = None
#    ds_config.batch_size = batch_size
#    ds_config.hidden_size = hidden_size
#    ds_config.max_seq_length = seq_len
#    ds_config.heads = heads
#    ds_config.attn_dropout_ratio = 0.0
#    ds_config.hidden_dropout_ratio = 0.0
#    ds_config.num_hidden_layers = num_layers
#    ds_config.pre_layer_norm = is_preln
#    ds_config.initializer_range = 0.02
#    ds_config.fp16 = use_fp16
#    ds_config.stochastic_mode = True
#
#    run_backward(ds_config, atol=atol)