fused_lamb_cuda_kernel.cu

/* Copyright 2019 The Microsoft DeepSpeed Team */

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
// #include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;

// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
    // Ensure that we won't compile any un-specialized types
    __device__ inline operator T*()
    {
#ifndef _WIN32
        extern __device__ void error(void);
        error();
#endif
        return NULL;
    }
};

template <>
struct SharedMemory<float> {
    __device__ inline operator float*()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};

template <>
struct SharedMemory<double> {
    __device__ inline operator double*()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};
}  // namespace
#include "type_shim.h"

typedef enum {
    ADAM_MODE_0 = 0,  // eps under square root
    ADAM_MODE_1 = 1   // eps outside square root
} adamMode_t;
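// For a second-moment estimate v, ADAM_MODE_0 uses denom = sqrtf(v + eps),
// while ADAM_MODE_1 uses denom = sqrtf(v) + eps (see the update loops below).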
// s_a and s_b are in shared memory
// g_a and g_b are in global memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
    // Handle to thread block group
    cg::thread_block cta = cg::this_thread_block();

    // perform block reduction in shared memory,
    unsigned int tid = cta.thread_rank();

    T a_sum = s_a[tid];
    T b_sum = s_b[tid];

    cg::sync(cta);

    // do reduction in shared mem
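    // Tree reduction: each step halves the number of active threads, with the
    // surviving threads folding the upper half of s_a and s_b into the lower half.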
    if ((blockSize >= 512) && (tid < 256)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 256];
        s_b[tid] = b_sum = b_sum + s_b[tid + 256];
    }
    cg::sync(cta);

    if ((blockSize >= 256) && (tid < 128)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 128];
        s_b[tid] = b_sum = b_sum + s_b[tid + 128];
    }
    cg::sync(cta);

    if ((blockSize >= 128) && (tid < 64)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 64];
        s_b[tid] = b_sum = b_sum + s_b[tid + 64];
    }
    cg::sync(cta);

#if (__CUDA_ARCH__ >= 300) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 502)
    if (tid < 32) {
        cg::coalesced_group active = cg::coalesced_threads();

        // Fetch final intermediate sum from 2nd warp
        if (blockSize >= 64) {
            a_sum = a_sum + s_a[tid + 32];
            b_sum = b_sum + s_b[tid + 32];
        }

        // Reduce final warp using shuffle
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            a_sum += active.shfl_down(a_sum, offset);
            b_sum += active.shfl_down(b_sum, offset);
        }
    }
#else
    if ((blockSize >= 64) && (tid < 32)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 32];
        s_b[tid] = b_sum = b_sum + s_b[tid + 32];
    }
    cg::sync(cta);

    if ((blockSize >= 32) && (tid < 16)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 16];
        s_b[tid] = b_sum = b_sum + s_b[tid + 16];
    }
    cg::sync(cta);

    if ((blockSize >= 16) && (tid < 8)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 8];
        s_b[tid] = b_sum = b_sum + s_b[tid + 8];
    }
    cg::sync(cta);

    if ((blockSize >= 8) && (tid < 4)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 4];
        s_b[tid] = b_sum = b_sum + s_b[tid + 4];
    }
    cg::sync(cta);

    if ((blockSize >= 4) && (tid < 2)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 2];
        s_b[tid] = b_sum = b_sum + s_b[tid + 2];
    }
    cg::sync(cta);

    if ((blockSize >= 2) && (tid < 1)) {
        s_a[tid] = a_sum = a_sum + s_a[tid + 1];
        s_b[tid] = b_sum = b_sum + s_b[tid + 1];
    }
    cg::sync(cta);
#endif

    // write result for this block to global mem
    if (tid == 0) {
        g_a[blockIdx.x] = (T)a_sum;
        g_b[blockIdx.x] = (T)b_sum;
    }
}
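// Each thread deposits its register accumulators into the two halves of the
// dynamic shared-memory buffer, then the block-level reduction above folds
// them down and thread 0 writes one partial sum per block to g_a / g_b.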
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
    const int threadIdInBlock = cg::this_thread_block().thread_rank();

    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    s_a[threadIdInBlock] = a;
    s_b[threadIdInBlock] = b;

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
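// Part 1: update the Adam moments m and v for every element, accumulate the
// squared L2 norms of the parameters (reg_w) and of the would-be updates
// (reg_u), and reduce them to one partial sum per block in w_l2_i / u_l2_i.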
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = 0;
    T reg_u = 0;

    for (int j = i; j < tsize; j += totThreads) {
        T scaled_grad = g[j] / grad_scale;
        T pj = p[j];
        m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
        v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(v[j] + eps);
        else  // Mode 1
            denom = sqrtf(v[j]) + eps;
        T update = (m[j] / denom) + (decay * p[j]);
        reg_u += update * update;
        reg_w += pj * pj;
    }

    reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
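// Part 2: launched with a single block; folds the per-block partial sums
// produced by part 1 (the first tsize entries of g_a / g_b, one per part-1
// block) into a single pair of squared norms stored in g_a[0] / g_b[0].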
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
    T* s_a = SharedMemory<T>();
    T* s_b = SharedMemory<T>() + cg::this_thread_block().size();

    const int threadIdInBlock = cg::this_thread_block().thread_rank();

    s_a[threadIdInBlock] = g_a[threadIdInBlock];
    s_b[threadIdInBlock] = g_b[threadIdInBlock];

    if (threadIdInBlock >= tsize) {
        s_a[threadIdInBlock] = 0.0;
        s_b[threadIdInBlock] = 0.0;
    }

    reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
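// Part 3: reads the fully reduced squared norms, takes square roots to get
// ||w|| and ||update||, computes the LAMB trust ratio ||w|| / ||update||
// clamped to [min_coeff, max_coeff], and applies the scaled update
// p -= step_size * lamb_coeff * ((m / denom) + decay * p).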
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
    T* __restrict__ p,
    GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
    T* __restrict__ m,
    T* __restrict__ v,
    const GRAD_T* __restrict__ g,
    const float b1,
    const float b2,
    const float max_coeff,
    const float min_coeff,
    const float eps,
    const float grad_scale,
    const float step_size,
    const size_t tsize,
    adamMode_t mode,
    const float decay,
    T* __restrict__ w_l2_i,
    T* __restrict__ u_l2_i,
    T* __restrict__ lamb_coeff_val)
{
    // Assuming 2D grids and 2D blocks
    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
    const int threadsPerBlock = blockDim.x * blockDim.y;
    const int threadIdInBlock = cg::this_thread_block().thread_rank();
    const int i = (blockId * threadsPerBlock + threadIdInBlock);
    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

    T reg_w = sqrtf(w_l2_i[0]);
    T reg_u = sqrtf(u_l2_i[0]);

    float lamb_coeff = 1.0;
    if (reg_w != 0 && reg_u != 0) {
        lamb_coeff = reg_w / reg_u;
        if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
        if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
    }

    if (blockId == 0 && threadIdInBlock == 0) {
        lamb_coeff_val[0] = lamb_coeff;
        // printf("Cuda Lamb Coeff is %.6f \n", lamb_coeff);
    }

    for (int j = i; j < tsize; j += totThreads) {
        T pj = (float)p[j];
        T mj = m[j];
        T vj = v[j];
        float denom;
        if (mode == ADAM_MODE_0)
            denom = sqrtf(vj + eps);
        else  // Mode 1
            denom = sqrtf(vj) + eps;
        T update = (mj / denom) + (decay * pj);
        pj = pj - (step_size * lamb_coeff * update);
        p[j] = pj;
        if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
    }
}
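// Host-side launcher: sets up the launch configuration and dynamic shared
// memory, derives the (optionally bias-corrected) step size, and runs the
// three kernels above on the current CUDA stream. Gradients may be fp16
// (with fp32 parameters and moments) or fp32/fp64; dispatch is done on the
// gradient dtype.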
void fused_lamb_cuda(at::Tensor& p,
                     at::Tensor& p_copy,
                     at::Tensor& m,
                     at::Tensor& v,
                     at::Tensor& g,
                     float lr,
                     float beta1,
                     float beta2,
                     float max_coeff,
                     float min_coeff,
                     float eps,
                     float grad_scale,
                     int step,
                     int mode,
                     int bias_correction,
                     float decay,
                     at::Tensor& w_l2_i,
                     at::Tensor& u_l2_i,
                     at::Tensor& lamb_coeff)
{
    // using namespace at;
    // Get tensor size
    int tsize = p.numel();

    // Determine #threads and #blocks
    const int threadsPerBlock = 512;
    int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
    if (num_blocks > 512) num_blocks = 512;
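    // Capping at 512 blocks is safe because the kernels use grid-stride loops;
    // it also bounds the number of per-block partial sums that part 2 (a single
    // block of 512 threads) has to reduce.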
    int smemsize = 0;
    if (p.type().scalarType() == at::ScalarType::Double)
        smemsize = 2 * threadsPerBlock * sizeof(double);
    else
        smemsize = 2 * threadsPerBlock * sizeof(float);

    const dim3 blocks(num_blocks);
    const dim3 threads(threadsPerBlock);

    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
               "parameter tensor is too large to be indexed with int32");

    // Constants
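    // With bias correction the effective step size is
    //   step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step),
    // matching the Adam bias-correction terms; otherwise it is just lr.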
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    } else {
        step_size = lr;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.type().scalarType() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
                   "expected parameter to be of float type");
        // dispatch is done on the gradient type
        using namespace at;  // prevents "toString is undefined" errors
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;
                lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>());
                lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
                lamb_cuda_kernel_part3<accscalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<accscalar_t>(),
                        p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
                        m.data<accscalar_t>(),
                        v.data<accscalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<accscalar_t>(),
                        u_l2_i.data<accscalar_t>(),
                        lamb_coeff.data<accscalar_t>());
            }));
    } else {
        using namespace at;
        AT_DISPATCH_FLOATING_TYPES(
            g.scalar_type(), "lamb_cuda_kernel", ([&] {
                lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>());
                lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>
                    <<<1, threadsPerBlock, smemsize, stream>>>(
                        num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
                lamb_cuda_kernel_part3<scalar_t, scalar_t>
                    <<<blocks, threadsPerBlock, smemsize, stream>>>(
                        p.data<scalar_t>(),
                        NULL,  // don't output p_copy for fp32, it's wasted write
                        m.data<scalar_t>(),
                        v.data<scalar_t>(),
                        g.data<scalar_t>(),
                        beta1,
                        beta2,
                        max_coeff,
                        min_coeff,
                        eps,
                        grad_scale,
                        step_size,
                        tsize,
                        (adamMode_t)mode,
                        decay,
                        w_l2_i.data<scalar_t>(),
                        u_l2_i.data<scalar_t>(),
                        lamb_coeff.data<scalar_t>());
            }));
    }

    C10_CUDA_CHECK(cudaGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);