// gelu_kernels.cu

#include "custom_cuda_layers.h"

// Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
inline __device__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;  // sqrt(2 / pi)
    const float mul_param = 0.044715f;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}

// Derivative of the tanh-approximated GELU above, used in the backward pass.
inline __device__ float d_gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;  // sqrt(2 / pi)
    const float mul_param = 0.044715f;

    float x2mul = x * x * mul_param;
    float tan_h = tanhf(sqrt_param * (x + x * x2mul));
    float dg1 = 0.5f * (1.0f + tan_h);
    float dg2 = x * 0.5f * sqrt_param * (1.0f - tan_h * tan_h);
    float dg3 = dg2 * 3.0f * x2mul;
    return (dg1 + dg2 + dg3);
}
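/*
Derivation sketch for d_gelu: with u = sqrt_param * (x + mul_param * x^3),
the forward pass is g(x) = 0.5 * x * (1 + tanh(u)). By the product and
chain rules,

    g'(x) = 0.5 * (1 + tanh(u))
          + 0.5 * x * sqrt_param * (1 - tanh(u)^2) * (1 + 3 * mul_param * x^2).

The first term is dg1; the second splits into dg2 (the "1" part of the
last factor) plus dg3 = dg2 * 3 * mul_param * x^2 = dg2 * 3 * x2mul,
which is exactly the sum returned above.
*/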
/*
Fused bias add with GELU

Loads a vector of 4 elements each iteration, for `iterations` passes over
the row, stepping by the block width each pass. The kernel was written with
the intention of launching 256-thread threadblocks, so to launch for
bert-large we would set ITERATIONS to 4. This is currently done
automatically as a heuristic that sizes the iterations in blocks of 1024
elements.

For FP16, the values are loaded from memory as __half but converted to FP32
for the arithmetic itself, to avoid overflow in the intermediate hyperbolic
tangent, since there is no __half intrinsic that computes it directly.
*/
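/*
A worked example of the launch heuristic implemented in the launch_*
functions below (assuming bert-large's FFN width, intermediate_size = 4096):

    iterations = (4096 + 1023) / 1024          = 4
    threads    = (4096 - 1) / (4 * 4) + 1      = 256
    row_stride = 4096 / 4                      = 1024

so each of the 256 threads processes 4 vectors of 4 elements per row
(row_stride is the row length in 4-element vectors: float4 for FP32,
float2 = 2x __half2 for FP16).
*/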
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    // Vectorized access: each index covers 4 floats.
    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float4* vals_cast = reinterpret_cast<float4*>(vals);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 data = input_cast[row * row_stride + i * loop_stride + id];

            data.x = gelu(data.x);
            data.y = gelu(data.y);
            data.z = gelu(data.z);
            data.w = gelu(data.w);

            vals_cast[row * row_stride + i * loop_stride + id] = data;
        }
    }
}
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    // Each float2 load carries 4 __half values, handled as two __half2 pairs.
    const float2* input_cast = reinterpret_cast<const float2*>(input);
    float2* vals_cast = reinterpret_cast<float2*>(vals);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
            __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);

            // Promote to FP32 for the tanh-based GELU, then round back to FP16.
            float2 low_data = __half22float2(vals_half[0]);
            float2 high_data = __half22float2(vals_half[1]);

            low_data.x = gelu(low_data.x);
            low_data.y = gelu(low_data.y);
            high_data.x = gelu(high_data.x);
            high_data.y = gelu(high_data.y);

            vals_half[0] = __float22half2_rn(low_data);
            vals_half[1] = __float22half2_rn(high_data);

            vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
        }
    }
#endif
}
__global__ void fused_bias_gelu(const float* input,
                                const float* bias,
                                float* vals,
                                int row_stride,
                                int iterations)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float4* vals_cast = reinterpret_cast<float4*>(vals);
    // The bias is shared by every row, so it is indexed without the row offset.
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 data = input_cast[row * row_stride + i * loop_stride + id];
            float4 bias_data = bias_cast[i * loop_stride + id];

            data.x += bias_data.x;
            data.y += bias_data.y;
            data.z += bias_data.z;
            data.w += bias_data.w;

            data.x = gelu(data.x);
            data.y = gelu(data.y);
            data.z = gelu(data.z);
            data.w = gelu(data.w);

            vals_cast[row * row_stride + i * loop_stride + id] = data;
        }
    }
}
__global__ void fused_bias_gelu(const __half* input,
                                const __half* bias,
                                __half* vals,
                                int row_stride,
                                int iterations)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    const float2* input_cast = reinterpret_cast<const float2*>(input);
    float2* vals_cast = reinterpret_cast<float2*>(vals);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
            float2 bias_vec = bias_cast[i * loop_stride + id];

            __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

            float2 low_data = __half22float2(vals_half[0]);
            float2 high_data = __half22float2(vals_half[1]);
            float2 low_bias = __half22float2(bias_half[0]);
            float2 high_bias = __half22float2(bias_half[1]);

            low_data.x += low_bias.x;
            low_data.y += low_bias.y;
            high_data.x += high_bias.x;
            high_data.y += high_bias.y;

            low_data.x = gelu(low_data.x);
            low_data.y = gelu(low_data.y);
            high_data.x = gelu(high_data.x);
            high_data.y = gelu(high_data.y);

            vals_half[0] = __float22half2_rn(low_data);
            vals_half[1] = __float22half2_rn(high_data);

            vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
        }
    }
#endif
}
// Backward pass: recompute the bias-added pre-activation and scale the
// incoming gradient by d_gelu, writing the result back into d_output.
__global__ void d_gelu_func(float* d_output,
                            const float* gelu_input,
                            const float* bias,
                            int row_stride,
                            int iterations)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    float4* d_output_cast = reinterpret_cast<float4*>(d_output);
    const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
            float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
            float4 bias_data = bias_cast[i * loop_stride + id];

            gelu_input_data.x += bias_data.x;
            gelu_input_data.y += bias_data.y;
            gelu_input_data.z += bias_data.z;
            gelu_input_data.w += bias_data.w;

            output_data.x *= d_gelu(gelu_input_data.x);
            output_data.y *= d_gelu(gelu_input_data.y);
            output_data.z *= d_gelu(gelu_input_data.z);
            output_data.w *= d_gelu(gelu_input_data.w);

            d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
        }
    }
}
__global__ void d_gelu_func(__half* d_output,
                            const __half* gelu_input,
                            const __half* bias,
                            int row_stride,
                            int iterations)
{
#if __CUDA_ARCH__ >= 700
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    float2* d_output_cast = reinterpret_cast<float2*>(d_output);
    const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
            float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
            float2 bias_vec = bias_cast[i * loop_stride + id];

            __half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
            __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

            float2 output_half_0 = __half22float2(output_data_half[0]);
            float2 output_half_1 = __half22float2(output_data_half[1]);
            float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
            float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
            float2 bias_half_0 = __half22float2(bias_half[0]);
            float2 bias_half_1 = __half22float2(bias_half[1]);

            gelu_input_half_0.x += bias_half_0.x;
            gelu_input_half_0.y += bias_half_0.y;
            gelu_input_half_1.x += bias_half_1.x;
            gelu_input_half_1.y += bias_half_1.y;

            output_half_0.x *= d_gelu(gelu_input_half_0.x);
            output_half_0.y *= d_gelu(gelu_input_half_0.y);
            output_half_1.x *= d_gelu(gelu_input_half_1.x);
            output_half_1.y *= d_gelu(gelu_input_half_1.y);

            float2 result;
            __half2* result_half2 = reinterpret_cast<__half2*>(&result);
            result_half2[0] = __float22half2_rn(output_half_0);
            result_half2[1] = __float22half2_rn(output_half_1);

            d_output_cast[row * row_stride + i * loop_stride + id] = result;
        }
    }
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream)
{
    // Heuristic from the comment above: one iteration per 1024 elements,
    // with each thread handling 4 elements per iteration.
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    // row_stride is passed in units of 4-element vectors.
    fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
        input, bias, output, intermediate_size / 4, iterations);
}
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(
        input, output, intermediate_size / 4, iterations);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
                                       const __half*,
                                       __half*,
                                       int,
                                       int,
                                       cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(
        d_output, input, bias, intermediate_size / 4, iterations);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
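
/*
Minimal host-side usage sketch, not part of the library proper: it calls the
launch_bias_gelu<float> instantiation above on the default stream, and is
guarded by a hypothetical GELU_KERNELS_EXAMPLE build flag so the file still
compiles unchanged. Buffer sizes and values are illustrative; error handling
is omitted for brevity.
*/
#ifdef GELU_KERNELS_EXAMPLE
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main()
{
    const int batch_tokens = 8;          // rows (e.g. batch_size * seq_len)
    const int intermediate_size = 4096;  // elements per row, multiple of 4
    const size_t n = static_cast<size_t>(batch_tokens) * intermediate_size;

    std::vector<float> h_in(n, 1.0f), h_bias(intermediate_size, 0.5f), h_out(n);

    float *d_in, *d_bias, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_bias, intermediate_size * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_bias, h_bias.data(), intermediate_size * sizeof(float), cudaMemcpyHostToDevice);

    // out[row][col] = gelu(in[row][col] + bias[col]), launched on the default stream.
    launch_bias_gelu<float>(d_in, d_bias, d_out, intermediate_size, batch_tokens, 0);

    cudaMemcpy(h_out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[0] = %f (gelu(1.0 + 0.5) is roughly 1.3996)\n", h_out[0]);

    cudaFree(d_in);
    cudaFree(d_bias);
    cudaFree(d_out);
    return 0;
}
#endif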