dropout_kernels.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "custom_cuda_layers.h"

const int unroll_factor = 4;
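
// Forward dropout kernels (float and __half overloads). Each thread draws four
// uniform Philox samples per iteration with curand_uniform4, builds a keep mask
// (rand > ratio), stores the mask as bytes, and scales surviving activations by
// 1 / (1 - ratio). Elements are processed unroll_factor (= 4) at a time; the
// trailing high_index loop handles the remainder when N is not covered by the
// vectorized grid-stride loop.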
__global__ void dropout_kernel(const int N,
                               const float ratio,
                               float* out,
                               const float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float4 rand = curand_uniform4(&state);
        uint8_t m[unroll_factor];

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        int i = j * unroll_factor;

        mask[i] = (uint8_t)m[0];
        mask[i + 1] = (uint8_t)m[1];
        mask[i + 2] = (uint8_t)m[2];
        mask[i + 3] = (uint8_t)m[3];

        out[i] = Xdata[i] * scale * m[0];
        out[i + 1] = Xdata[i + 1] * scale * m[1];
        out[i + 2] = Xdata[i + 2] * scale * m[2];
        out[i + 3] = Xdata[i + 3] * scale * m[3];
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            out[i] = Xdata[i] * scale * m;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N,
                               const float ratio,
                               __half* out,
                               const __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    uint32_t m_32;
    uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);

    float2 result_f;
    __half2* result_h = reinterpret_cast<__half2*>(&result_f);

    __half2 mask_h[2];
    float2 mask_f[2];

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        float4 rand = curand_uniform4(&state);

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

        mask_h[0] = __float22half2_rn(mask_f[0]);
        mask_h[1] = __float22half2_rn(mask_f[1]);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
        mask_cast[j] = m_32;
    }

#else

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        uint8_t m[unroll_factor];
        float4 rand = curand_uniform4(&state);
        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);

        mask[i] = m[0];
        mask[i + 1] = m[1];
        mask[i + 2] = m[2];
        mask[i + 3] = m[3];
    }

#endif
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            out[i] = __float2half((float)Xdata[i] * scale * m);
            mask[i] = m;
        }
    }
}
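
// Backward dropout kernels: re-apply a previously stored mask to the incoming
// tensor, scaling surviving elements by 1 / (1 - ratio). The seed argument is
// kept for signature parity with the forward kernels; no new random numbers
// are drawn here.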
__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const float* Xdata,
                                   float* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        out[i] = mask[i] ? Xdata[i] * scale : 0.0;
        out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
        out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
        out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
    }
}

__global__ void dropout_kernel_bwd(const int N,
                                   const float ratio,
                                   const __half* Xdata,
                                   __half* out,
                                   uint8_t* mask,
                                   std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);

#ifdef __STOCHASTIC_MODE__

    const __half2 h_scale = __float2half2_rn(scale);

    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_f = x_cast[j];
        __half2* x_h = reinterpret_cast<__half2*>(&x_f);

        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);

#pragma unroll
        for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = x_h[0] * h_scale * mask_h[0];
        result_h[1] = x_h[1] * h_scale * mask_h[1];

        out_cast[j] = result_f;
    }

#else

    const __half h_scale = __float2half(scale);
    const __half h_zero = __float2half(0.0);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        int i = j * unroll_factor;

        const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);

        uint8_t* m = mask + i;

        float2 vals_half_f[2];
        vals_half_f[0] = __half22float2(vals_half[0]);
        vals_half_f[1] = __half22float2(vals_half[1]);

        out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
        out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
        out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
        out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
    }

#endif
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}
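
// Launcher for the plain dropout kernels. Grid and block sizes come from
// DS_GET_BLOCKS / DS_CUDA_NUM_THREADS (defined in custom_cuda_layers.h); for
// wide rows (dim > 512) the block is halved and the grid doubled. The Philox
// offset is advanced through TrainingContext so successive launches draw fresh
// random numbers. The bwd flag selects mask re-application instead of mask
// generation.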
template <typename T>
void launch_dropout(T* out,
                    const T* vals,
                    uint8_t* mask,
                    int total_count,
                    int dim,
                    float ratio,
                    cudaStream_t stream,
                    bool bwd)
{
    assert(unroll_factor == 4);

    dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    if (dim > 512) {
        block_dim.x >>= 1;
        grid_dim.x <<= 1;
    }
    uint64_t inc = total_count / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
    if (bwd)
        dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, vals, out, mask, seed);
    else
        dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
            total_count, ratio, out, vals, mask, seed);
}

template void launch_dropout(float* out,
                             const float* vals,
                             uint8_t* mask,
                             int total_count,
                             int dim,
                             float ratio,
                             cudaStream_t stream,
                             bool);
template void launch_dropout(__half* out,
                             const __half* vals,
                             uint8_t* mask,
                             int total_count,
                             int dim,
                             float ratio,
                             cudaStream_t stream,
                             bool);
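
// In-place dropout gradient: Xdata[i] *= scale * mask[i]. The __half overload
// processes four elements per iteration through vectorized float2 / uint32_t
// loads, with a scalar tail loop for the remainder.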
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}

__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
    const __half2 h_scale = __float2half2_rn(scale);
    float2* x_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_data = x_cast[j];
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

#ifdef __STOCHASTIC_MODE__

        __half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
        __half2 mask_h[2];
        float2 mask_f[2];

        float* mask_f_data = &mask_f[0].x;
#pragma unroll
        for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);

        mask_h[0] = __float22half2_rn(mask_f[0]);
        mask_h[1] = __float22half2_rn(mask_f[1]);

        result_h[0] = x_data_h[0] * h_scale * mask_h[0];
        result_h[1] = x_data_h[1] * h_scale * mask_h[1];

#else

        __half* x_data_h = reinterpret_cast<__half*>(&x_data);
        float2 result[2];

        result[0].x = (float)x_data_h[0] * scale * m[0];
        result[0].y = (float)x_data_h[1] * scale * m[1];
        result[1].x = (float)x_data_h[2] * scale * m[2];
        result[1].y = (float)x_data_h[3] * scale * m[3];

        result_h[0] = __float22half2_rn(result[0]);
        result_h[1] = __float22half2_rn(result[1]);

#endif
        x_cast[j] = result_f;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
                          DS_CUDA_NUM_THREADS,
                          0,
                          stream>>>(total_count, scale, vals, mask);
}

template void launch_dropout_grad(float* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
template void launch_dropout_grad(__half* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
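
// Out-of-place dropout gradient: identical math to the in-place variant above,
// but reads Xdata and writes the scaled, masked result to a separate out
// buffer.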
__global__ void dropout_grad_kernel(const int N,
                                    const float scale,
                                    const float* Xdata,
                                    float* out,
                                    uint8_t* mask)
{
    CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}

__global__ void dropout_grad_kernel(const int N,
                                    const float scale,
                                    const __half* Xdata,
                                    __half* out,
                                    uint8_t* mask)
{
    const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
    float2* out_cast = reinterpret_cast<float2*>(out);
    const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);

    float2 result_f;
    __half2* result_h = reinterpret_cast<__half2*>(&result_f);

    CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
    {
        float2 x_data = x_cast[j];
        uint32_t m_32 = mask_cast[j];
        uint8_t* m = (uint8_t*)&m_32;

        __half* x_data_h = reinterpret_cast<__half*>(&x_data);
        float2 result[2];

        result[0].x = (float)x_data_h[0] * scale * m[0];
        result[0].y = (float)x_data_h[1] * scale * m[1];
        result[1].x = (float)x_data_h[2] * scale * m[2];
        result[1].y = (float)x_data_h[3] * scale * m[3];

        result_h[0] = __float22half2_rn(result[0]);
        result_h[1] = __float22half2_rn(result[1]);

        out_cast[j] = result_f;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        for (int i = high_index; i < N; i++) {
            out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
        }
    }
}

template <typename T>
void launch_dropout_grad(T* vals_out,
                         const T* vals,
                         uint8_t* mask,
                         int total_count,
                         float ratio,
                         cudaStream_t stream)
{
    assert(unroll_factor == 4);

    const float scale = 1. / (1. - ratio);
    dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
                          DS_CUDA_NUM_THREADS,
                          0,
                          stream>>>(total_count, scale, vals, vals_out, mask);
}

template void launch_dropout_grad(float*,
                                  const float* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
template void launch_dropout_grad(__half*,
                                  const __half* vals,
                                  uint8_t* mask,
                                  int total_count,
                                  float ratio,
                                  cudaStream_t stream);
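
// Fused bias + dropout kernels: add a bias vector of length `dim` (broadcast
// across the batch, indexed j % (dim / unroll_factor)) to Xdata in place, then
// apply dropout with a freshly drawn mask. Vectorized over unroll_factor
// elements per iteration, with a scalar tail loop.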
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* bias,
                               float* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 x_data = Xdata_cast[j];
        float4 b_data = bias_cast[j % (dim / unroll_factor)];

        x_data.x += b_data.x;
        x_data.y += b_data.y;
        x_data.z += b_data.z;
        x_data.w += b_data.w;

        x_data.x = x_data.x * scale * m[0];
        x_data.y = x_data.y * scale * m[1];
        x_data.z = x_data.z * scale * m[2];
        x_data.w = x_data.w * scale * m[3];

        mask_32[j] = m_32;
        Xdata_cast[j] = x_data;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = Xdata[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = x_data * scale * m;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* bias,
                               __half* Xdata,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        data_f = Xdata_cast[j];
        bias_f = bias_cast[j % (dim / unroll_factor)];

        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);

        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        data_h_0.x += bias_h_0.x;
        data_h_0.y += bias_h_0.y;
        data_h_1.x += bias_h_1.x;
        data_h_1.y += bias_h_1.y;

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        Xdata_cast[j] = result_f;
        mask_32[j] = m_32;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)Xdata[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            Xdata[i] = __float2half(x_data * scale * m);
            mask[i] = m;
        }
    }
}

template <typename T>
void launch_dropout(T* out,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream)
{
    assert(unroll_factor == 4);

    int total_count = batch * dim / unroll_factor;

    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, bias, out, mask, seed);
}

template void launch_dropout(float*,
                             const float* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
template void launch_dropout(__half*,
                             const __half* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
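
// Fused bias + dropout + residual kernels: out = dropout(input + bias) + residual,
// with the mask recorded for the backward pass. Note that the residual is added
// after the dropout scaling. As above, bias is indexed modulo dim / unroll_factor
// so one bias row is shared across the batch.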
__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const float* input,
                               const float* residual,
                               const float* bias,
                               float* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float4* out_cast = reinterpret_cast<float4*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);

    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
    const float4* residual_cast = reinterpret_cast<const float4*>(residual);
    const float4* input_cast = reinterpret_cast<const float4*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        float4 out_data;
        float4 b_data = bias_cast[j % (dim / unroll_factor)];
        float4 res_data = residual_cast[j];
        float4 inp_data = input_cast[j];

        out_data.x = (b_data.x + inp_data.x);
        out_data.y = (b_data.y + inp_data.y);
        out_data.z = (b_data.z + inp_data.z);
        out_data.w = (b_data.w + inp_data.w);

        out_data.x = out_data.x * scale * m[0];
        out_data.y = out_data.y * scale * m[1];
        out_data.z = out_data.z * scale * m[2];
        out_data.w = out_data.w * scale * m[3];

        out_data.x += res_data.x;
        out_data.y += res_data.y;
        out_data.z += res_data.z;
        out_data.w += res_data.w;

        mask_32[j] = m_32;
        out_cast[j] = out_data;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = input[i] + bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += residual[i];

            out[i] = x_data;
            mask[i] = m;
        }
    }
}

__global__ void dropout_kernel(const int N,
                               const int dim,
                               const float ratio,
                               const __half* input,
                               const __half* residual,
                               const __half* bias,
                               __half* out,
                               uint8_t* mask,
                               std::pair<uint64_t, uint64_t> seed)
{
    const float scale = 1. / (1. - ratio);
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x % (dim / unroll_factor);

    curandStatePhilox4_32_10_t state;
    curand_init(seed.first, idx, seed.second, &state);

    float2* out_cast = reinterpret_cast<float2*>(out);
    uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);

    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
    const float2* residual_cast = reinterpret_cast<const float2*>(residual);
    const float2* input_cast = reinterpret_cast<const float2*>(input);

    CUDA_1D_KERNEL_LOOP(j, N)
    {
        float4 rand = curand_uniform4(&state);

        float2 data_f;
        __half2* data_h = reinterpret_cast<__half2*>(&data_f);

        float2 bias_f;
        __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);

        float2 residual_f;
        __half2* residual_h = reinterpret_cast<__half2*>(&residual_f);

        float2 input_f;
        __half2* input_h = reinterpret_cast<__half2*>(&input_f);

        bias_f = bias_cast[j % (dim / unroll_factor)];
        residual_f = residual_cast[j];
        input_f = input_cast[j];

        float2 data_h_0 = __half22float2(data_h[0]);
        float2 data_h_1 = __half22float2(data_h[1]);

        float2 bias_h_0 = __half22float2(bias_h[0]);
        float2 bias_h_1 = __half22float2(bias_h[1]);

        float2 residual_h_0 = __half22float2(residual_h[0]);
        float2 residual_h_1 = __half22float2(residual_h[1]);

        float2 input_h_0 = __half22float2(input_h[0]);
        float2 input_h_1 = __half22float2(input_h[1]);

        data_h_0.x = (bias_h_0.x + input_h_0.x);
        data_h_0.y = (bias_h_0.y + input_h_0.y);
        data_h_1.x = (bias_h_1.x + input_h_1.x);
        data_h_1.y = (bias_h_1.y + input_h_1.y);

        uint32_t m_32;
        uint8_t* m = (uint8_t*)&m_32;

        m[0] = (uint8_t)(rand.x > ratio);
        m[1] = (uint8_t)(rand.y > ratio);
        m[2] = (uint8_t)(rand.z > ratio);
        m[3] = (uint8_t)(rand.w > ratio);

        data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
        data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
        data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
        data_h_1.y = __float2half(data_h_1.y * scale * m[3]);

        data_h_0.x += residual_h_0.x;
        data_h_0.y += residual_h_0.y;
        data_h_1.x += residual_h_1.x;
        data_h_1.y += residual_h_1.y;

        float2 result_f;
        __half2* result_h = reinterpret_cast<__half2*>(&result_f);

        result_h[0] = __float22half2_rn(data_h_0);
        result_h[1] = __float22half2_rn(data_h_1);

        out_cast[j] = result_f;
        mask_32[j] = m_32;
    }
    int high_index =
        ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
    if (N > high_index) {
        float4 rand = curand_uniform4(&state);
        float* rand_data = &(rand.x);
        int k = 0;
        for (int i = high_index; i < N; i++) {
            float x_data = (float)input[i] + (float)bias[i % dim];
            uint8_t m = (uint8_t)(rand_data[k++] > ratio);
            x_data = x_data * scale * m;
            x_data += (float)residual[i];

            out[i] = __float2half(x_data);
            mask[i] = m;
        }
    }
}

template <typename T>
void launch_dropout(T* out,
                    const T* input,
                    const T* residual,
                    const T* bias,
                    uint8_t* mask,
                    int batch,
                    int dim,
                    float ratio,
                    cudaStream_t stream)
{
    assert(unroll_factor == 4);

    int total_count = batch * dim / unroll_factor;
    dim3 grid_dim = DS_GET_BLOCKS(total_count);
    dim3 block_dim = DS_CUDA_NUM_THREADS;

    uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
    std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);

    dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
        total_count, dim, ratio, input, residual, bias, out, mask, seed);
}

template void launch_dropout(float*,
                             const float*,
                             const float* residual,
                             const float* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
template void launch_dropout(__half*,
                             const __half*,
                             const __half* residual,
                             const __half* bias,
                             uint8_t* mask,
                             int batch,
                             int dim,
                             float ratio,
                             cudaStream_t stream);
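
// A minimal host-side usage sketch (illustration only; the buffer names, sizes,
// and stream below are assumptions, not part of this file). It calls the plain
// forward launcher defined above on preallocated device buffers:
//
//     float *d_in, *d_out;   // device activations, batch * dim floats
//     uint8_t* d_mask;       // device dropout mask, batch * dim bytes
//     cudaMalloc(&d_in, batch * dim * sizeof(float));
//     cudaMalloc(&d_out, batch * dim * sizeof(float));
//     cudaMalloc(&d_mask, batch * dim * sizeof(uint8_t));
//     launch_dropout(d_out, d_in, d_mask,
//                    batch * dim,   // total_count
//                    dim,
//                    0.1f,          // dropout ratio
//                    stream,
//                    /*bwd=*/false);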