// fused_lamb_cuda_kernel.cu
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
// #include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_runtime_api.h>
#include <stdio.h>

namespace cg = cooperative_groups;

// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
    // Ensure that we won't compile any un-specialized types
    __device__ inline operator T*()
    {
#ifndef _WIN32
        extern __device__ void error(void);
        error();
#endif
        return NULL;
    }
};

template <>
struct SharedMemory<float> {
    __device__ inline operator float*()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

template <>
struct SharedMemory<double> {
    __device__ inline operator double*()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};
}  // namespace

#include "type_shim.h"

typedef enum {
    ADAM_MODE_0 = 0,  // eps under square root
    ADAM_MODE_1 = 1   // eps outside square root
} adamMode_t;

// s_a and s_b are in shared memory
// g_a and g_b are in global memory
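// Standard two-array block tree reduction: successive halving in shared memory,
// then (on CUDA arch >= 3.0 or recent ROCm) a warp-shuffle reduction for the
// last 32 lanes. Thread 0 writes this block's partial sums to g_a[blockIdx.x]
// and g_b[blockIdx.x].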
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
    // Handle to thread block group
    cg::thread_block cta = cg::this_thread_block();

    // perform block reduction in shared memory,
    unsigned int tid = cta.thread_rank();

    T a_sum = s_a[tid];
    T b_sum = s_b[tid];

    cg::sync(cta);

    // do reduction in shared mem
    if ((blockSize >= 512) && (tid < 256)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 256];
        s_b[tid] = b_sum = b_sum + s_b[tid + 256];
    }
    cg::sync(cta);

    if ((blockSize >= 256) && (tid < 128)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 128];
        s_b[tid] = b_sum = b_sum + s_b[tid + 128];
    }
    cg::sync(cta);

    if ((blockSize >= 128) && (tid < 64)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 64];
        s_b[tid] = b_sum = b_sum + s_b[tid + 64];
    }
    cg::sync(cta);

#if (__CUDA_ARCH__ >= 300) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 502)
    if (tid < 32) {
        cg::coalesced_group active = cg::coalesced_threads();

        // Fetch final intermediate sum from 2nd warp
        if (blockSize >= 64) {
            a_sum = a_sum + s_a[tid + 32];
            b_sum = b_sum + s_b[tid + 32];
        }

        // Reduce final warp using shuffle
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            a_sum += active.shfl_down(a_sum, offset);
            b_sum += active.shfl_down(b_sum, offset);
        }
    }
#else
    if ((blockSize >= 64) && (tid < 32)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 32];
        s_b[tid] = b_sum = b_sum + s_b[tid + 32];
    }
    cg::sync(cta);

    if ((blockSize >= 32) && (tid < 16)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 16];
        s_b[tid] = b_sum = b_sum + s_b[tid + 16];
    }
    cg::sync(cta);

    if ((blockSize >= 16) && (tid < 8)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 8];
        s_b[tid] = b_sum = b_sum + s_b[tid + 8];
    }
    cg::sync(cta);

    if ((blockSize >= 8) && (tid < 4)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 4];
        s_b[tid] = b_sum = b_sum + s_b[tid + 4];
    }
    cg::sync(cta);

    if ((blockSize >= 4) && (tid < 2)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 2];
        s_b[tid] = b_sum = b_sum + s_b[tid + 2];
    }
    cg::sync(cta);

    if ((blockSize >= 2) && (tid < 1)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 1];
        s_b[tid] = b_sum = b_sum + s_b[tid + 1];
    }
    cg::sync(cta);
#endif

    // write result for this block to global mem
    if (tid == 0) {
        g_a[blockIdx.x] = (T)a_sum;
        g_b[blockIdx.x] = (T)b_sum;
    }
}
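
// Stage one per-thread value for each of the two sums in dynamic shared memory
// (the launch reserves room for 2 * blockDim elements of T), then reduce the
// whole block's contributions.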
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
    const int threadIdInBlock = cg::this_thread_block().thread_rank();

    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    s_a[threadIdInBlock] = a;
    s_b[threadIdInBlock] = b;

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
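
// Phase 1 of the fused LAMB step. Each thread strides over the tensor and, for
// every element j:
//   m[j] = b1 * m[j] + (1 - b1) * (g[j] / grad_scale)
//   v[j] = b2 * v[j] + (1 - b2) * (g[j] / grad_scale)^2
//   update_j = m[j] / denom + decay * p[j]
// where denom is sqrt(v[j] + eps) in ADAM_MODE_0 or sqrt(v[j]) + eps in
// ADAM_MODE_1. The squared norms of the parameter (reg_w) and of the update
// (reg_u) are accumulated in registers and reduced to one partial sum per block
// in w_l2_i / u_l2_i; the parameter itself is not modified yet.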
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = 0;
    T reg_u = 0;

    for (int j = i; j < tsize; j += totThreads) {
        T scaled_grad = g[j] / grad_scale;
        T pj = p[j];
        m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
        v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(v[j] + eps);
        else  // Mode 1
            denom = sqrtf(v[j]) + eps;
        T update = (m[j] / denom) + (decay * p[j]);
        reg_u += update * update;
        reg_w += pj * pj;
    }

    reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
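
// Phase 2: a single block reduces the per-block partial sums produced by phase 1
// (tsize here is the number of blocks launched in phase 1, not the tensor size).
// Entries beyond tsize are zeroed before the reduction, and the final sums land
// in w_l2_i[0] / u_l2_i[0].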
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    const int threadIdInBlock = cg::this_thread_block().thread_rank();

    s_a[threadIdInBlock] = g_a[threadIdInBlock];
    s_b[threadIdInBlock] = g_b[threadIdInBlock];

    if (threadIdInBlock >= tsize) {
        s_a[threadIdInBlock] = 0.0;
        s_b[threadIdInBlock] = 0.0;
    }

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
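
// Phase 3: turn the reduced squared norms into the LAMB trust ratio
//   lamb_coeff = clamp(||w|| / ||update||, min_coeff, max_coeff)
// (kept at 1.0 if either norm is zero), recompute the Adam update from the
// already-updated m and v, and apply
//   p[j] -= step_size * lamb_coeff * update_j
// Optionally mirrors the new parameters into p_copy (a reduced-precision copy
// for mixed-precision training) and stores the coefficient in lamb_coeff_val[0].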
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float max_coeff,
    const float min_coeff,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i,
    T* __restrict__ lamb_coeff_val)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = sqrtf(w_l2_i[0]);
    T reg_u = sqrtf(u_l2_i[0]);

    float lamb_coeff = 1.0;
    if (reg_w != 0 && reg_u != 0) {
        lamb_coeff = reg_w / reg_u;
        if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
        if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
    }

    if (blockId == 0 && threadIdInBlock == 0) {
        lamb_coeff_val[0] = lamb_coeff;
        // printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
    }

    for (int j = i; j < tsize; j += totThreads) {
        T pj = (float)p[j];
        T mj = m[j];
        T vj = v[j];
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(vj + eps);
        else  // Mode 1
            denom = sqrtf(vj) + eps;
        T update = (mj / denom) + (decay * pj);
        pj = pj - (step_size * lamb_coeff * update);
        p[j] = pj;
        if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
    }
}
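
// Host-side launcher. Uses 512 threads per block and at most 512 blocks, with
// dynamic shared memory sized for the two reduction arrays (2 * threads
// elements of double for double parameters, float otherwise). With bias
// correction enabled the step size is
//   step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step)
// Dispatch is done on the gradient dtype: half gradients require fp32
// parameters and accumulate in the corresponding accumulation type, while
// fp32/fp64 gradients run everything in the gradient's own type and skip the
// p_copy output.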
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff)
{
    // using namespace at;

    // Get tensor size
    int tsize = p.numel();
    // Determine #threads and #blocks
    const int threadsPerBlock = 512;
    int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
    if (num_blocks > 512) num_blocks = 512;

    int smemsize = 0;
    if (p.type().scalarType() == at::ScalarType::Double)
        smemsize = 2 * threadsPerBlock * sizeof(double);
    else
        smemsize = 2 * threadsPerBlock * sizeof(float);

    const dim3 blocks(num_blocks);
    const dim3 threads(threadsPerBlock);

    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
               "parameter tensor is too large to be indexed with int32");

    // Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    } else {
        step_size = lr;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.type().scalarType() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
                   "expected parameter to be of float type");
        // dispatch is done on the gradient type
        using namespace at;  // prevents "toString is undefined" errors
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;

                lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>());

                lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());

                lamb_cuda_kernel_part3<accscalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>(),
                        lamb_coeff.data<accscalar_t>());
            }));
    } else {
        using namespace at;
        AT_DISPATCH_FLOATING_TYPES(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>());

                lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());

                lamb_cuda_kernel_part3<scalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>(),
                        lamb_coeff.data<scalar_t>());
            }));
    }

    C10_CUDA_CHECK(cudaGetLastError());
}

// template __device__ void reduce_two_vectors_in_register<float, 512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
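
// The sketch below is illustrative only (module and function names are assumed,
// not taken from this file): in a typical PyTorch C++ extension the launcher
// above would be declared in a companion .cpp and exposed to Python roughly
// like so.
//
//   // fused_lamb_binding.cpp (hypothetical)
//   #include <torch/extension.h>
//
//   void fused_lamb_cuda(at::Tensor& p, at::Tensor& p_copy, at::Tensor& m,
//                        at::Tensor& v, at::Tensor& g, float lr, float beta1,
//                        float beta2, float max_coeff, float min_coeff, float eps,
//                        float grad_scale, int step, int mode, int bias_correction,
//                        float decay, at::Tensor& w_l2_i, at::Tensor& u_l2_i,
//                        at::Tensor& lamb_coeff);
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       m.def("lamb", &fused_lamb_cuda, "Fused LAMB optimizer step (CUDA)");
//   }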