test_cuda_forward.py

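# Unit tests for the forward pass of the DeepSpeed transformer CUDA kernels:
# each test builds a stack of DeepSpeedTransformerLayer modules and a reference
# PyTorch BertEncoder from the same weights, runs both on identical random
# inputs, and checks that the outputs agree within a tolerance.
# Typically run with: pytest test_cuda_forward.py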
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import pytest
import json
import random
import time
import copy
import math  # used by the checkpoint_activations path in DSEncoder.forward
from torch import nn
from torch.utils import checkpoint  # used by the checkpoint_activations path in DSEncoder.forward
from modelingpreln import BertEncoder as BertEncoderPreln
from modeling import BertEncoder as BertEncoderPostln
from modeling import BertLayerNorm, BertConfig
from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
import deepspeed
import sys

if not deepspeed.ops.__installed_ops__['transformer']:
    pytest.skip("transformer kernels are not installed", allow_module_level=True)

def check_equal(first, second, atol=1e-2, verbose=False):
    if verbose:
        print()
    for i, (x, y) in enumerate(zip(first, second)):
        x = x[0].cpu().detach().numpy()
        y = y[0].cpu().detach().numpy()
        if verbose:
            print("x = {}".format(x.flatten()))
            print("y = {}".format(y.flatten()))
            print('-' * 80)
        np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i), atol=atol)


def zero_grad(variables):
    for variable in variables:
        variable.grad.zero_()


device = torch.device("cuda")
kwargs_fp32 = {'dtype': torch.float, 'device': device, 'requires_grad': True}
kwargs_fp16 = {'dtype': torch.half, 'device': device, 'requires_grad': True}

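# DSEncoder mirrors the reference BertEncoder but stacks DeepSpeedTransformerLayer
# modules; its forward output is what the tests compare against the PyTorch baseline.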
class DSEncoder(nn.Module):
    def __init__(self, config, weights, biases):
        super(DSEncoder, self).__init__()
        self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.layer = nn.ModuleList([
            copy.deepcopy(DeepSpeedTransformerLayer(i, config, weights, biases))
            for i in range(config.num_hidden_layers)
        ])
        self.grads = []
        self.pre_or_post = config.pre_layer_norm

    def forward(self,
                hidden_states,
                attention_mask,
                output_all_encoded_layers=True,
                checkpoint_activations=False):
        all_encoder_layers = []

        def custom(start, end):
            def custom_forward(*inputs):
                layers = self.layer[start:end]
                x_ = inputs[0]
                for layer in layers:
                    x_ = layer(x_, inputs[1])
                return x_

            return custom_forward

        if checkpoint_activations:
            l = 0
            num_layers = len(self.layer)
            chunk_length = math.ceil(math.sqrt(num_layers))
            while l < num_layers:
                hidden_states = checkpoint.checkpoint(custom(l, l + chunk_length),
                                                      hidden_states,
                                                      attention_mask * 1)
                l += chunk_length
            # decoder layers
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states = layer_module(hidden_states, attention_mask)
                hidden_states.register_hook(
                    lambda x, i=i, self=self: self.grads.append([x, "hidden_state"]))

                if output_all_encoded_layers:
                    all_encoder_layers.append(hidden_states)

        if not output_all_encoded_layers or checkpoint_activations:
            if (self.pre_or_post):
                hidden_states = self.FinalLayerNorm(hidden_states)
            all_encoder_layers.append(hidden_states)
        return all_encoder_layers

    def get_grads(self):
        return self.grads

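# Build the reference BertEncoder (pre- or post-LayerNorm) and a DSEncoder from the
# same set of randomly initialized weights and biases so their outputs are comparable.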
def create_models(ds_config):
    bert_config = BertConfig(vocab_size_or_config_json_file=119547,
                             hidden_size=ds_config.hidden_size,
                             num_hidden_layers=ds_config.num_hidden_layers,
                             num_attention_heads=ds_config.heads,
                             batch_size=ds_config.batch_size,
                             intermediate_size=ds_config.intermediate_size,
                             hidden_act="gelu",
                             hidden_dropout_prob=ds_config.hidden_dropout_ratio,
                             attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
                             max_position_embeddings=ds_config.max_seq_length,
                             type_vocab_size=2,
                             initializer_range=ds_config.initializer_range,
                             fp16=ds_config.fp16)

    weights = []
    biases = []

    for i in range(4):
        weights.append(
            nn.Parameter(torch.Tensor(ds_config.hidden_size, ds_config.hidden_size)))
        weights[i].data.normal_(mean=0.0, std=ds_config.initializer_range)

    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    weights[4].data.fill_(1.0)
    weights.append(
        nn.Parameter(torch.Tensor(ds_config.intermediate_size, ds_config.hidden_size)))
    weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
    weights.append(
        nn.Parameter(torch.Tensor(ds_config.hidden_size, ds_config.intermediate_size)))
    weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
    weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    weights[7].data.fill_(1.0)

    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[0].data.zero_()
    for i in range(4):
        biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
        biases[i + 1].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
    biases[5].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[6].data.zero_()
    biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
    biases[7].data.zero_()

    if (ds_config.pre_layer_norm):
        bert_encoder = BertEncoderPreln(bert_config, weights, biases)
    else:
        bert_encoder = BertEncoderPostln(bert_config, weights, biases)
    ds_encoder = DSEncoder(ds_config, weights, biases)

    if ds_config.fp16:
        bert_encoder.half()
        ds_encoder.half()

    bert_encoder.cuda()
    ds_encoder.cuda()

    return bert_encoder, ds_encoder

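# Seed the Python, NumPy, and PyTorch RNGs so the random test inputs are reproducible.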
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

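# Run both encoders on the same random hidden states and attention mask, then assert
# that their outputs match within the given absolute tolerance.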
def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
    set_seed(123)
    bert_encoder, ds_encoder = create_models(ds_config)
    bsz = ds_config.batch_size if test_bsz is None else test_bsz

    # prepare test data
    kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
    hidden_states = torch.randn(bsz,
                                seq_len,  #ds_config.max_seq_length,
                                ds_config.hidden_size,
                                **kwargs)
    input_mask = torch.randn(bsz, 1, 1,
                             seq_len,  #ds_config.max_seq_length,
                             **kwargs)

    # run baseline
    base_results = bert_encoder(hidden_states,
                                input_mask,
                                output_all_encoded_layers=False,
                                checkpoint_activations=False)

    # run ds
    ds_results = ds_encoder(hidden_states,
                            input_mask,
                            output_all_encoded_layers=False,
                            checkpoint_activations=False)

    # compare the forward outputs of the two encoders
    check_equal(base_results, ds_results, atol=atol, verbose=verbose)

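# Forward-pass parity tests across batch sizes, hidden sizes, sequence lengths, and
# head counts, for both pre- and post-LayerNorm configurations, with and without FP16.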
# FP16 test cases can only run on devices that support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                         [
                             (64,1024,128,16,3,True,False),
                             (64,1024,128,16,3,True,True),
                             (8,1024,384,16,3,True,False),
                             (8,1024,384,16,3,True,True),
                             (8,1024,384,16,3,True,True),
                             (8,1024,120,16,3,True,False),
                             (8,1024,120,16,3,True,True),
                             (8,1024,512,16,3,True,False),
                             (8,1024,512,16,3,True,True),
                             (64,1024,56,16,3,False,False),
                             (64,1024,56,16,3,False,True),
                             (64,1024,24,16,3,False,False),
                             (64,1024,24,16,3,False,True),
                             (8,1024,384,16,3,False,False),
                             (8,1024,384,16,3,False,True),
                             (8,1024,512,16,3,False,False),
                             (8,1024,512,16,3,False,True),
                             (8,1536,128,24,3,False,False),
                             (8,1536,128,24,3,False,True),
                             (8,2048,128,32,3,False,False),
                             (8,2048,128,32,3,False,True),
                             (8,2560,128,40,3,False,False),
                             (8,2560,128,40,3,False,True),
                         ]) # yapf: disable
def test_forward(batch_size,
                 hidden_size,
                 seq_len,
                 heads,
                 num_layers,
                 is_preln,
                 use_fp16):
    # Only run fp16 test cases on devices with 7+ capability.
    major, _ = torch.cuda.get_device_capability()
    if major < 7 and use_fp16 is True:
        return

    ds_config = DeepSpeedTransformerConfig()
    ds_config.layer_id = None
    ds_config.batch_size = batch_size
    ds_config.hidden_size = hidden_size
    ds_config.max_seq_length = 128  #seq_len
    ds_config.intermediate_size = 4 * hidden_size
    ds_config.heads = heads
    ds_config.attn_dropout_ratio = 0.0
    ds_config.hidden_dropout_ratio = 0.0
    ds_config.num_hidden_layers = num_layers
    ds_config.pre_layer_norm = is_preln
    ds_config.initializer_range = 0.02
    ds_config.fp16 = use_fp16

    run_forward(ds_config, seq_len, atol=2e-2)

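# Verify that a layer configured for a given batch_size also produces correct results
# when the forward pass is run with a smaller batch.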
@pytest.mark.parametrize('batch_size, small_bsz, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                         [
                             (8,3,1024,512,16,3,True,False),
                             (8,7,1024,512,16,3,True,True),
                             (8,3,1024,512,16,3,False,False),
                             (8,7,1024,512,16,3,False,True),
                         ]) # yapf: disable
def test_forward_with_small_bsz(batch_size,
                                small_bsz,
                                hidden_size,
                                seq_len,
                                heads,
                                num_layers,
                                is_preln,
                                use_fp16):
    # Only run fp16 test cases on devices with 7+ capability.
    major, _ = torch.cuda.get_device_capability()
    if major < 7 and use_fp16 is True:
        return

    ds_config = DeepSpeedTransformerConfig()
    ds_config.layer_id = None
    ds_config.batch_size = batch_size
    ds_config.hidden_size = hidden_size
    ds_config.intermediate_size = 4 * hidden_size
    ds_config.max_seq_length = seq_len
    ds_config.heads = heads
    ds_config.attn_dropout_ratio = 0.0
    ds_config.hidden_dropout_ratio = 0.0
    ds_config.num_hidden_layers = num_layers
    ds_config.pre_layer_norm = is_preln
    ds_config.initializer_range = 0.02
    ds_config.fp16 = use_fp16

    run_forward(ds_config, seq_len, atol=2e-2, test_bsz=small_bsz)

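# Same parity check with the kernel's stochastic mode enabled, using a looser tolerance.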
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                         [
                             (64,1024,128,16,3,True,False),
                             (64,1024,128,16,3,True,True),
                             (64,1024,128,16,3,False,False),
                             (64,1024,128,16,3,False,True),
                         ]) # yapf: disable
def test_forward_stochastic(batch_size,
                            hidden_size,
                            seq_len,
                            heads,
                            num_layers,
                            is_preln,
                            use_fp16):
    # Only run fp16 test cases on devices with 7+ capability.
    major, _ = torch.cuda.get_device_capability()
    if major < 7 and use_fp16 is True:
        return

    ds_config = DeepSpeedTransformerConfig()
    ds_config.layer_id = None
    ds_config.batch_size = batch_size
    ds_config.hidden_size = hidden_size
    ds_config.intermediate_size = 4 * hidden_size
    ds_config.max_seq_length = seq_len
    ds_config.heads = heads
    ds_config.attn_dropout_ratio = 0.0
    ds_config.hidden_dropout_ratio = 0.0
    ds_config.num_hidden_layers = num_layers
    ds_config.pre_layer_norm = is_preln
    ds_config.initializer_range = 0.02
    ds_config.fp16 = use_fp16
    ds_config.stochastic_mode = True

    run_forward(ds_config, seq_len, atol=7e-2)