dropout_kernels.cu 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. #include "custom_cuda_layers.h"
  2. const int unroll_factor = 4;
  3. __global__ void dropout_kernel(const int N,
  4. const float ratio,
  5. float* out,
  6. const float* Xdata,
  7. uint8_t* mask,
  8. std::pair<uint64_t, uint64_t> seed)
  9. {
  10. const float scale = 1. / (1. - ratio);
  11. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  12. curandStatePhilox4_32_10_t state;
  13. curand_init(seed.first, idx, seed.second, &state);
  14. CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
  15. {
  16. float4 rand = curand_uniform4(&state);
  17. uint8_t m[unroll_factor];
  18. m[0] = (uint8_t)(rand.x > ratio);
  19. m[1] = (uint8_t)(rand.y > ratio);
  20. m[2] = (uint8_t)(rand.z > ratio);
  21. m[3] = (uint8_t)(rand.w > ratio);
  22. int i = j * unroll_factor;
  23. mask[i] = (uint8_t)m[0];
  24. mask[i + 1] = (uint8_t)m[1];
  25. mask[i + 2] = (uint8_t)m[2];
  26. mask[i + 3] = (uint8_t)m[3];
  27. out[i] = Xdata[i] * scale * m[0];
  28. out[i + 1] = Xdata[i + 1] * scale * m[1];
  29. out[i + 2] = Xdata[i + 2] * scale * m[2];
  30. out[i + 3] = Xdata[i + 3] * scale * m[3];
  31. }
  32. int high_index =
  33. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  34. if (N > high_index) {
  35. float4 rand = curand_uniform4(&state);
  36. float* rand_data = &(rand.x);
  37. int k = 0;
  38. for (int i = high_index; i < N; i++) {
  39. uint8_t m = (uint8_t)(rand_data[k++] > ratio);
  40. out[i] = Xdata[i] * scale * m;
  41. mask[i] = m;
  42. }
  43. }
  44. }
  45. __global__ void dropout_kernel(const int N,
  46. const float ratio,
  47. __half* out,
  48. const __half* Xdata,
  49. uint8_t* mask,
  50. std::pair<uint64_t, uint64_t> seed)
  51. {
  52. const float scale = 1. / (1. - ratio);
  53. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  54. curandStatePhilox4_32_10_t state;
  55. curand_init(seed.first, idx, seed.second, &state);
  56. #ifdef __STOCHASTIC_MODE__
  57. const __half2 h_scale = __float2half2_rn(scale);
  58. const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
  59. float2* out_cast = reinterpret_cast<float2*>(out);
  60. uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
  61. uint32_t m_32;
  62. uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);
  63. float2 result_f;
  64. __half2* result_h = reinterpret_cast<__half2*>(&result_f);
  65. __half2 mask_h[2];
  66. float2 mask_f[2];
  67. CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
  68. {
  69. float2 x_f = x_cast[j];
  70. __half2* x_h = reinterpret_cast<__half2*>(&x_f);
  71. float4 rand = curand_uniform4(&state);
  72. m[0] = (uint8_t)(rand.x > ratio);
  73. m[1] = (uint8_t)(rand.y > ratio);
  74. m[2] = (uint8_t)(rand.z > ratio);
  75. m[3] = (uint8_t)(rand.w > ratio);
  76. float* mask_f_data = &mask_f[0].x;
  77. #pragma unroll
  78. for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
  79. mask_h[0] = __float22half2_rn(mask_f[0]);
  80. mask_h[1] = __float22half2_rn(mask_f[1]);
  81. result_h[0] = x_h[0] * h_scale * mask_h[0];
  82. result_h[1] = x_h[1] * h_scale * mask_h[1];
  83. out_cast[j] = result_f;
  84. mask_cast[j] = m_32;
  85. }
  86. #else
  87. CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
  88. {
  89. int i = j * unroll_factor;
  90. const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
  91. float2 vals_half_f[2];
  92. vals_half_f[0] = __half22float2(vals_half[0]);
  93. vals_half_f[1] = __half22float2(vals_half[1]);
  94. uint8_t m[unroll_factor];
  95. float4 rand = curand_uniform4(&state);
  96. m[0] = (uint8_t)(rand.x > ratio);
  97. m[1] = (uint8_t)(rand.y > ratio);
  98. m[2] = (uint8_t)(rand.z > ratio);
  99. m[3] = (uint8_t)(rand.w > ratio);
  100. out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
  101. out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
  102. out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
  103. out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
  104. mask[i] = m[0];
  105. mask[i + 1] = m[1];
  106. mask[i + 2] = m[2];
  107. mask[i + 3] = m[3];
  108. }
  109. #endif
  110. int high_index =
  111. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  112. if (N > high_index) {
  113. float4 rand = curand_uniform4(&state);
  114. float* rand_data = &(rand.x);
  115. int k = 0;
  116. for (int i = high_index; i < N; i++) {
  117. uint8_t m = (uint8_t)(rand_data[k++] > ratio);
  118. out[i] = __float2half((float)Xdata[i] * scale * m);
  119. mask[i] = m;
  120. }
  121. }
  122. }
  123. __global__ void dropout_kernel_bwd(const int N,
  124. const float ratio,
  125. const float* Xdata,
  126. float* out,
  127. uint8_t* mask,
  128. std::pair<uint64_t, uint64_t> seed)
  129. {
  130. const float scale = 1. / (1. - ratio);
  131. CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
  132. {
  133. int i = j * unroll_factor;
  134. out[i] = mask[i] ? Xdata[i] * scale : 0.0;
  135. out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
  136. out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
  137. out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
  138. }
  139. int high_index =
  140. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  141. if (N > high_index) {
  142. for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
  143. }
  144. }
  145. __global__ void dropout_kernel_bwd(const int N,
  146. const float ratio,
  147. const __half* Xdata,
  148. __half* out,
  149. uint8_t* mask,
  150. std::pair<uint64_t, uint64_t> seed)
  151. {
  152. const float scale = 1. / (1. - ratio);
  153. #ifdef __STOCHASTIC_MODE__
  154. const __half2 h_scale = __float2half2_rn(scale);
  155. const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
  156. float2* out_cast = reinterpret_cast<float2*>(out);
  157. uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
  158. CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
  159. {
  160. float2 x_f = x_cast[j];
  161. __half2* x_h = reinterpret_cast<__half2*>(&x_f);
  162. uint32_t m_32 = mask_cast[j];
  163. uint8_t* m = (uint8_t*)&m_32;
  164. __half2 mask_h[2];
  165. float2 mask_f[2];
  166. float* mask_f_data = &mask_f[0].x;
  167. #pragma unroll
  168. for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
  169. #pragma unroll
  170. for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
  171. float2 result_f;
  172. __half2* result_h = reinterpret_cast<__half2*>(&result_f);
  173. result_h[0] = x_h[0] * h_scale * mask_h[0];
  174. result_h[1] = x_h[1] * h_scale * mask_h[1];
  175. out_cast[j] = result_f;
  176. }
  177. #else
  178. const __half h_scale = __float2half(scale);
  179. const __half h_zero = __float2half(0.0);
  180. CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
  181. {
  182. int i = j * unroll_factor;
  183. const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
  184. uint8_t* m = mask + i;
  185. float2 vals_half_f[2];
  186. vals_half_f[0] = __half22float2(vals_half[0]);
  187. vals_half_f[1] = __half22float2(vals_half[1]);
  188. out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
  189. out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
  190. out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
  191. out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
  192. }
  193. #endif
  194. int high_index =
  195. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  196. if (N > high_index) {
  197. for (int i = high_index; i < N; i++) {
  198. out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
  199. }
  200. }
  201. }
  202. template <typename T>
  203. void launch_dropout(T* out,
  204. const T* vals,
  205. uint8_t* mask,
  206. int total_count,
  207. int dim,
  208. float ratio,
  209. cudaStream_t stream,
  210. bool bwd)
  211. {
  212. assert(unroll_factor == 4);
  213. dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
  214. dim3 block_dim = DS_CUDA_NUM_THREADS;
  215. if (dim > 512) {
  216. block_dim.x >>= 1;
  217. grid_dim.x <<= 1;
  218. }
  219. uint64_t inc = total_count / grid_dim.x / block_dim.x;
  220. std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
  221. if (bwd)
  222. dropout_kernel_bwd<<<grid_dim, block_dim, 0, stream>>>(
  223. total_count, ratio, vals, out, mask, seed);
  224. else
  225. dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
  226. total_count, ratio, out, vals, mask, seed);
  227. }
  228. template void launch_dropout(float* out,
  229. const float* vals,
  230. uint8_t* mask,
  231. int total_count,
  232. int dim,
  233. float ratio,
  234. cudaStream_t stream,
  235. bool);
  236. template void launch_dropout(__half* out,
  237. const __half* vals,
  238. uint8_t* mask,
  239. int total_count,
  240. int dim,
  241. float ratio,
  242. cudaStream_t stream,
  243. bool);
  244. __global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
  245. {
  246. CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
  247. }
  248. __global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
  249. {
  250. const __half2 h_scale = __float2half2_rn(scale);
  251. float2* x_cast = reinterpret_cast<float2*>(Xdata);
  252. uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
  253. CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
  254. {
  255. float2 x_data = x_cast[j];
  256. uint32_t m_32 = mask_cast[j];
  257. uint8_t* m = (uint8_t*)&m_32;
  258. float2 result_f;
  259. __half2* result_h = reinterpret_cast<__half2*>(&result_f);
  260. #ifdef __STOCHASTIC_MODE__
  261. __half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
  262. __half2 mask_h[2];
  263. float2 mask_f[2];
  264. float* mask_f_data = &mask_f[0].x;
  265. #pragma unroll
  266. for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
  267. mask_h[0] = __float22half2_rn(mask_f[0]);
  268. mask_h[1] = __float22half2_rn(mask_f[1]);
  269. result_h[0] = x_data_h[0] * h_scale * mask_h[0];
  270. result_h[1] = x_data_h[1] * h_scale * mask_h[1];
  271. #else
  272. __half* x_data_h = reinterpret_cast<__half*>(&x_data);
  273. float2 result[2];
  274. result[0].x = (float)x_data_h[0] * scale * m[0];
  275. result[0].y = (float)x_data_h[1] * scale * m[1];
  276. result[1].x = (float)x_data_h[2] * scale * m[2];
  277. result[1].y = (float)x_data_h[3] * scale * m[3];
  278. result_h[0] = __float22half2_rn(result[0]);
  279. result_h[1] = __float22half2_rn(result[1]);
  280. #endif
  281. x_cast[j] = result_f;
  282. }
  283. int high_index =
  284. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  285. if (N > high_index) {
  286. for (int i = high_index; i < N; i++) {
  287. Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
  288. }
  289. }
  290. }
  291. template <typename T>
  292. void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream)
  293. {
  294. assert(unroll_factor == 4);
  295. const float scale = 1. / (1. - ratio);
  296. dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
  297. DS_CUDA_NUM_THREADS,
  298. 0,
  299. stream>>>(total_count, scale, vals, mask);
  300. }
  301. template void launch_dropout_grad(float* vals,
  302. uint8_t* mask,
  303. int total_count,
  304. float ratio,
  305. cudaStream_t stream);
  306. template void launch_dropout_grad(__half* vals,
  307. uint8_t* mask,
  308. int total_count,
  309. float ratio,
  310. cudaStream_t stream);
  311. __global__ void dropout_grad_kernel(const int N,
  312. const float scale,
  313. const float* Xdata,
  314. float* out,
  315. uint8_t* mask)
  316. {
  317. CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
  318. }
  319. __global__ void dropout_grad_kernel(const int N,
  320. const float scale,
  321. const __half* Xdata,
  322. __half* out,
  323. uint8_t* mask)
  324. {
  325. const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
  326. float2* out_cast = reinterpret_cast<float2*>(out);
  327. const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
  328. float2 result_f;
  329. __half2* result_h = reinterpret_cast<__half2*>(&result_f);
  330. CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
  331. {
  332. float2 x_data = x_cast[j];
  333. uint32_t m_32 = mask_cast[j];
  334. uint8_t* m = (uint8_t*)&m_32;
  335. __half* x_data_h = reinterpret_cast<__half*>(&x_data);
  336. float2 result[2];
  337. result[0].x = (float)x_data_h[0] * scale * m[0];
  338. result[0].y = (float)x_data_h[1] * scale * m[1];
  339. result[1].x = (float)x_data_h[2] * scale * m[2];
  340. result[1].y = (float)x_data_h[3] * scale * m[3];
  341. result_h[0] = __float22half2_rn(result[0]);
  342. result_h[1] = __float22half2_rn(result[1]);
  343. out_cast[j] = result_f;
  344. }
  345. int high_index =
  346. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  347. if (N > high_index) {
  348. for (int i = high_index; i < N; i++) {
  349. out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
  350. }
  351. }
  352. }
  353. template <typename T>
  354. void launch_dropout_grad(T* vals_out,
  355. const T* vals,
  356. uint8_t* mask,
  357. int total_count,
  358. float ratio,
  359. cudaStream_t stream)
  360. {
  361. assert(unroll_factor == 4);
  362. const float scale = 1. / (1. - ratio);
  363. dropout_grad_kernel<<<DS_GET_BLOCKS(total_count / unroll_factor),
  364. DS_CUDA_NUM_THREADS,
  365. 0,
  366. stream>>>(total_count, scale, vals, vals_out, mask);
  367. }
  368. template void launch_dropout_grad(float*,
  369. const float* vals,
  370. uint8_t* mask,
  371. int total_count,
  372. float ratio,
  373. cudaStream_t stream);
  374. template void launch_dropout_grad(__half*,
  375. const __half* vals,
  376. uint8_t* mask,
  377. int total_count,
  378. float ratio,
  379. cudaStream_t stream);
  380. __global__ void dropout_kernel(const int N,
  381. const int dim,
  382. const float ratio,
  383. const float* bias,
  384. float* Xdata,
  385. uint8_t* mask,
  386. std::pair<uint64_t, uint64_t> seed)
  387. {
  388. const float scale = 1. / (1. - ratio);
  389. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  390. int tid = threadIdx.x % (dim / unroll_factor);
  391. curandStatePhilox4_32_10_t state;
  392. curand_init(seed.first, idx, seed.second, &state);
  393. float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
  394. uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
  395. const float4* bias_cast = reinterpret_cast<const float4*>(bias);
  396. CUDA_1D_KERNEL_LOOP(j, N)
  397. {
  398. float4 rand = curand_uniform4(&state);
  399. uint32_t m_32;
  400. uint8_t* m = (uint8_t*)&m_32;
  401. m[0] = (uint8_t)(rand.x > ratio);
  402. m[1] = (uint8_t)(rand.y > ratio);
  403. m[2] = (uint8_t)(rand.z > ratio);
  404. m[3] = (uint8_t)(rand.w > ratio);
  405. float4 x_data = Xdata_cast[j];
  406. float4 b_data = bias_cast[j % (dim / unroll_factor)];
  407. x_data.x += b_data.x;
  408. x_data.y += b_data.y;
  409. x_data.z += b_data.z;
  410. x_data.w += b_data.w;
  411. x_data.x = x_data.x * scale * m[0];
  412. x_data.y = x_data.y * scale * m[1];
  413. x_data.z = x_data.z * scale * m[2];
  414. x_data.w = x_data.w * scale * m[3];
  415. mask_32[j] = m_32;
  416. Xdata_cast[j] = x_data;
  417. }
  418. int high_index =
  419. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  420. if (N > high_index) {
  421. float4 rand = curand_uniform4(&state);
  422. float* rand_data = &(rand.x);
  423. int k = 0;
  424. for (int i = high_index; i < N; i++) {
  425. float x_data = Xdata[i] + bias[i % dim];
  426. uint8_t m = (uint8_t)(rand_data[k++] > ratio);
  427. Xdata[i] = x_data * scale * m;
  428. mask[i] = m;
  429. }
  430. }
  431. }
  432. __global__ void dropout_kernel(const int N,
  433. const int dim,
  434. const float ratio,
  435. const __half* bias,
  436. __half* Xdata,
  437. uint8_t* mask,
  438. std::pair<uint64_t, uint64_t> seed)
  439. {
  440. const float scale = 1. / (1. - ratio);
  441. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  442. int tid = threadIdx.x % (dim / unroll_factor);
  443. curandStatePhilox4_32_10_t state;
  444. curand_init(seed.first, idx, seed.second, &state);
  445. float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
  446. uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
  447. const float2* bias_cast = reinterpret_cast<const float2*>(bias);
  448. CUDA_1D_KERNEL_LOOP(j, N)
  449. {
  450. float4 rand = curand_uniform4(&state);
  451. float2 data_f;
  452. __half2* data_h = reinterpret_cast<__half2*>(&data_f);
  453. float2 bias_f;
  454. __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
  455. data_f = Xdata_cast[j];
  456. bias_f = bias_cast[j % (dim / unroll_factor)];
  457. float2 data_h_0 = __half22float2(data_h[0]);
  458. float2 data_h_1 = __half22float2(data_h[1]);
  459. float2 bias_h_0 = __half22float2(bias_h[0]);
  460. float2 bias_h_1 = __half22float2(bias_h[1]);
  461. data_h_0.x += bias_h_0.x;
  462. data_h_0.y += bias_h_0.y;
  463. data_h_1.x += bias_h_1.x;
  464. data_h_1.y += bias_h_1.y;
  465. uint32_t m_32;
  466. uint8_t* m = (uint8_t*)&m_32;
  467. m[0] = (uint8_t)(rand.x > ratio);
  468. m[1] = (uint8_t)(rand.y > ratio);
  469. m[2] = (uint8_t)(rand.z > ratio);
  470. m[3] = (uint8_t)(rand.w > ratio);
  471. data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
  472. data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
  473. data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
  474. data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
  475. float2 result_f;
  476. __half2* result_h = reinterpret_cast<__half2*>(&result_f);
  477. result_h[0] = __float22half2_rn(data_h_0);
  478. result_h[1] = __float22half2_rn(data_h_1);
  479. Xdata_cast[j] = result_f;
  480. mask_32[j] = m_32;
  481. }
  482. int high_index =
  483. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  484. if (N > high_index) {
  485. float4 rand = curand_uniform4(&state);
  486. float* rand_data = &(rand.x);
  487. int k = 0;
  488. for (int i = high_index; i < N; i++) {
  489. float x_data = (float)Xdata[i] + (float)bias[i % dim];
  490. uint8_t m = (uint8_t)(rand_data[k++] > ratio);
  491. Xdata[i] = __float2half(x_data * scale * m);
  492. mask[i] = m;
  493. }
  494. }
  495. }
  496. template <typename T>
  497. void launch_dropout(T* out,
  498. const T* bias,
  499. uint8_t* mask,
  500. int batch,
  501. int dim,
  502. float ratio,
  503. cudaStream_t stream)
  504. {
  505. assert(unroll_factor == 4);
  506. int total_count = batch * dim / unroll_factor;
  507. dim3 grid_dim = DS_GET_BLOCKS(total_count);
  508. dim3 block_dim = DS_CUDA_NUM_THREADS;
  509. uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
  510. std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
  511. dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
  512. total_count, dim, ratio, bias, out, mask, seed);
  513. }
  514. template void launch_dropout(float*,
  515. const float* bias,
  516. uint8_t* mask,
  517. int batch,
  518. int dim,
  519. float ratio,
  520. cudaStream_t stream);
  521. template void launch_dropout(__half*,
  522. const __half* bias,
  523. uint8_t* mask,
  524. int batch,
  525. int dim,
  526. float ratio,
  527. cudaStream_t stream);
  528. __global__ void dropout_kernel(const int N,
  529. const int dim,
  530. const float ratio,
  531. const float* input,
  532. const float* residual,
  533. const float* bias,
  534. float* out,
  535. uint8_t* mask,
  536. std::pair<uint64_t, uint64_t> seed)
  537. {
  538. const float scale = 1. / (1. - ratio);
  539. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  540. int tid = threadIdx.x % (dim / unroll_factor);
  541. curandStatePhilox4_32_10_t state;
  542. curand_init(seed.first, idx, seed.second, &state);
  543. float4* out_cast = reinterpret_cast<float4*>(out);
  544. uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
  545. const float4* bias_cast = reinterpret_cast<const float4*>(bias);
  546. const float4* residual_cast = reinterpret_cast<const float4*>(residual);
  547. const float4* input_cast = reinterpret_cast<const float4*>(input);
  548. CUDA_1D_KERNEL_LOOP(j, N)
  549. {
  550. float4 rand = curand_uniform4(&state);
  551. uint32_t m_32;
  552. uint8_t* m = (uint8_t*)&m_32;
  553. m[0] = (uint8_t)(rand.x > ratio);
  554. m[1] = (uint8_t)(rand.y > ratio);
  555. m[2] = (uint8_t)(rand.z > ratio);
  556. m[3] = (uint8_t)(rand.w > ratio);
  557. float4 out_data;
  558. float4 b_data = bias_cast[j % (dim / unroll_factor)];
  559. float4 res_data = residual_cast[j];
  560. float4 inp_data = input_cast[j];
  561. out_data.x = (b_data.x + inp_data.x);
  562. out_data.y = (b_data.y + inp_data.y);
  563. out_data.z = (b_data.z + inp_data.z);
  564. out_data.w = (b_data.w + inp_data.w);
  565. out_data.x = out_data.x * scale * m[0];
  566. out_data.y = out_data.y * scale * m[1];
  567. out_data.z = out_data.z * scale * m[2];
  568. out_data.w = out_data.w * scale * m[3];
  569. out_data.x += res_data.x;
  570. out_data.y += res_data.y;
  571. out_data.z += res_data.z;
  572. out_data.w += res_data.w;
  573. mask_32[j] = m_32;
  574. out_cast[j] = out_data;
  575. }
  576. int high_index =
  577. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  578. if (N > high_index) {
  579. float4 rand = curand_uniform4(&state);
  580. float* rand_data = &(rand.x);
  581. int k = 0;
  582. for (int i = high_index; i < N; i++) {
  583. float x_data = input[i] + bias[i % dim];
  584. uint8_t m = (uint8_t)(rand_data[k++] > ratio);
  585. x_data = x_data * scale * m;
  586. x_data += residual[i];
  587. out[i] = x_data;
  588. mask[i] = m;
  589. }
  590. }
  591. }
  592. __global__ void dropout_kernel(const int N,
  593. const int dim,
  594. const float ratio,
  595. const __half* input,
  596. const __half* residual,
  597. const __half* bias,
  598. __half* out,
  599. uint8_t* mask,
  600. std::pair<uint64_t, uint64_t> seed)
  601. {
  602. const float scale = 1. / (1. - ratio);
  603. int idx = blockIdx.x * blockDim.x + threadIdx.x;
  604. int tid = threadIdx.x % (dim / unroll_factor);
  605. curandStatePhilox4_32_10_t state;
  606. curand_init(seed.first, idx, seed.second, &state);
  607. float2* out_cast = reinterpret_cast<float2*>(out);
  608. uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
  609. const float2* bias_cast = reinterpret_cast<const float2*>(bias);
  610. const float2* residual_cast = reinterpret_cast<const float2*>(residual);
  611. const float2* input_cast = reinterpret_cast<const float2*>(input);
  612. CUDA_1D_KERNEL_LOOP(j, N)
  613. {
  614. float4 rand = curand_uniform4(&state);
  615. float2 data_f;
  616. __half2* data_h = reinterpret_cast<__half2*>(&data_f);
  617. float2 bias_f;
  618. __half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
  619. float2 residual_f;
  620. __half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
  621. float2 input_f;
  622. __half2* input_h = reinterpret_cast<__half2*>(&input_f);
  623. bias_f = bias_cast[j % (dim / unroll_factor)];
  624. residual_f = residual_cast[j];
  625. input_f = input_cast[j];
  626. float2 data_h_0 = __half22float2(data_h[0]);
  627. float2 data_h_1 = __half22float2(data_h[1]);
  628. float2 bias_h_0 = __half22float2(bias_h[0]);
  629. float2 bias_h_1 = __half22float2(bias_h[1]);
  630. float2 residual_h_0 = __half22float2(residual_h[0]);
  631. float2 residual_h_1 = __half22float2(residual_h[1]);
  632. float2 input_h_0 = __half22float2(input_h[0]);
  633. float2 input_h_1 = __half22float2(input_h[1]);
  634. data_h_0.x = (bias_h_0.x + input_h_0.x);
  635. data_h_0.y = (bias_h_0.y + input_h_0.y);
  636. data_h_1.x = (bias_h_1.x + input_h_1.x);
  637. data_h_1.y = (bias_h_1.y + input_h_1.y);
  638. uint32_t m_32;
  639. uint8_t* m = (uint8_t*)&m_32;
  640. m[0] = (uint8_t)(rand.x > ratio);
  641. m[1] = (uint8_t)(rand.y > ratio);
  642. m[2] = (uint8_t)(rand.z > ratio);
  643. m[3] = (uint8_t)(rand.w > ratio);
  644. data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
  645. data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
  646. data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
  647. data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
  648. data_h_0.x += residual_h_0.x;
  649. data_h_0.y += residual_h_0.y;
  650. data_h_1.x += residual_h_1.x;
  651. data_h_1.y += residual_h_1.y;
  652. float2 result_f;
  653. __half2* result_h = reinterpret_cast<__half2*>(&result_f);
  654. result_h[0] = __float22half2_rn(data_h_0);
  655. result_h[1] = __float22half2_rn(data_h_1);
  656. out_cast[j] = result_f;
  657. mask_32[j] = m_32;
  658. }
  659. int high_index =
  660. ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
  661. if (N > high_index) {
  662. float4 rand = curand_uniform4(&state);
  663. float* rand_data = &(rand.x);
  664. int k = 0;
  665. for (int i = high_index; i < N; i++) {
  666. float x_data = (float)input[i] + (float)bias[i % dim];
  667. uint8_t m = (uint8_t)(rand_data[k++] > ratio);
  668. x_data = x_data * scale * m;
  669. x_data += (float)residual[i];
  670. out[i] = __float2half(x_data);
  671. mask[i] = m;
  672. }
  673. }
  674. }
  675. template <typename T>
  676. void launch_dropout(T* out,
  677. const T* input,
  678. const T* residual,
  679. const T* bias,
  680. uint8_t* mask,
  681. int batch,
  682. int dim,
  683. float ratio,
  684. cudaStream_t stream)
  685. {
  686. assert(unroll_factor == 4);
  687. int total_count = batch * dim / unroll_factor;
  688. dim3 grid_dim = DS_GET_BLOCKS(total_count);
  689. dim3 block_dim = DS_CUDA_NUM_THREADS;
  690. uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
  691. std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
  692. dropout_kernel<<<grid_dim, block_dim, 0, stream>>>(
  693. total_count, dim, ratio, input, residual, bias, out, mask, seed);
  694. }
  695. template void launch_dropout(float*,
  696. const float*,
  697. const float* residual,
  698. const float* bias,
  699. uint8_t* mask,
  700. int batch,
  701. int dim,
  702. float ratio,
  703. cudaStream_t stream);
  704. template void launch_dropout(__half*,
  705. const __half*,
  706. const __half* residual,
  707. const __half* bias,
  708. uint8_t* mask,
  709. int batch,
  710. int dim,
  711. float ratio,
  712. cudaStream_t stream);