gelu_kernels.cu

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "custom_cuda_layers.h"
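
// Tanh approximation of the Gaussian Error Linear Unit (GELU):
//   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// sqrt_param below is sqrt(2/pi); mul_param is the cubic coefficient 0.044715.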
inline __device__ float gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;
    const float mul_param = 0.044715;
    return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
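
// Derivative of the tanh-approximated GELU, split into the three terms the
// code below computes. With u = sqrt(2/pi) * (x + 0.044715 * x^3):
//   d/dx gelu(x) = 0.5 * (1 + tanh(u))                        (dg1)
//                + 0.5 * x * sqrt(2/pi) * (1 - tanh(u)^2)     (dg2)
//                + dg2 * 3 * 0.044715 * x^2                   (dg3)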
inline __device__ float d_gelu(const float x)
{
    const float sqrt_param = 0.79788456080286535587989211986876f;
    const float mul_param = 0.044715;

    float x2mul = x * x * mul_param;
    float tan_h = tanhf(sqrt_param * (x + x * x2mul));
    float dg1 = 0.5f * (1.0f + tan_h);
    float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
    float dg3 = dg2 * 3 * x2mul;
    return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU

Each thread loads a vector of 4 elements per iteration and strides by the
block size, for `iterations` iterations. The kernels were written with the
intention of launching 256-thread thread blocks, so for bert-large the
iteration count would be 4. This is now derived automatically as a heuristic
that treats each row as blocks of 1024 elements; a worked example of the
heuristic follows this comment.

For FP16, values are loaded from memory as __half but converted to FP32 for
the arithmetic itself, to avoid overflows in the intermediate hyperbolic
tangent, since there is no intrinsic that computes it directly.
*/
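
// Worked example of the launch heuristic (see launch_bias_gelu / launch_gelu
// below), using bert-large's 4096-wide intermediate size as an illustration:
//   iterations = (4096 + 1023) / 1024          = 4
//   threads    = (4096 - 1) / (4 * 4) + 1      = 256
//   row_stride = 4096 / 4                      = 1024  (4-element vectors per row)
// Each of the 256 threads then processes 4 vectors of 4 elements,
// covering 256 * 4 * 4 = 4096 elements per row.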
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float4* vals_cast = reinterpret_cast<float4*>(vals);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 data = input_cast[row * row_stride + i * loop_stride + id];

            data.x = gelu(data.x);
            data.y = gelu(data.y);
            data.z = gelu(data.z);
            data.w = gelu(data.w);

            vals_cast[row * row_stride + i * loop_stride + id] = data;
        }
    }
}
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE

    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    const float2* input_cast = reinterpret_cast<const float2*>(input);
    float2* vals_cast = reinterpret_cast<float2*>(vals);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];

            __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);

            float2 low_data = __half22float2(vals_half[0]);
            float2 high_data = __half22float2(vals_half[1]);

            low_data.x = gelu(low_data.x);
            low_data.y = gelu(low_data.y);
            high_data.x = gelu(high_data.x);
            high_data.y = gelu(high_data.y);

            vals_half[0] = __float22half2_rn(low_data);
            vals_half[1] = __float22half2_rn(high_data);

            vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
        }
    }
#endif
}
__global__ void fused_bias_gelu(const float* input,
                                const float* bias,
                                float* vals,
                                int row_stride,
                                int iterations)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    const float4* input_cast = reinterpret_cast<const float4*>(input);
    float4* vals_cast = reinterpret_cast<float4*>(vals);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 data = input_cast[row * row_stride + i * loop_stride + id];
            float4 bias_data = bias_cast[i * loop_stride + id];

            data.x += bias_data.x;
            data.y += bias_data.y;
            data.z += bias_data.z;
            data.w += bias_data.w;

            data.x = gelu(data.x);
            data.y = gelu(data.y);
            data.z = gelu(data.z);
            data.w = gelu(data.w);

            vals_cast[row * row_stride + i * loop_stride + id] = data;
        }
    }
}
__global__ void fused_bias_gelu(const __half* input,
                                const __half* bias,
                                __half* vals,
                                int row_stride,
                                int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE

    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    const float2* input_cast = reinterpret_cast<const float2*>(input);
    float2* vals_cast = reinterpret_cast<float2*>(vals);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
            float2 bias_vec = bias_cast[i * loop_stride + id];

            __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

            float2 low_data = __half22float2(vals_half[0]);
            float2 high_data = __half22float2(vals_half[1]);
            float2 low_bias = __half22float2(bias_half[0]);
            float2 high_bias = __half22float2(bias_half[1]);

            low_data.x += low_bias.x;
            low_data.y += low_bias.y;
            high_data.x += high_bias.x;
            high_data.y += high_bias.y;

            low_data.x = gelu(low_data.x);
            low_data.y = gelu(low_data.y);
            high_data.x = gelu(high_data.x);
            high_data.y = gelu(high_data.y);

            vals_half[0] = __float22half2_rn(low_data);
            vals_half[1] = __float22half2_rn(high_data);

            vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
        }
    }
#endif
}
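
// Backward pass of the fused bias + GELU: d_output is scaled in place by the
// GELU derivative evaluated at (gelu_input + bias), i.e. at the same
// pre-activation values the forward kernel produced.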
__global__ void d_gelu_func(float* d_output,
                            const float* gelu_input,
                            const float* bias,
                            int row_stride,
                            int iterations)
{
    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    float4* d_output_cast = reinterpret_cast<float4*>(d_output);
    const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
    const float4* bias_cast = reinterpret_cast<const float4*>(bias);

    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
            float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
            float4 bias_data = bias_cast[i * loop_stride + id];

            gelu_input_data.x += bias_data.x;
            gelu_input_data.y += bias_data.y;
            gelu_input_data.z += bias_data.z;
            gelu_input_data.w += bias_data.w;

            output_data.x *= d_gelu(gelu_input_data.x);
            output_data.y *= d_gelu(gelu_input_data.y);
            output_data.z *= d_gelu(gelu_input_data.z);
            output_data.w *= d_gelu(gelu_input_data.w);

            d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
        }
    }
}
__global__ void d_gelu_func(__half* d_output,
                            const __half* gelu_input,
                            const __half* bias,
                            int row_stride,
                            int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE

    int row = blockIdx.x;
    int id = threadIdx.x;
    int loop_stride = blockDim.x;

    float2* d_output_cast = reinterpret_cast<float2*>(d_output);
    const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
    const float2* bias_cast = reinterpret_cast<const float2*>(bias);

#pragma unroll
    for (int i = 0; i < iterations; i++) {
        if (i * loop_stride + id < row_stride) {
            float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
            float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
            float2 bias_vec = bias_cast[i * loop_stride + id];

            __half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
            __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
            __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

            float2 output_half_0 = __half22float2(output_data_half[0]);
            float2 output_half_1 = __half22float2(output_data_half[1]);
            float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
            float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
            float2 bias_half_0 = __half22float2(bias_half[0]);
            float2 bias_half_1 = __half22float2(bias_half[1]);

            gelu_input_half_0.x += bias_half_0.x;
            gelu_input_half_0.y += bias_half_0.y;
            gelu_input_half_1.x += bias_half_1.x;
            gelu_input_half_1.y += bias_half_1.y;

            output_half_0.x *= d_gelu(gelu_input_half_0.x);
            output_half_0.y *= d_gelu(gelu_input_half_0.y);
            output_half_1.x *= d_gelu(gelu_input_half_1.x);
            output_half_1.y *= d_gelu(gelu_input_half_1.y);

            float2 result;
            __half2* result_half2 = reinterpret_cast<__half2*>(&result);

            result_half2[0] = __float22half2_rn(output_half_0);
            result_half2[1] = __float22half2_rn(output_half_1);

            d_output_cast[row * row_stride + i * loop_stride + id] = result;
        }
    }
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
                      const T* bias,
                      T* output,
                      int intermediate_size,
                      int batch_size,
                      cudaStream_t stream)
{
    // Heuristic: split each row into blocks of 1024 elements, then size the
    // thread block so each thread handles one 4-element vector per iteration.
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    fused_bias_gelu<<<grid_dims, block_dims, 0, stream>>>(
        input, bias, output, intermediate_size / 4, iterations);
}
template <typename T>
void launch_gelu(const T* input,
                 T* output,
                 int intermediate_size,
                 int batch_size,
                 cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    gelu_kernel<<<grid_dims, block_dims, 0, stream>>>(
        input, output, intermediate_size / 4, iterations);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, cudaStream_t);
template void launch_bias_gelu<__half>(const __half*,
                                       const __half*,
                                       __half*,
                                       int,
                                       int,
                                       cudaStream_t);
template void launch_gelu<float>(const float*, float*, int, int, cudaStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
                   const T* input,
                   const T* bias,
                   int intermediate_size,
                   int batch_size,
                   cudaStream_t stream)
{
    int iterations = (intermediate_size + 1023) / 1024;
    int threads = (intermediate_size - 1) / (iterations * 4) + 1;
    dim3 block_dims(threads);
    dim3 grid_dims(batch_size);

    d_gelu_func<<<grid_dims, block_dims, 0, stream>>>(
        d_output, input, bias, intermediate_size / 4, iterations);
}

template void launch_d_gelu<float>(float*, const float*, const float*, int, int, cudaStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t);
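
/*
Illustrative host-side usage (not part of this file): a minimal sketch of how
launch_bias_gelu might be driven, assuming device buffers `ff1_out`, `bias`,
and `act_out` have already been allocated and populated. The buffer names,
the 4096-wide intermediate size, and the batch*seq row count are assumptions
made for the example only.

    const int rows = batch_size * seq_length;   // one CUDA block per row
    const int intermediate_size = 4096;         // e.g. bert-large FFN width

    // Applies GELU(ff1_out + bias) row by row, writing into act_out.
    launch_bias_gelu<float>(ff1_out, bias, act_out, intermediate_size, rows, stream);

Note that the launchers assume intermediate_size is a multiple of 4 (rows are
indexed as 4-element vectors) and that the pointers satisfy the alignment
required by the vectorized float4/float2 loads.
*/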