// ds_transformer_cuda.cpp

#include <torch/extension.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"

static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface
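// Number of elements (not bytes) of scratch space required by one transformer
// layer: four activation-sized buffers for inference, plus extra buffers for
// training, attention-score storage, and the optional GeLU checkpoint.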
template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

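// A single BERT transformer layer. The constructor configures every
// sub-operator (QKV / attention-output / FFN GEMMs, layer norms, softmax,
// GeLU, dropouts, and the two strided-batch attention GEMMs) for a fixed
// batch size, hidden size, head count, and initial sequence length.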
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                  3 * hidden_size,
                                                  hidden_size,
                                                  gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                       hidden_size,
                                                       hidden_size,
                                                       gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                           seq_length,
                                                           hidden_size,
                                                           layer_norm_eps,
                                                           true,
                                                           !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                      seq_length,
                                                      hidden_size,
                                                      layer_norm_eps,
                                                      true,
                                                      !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                           _intermediate_size,
                                           hidden_size,
                                           gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                           hidden_size,
                                           _intermediate_size,
                                           gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                        _seq_length,
                                                        _seq_length,
                                                        _hidden_size / _heads,
                                                        (T(1.0) / T(sqrt(_hidden_size / _heads))),
                                                        T(0.0),
                                                        CUBLAS_OP_T,
                                                        CUBLAS_OP_N,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         CUBLAS_OP_N,
                                                         CUBLAS_OP_N,
                                                         gemm_algos[4]))
{
    assert(_hidden_size % _heads == 0);

    Initialize();
}

template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}

template <typename T>
void BertTransformerLayer<T>::Initialize()
{
    if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
}

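// Forward pass for one layer: (optional pre-)layer norm, QKV projection,
// scaled dot-product attention with masked softmax and attention-probability
// dropout, attention output projection + dropout + residual, a second layer
// norm, and the GeLU feed-forward block. Intermediate activations are written
// to the caller-provided buffers so Backward can reuse them.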
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr,
                                      const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr,
                                      const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr,
                                      const T* attn_ob_ptr,
                                      const T* attn_nw_ptr,
                                      const T* attn_nb_ptr,
                                      const T* inter_w_ptr,
                                      const T* inter_b_ptr,
                                      const T* output_w_ptr,
                                      const T* output_b_ptr,
                                      const T* norm_w_ptr,
                                      const T* norm_b_ptr,
                                      T* out_ptr,
                                      T* inp_norm_ptr,
                                      T* q_tf_ptr,
                                      T* k_tf_ptr,
                                      T* v_tf_ptr,
                                      T* soft_out_ptr,
                                      T* ctx_bufB_ptr,
                                      T* attn_o_inp_ptr,
                                      T* add_res_ptr,
                                      T* ff1_inp_ptr,
                                      T* gelu_inp_ptr,
                                      T* ff2_inp_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) buf_2 += small_buf_size;
    if (_attn_dropout_checkpoint)
        ctx_bufB_ptr =
            (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                              : (buf_1 + 4 * small_buf_size));

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    if (_pre_or_postLayerNorm)
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    else
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    if (_pre_or_postLayerNorm)
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
    else
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);

    // attn output dropout.
    if (_pre_or_postLayerNorm)
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
    else
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    _ff1.Forward(bsz_seq,
                 ff1_inp_ptr,
                 inter_w_ptr,
                 (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                 _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq,
                             (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                             inter_b_ptr,
                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                             _stream);

    _ff2.Forward(
        bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}

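// Backward pass for one layer. Walks the forward graph in reverse, recomputing
// any checkpointed activations (GeLU input, attention-probability dropout) on
// the fly, and produces gradients for the layer input and for every weight and
// bias of the attention and feed-forward blocks.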
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
                                       const T* grad_output_ptr,
                                       const T* input_ptr,
                                       const T* output_ptr,
                                       const T* inp_norm_ptr,
                                       const T* q_tf_ptr,
                                       const T* k_tf_ptr,
                                       const T* v_tf_ptr,
                                       const T* soft_out_ptr,
                                       const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr,
                                       const T* add_res_ptr,
                                       const T* ff1_inp_ptr,
                                       const T* gelu_inp_ptr,
                                       const T* ff2_inp_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr,
                                       const T* attn_nw_ptr,
                                       const T* attn_nb_ptr,
                                       const T* inter_w_ptr,
                                       const T* inter_b_ptr,
                                       const T* output_w_ptr,
                                       const T* norm_w_ptr,
                                       const T* norm_b_ptr,
                                       T* grad_input_ptr,
                                       T* grad_attn_qkvw_ptr,
                                       T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr,
                                       T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr,
                                       T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr,
                                       T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr,
                                       T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr,
                                       T* grad_norm_b_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                   : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    cudaStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 inp_norm_ptr);
        else
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 norm_b_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 output_ptr);
    }

    if (_pre_or_postLayerNorm)
        _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
    else
        _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);

    const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
                                     ? buf_0
                                     : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);

    if (_gelu_checkpoint)
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
    _ff2.Backward(bsz_seq,
                  layer_dropout_buf,
                  (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                  output_w_ptr,
                  grad_output_w_ptr,
                  grad_output_b_ptr,
                  _cublasHandle,
                  _stream,
                  ff2_buf);

    _gelu.Backward(
        bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq,
                  ff2_buf,
                  ff1_inp_ptr,
                  inter_w_ptr,
                  grad_inter_w_ptr,
                  grad_inter_b_ptr,
                  _cublasHandle,
                  _stream,
                  buf_3);

    if (!_pre_or_postLayerNorm)
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              add_res_ptr);
        else
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              attn_nb_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              ff1_inp_ptr);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      add_res_ptr);
        else
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      attn_nb_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      ff1_inp_ptr);
    }

    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq,
                              attn_output_dropout_buf,
                              attn_o_inp_ptr,
                              attn_ow_ptr,
                              grad_attn_ow_ptr,
                              grad_attn_ob_ptr,
                              _cublasHandle,
                              _stream,
                              buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    if (_attn_prob_dropout.HasDropout()) {
        if (_attn_dropout_checkpoint)
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);

        _attn_context.Backward(bsz_heads,
                               buf_2,
                               v_tf_ptr,
                               (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle,
                               buf_3,
                               ff2_buf);
    } else
        _attn_context.Backward(
            bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    if (_pre_or_postLayerNorm)
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             inp_norm_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);
    else
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             input_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         input_ptr);
        else
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         norm_b_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         inp_norm_ptr);
    } else
        launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}

template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Dropout will be skipped when not in training mode.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}

template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}

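// Called when an incoming batch has a different sequence length than the one
// the layer was configured with; propagates the new length to the softmax, the
// attention-probability dropout, and both strided-batch attention GEMMs.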
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);
    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}

template <typename T>
int create_transformer_layer(unsigned layer_id,
                             unsigned batch_size,
                             unsigned hidden_dim,
                             unsigned num_heads,
                             unsigned intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                           batch_size,
                                                           hidden_dim,
                                                           num_heads,
                                                           intermediate_size,
                                                           init_seq_length,
                                                           attn_dropout_ratio,
                                                           hidden_dropout_ratio,
                                                           layer_norm_eps,
                                                           pre_or_postLayerNorm,
                                                           Context::Instance().GetGemmAlgos(),
                                                           attn_dropout_checkpoint,
                                                           normalize_invertible,
                                                           gelu_checkpoint,
                                                           stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
    std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]."
              << std::endl;

    return 0;
}

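// Torch-facing forward entry point. Validates the input tensors, resizes the
// layer for the incoming sequence length if needed, allocates the shared
// workspace plus every intermediate tensor the backward pass will need, and
// dispatches to BertTransformerLayer<T>::Forward. Returns the layer output
// together with those saved intermediates.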
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b,
                                                   bool training_mode,
                                                   bool prelayernorm,
                                                   bool attn_dropout_checkpoint,
                                                   bool normalize_invertible,
                                                   bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}

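// Torch-facing backward entry point: takes the upstream gradient plus every
// tensor saved by the forward pass, allocates the gradient tensors, and
// dispatches to BertTransformerLayer<T>::Backward.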
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                    const torch::Tensor& grad_output,
                                                    const torch::Tensor& output,
                                                    const torch::Tensor& inp_norm,
                                                    const torch::Tensor& qkv_tf,
                                                    const torch::Tensor& soft_out,
                                                    const torch::Tensor& ctx_bufB,
                                                    const torch::Tensor& attn_o_inp,
                                                    const torch::Tensor& add_res,
                                                    const torch::Tensor& ff1_inp,
                                                    const torch::Tensor& gelu_inp,
                                                    const torch::Tensor& ff2_inp,
                                                    const torch::Tensor& attn_prob_dropout_mask,
                                                    const torch::Tensor& attn_output_dropout_mask,
                                                    const torch::Tensor& layer_output_dropout_mask,
                                                    const torch::Tensor& attn_layer_norm_var,
                                                    const torch::Tensor& attn_layer_norm_mean,
                                                    const torch::Tensor& layer_norm_var,
                                                    const torch::Tensor& layer_norm_mean,
                                                    const torch::Tensor& input,
                                                    const torch::Tensor& input_mask,
                                                    const torch::Tensor& attn_qkvw,
                                                    const torch::Tensor& attn_qkvb,
                                                    const torch::Tensor& attn_ow,
                                                    const torch::Tensor& attn_ob,
                                                    const torch::Tensor& attn_nw,
                                                    const torch::Tensor& attn_nb,
                                                    const torch::Tensor& inter_w,
                                                    const torch::Tensor& inter_b,
                                                    const torch::Tensor& output_w,
                                                    const torch::Tensor& output_b,
                                                    const torch::Tensor& norm_w,
                                                    const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);
    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}

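// Python bindings: each templated entry point is exported twice, once for
// fp32 (float) and once for fp16 (__half).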
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Layer with fp16 (CUDA)");
}