// ds_transformer_cuda.cpp

#include <torch/extension.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"

static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface
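// Size of the shared scratch workspace, in elements (not bytes): four
// batch x seq x hidden buffers in all cases; during training, two more such buffers
// plus room for whichever is larger, the intermediate activation or the attention
// score matrices; and two intermediate-sized buffers when GELU checkpointing is on.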
template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)
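// BertTransformerLayer wires together the per-layer sub-operators: the QKV and
// attention-output GEMMs, the two layer norms, GELU, softmax, the three dropouts, and
// the strided-batch GEMMs for attention scores and attention context. Each gemm_algos
// entry selects the cuBLAS algorithm for one GEMM; the checkpoint/invertible flags
// trade activation memory for recomputation in the backward pass.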
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(Context::Instance().GetCurrentStream()),
      _cublasHandle(Context::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                  3 * hidden_size,
                                                  hidden_size,
                                                  gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                       hidden_size,
                                                       hidden_size,
                                                       gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                           seq_length,
                                                           hidden_size,
                                                           layer_norm_eps,
                                                           true,
                                                           !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                      seq_length,
                                                      hidden_size,
                                                      layer_norm_eps,
                                                      true,
                                                      !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                           _intermediate_size,
                                           hidden_size,
                                           gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                           hidden_size,
                                           _intermediate_size,
                                           gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                        _seq_length,
                                                        _seq_length,
                                                        _hidden_size / _heads,
                                                        (T(1.0) / T(sqrt(_hidden_size / _heads))),
                                                        T(0.0),
                                                        CUBLAS_OP_T,
                                                        CUBLAS_OP_N,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         CUBLAS_OP_N,
                                                         CUBLAS_OP_N,
                                                         gemm_algos[4]))
{
    assert(_hidden_size % _heads == 0);

    Initialize();
}
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}

template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
    if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}
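// Forward pass for one transformer layer. Large temporaries are carved out of the
// shared workspace registered with Context (sized by get_workspace_size); the T*
// output arguments are the activations the caller keeps for the backward pass. When
// attn_dropout_checkpoint or gelu_checkpoint is set, the corresponding activation is
// redirected into the workspace and recomputed in Backward instead of being saved.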
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr,
                                      const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr,
                                      const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr,
                                      const T* attn_ob_ptr,
                                      const T* attn_nw_ptr,
                                      const T* attn_nb_ptr,
                                      const T* inter_w_ptr,
                                      const T* inter_b_ptr,
                                      const T* output_w_ptr,
                                      const T* output_b_ptr,
                                      const T* norm_w_ptr,
                                      const T* norm_b_ptr,
                                      T* out_ptr,
                                      T* inp_norm_ptr,
                                      T* q_tf_ptr,
                                      T* k_tf_ptr,
                                      T* v_tf_ptr,
                                      T* soft_out_ptr,
                                      T* ctx_bufB_ptr,
                                      T* attn_o_inp_ptr,
                                      T* add_res_ptr,
                                      T* ff1_inp_ptr,
                                      T* gelu_inp_ptr,
                                      T* ff2_inp_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) buf_2 += small_buf_size;
    if (_attn_dropout_checkpoint)
        ctx_bufB_ptr =
            (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                              : (buf_1 + 4 * small_buf_size));

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    if (_pre_or_postLayerNorm)
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    else
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    if (_pre_or_postLayerNorm)
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
    else
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);

    // attn output dropout.
    if (_pre_or_postLayerNorm)
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
    else
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    _ff1.Forward(bsz_seq,
                 ff1_inp_ptr,
                 inter_w_ptr,
                 (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                 _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq,
                             (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                             inter_b_ptr,
                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                             _stream);

    _ff2.Forward(
        bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}
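// Backward pass: retraces Forward in reverse, reusing the shared workspace for
// gradient buffers. When gelu_checkpoint is set the GELU activation is recomputed
// here, and when attn_dropout_checkpoint is set the attention-probability dropout
// output is recomputed from soft_out_ptr before back-propagating through the
// attention-context GEMM.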
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
                                       const T* grad_output_ptr,
                                       const T* input_ptr,
                                       const T* output_ptr,
                                       const T* inp_norm_ptr,
                                       const T* q_tf_ptr,
                                       const T* k_tf_ptr,
                                       const T* v_tf_ptr,
                                       const T* soft_out_ptr,
                                       const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr,
                                       const T* add_res_ptr,
                                       const T* ff1_inp_ptr,
                                       const T* gelu_inp_ptr,
                                       const T* ff2_inp_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr,
                                       const T* attn_nw_ptr,
                                       const T* attn_nb_ptr,
                                       const T* inter_w_ptr,
                                       const T* inter_b_ptr,
                                       const T* output_w_ptr,
                                       const T* norm_w_ptr,
                                       const T* norm_b_ptr,
                                       T* grad_input_ptr,
                                       T* grad_attn_qkvw_ptr,
                                       T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr,
                                       T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr,
                                       T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr,
                                       T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr,
                                       T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr,
                                       T* grad_norm_b_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                   : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    cudaStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 inp_norm_ptr);
        else
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 norm_b_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 output_ptr);
    }

    if (_pre_or_postLayerNorm)
        _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
    else
        _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);

    const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
                                     ? buf_0
                                     : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);

    if (_gelu_checkpoint)
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
    _ff2.Backward(bsz_seq,
                  layer_dropout_buf,
                  (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                  output_w_ptr,
                  grad_output_w_ptr,
                  grad_output_b_ptr,
                  _cublasHandle,
                  _stream,
                  ff2_buf);

    _gelu.Backward(
        bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq,
                  ff2_buf,
                  ff1_inp_ptr,
                  inter_w_ptr,
                  grad_inter_w_ptr,
                  grad_inter_b_ptr,
                  _cublasHandle,
                  _stream,
                  buf_3);

    if (!_pre_or_postLayerNorm)
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              add_res_ptr);
        else
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              attn_nb_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              ff1_inp_ptr);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      add_res_ptr);
        else
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      attn_nb_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      ff1_inp_ptr);
    }

    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq,
                              attn_output_dropout_buf,
                              attn_o_inp_ptr,
                              attn_ow_ptr,
                              grad_attn_ow_ptr,
                              grad_attn_ob_ptr,
                              _cublasHandle,
                              _stream,
                              buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    if (_attn_prob_dropout.HasDropout()) {
        if (_attn_dropout_checkpoint)
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);

        _attn_context.Backward(bsz_heads,
                               buf_2,
                               v_tf_ptr,
                               (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle,
                               buf_3,
                               ff2_buf);
    } else
        _attn_context.Backward(
            bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    if (_pre_or_postLayerNorm)
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             inp_norm_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);
    else
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             input_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         input_ptr);
        else
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         norm_b_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         inp_norm_ptr);
    } else
        launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Dropout will be skipped when not in training mode.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);
    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
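// Layers are constructed once (at init_seq_length) and registered in
// s_transformer_layers keyed by layer_id; ds_transformer_forward/backward look the
// layer up by id and call SetSeqLength when the incoming batch uses a different
// sequence length.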
template <typename T>
int create_transformer_layer(unsigned layer_id,
                             unsigned batch_size,
                             unsigned hidden_dim,
                             unsigned num_heads,
                             unsigned intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    Context::Instance().SetSeed(seed);
    Context::Instance().TestGemmFP16(
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

    auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                           batch_size,
                                                           hidden_dim,
                                                           num_heads,
                                                           intermediate_size,
                                                           init_seq_length,
                                                           attn_dropout_ratio,
                                                           hidden_dropout_ratio,
                                                           layer_norm_eps,
                                                           pre_or_postLayerNorm,
                                                           Context::Instance().GetGemmAlgos(),
                                                           attn_dropout_checkpoint,
                                                           normalize_invertible,
                                                           gelu_checkpoint,
                                                           stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
    std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]."
              << std::endl;

    return 0;
}
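// Python-facing forward entry point: checks that every tensor is a contiguous CUDA
// tensor, sizes the shared workspace for the current batch and sequence length, and
// returns the layer output together with all intermediate activations, dropout masks,
// and layer-norm statistics that ds_transformer_backward needs (the caller is
// expected to hold on to these between the two calls, e.g. in an autograd context).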
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                  const torch::Tensor& input,
                                                  const torch::Tensor& input_mask,
                                                  const torch::Tensor& attn_qkvw,
                                                  const torch::Tensor& attn_qkvb,
                                                  const torch::Tensor& attn_ow,
                                                  const torch::Tensor& attn_ob,
                                                  const torch::Tensor& attn_nw,
                                                  const torch::Tensor& attn_nb,
                                                  const torch::Tensor& inter_w,
                                                  const torch::Tensor& inter_b,
                                                  const torch::Tensor& output_w,
                                                  const torch::Tensor& output_b,
                                                  const torch::Tensor& norm_w,
                                                  const torch::Tensor& norm_b,
                                                  bool training_mode,
                                                  bool prelayernorm,
                                                  bool attn_dropout_checkpoint,
                                                  bool normalize_invertible,
                                                  bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}
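// Python-facing backward entry point: takes the incoming gradient plus the
// activations, dropout masks, and layer-norm statistics returned by
// ds_transformer_forward, re-registers those buffers with the layer, and returns the
// input gradient followed by the gradients for every weight and bias of the layer.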
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                   const torch::Tensor& grad_output,
                                                   const torch::Tensor& output,
                                                   const torch::Tensor& inp_norm,
                                                   const torch::Tensor& qkv_tf,
                                                   const torch::Tensor& soft_out,
                                                   const torch::Tensor& ctx_bufB,
                                                   const torch::Tensor& attn_o_inp,
                                                   const torch::Tensor& add_res,
                                                   const torch::Tensor& ff1_inp,
                                                   const torch::Tensor& gelu_inp,
                                                   const torch::Tensor& ff2_inp,
                                                   const torch::Tensor& attn_prob_dropout_mask,
                                                   const torch::Tensor& attn_output_dropout_mask,
                                                   const torch::Tensor& layer_output_dropout_mask,
                                                   const torch::Tensor& attn_layer_norm_var,
                                                   const torch::Tensor& attn_layer_norm_mean,
                                                   const torch::Tensor& layer_norm_var,
                                                   const torch::Tensor& layer_norm_mean,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    Context::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}
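// pybind11 bindings: each entry point is exported twice, instantiated for float
// (fp32) and __half (fp16).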
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Layer with fp16 (CUDA)");
}