// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <torch/extension.h>

#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <type_traits>
#include <unordered_map>
#include <vector>

#include "Timer.h"
#include "context.h"
#include "cublas_wrappers.h"
#include "custom_cuda_layers.h"
#include "ds_transformer_cuda.h"
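
// Layers created from Python are kept alive in this registry, keyed by layer id.
// New layers are built with init_seq_length and resized later if an incoming
// batch uses a different sequence length.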
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

const int init_seq_length = 128;

// C++ interface
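
// Scratch-space accounting for one layer: returns the number of elements (not
// bytes) of workspace needed by Forward/Backward. Training enlarges the buffer
// for backward-pass intermediates, and gelu checkpointing reserves extra room
// to recompute the intermediate activation.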
template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
                            unsigned seq_len,
                            unsigned hidden_size,
                            unsigned intermediate_size,
                            unsigned heads,
                            bool training,
                            bool gelu_checkpoint)
{
    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    if (training) {
        workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
        workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
                                     2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
        if (gelu_checkpoint)
            workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
    }
    return workSpacesize;  // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)
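
// The constructor wires up every fused op in the layer (GEMMs, layer norms,
// softmax, dropouts, and the strided-batched attention GEMMs) with sizes derived
// from the layer geometry and the cuBLAS GEMM algorithms chosen by the context.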
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
                                              unsigned batch_size,
                                              unsigned hidden_size,
                                              unsigned num_heads,
                                              unsigned intermediate_size,
                                              unsigned seq_length,
                                              float attn_prob_dropout_ratio,
                                              float hidden_output_dropout_ratio,
                                              float layer_norm_eps,
                                              bool pre_or_postLayerNorm,
                                              const std::vector<std::array<int, 3>>& gemm_algos,
                                              bool attn_dropout_checkpoint,
                                              bool normalize_invertible,
                                              bool gelu_checkpoint,
                                              bool stochastic_mode)
    : _layer_id(layer_id),
      _batch_size(batch_size),
      _hidden_size(hidden_size),
      _heads(num_heads),
      _intermediate_size(intermediate_size),
      _seq_length(seq_length),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _attn_dropout_checkpoint(attn_dropout_checkpoint),
      _normalize_invertible(normalize_invertible),
      _gelu_checkpoint(gelu_checkpoint),
      _stochastic_mode(stochastic_mode),
      _stream(TrainingContext::Instance().GetCurrentStream()),
      _cublasHandle(TrainingContext::Instance().GetCublasHandle()),
      _qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                  3 * hidden_size,
                                                  hidden_size,
                                                  gemm_algos[0])),
      _attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
                                                       hidden_size,
                                                       hidden_size,
                                                       gemm_algos[0])),
      _attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                           seq_length,
                                                           hidden_size,
                                                           layer_norm_eps,
                                                           true,
                                                           !normalize_invertible)),
      _layer_norm(typename Normalize_Layer<T>::Config(batch_size,
                                                      seq_length,
                                                      hidden_size,
                                                      layer_norm_eps,
                                                      true,
                                                      !normalize_invertible)),
      _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
                                           _intermediate_size,
                                           hidden_size,
                                           gemm_algos[1])),
      _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                           hidden_size,
                                           _intermediate_size,
                                           gemm_algos[2])),
      _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
      _gelu(typename Gelu<T>::Config(_intermediate_size)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
      _attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
      _attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                        _seq_length,
                                                        _seq_length,
                                                        _hidden_size / _heads,
                                                        (T(1.0) / T(sqrt(_hidden_size / _heads))),
                                                        T(0.0),
                                                        CUBLAS_OP_T,
                                                        CUBLAS_OP_N,
                                                        gemm_algos[3])),
      _attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
                                                         _hidden_size / _heads,
                                                         _seq_length,
                                                         _seq_length,
                                                         T(1.0),
                                                         T(0.0),
                                                         CUBLAS_OP_N,
                                                         CUBLAS_OP_N,
                                                         gemm_algos[4]))
{
    assert(_hidden_size % _heads == 0);

    Initialize();
}
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}

template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
    if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}
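
// Forward pass for one transformer layer. The shared workspace is partitioned
// into per-activation scratch buffers; when normalize_invertible, gelu_checkpoint,
// or attn_dropout_checkpoint is enabled, the corresponding activation pointers are
// redirected into the workspace instead of the caller-provided tensors, and the
// backward pass recomputes or reconstructs those activations.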
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
                                      const T* input_ptr,
                                      const T* input_mask_ptr,
                                      const T* attn_qkvw_ptr,
                                      const T* attn_qkvb_ptr,
                                      const T* attn_ow_ptr,
                                      const T* attn_ob_ptr,
                                      const T* attn_nw_ptr,
                                      const T* attn_nb_ptr,
                                      const T* inter_w_ptr,
                                      const T* inter_b_ptr,
                                      const T* output_w_ptr,
                                      const T* output_b_ptr,
                                      const T* norm_w_ptr,
                                      const T* norm_b_ptr,
                                      T* out_ptr,
                                      T* inp_norm_ptr,
                                      T* q_tf_ptr,
                                      T* k_tf_ptr,
                                      T* v_tf_ptr,
                                      T* soft_out_ptr,
                                      T* ctx_bufB_ptr,
                                      T* attn_o_inp_ptr,
                                      T* add_res_ptr,
                                      T* ff1_inp_ptr,
                                      T* gelu_inp_ptr,
                                      T* ff2_inp_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1;

    if (_normalize_invertible) {
        add_res_ptr = buf_1 + 3 * small_buf_size;
        buf_2 = add_res_ptr;
    }
    if (_gelu_checkpoint) buf_2 += small_buf_size;
    if (_attn_dropout_checkpoint)
        ctx_bufB_ptr =
            (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
                              : (buf_1 + 4 * small_buf_size));

    int bsz_seq = bsz * _seq_length;

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }

    if (_pre_or_postLayerNorm)
        _qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
    else
        _qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);

    launch_bias_add_transform_0213<T>(
        q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);

    int bsz_heads = bsz * _heads;

    // attention scores
    _attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);

    // Softmax + Mask
    _softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);

    // attn prob dropout.
    _attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);

    // attention context
    _attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);

    launch_transform4d_0213<T>(
        attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);

    if (_pre_or_postLayerNorm)
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
    else
        _attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);

    // attn output dropout.
    if (_pre_or_postLayerNorm)
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
    else
        _attn_output_dropout.ForwardWithBias(
            bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.ForwardCheckpoint(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
        else
            _attn_layer_norm.Forward(
                bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
    }

    _ff1.Forward(bsz_seq,
                 ff1_inp_ptr,
                 inter_w_ptr,
                 (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                 _cublasHandle);

    _gelu.ForwardWithBiasAdd(bsz_seq,
                             (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
                             inter_b_ptr,
                             (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                             _stream);

    _ff2.Forward(
        bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);

    // layer output dropout.
    if (_pre_or_postLayerNorm)
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
    else
        _layer_output_dropout.ForwardWithBias(
            bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.ForwardCheckpoint(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
        else
            _layer_norm.Forward(
                bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
    }
}
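
// Backward pass: walks the forward pipeline in reverse, reusing the shared
// workspace for intermediate gradients. Checkpointed activations (the GeLU output
// and the attention-dropout result) are recomputed here before their gradients
// are taken, trading compute for the memory saved in Forward.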
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
                                       const T* grad_output_ptr,
                                       const T* input_ptr,
                                       const T* output_ptr,
                                       const T* inp_norm_ptr,
                                       const T* q_tf_ptr,
                                       const T* k_tf_ptr,
                                       const T* v_tf_ptr,
                                       const T* soft_out_ptr,
                                       const T* ctx_bufB_ptr,
                                       const T* attn_o_inp_ptr,
                                       const T* add_res_ptr,
                                       const T* ff1_inp_ptr,
                                       const T* gelu_inp_ptr,
                                       const T* ff2_inp_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
                                       const T* attn_ow_ptr,
                                       const T* attn_nw_ptr,
                                       const T* attn_nb_ptr,
                                       const T* inter_w_ptr,
                                       const T* inter_b_ptr,
                                       const T* output_w_ptr,
                                       const T* norm_w_ptr,
                                       const T* norm_b_ptr,
                                       T* grad_input_ptr,
                                       T* grad_attn_qkvw_ptr,
                                       T* grad_attn_qkvb_ptr,
                                       T* grad_attn_ow_ptr,
                                       T* grad_attn_ob_ptr,
                                       T* grad_attn_nw_ptr,
                                       T* grad_attn_nb_ptr,
                                       T* grad_inter_w_ptr,
                                       T* grad_inter_b_ptr,
                                       T* grad_output_w_ptr,
                                       T* grad_output_b_ptr,
                                       T* grad_norm_w_ptr,
                                       T* grad_norm_b_ptr)
{
    cublasSetStream(_cublasHandle, _stream);

    if (!_stochastic_mode) cudaStreamSynchronize(_stream);

    T* workspace = static_cast<T*>(TrainingContext::Instance().GetWorkSpace());
    size_t small_buf_size = bsz * _seq_length * _hidden_size;
    T* buf_0 = workspace;
    T* buf_1 = buf_0 + small_buf_size;
    T* buf_2 = buf_1 + small_buf_size;
    T* buf_3 = buf_2 + small_buf_size;

    T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
                                   : buf_3 + small_buf_size);
    T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);

    cudaStream_t streams[2] = {_stream, _stream};

    int bsz_seq = bsz * _seq_length;
    int bsz_heads = bsz * _heads;

    if (!_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 inp_norm_ptr);
        else
            _layer_norm.Backward(bsz_seq,
                                 grad_output_ptr,
                                 norm_w_ptr,
                                 norm_b_ptr,
                                 grad_norm_w_ptr,
                                 grad_norm_b_ptr,
                                 streams,
                                 buf_1,
                                 output_ptr);
    }

    if (_pre_or_postLayerNorm)
        _layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
    else
        _layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);

    const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
                                     ? buf_0
                                     : (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);

    if (_gelu_checkpoint)
        _gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
    _ff2.Backward(bsz_seq,
                  layer_dropout_buf,
                  (_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
                  output_w_ptr,
                  grad_output_w_ptr,
                  grad_output_b_ptr,
                  _cublasHandle,
                  _stream,
                  ff2_buf);

    _gelu.Backward(
        bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);

    _ff1.Backward(bsz_seq,
                  ff2_buf,
                  ff1_inp_ptr,
                  inter_w_ptr,
                  grad_inter_w_ptr,
                  grad_inter_b_ptr,
                  _cublasHandle,
                  _stream,
                  buf_3);

    if (!_pre_or_postLayerNorm)
        launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);

    if (_pre_or_postLayerNorm) {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              add_res_ptr);
        else
            _attn_layer_norm.BackwardFusedAdd(bsz_seq,
                                              buf_3,
                                              grad_output_ptr,
                                              attn_nw_ptr,
                                              attn_nb_ptr,
                                              grad_attn_nw_ptr,
                                              grad_attn_nb_ptr,
                                              streams,
                                              buf_0,
                                              ff1_inp_ptr);
    } else {
        if (_attn_layer_norm.UseMean())
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      add_res_ptr);
        else
            _attn_layer_norm.Backward(bsz_seq,
                                      buf_2,
                                      attn_nw_ptr,
                                      attn_nb_ptr,
                                      grad_attn_nw_ptr,
                                      grad_attn_nb_ptr,
                                      streams,
                                      buf_0,
                                      ff1_inp_ptr);
    }

    _attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);

    T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;

    _attn_out_linear.Backward(bsz_seq,
                              attn_output_dropout_buf,
                              attn_o_inp_ptr,
                              attn_ow_ptr,
                              grad_attn_ow_ptr,
                              grad_attn_ob_ptr,
                              _cublasHandle,
                              _stream,
                              buf_1);

    launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);

    if (_attn_prob_dropout.HasDropout()) {
        if (_attn_dropout_checkpoint)
            _attn_prob_dropout.Forward(
                bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);

        _attn_context.Backward(bsz_heads,
                               buf_2,
                               v_tf_ptr,
                               (_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
                               _cublasHandle,
                               buf_3,
                               ff2_buf);
    } else
        _attn_context.Backward(
            bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);

    _attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);

    _softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);

    _attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);

    launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);

    if (_pre_or_postLayerNorm)
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             inp_norm_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);
    else
        _qkv_linear.Backward(bsz_seq,
                             ff2_buf,
                             input_ptr,
                             attn_qkvw_ptr,
                             grad_attn_qkvw_ptr,
                             grad_attn_qkvb_ptr,
                             _cublasHandle,
                             _stream,
                             buf_2);

    if (_pre_or_postLayerNorm) {
        if (_layer_norm.UseMean())
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         input_ptr);
        else
            _layer_norm.BackwardFusedAdd(bsz_seq,
                                         buf_2,
                                         buf_0,
                                         norm_w_ptr,
                                         norm_b_ptr,
                                         grad_norm_w_ptr,
                                         grad_norm_b_ptr,
                                         streams,
                                         grad_input_ptr,
                                         inp_norm_ptr);
    } else
        launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
    // Dropout will be skipped when not in training mode.
    _attn_prob_dropout.SetTrainingMode(training);
    _attn_output_dropout.SetTrainingMode(training);
    _layer_output_dropout.SetTrainingMode(training);
}
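
// Point the dropout masks and layer-norm statistics at tensors allocated by the
// caller so they persist between the forward and backward calls.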
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
                                                     uint8_t* attn_output_dropout_mask_ptr,
                                                     uint8_t* layer_output_dropout_mask_ptr,
                                                     T* attn_layer_norm_var,
                                                     T* attn_layer_norm_mean,
                                                     T* layer_norm_var,
                                                     T* layer_norm_mean)
{
    _attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
    _attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
    _layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);

    _attn_layer_norm.SetVar(attn_layer_norm_var);
    _attn_layer_norm.SetMean(attn_layer_norm_mean);
    _layer_norm.SetVar(layer_norm_var);
    _layer_norm.SetMean(layer_norm_mean);
}
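
// Resize the sequence-length-dependent ops when an incoming batch uses a
// different sequence length than the one the layer was created with.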
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
    _seq_length = seq_len;

    _softmax.SetSeqLength(_seq_length);
    _attn_prob_dropout.SetDimension(_seq_length);
    _attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
    _attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
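
// Construct a layer with the default init_seq_length, register it in
// s_transformer_layers under its id, and report the data type it was built with.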
template <typename T>
int create_transformer_layer(unsigned layer_id,
                             unsigned batch_size,
                             unsigned hidden_dim,
                             unsigned num_heads,
                             unsigned intermediate_size,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             float layer_norm_eps,
                             int seed,
                             bool pre_or_postLayerNorm,
                             bool test_gemm,
                             bool attn_dropout_checkpoint,
                             bool normalize_invertible,
                             bool gelu_checkpoint,
                             bool stochastic_mode)
{
    TrainingContext::Instance().SetSeed(seed);
    TrainingContext::Instance().TestGemmFP16(
        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

    auto layer =
        std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                  batch_size,
                                                  hidden_dim,
                                                  num_heads,
                                                  intermediate_size,
                                                  init_seq_length,
                                                  attn_dropout_ratio,
                                                  hidden_dropout_ratio,
                                                  layer_norm_eps,
                                                  pre_or_postLayerNorm,
                                                  TrainingContext::Instance().GetGemmAlgos(),
                                                  attn_dropout_checkpoint,
                                                  normalize_invertible,
                                                  gelu_checkpoint,
                                                  stochastic_mode);

    s_transformer_layers[layer_id] = layer;

    std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
    std::cout << "layer #" << layer_id << " is created with data type [" << dtype << "]."
              << std::endl;

    return 0;
}
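
// Python-facing forward entry point: validates the inputs, allocates the
// activation tensors that must be stashed for backward, resizes the layer and
// workspace for the current sequence length, and dispatches to
// BertTransformerLayer::Forward.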
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                  const torch::Tensor& input,
                                                  const torch::Tensor& input_mask,
                                                  const torch::Tensor& attn_qkvw,
                                                  const torch::Tensor& attn_qkvb,
                                                  const torch::Tensor& attn_ow,
                                                  const torch::Tensor& attn_ob,
                                                  const torch::Tensor& attn_nw,
                                                  const torch::Tensor& attn_nb,
                                                  const torch::Tensor& inter_w,
                                                  const torch::Tensor& inter_b,
                                                  const torch::Tensor& output_w,
                                                  const torch::Tensor& output_b,
                                                  const torch::Tensor& norm_w,
                                                  const torch::Tensor& norm_b,
                                                  bool training_mode,
                                                  bool prelayernorm,
                                                  bool attn_dropout_checkpoint,
                                                  bool normalize_invertible,
                                                  bool gelu_checkpoint)
{
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = input.size(0);

    const T* input_ptr = (const T*)input.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* output_b_ptr = (const T*)output_b.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    auto output = torch::empty_like(input);
    T* out_ptr = (T*)output.data_ptr();

    auto options = torch::TensorOptions()
                       .dtype(input.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto uint8_options = torch::TensorOptions()
                             .dtype(torch::kInt8)
                             .layout(torch::kStrided)
                             .device(torch::kCUDA)
                             .requires_grad(false);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (input.size(1) != seq_len) {
        seq_len = input.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
    auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
    auto attn_o_inp = torch::empty_like(input);
    auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);

    auto attn_prob_dropout_mask =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
    auto attn_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
    auto layer_output_dropout_mask =
        torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);

    auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
    auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);

    T* inp_norm_ptr = (T*)inp_norm.data_ptr();
    T* add_res_ptr = (T*)add_res.data_ptr();
    T* q_tf_ptr = (T*)qkv_tf.data_ptr();
    T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)k_tf.data_ptr();
    T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0));  //(T*)v_tf.data_ptr();
    T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();

    torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
    torch::Tensor gelu_inp =
        (gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
    auto ff1_inp = torch::empty_like(input);
    T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
    T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
    T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();

    torch::Tensor soft_out =
        torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
    torch::Tensor ctx_bufB =
        (attn_dropout_checkpoint
             ? soft_out
             : torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
    T* soft_out_ptr = (T*)soft_out.data_ptr();
    T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();

    layer->SetTrainingMode(training_mode);
    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Forward(bsz,
                   input_ptr,
                   input_mask_ptr,
                   attn_qkvw_ptr,
                   attn_qkvb_ptr,
                   attn_ow_ptr,
                   attn_ob_ptr,
                   attn_nw_ptr,
                   attn_nb_ptr,
                   inter_w_ptr,
                   inter_b_ptr,
                   output_w_ptr,
                   output_b_ptr,
                   norm_w_ptr,
                   norm_b_ptr,
                   out_ptr,
                   inp_norm_ptr,
                   q_tf_ptr,
                   k_tf_ptr,
                   v_tf_ptr,
                   soft_out_ptr,
                   ctx_bufB_ptr,
                   attn_o_inp_ptr,
                   add_res_ptr,
                   ff1_inp_ptr,
                   gelu_inp_ptr,
                   ff2_inp_ptr);

    return {output,
            inp_norm,
            qkv_tf,
            soft_out,
            ctx_bufB,
            attn_o_inp,
            add_res,
            ff1_inp,
            gelu_inp,
            ff2_inp,
            attn_prob_dropout_mask,
            attn_output_dropout_mask,
            layer_output_dropout_mask,
            attn_layer_norm_var,
            attn_layer_norm_mean,
            layer_norm_var,
            layer_norm_mean};
}
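
// Python-facing backward entry point: validates the saved activations, allocates
// one gradient tensor per parameter, restores the intermediate buffers on the
// layer, and dispatches to BertTransformerLayer::Backward.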
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                   const torch::Tensor& grad_output,
                                                   const torch::Tensor& output,
                                                   const torch::Tensor& inp_norm,
                                                   const torch::Tensor& qkv_tf,
                                                   const torch::Tensor& soft_out,
                                                   const torch::Tensor& ctx_bufB,
                                                   const torch::Tensor& attn_o_inp,
                                                   const torch::Tensor& add_res,
                                                   const torch::Tensor& ff1_inp,
                                                   const torch::Tensor& gelu_inp,
                                                   const torch::Tensor& ff2_inp,
                                                   const torch::Tensor& attn_prob_dropout_mask,
                                                   const torch::Tensor& attn_output_dropout_mask,
                                                   const torch::Tensor& layer_output_dropout_mask,
                                                   const torch::Tensor& attn_layer_norm_var,
                                                   const torch::Tensor& attn_layer_norm_mean,
                                                   const torch::Tensor& layer_norm_var,
                                                   const torch::Tensor& layer_norm_mean,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
                                                   const torch::Tensor& attn_qkvb,
                                                   const torch::Tensor& attn_ow,
                                                   const torch::Tensor& attn_ob,
                                                   const torch::Tensor& attn_nw,
                                                   const torch::Tensor& attn_nb,
                                                   const torch::Tensor& inter_w,
                                                   const torch::Tensor& inter_b,
                                                   const torch::Tensor& output_w,
                                                   const torch::Tensor& output_b,
                                                   const torch::Tensor& norm_w,
                                                   const torch::Tensor& norm_b)
{
    auto g_output = grad_output.contiguous();
    CHECK_INPUT(g_output);
    CHECK_INPUT(output);
    CHECK_INPUT(inp_norm);
    CHECK_INPUT(qkv_tf);
    CHECK_INPUT(add_res);
    CHECK_INPUT(soft_out);
    CHECK_INPUT(ctx_bufB);
    CHECK_INPUT(attn_o_inp);
    CHECK_INPUT(ff1_inp);
    CHECK_INPUT(gelu_inp);
    CHECK_INPUT(ff2_inp);
    CHECK_INPUT(input);
    CHECK_INPUT(input_mask);
    CHECK_INPUT(attn_qkvw);
    CHECK_INPUT(attn_qkvb);
    CHECK_INPUT(attn_ow);
    CHECK_INPUT(attn_ob);
    CHECK_INPUT(attn_nw);
    CHECK_INPUT(attn_nb);
    CHECK_INPUT(inter_w);
    CHECK_INPUT(inter_b);
    CHECK_INPUT(output_w);
    CHECK_INPUT(output_b);
    CHECK_INPUT(norm_w);
    CHECK_INPUT(norm_b);

    unsigned bsz = g_output.size(0);

    std::shared_ptr<BertTransformerLayer<T>> layer =
        std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

    unsigned seq_len = layer->GetSeqLength();
    if (g_output.size(1) != seq_len) {
        seq_len = g_output.size(1);
        layer->SetSeqLength(seq_len);
    }

    auto options = torch::TensorOptions()
                       .dtype(g_output.options().dtype())
                       .layout(torch::kStrided)
                       .device(torch::kCUDA)
                       .requires_grad(true);

    auto workspace = torch::empty({get_workspace_size<T>(bsz,
                                                         seq_len,
                                                         layer->GetHiddenSize(),
                                                         layer->GetIntermediateSize(),
                                                         layer->GetNumHeads(),
                                                         layer->IsTrainingMode(),
                                                         layer->GeluCheckpoint())},
                                  options);
    TrainingContext::Instance().SetWorkSpace((T*)workspace.data_ptr());

    auto grad_input = torch::empty_like(input);
    auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
    auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
    auto grad_attn_ow = torch::empty_like(attn_ow);
    auto grad_attn_ob = torch::empty_like(attn_ob);
    auto grad_attn_nw = torch::empty_like(attn_nw);
    auto grad_attn_nb = torch::empty_like(attn_nb);
    auto grad_inter_w = torch::empty_like(inter_w);
    auto grad_inter_b = torch::empty_like(inter_b);
    auto grad_output_w = torch::empty_like(output_w);
    auto grad_output_b = torch::empty_like(output_b);
    auto grad_norm_w = torch::empty_like(norm_w);
    auto grad_norm_b = torch::empty_like(norm_b);

    // inputs.
    const T* grad_output_ptr = (const T*)g_output.data_ptr();
    const T* input_ptr = (const T*)input.data_ptr();
    const T* output_ptr = (const T*)output.data_ptr();
    const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
    const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
    const T* add_res_ptr = (const T*)add_res.data_ptr();
    const T* k_tf_ptr =
        q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)k_tf.data_ptr();
    const T* v_tf_ptr =
        k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0));  //(const T*)v_tf.data_ptr();
    const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
    const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
    const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
    const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
    const T* soft_out_ptr = (const T*)soft_out.data_ptr();
    const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
    const T* input_mask_ptr = (const T*)input_mask.data_ptr();
    const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
    const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
    const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
    const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
    const T* inter_w_ptr = (const T*)inter_w.data_ptr();
    const T* inter_b_ptr = (const T*)inter_b.data_ptr();
    const T* output_w_ptr = (const T*)output_w.data_ptr();
    const T* norm_w_ptr = (const T*)norm_w.data_ptr();
    const T* norm_b_ptr = (const T*)norm_b.data_ptr();

    // outputs.
    T* grad_input_ptr = (T*)grad_input.data_ptr();
    T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
    T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
    T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
    T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
    T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
    T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
    T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
    T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
    T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
    T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
    T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
    T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();

    layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
                                  (uint8_t*)attn_output_dropout_mask.data_ptr(),
                                  (uint8_t*)layer_output_dropout_mask.data_ptr(),
                                  (T*)attn_layer_norm_var.data_ptr(),
                                  (T*)attn_layer_norm_mean.data_ptr(),
                                  (T*)layer_norm_var.data_ptr(),
                                  (T*)layer_norm_mean.data_ptr());

    layer->Backward(bsz,
                    grad_output_ptr,
                    input_ptr,
                    output_ptr,
                    inp_norm_ptr,
                    q_tf_ptr,
                    k_tf_ptr,
                    v_tf_ptr,
                    soft_out_ptr,
                    ctx_bufB_ptr,
                    attn_o_inp_ptr,
                    add_res_ptr,
                    ff1_inp_ptr,
                    gelu_inp_ptr,
                    ff2_inp_ptr,
                    input_mask_ptr,
                    attn_qkvw_ptr,
                    attn_ow_ptr,
                    attn_nw_ptr,
                    attn_nb_ptr,
                    inter_w_ptr,
                    inter_b_ptr,
                    output_w_ptr,
                    norm_w_ptr,
                    norm_b_ptr,
                    grad_input_ptr,
                    grad_attn_qkvw_ptr,
                    grad_attn_qkvb_ptr,
                    grad_attn_ow_ptr,
                    grad_attn_ob_ptr,
                    grad_attn_nw_ptr,
                    grad_attn_nb_ptr,
                    grad_inter_w_ptr,
                    grad_inter_b_ptr,
                    grad_output_w_ptr,
                    grad_output_b_ptr,
                    grad_norm_w_ptr,
                    grad_norm_b_ptr);

    return {grad_input,
            grad_attn_qkvw,
            grad_attn_qkvb,
            grad_attn_ow,
            grad_attn_ob,
            grad_attn_nw,
            grad_attn_nb,
            grad_inter_w,
            grad_inter_b,
            grad_output_w,
            grad_output_b,
            grad_norm_w,
            grad_norm_b};
}
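
// Bindings are instantiated separately for fp32 and fp16; the caller picks the
// variant matching its tensor dtype.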
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward_fp32",
          &ds_transformer_forward<float>,
          "DeepSpeed Transformer forward with fp32 (CUDA)");
    m.def("forward_fp16",
          &ds_transformer_forward<__half>,
          "DeepSpeed Transformer forward with fp16 (CUDA)");
    m.def("backward_fp32",
          &ds_transformer_backward<float>,
          "DeepSpeed Transformer backward with fp32 (CUDA)");
    m.def("backward_fp16",
          &ds_transformer_backward<__half>,
          "DeepSpeed Transformer backward with fp16 (CUDA)");
    m.def("create_transformer_layer_fp32",
          &create_transformer_layer<float>,
          "Create DeepSpeed Transformer Layer with fp32 (CUDA)");
    m.def("create_transformer_layer_fp16",
          &create_transformer_layer<__half>,
          "Create DeepSpeed Transformer Layer with fp16 (CUDA)");
}