'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import json
import math
import importlib
import torch
from torch import nn
from torch.autograd import Function

from ..op_builder import TransformerBuilder, StochasticTransformerBuilder

# Cuda modules will be imported if needed
transformer_cuda_module = None
stochastic_transformer_cuda_module = None


class TransformerConfig():
    def __init__(self,
                 batch_size,
                 hidden_size,
                 intermediate_size,
                 heads,
                 attn_dropout_ratio,
                 hidden_dropout_ratio,
                 num_hidden_layers,
                 initializer_range):
        self.layer_id = -1
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.heads = heads
        self.attn_dropout_ratio = attn_dropout_ratio
        self.hidden_dropout_ratio = hidden_dropout_ratio
        self.num_hidden_layers = num_hidden_layers
        self.initializer_range = initializer_range


class DeepSpeedTransformerConfig(TransformerConfig):
    """Initialize the DeepSpeed Transformer Config.

    Arguments:
        batch_size: The maximum batch size used for running the kernel on each GPU.

        hidden_size: The hidden size of the transformer layer.

        intermediate_size: The intermediate size of the feed-forward part of the transformer layer.

        heads: The number of heads in the self-attention of the transformer layer.

        attn_dropout_ratio: The ratio of dropout for the attention's output.

        hidden_dropout_ratio: The ratio of dropout for the transformer's output.

        num_hidden_layers: The number of transformer layers.

        initializer_range: BERT model's initializer range for initializing parameter data.

        local_rank: Optional: The rank of the GPU running the transformer kernel. It is not
            required if the model has already set the current device; otherwise it needs to be
            set so that the transformer kernel can work on the right device.

        seed: The random seed for the dropout layers.

        fp16: Enable half-precision computation.

        pre_layer_norm: Select between the Pre-LN and Post-LN transformer architectures.

        normalize_invertible: Optional: Enable invertible LayerNorm execution (dropping the
            input activation); default is False.

        gelu_checkpoint: Optional: Enable checkpointing of the Gelu activation output to save
            memory; default is False.

        adjust_init_range: Optional: Set to True (default) to adjust the initial values of the
            self-attention output and layer output weights; False keeps the initializer_range
            unchanged. See the adjustment below:
                output_std = self.config.initializer_range / math.sqrt(2.0 * num_layers)

        attn_dropout_checkpoint: Optional: Enable checkpointing of the attention dropout to save
            memory; default is False.

        stochastic_mode: Enable for high performance. Please note that this flag has some level
            of non-determinism and can produce different results on different runs. However, we
            have seen that enabling it does not affect pretraining tasks such as BERT, which can
            still reach a high accuracy level. For downstream tasks such as fine-tuning, we
            recommend turning it off in order to reproduce the same results through the regular
            kernel execution.

        return_tuple: Enable if using the return_tuple interface style for sending out the
            forward results.

        training: Enable for training rather than inference.
    """

    def __init__(self,
                 batch_size=-1,
                 hidden_size=-1,
                 intermediate_size=-1,
                 heads=-1,
                 attn_dropout_ratio=-1,
                 hidden_dropout_ratio=-1,
                 num_hidden_layers=-1,
                 initializer_range=-1,
                 layer_norm_eps=1e-12,
                 local_rank=-1,
                 seed=-1,
                 fp16=False,
                 pre_layer_norm=True,
                 normalize_invertible=False,
                 gelu_checkpoint=False,
                 adjust_init_range=True,
                 attn_dropout_checkpoint=False,
                 stochastic_mode=False,
                 return_tuple=False,
                 training=True):
        super(DeepSpeedTransformerConfig,
              self).__init__(batch_size,
                             hidden_size,
                             (intermediate_size if intermediate_size > 0 else 4 * hidden_size),
                             heads,
                             attn_dropout_ratio,
                             hidden_dropout_ratio,
                             num_hidden_layers,
                             initializer_range)
        self.fp16 = fp16
        self.pre_layer_norm = pre_layer_norm
        self.local_rank = local_rank
        self.seed = seed
        self.normalize_invertible = normalize_invertible
        self.gelu_checkpoint = gelu_checkpoint  # True: if higher batch size is required
        self.adjust_init_range = adjust_init_range
        self.test_gemm = False
        self.layer_norm_eps = layer_norm_eps
        self.training = training
        self.is_grad_enabled = True
        self.attn_dropout_checkpoint = attn_dropout_checkpoint
        self.stochastic_mode = stochastic_mode
        self.return_tuple = return_tuple

    @classmethod
    def from_dict(cls, json_object):
        config = DeepSpeedTransformerConfig()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        with open(json_file, "r", encoding='utf-16') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))
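
# Usage sketch (hypothetical, for illustration only; the argument values below are
# illustrative assumptions, not defaults defined in this file):
#
#   config = DeepSpeedTransformerConfig(batch_size=8,
#                                       hidden_size=1024,
#                                       intermediate_size=4096,
#                                       heads=16,
#                                       attn_dropout_ratio=0.1,
#                                       hidden_dropout_ratio=0.1,
#                                       num_hidden_layers=24,
#                                       initializer_range=0.02,
#                                       local_rank=0,
#                                       fp16=True,
#                                       pre_layer_norm=True)
#
# A config can also be built from a JSON object or file whose keys match the attribute
# names above, e.g. DeepSpeedTransformerConfig.from_json_file("transformer_config.json")
# (the file name here is only an example).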


class DeepSpeedTransformerFunction(Function):
    @staticmethod
    def forward(ctx,
                input,
                input_mask,
                self,
                grads,
                layer_id,
                attn_qkvw,
                attn_qkvb,
                attn_ow,
                attn_ob,
                attn_nw,
                attn_nb,
                inter_w,
                inter_b,
                output_w,
                output_b,
                norm_w,
                norm_b,
                config):
        cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
        forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32

        inp_size = input.size()
        if inp_size[1] % 16 != 0:
            # Pad the sequence length up to a multiple of 16 before calling the kernel;
            # the padded positions are masked out with a large negative attention bias.
            input = torch.cat((input,
                               torch.randn((inp_size[0],
                                            (16 - (inp_size[1] % 16)),
                                            inp_size[2]),
                                           device=input.device,
                                           dtype=input.dtype)),
                              1)
            input_mask = torch.cat((input_mask,
                                    torch.ones((inp_size[0],
                                                input_mask.shape[1],
                                                input_mask.shape[2],
                                                (16 - (inp_size[1] % 16))),
                                               device=input_mask.device,
                                               dtype=input_mask.dtype) * -10000),
                                   3)
        (output,
         inp_norm,
         qkv_tf,
         soft_inp,
         ctx_bufB,
         attn_o_inp,
         add_res,
         ff1_inp,
         gelu_inp,
         ff2_inp,
         attn_prob_dropout_mask,
         attn_output_dropout_mask,
         layer_output_dropout_mask,
         attn_layer_norm_var,
         attn_layer_norm_mean,
         layer_norm_var,
         layer_norm_mean) = forward_func(config.layer_id,
                                         input,
                                         input_mask,
                                         attn_qkvw,
                                         attn_qkvb,
                                         attn_ow,
                                         attn_ob,
                                         attn_nw,
                                         attn_nb,
                                         inter_w,
                                         inter_b,
                                         output_w,
                                         output_b,
                                         norm_w,
                                         norm_b,
                                         config.training,
                                         config.pre_layer_norm,
                                         config.attn_dropout_checkpoint,
                                         config.normalize_invertible,
                                         config.gelu_checkpoint)

        # For testing only.
        if grads is not None:
            for i in [2]:
                attn_qkvw.register_hook(
                    lambda x, i=i, self=self: grads.append([
                        x[i * attn_ow.size(0):(i + 1) * attn_ow.size(0)],
                        ("Q_W" if i == 0 else "K_W" if i == 1 else "V_W")
                    ]))
            for i in [2]:
                attn_qkvb.register_hook(
                    lambda x, i=i, self=self: grads.append([
                        x[i * attn_ow.size(0):(i + 1) * attn_ow.size(0)],
                        ("Q_B" if i == 0 else "K_B" if i == 1 else "V_B")
                    ]))

            attn_ow.register_hook(lambda x, self=self: grads.append([x, "O_W"]))
            attn_ob.register_hook(lambda x, self=self: grads.append([x, "O_B"]))
            attn_nw.register_hook(lambda x, self=self: grads.append([x, "N2_W"]))
            attn_nb.register_hook(lambda x, self=self: grads.append([x, "N2_B"]))
            inter_w.register_hook(lambda x, self=self: grads.append([x, "int_W"]))
            inter_b.register_hook(lambda x, self=self: grads.append([x, "int_B"]))
            output_w.register_hook(lambda x, self=self: grads.append([x, "out_W"]))
            output_b.register_hook(lambda x, self=self: grads.append([x, "out_B"]))
            norm_w.register_hook(lambda x, self=self: grads.append([x, "norm_W"]))
            norm_b.register_hook(lambda x, self=self: grads.append([x, "norm_B"]))

        if config.is_grad_enabled and config.training:
            if (config.pre_layer_norm and config.normalize_invertible):
                ctx.save_for_backward(input_mask,
                                      attn_qkvw,
                                      attn_qkvb,
                                      attn_ow,
                                      attn_ob,
                                      attn_nw,
                                      attn_nb,
                                      inter_w,
                                      inter_b,
                                      output_w,
                                      output_b,
                                      norm_w,
                                      norm_b)
            else:
                ctx.save_for_backward(output,
                                      input,
                                      input_mask,
                                      attn_qkvw,
                                      attn_qkvb,
                                      attn_ow,
                                      attn_ob,
                                      attn_nw,
                                      attn_nb,
                                      inter_w,
                                      inter_b,
                                      output_w,
                                      output_b,
                                      norm_w,
                                      norm_b)

            ctx.config = config
            if (config.pre_layer_norm or not config.normalize_invertible):
                ctx.inp_norm = inp_norm

            ctx.qkv_tf = qkv_tf
            ctx.soft_inp = soft_inp
            if not config.attn_dropout_checkpoint:
                ctx.ctx_bufB = ctx_bufB

            ctx.attn_o_inp = attn_o_inp
            if not config.normalize_invertible:
                ctx.add_res = add_res

            ctx.attn_layer_norm_mean = attn_layer_norm_mean
            ctx.layer_norm_mean = layer_norm_mean

            ctx.ff1_inp = ff1_inp
            if not config.gelu_checkpoint:
                ctx.gelu_inp = gelu_inp

            ctx.ff2_inp = ff2_inp
            ctx.attn_prob_dropout_mask = attn_prob_dropout_mask
            ctx.attn_output_dropout_mask = attn_output_dropout_mask
            ctx.layer_output_dropout_mask = layer_output_dropout_mask
            ctx.attn_layer_norm_var = attn_layer_norm_var
            ctx.layer_norm_var = layer_norm_var

        if inp_size[1] % 16 != 0:
            output = torch.narrow(output, 1, 0, inp_size[1])

        if config.return_tuple:
            return (output, )  # outputs -> (output) : outputs[0] = output
        else:
            return output

    @staticmethod
    def backward(ctx, grad_output):
        bsz = grad_output.shape[0]

        grad_output_shape = grad_output.size()
        if grad_output_shape[1] % 16 != 0:
            # Pad the sequence dimension of the gradient to a multiple of 16,
            # mirroring the padding applied in forward.
            grad_output = torch.cat((grad_output,
                                     torch.zeros((bsz,
                                                  (16 - (grad_output_shape[1] % 16)),
                                                  grad_output_shape[2]),
                                                 device=grad_output.device,
                                                 dtype=grad_output.dtype)),
                                    1)

        assert ctx.config.training

        if (ctx.config.pre_layer_norm and ctx.config.normalize_invertible):
            (input_mask,
             attn_qkvw,
             attn_qkvb,
             attn_ow,
             attn_ob,
             attn_nw,
             attn_nb,
             inter_w,
             inter_b,
             output_w,
             output_b,
             norm_w,
             norm_b) = ctx.saved_tensors
        else:
            (output,
             input,
             input_mask,
             attn_qkvw,
             attn_qkvb,
             attn_ow,
             attn_ob,
             attn_nw,
             attn_nb,
             inter_w,
             inter_b,
             output_w,
             output_b,
             norm_w,
             norm_b) = ctx.saved_tensors

        cuda_module = stochastic_transformer_cuda_module if ctx.config.stochastic_mode else transformer_cuda_module
        backward_func = cuda_module.backward_fp16 if ctx.config.fp16 else cuda_module.backward_fp32

        (grad_input,
         grad_attn_qkvw,
         grad_attn_qkvb,
         grad_attn_ow,
         grad_attn_ob,
         grad_attn_nw,
         grad_attn_nb,
         grad_inter_w,
         grad_inter_b,
         grad_output_w,
         grad_output_b,
         grad_norm_w,
         grad_norm_b) = backward_func(
             ctx.config.layer_id,
             grad_output,
             (ctx.inp_norm if (ctx.config.pre_layer_norm
                               and ctx.config.normalize_invertible) else output),
             (ctx.inp_norm if (ctx.config.pre_layer_norm
                               or not ctx.config.normalize_invertible) else input),
             ctx.qkv_tf,
             ctx.soft_inp,
             (ctx.soft_inp if ctx.config.attn_dropout_checkpoint else ctx.ctx_bufB),
             ctx.attn_o_inp,
             (ctx.ff1_inp if ctx.config.normalize_invertible else ctx.add_res),
             ctx.ff1_inp,
             (ctx.ff2_inp if ctx.config.gelu_checkpoint else ctx.gelu_inp),
             ctx.ff2_inp,
             ctx.attn_prob_dropout_mask,
             ctx.attn_output_dropout_mask,
             ctx.layer_output_dropout_mask,
             ctx.attn_layer_norm_var,
             ctx.attn_layer_norm_mean,
             ctx.layer_norm_var,
             ctx.layer_norm_mean,
             (ctx.inp_norm if (ctx.config.pre_layer_norm
                               and ctx.config.normalize_invertible) else input),
             input_mask,
             attn_qkvw,
             attn_qkvb,
             attn_ow,
             attn_ob,
             attn_nw,
             attn_nb,
             inter_w,
             inter_b,
             output_w,
             output_b,
             norm_w,
             norm_b)

        # This appears to be an effective way to release context memory.
        ctx.qkv_tf = None
        ctx.soft_inp = None
        ctx.ctx_bufB = None
        ctx.gelu_inp = None
        ctx.ff2_inp = None
        ctx.attn_o_inp = None
        ctx.ff1_inp = None
        ctx.add_res = None
        ctx.inp_norm = None
        ctx.config = None
        ctx.attn_layer_norm_mean = None
        ctx.layer_norm_mean = None
        ctx.attn_prob_dropout_mask = None
        ctx.attn_output_dropout_mask = None
        ctx.layer_output_dropout_mask = None
        ctx.attn_layer_norm_var = None
        ctx.layer_norm_var = None

        if grad_output_shape[1] % 16 != 0:
            grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])

        return (grad_input,
                None,
                None,
                None,
                None,
                grad_attn_qkvw,
                grad_attn_qkvb,
                grad_attn_ow,
                grad_attn_ob,
                grad_attn_nw,
                grad_attn_nb,
                grad_inter_w,
                grad_inter_b,
                grad_output_w,
                grad_output_b,
                grad_norm_w,
                grad_norm_b,
                None)


class DeepSpeedTransformerLayer(nn.Module):
    """Initialize the DeepSpeed Transformer Layer.

    Static variable:
        layer_id: A layer-index counter that starts from 0 and increments by 1 every time a
            layer object is instantiated; e.g., if a model has 24 transformer layers, layer_id
            goes from 0 to 23.
    Arguments:
        config: An object of DeepSpeedTransformerConfig.
        initial_weights: Optional: Only used for unit tests.
        initial_biases: Optional: Only used for unit tests.
    """
    layer_id = 0

    def __init__(self, config, initial_weights=None, initial_biases=None):
        super(DeepSpeedTransformerLayer, self).__init__()

        self.config = config
        self.config.layer_id = DeepSpeedTransformerLayer.layer_id
        DeepSpeedTransformerLayer.layer_id = DeepSpeedTransformerLayer.layer_id + 1

        print("DeepSpeed Transformer config is ", self.config.__dict__)

        if self.config.local_rank >= 0:
            torch.cuda.set_device(self.config.local_rank)

        if initial_weights is None and initial_biases is None:
            self.attn_qkvw = nn.Parameter(
                torch.Tensor(self.config.hidden_size * 3,
                             self.config.hidden_size))
            self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3))
            self.attn_ow = nn.Parameter(
                torch.Tensor(self.config.hidden_size,
                             self.config.hidden_size))
            self.attn_ob = nn.Parameter(torch.Tensor(self.config.hidden_size))
            self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
            self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))
            self.inter_w = nn.Parameter(
                torch.Tensor(self.config.intermediate_size,
                             self.config.hidden_size))
            self.inter_b = nn.Parameter(torch.Tensor(self.config.intermediate_size))
            self.output_w = nn.Parameter(
                torch.Tensor(self.config.hidden_size,
                             self.config.intermediate_size))
            self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
            self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
            self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))

            self.init_transformer_weights(self.config.adjust_init_range)
        else:
            # For testing only.
            q = initial_weights[0].data
            k = initial_weights[1].data
            v = initial_weights[2].data

            self.attn_qkvw = nn.Parameter(torch.cat((q, k, v)))
            #self.attn_qkvw[i * self.config.hidden_size:(i + 1) * self.config.hidden_size] = \
            #    initial_weights[i].clone()
            #torch.empty_like(initial_weights[i]).data.copy_(initial_weights[i].data)
            self.attn_qkvb = nn.Parameter(torch.Tensor(self.config.hidden_size * 3))
            self.attn_qkvb.data.zero_()
            self.attn_ow = initial_weights[3]
            self.attn_ob = initial_biases[3]
            self.attn_nw = initial_weights[4]
            self.attn_nb = initial_biases[4]
            self.inter_w = initial_weights[5]
            self.inter_b = initial_biases[5]
            self.output_w = initial_weights[6]
            self.output_b = initial_biases[6]
            self.norm_w = initial_weights[7]
            self.norm_b = initial_biases[7]

        # Load cuda modules if needed
        global transformer_cuda_module, stochastic_transformer_cuda_module
        if transformer_cuda_module is None and not self.config.stochastic_mode:
            transformer_cuda_module = TransformerBuilder().load()
        if stochastic_transformer_cuda_module is None and self.config.stochastic_mode:
            stochastic_transformer_cuda_module = StochasticTransformerBuilder().load()

        # create the layer in cuda kernels.
        cuda_module = stochastic_transformer_cuda_module if self.config.stochastic_mode else transformer_cuda_module
        create_layer_func = cuda_module.create_transformer_layer_fp16 if self.config.fp16 else cuda_module.create_transformer_layer_fp32

        create_layer_func(self.config.layer_id,
                          self.config.batch_size,
                          self.config.hidden_size,
                          self.config.heads,
                          self.config.intermediate_size,
                          self.config.attn_dropout_ratio,
                          self.config.hidden_dropout_ratio,
                          self.config.layer_norm_eps,
                          self.config.seed,
                          self.config.pre_layer_norm,
                          self.config.test_gemm,
                          self.config.attn_dropout_checkpoint,
                          self.config.normalize_invertible,
                          self.config.gelu_checkpoint,
                          self.config.stochastic_mode)

    def init_transformer_weights(self, adjust_init_range=False):
        num_layers = self.config.num_hidden_layers
        output_std = self.config.initializer_range
        if adjust_init_range:
            # Scale down the output-projection init std to account for accumulation on the
            # residual path; only rank 0 prints, but every rank applies the adjustment.
            if self.config.local_rank == 0:
                print("Accounting for accumulation on the residual path")
            output_std = self.config.initializer_range / math.sqrt(2.0 * num_layers)

        self.attn_qkvw.data.normal_(mean=0.0, std=self.config.initializer_range)
        self.attn_qkvb.data.zero_()
        self.attn_ow.data.normal_(mean=0.0, std=output_std)
        self.attn_ob.data.zero_()
        self.attn_nw.data.fill_(1.0)
        self.attn_nb.data.zero_()
        self.inter_w.data.normal_(mean=0.0, std=self.config.initializer_range)
        self.inter_b.data.zero_()
        self.output_w.data.normal_(mean=0.0, std=output_std)
        self.output_b.data.zero_()
        self.norm_w.data.fill_(1.0)
        self.norm_b.data.zero_()

    def forward(self,
                hidden_states,
                attention_mask=None,
                head_mask=None,
                layer_head_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                past_key_value=None,
                output_attentions=False,
                grads=None):
        self.config.is_grad_enabled = torch.is_grad_enabled()
        return DeepSpeedTransformerFunction.apply(hidden_states,
                                                  attention_mask,
                                                  self,
                                                  grads,
                                                  self.config.layer_id,
                                                  self.attn_qkvw,
                                                  self.attn_qkvb,
                                                  self.attn_ow,
                                                  self.attn_ob,
                                                  self.attn_nw,
                                                  self.attn_nb,
                                                  self.inter_w,
                                                  self.inter_b,
                                                  self.output_w,
                                                  self.output_b,
                                                  self.norm_w,
                                                  self.norm_b,
                                                  self.config)
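

# Usage sketch (hypothetical, for illustration only; the shapes, dtypes, and sequence
# length below are assumptions, not requirements stated elsewhere in this file):
#
#   layer = DeepSpeedTransformerLayer(config).cuda().half()  # assuming fp16=True in config
#   seq_len = 128
#   hidden_states = torch.randn(config.batch_size, seq_len, config.hidden_size,
#                               dtype=torch.half, device='cuda')
#   # forward() expects a 4-D additive attention mask with the sequence dimension last
#   # (commonly [batch, 1, 1, seq_len]): 0 for positions to attend to and a large negative
#   # value for padding, matching how forward() pads the mask internally.
#   attention_mask = torch.zeros(config.batch_size, 1, 1, seq_len,
#                                dtype=torch.half, device='cuda')
#   output = layer(hidden_states, attention_mask)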