fused_lamb_cuda_kernel.cu 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. /* Copyright 2019 The Microsoft DeepSpeed Team */
  2. #include <cuda.h>
  3. #include <cuda_runtime.h>
  4. #include <stdio.h>
  5. #include <cmath>
  6. #include "ATen/ATen.h"
  7. #include "ATen/TensorUtils.h"
  8. #include "ATen/cuda/CUDAContext.h"
  9. #include "ATen/cuda/detail/IndexUtils.cuh"
  10. //#include "ATen/Type.h"
  11. #include <THC/THCGeneral.h>
  12. #include "ATen/AccumulateType.h"
  13. #include <iostream>
  14. //#include <helper_functions.h>
  15. #include <cooperative_groups.h>
  16. #include <cuda_runtime_api.h>
  17. #include <stdio.h>
  18. namespace cg = cooperative_groups;
  // SharedMemory<T> hands out a T* to the kernel's dynamically sized
  // (extern __shared__) buffer. Each supported element type gets its own
  // explicitly named extern array inside a specialization; this avoids the
  // linker errors caused by declaring an extern unsized shared array whose
  // element type is a template parameter.
  namespace {
  // Primary template: deliberately unusable. Only the float and double
  // specializations below may be converted to a pointer.
  template <typename T>
  struct SharedMemory {
      __device__ inline operator T*()
      {
          // The condition depends on T, so it is evaluated (and always fails)
          // only if this conversion is instantiated for an unsupported T.
          // Unlike an undefined-symbol link error, this is enforced on every
          // platform, including _WIN32 builds.
          static_assert(sizeof(T) == 0, "SharedMemory<T> is only specialized for float and double");
          return nullptr;
      }
  };

  template <>
  struct SharedMemory<float> {
      __device__ inline operator float*()
      {
          extern __shared__ float s_float[];
          return s_float;
      }
  };

  template <>
  struct SharedMemory<double> {
      __device__ inline operator double*()
      {
          extern __shared__ double s_double[];
          return s_double;
      }
  };
  }  // namespace
  53. #include "type_shim.h"
  // Selects where eps enters the Adam denominator in the LAMB kernels.
  enum adamMode_t {
      ADAM_MODE_0 = 0,  // denom = sqrt(v + eps)  -- eps under the square root
      ADAM_MODE_1 = 1   // denom = sqrt(v) + eps  -- eps outside the square root
  };
  // Block-wide sum of two vectors held in shared memory.
  //
  // Preconditions (not checked at runtime):
  //   * the block has exactly `blockSize` threads (thread_rank() in [0, blockSize));
  //   * s_a[0..blockSize) and s_b[0..blockSize) are shared-memory buffers that
  //     every thread has already filled at its own rank;
  //   * every thread of the block reaches this call (it contains block barriers).
  // Postcondition: g_a[b] and g_b[b] -- global memory -- receive the totals,
  // where b is the linear index of this block in the grid, so 1-D, 2-D and 3-D
  // grids each get a distinct slot. s_a / s_b are used as scratch and are
  // overwritten.
  template <typename T, int blockSize>
  __device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
  {
      // The warp-shuffle tail assumes a full first warp; 1024 is the CUDA
      // per-block thread limit.
      static_assert(blockSize >= 32 && blockSize <= 1024 && (blockSize & (blockSize - 1)) == 0,
                    "reduce_block_in_shared_memory: blockSize must be a power of two in [32, 1024]");

      cg::thread_block cta = cg::this_thread_block();
      const unsigned int tid = cta.thread_rank();

      T a_sum = s_a[tid];
      T b_sum = s_b[tid];
      cg::sync(cta);

      // Tree reduction in shared memory until 64 partial sums remain.
      if ((blockSize >= 1024) && (tid < 512)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 512];
          s_b[tid] = b_sum = b_sum + s_b[tid + 512];
      }
      cg::sync(cta);
      if ((blockSize >= 512) && (tid < 256)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 256];
          s_b[tid] = b_sum = b_sum + s_b[tid + 256];
      }
      cg::sync(cta);
      if ((blockSize >= 256) && (tid < 128)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 128];
          s_b[tid] = b_sum = b_sum + s_b[tid + 128];
      }
      cg::sync(cta);
      if ((blockSize >= 128) && (tid < 64)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 64];
          s_b[tid] = b_sum = b_sum + s_b[tid + 64];
      }
      cg::sync(cta);

  #if (__CUDA_ARCH__ >= 300)
      // Every thread builds the tile outside divergent code; warp 0 then
      // finishes with register shuffles. It does not rely on whichever lanes
      // happen to be converged, which is what coalesced_threads() would return.
      constexpr int kWarpSize = 32;
      cg::thread_block_tile<kWarpSize> warp = cg::tiled_partition<kWarpSize>(cta);
      if (tid < 32) {
          // Fold in the partial sums still held by the second warp.
          if (blockSize >= 64) {
              a_sum += s_a[tid + 32];
              b_sum += s_b[tid + 32];
          }
          for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
              a_sum += warp.shfl_down(a_sum, offset);
              b_sum += warp.shfl_down(b_sum, offset);
          }
      }
  #else
      // Pre-shuffle architectures: keep halving in shared memory.
      if ((blockSize >= 64) && (tid < 32)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 32];
          s_b[tid] = b_sum = b_sum + s_b[tid + 32];
      }
      cg::sync(cta);
      if ((blockSize >= 32) && (tid < 16)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 16];
          s_b[tid] = b_sum = b_sum + s_b[tid + 16];
      }
      cg::sync(cta);
      if ((blockSize >= 16) && (tid < 8)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 8];
          s_b[tid] = b_sum = b_sum + s_b[tid + 8];
      }
      cg::sync(cta);
      if ((blockSize >= 8) && (tid < 4)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 4];
          s_b[tid] = b_sum = b_sum + s_b[tid + 4];
      }
      cg::sync(cta);
      if ((blockSize >= 4) && (tid < 2)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 2];
          s_b[tid] = b_sum = b_sum + s_b[tid + 2];
      }
      cg::sync(cta);
      if ((blockSize >= 2) && (tid < 1)) {
          s_a[tid] = a_sum = a_sum + s_a[tid + 1];
          s_b[tid] = b_sum = b_sum + s_b[tid + 1];
      }
      cg::sync(cta);
  #endif

      // Only thread 0 holds the complete totals. Index by the linear block id,
      // matching how the LAMB kernels linearize 2-D grids; a plain blockIdx.x
      // would make blocks in different grid rows share a slot.
      if (tid == 0) {
          const unsigned int block_id =
              (blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x;
          g_a[block_id] = a_sum;
          g_b[block_id] = b_sum;
      }
  }
  // Stages one (a, b) pair per thread into the two halves of the dynamic
  // shared-memory buffer, then sums both vectors across the block via
  // reduce_block_in_shared_memory, which writes the per-block totals to g_a / g_b.
  // Needs 2 * blockDim * sizeof(T) bytes of dynamic shared memory.
  template <typename T, int blockSize>
  __device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
  {
      cg::thread_block block = cg::this_thread_block();
      const int rank = block.thread_rank();

      // First half holds the `a` values, second half the `b` values.
      T* shared_a = SharedMemory<T>();
      T* shared_b = shared_a + block.size();

      shared_a[rank] = a;
      shared_b[rank] = b;

      reduce_block_in_shared_memory<T, blockSize>(shared_a, shared_b, g_a, g_b);
  }
  // LAMB pass 1: Adam moment update plus per-block partial squared norms.
  //
  // Launch requirements: block size == `blockSize` and at least
  // 2 * blockSize * sizeof(T) bytes of dynamic shared memory (used by the block
  // reduction). The grid may be 1-D or 2-D. Each thread visits elements i,
  // i + totThreads, ..., where totThreads is the number of launched threads.
  //
  // For every element j < tsize, with s = g[j] / grad_scale:
  //   m[j] <- b1 * m[j] + (1 - b1) * s
  //   v[j] <- b2 * v[j] + (1 - b2) * s * s
  //   update = m[j] / denom + decay * p[j]
  //   denom  = sqrt(v[j] + eps) for ADAM_MODE_0, sqrt(v[j]) + eps otherwise.
  // The block sums of p[j]^2 and update^2 are handed to
  // reduce_two_vectors_in_register, which stores one entry per block in w_l2_i
  // and u_l2_i respectively. p is only read. p_copy and step_size are unused
  // here; they keep the signature parallel to pass 3.
  template <typename T, typename GRAD_T, int blockSize>
  __global__ void lamb_cuda_kernel_part1(
      T* __restrict__ p,
      GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
      T* __restrict__ m,
      T* __restrict__ v,
      const GRAD_T* __restrict__ g,
      const float b1,
      const float b2,
      const float eps,
      const float grad_scale,
      const float step_size,
      const size_t tsize,
      adamMode_t mode,
      const float decay,
      T* __restrict__ w_l2_i,
      T* __restrict__ u_l2_i)
  {
      // Linearize a (possibly 2-D) grid of (possibly 2-D) blocks.
      const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
      const int threadsPerBlock = blockDim.x * blockDim.y;
      const int threadIdInBlock = cg::this_thread_block().thread_rank();
      const int i = (blockId * threadsPerBlock + threadIdInBlock);
      const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

      T reg_w = 0;
      T reg_u = 0;
      for (size_t j = i; j < tsize; j += totThreads) {
          const T scaled_grad = g[j] / grad_scale;
          const T pj = p[j];
          // Update the moments in registers and store each exactly once,
          // instead of re-reading m[j] / v[j] from global memory.
          const T mj = b1 * m[j] + (1 - b1) * scaled_grad;
          const T vj = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
          m[j] = mj;
          v[j] = vj;
          // Computed in T: `sqrt` is overloaded for float and double, so the
          // double path is no longer truncated through sqrtf / a float temp.
          const T denom = (mode == ADAM_MODE_0) ? sqrt(vj + eps) : sqrt(vj) + eps;
          const T update = (mj / denom) + (decay * pj);
          reg_u += update * update;
          reg_w += pj * pj;
      }
      reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
  }
  // LAMB pass 2: reduces the per-block partial sums produced by pass 1.
  //
  // Launch with ONE block of `blockSize` threads and at least
  // 2 * blockSize * sizeof(T) bytes of dynamic shared memory. On entry
  // g_a[0..tsize) and g_b[0..tsize) hold partial sums; on exit their totals
  // are in g_a[0] and g_b[0]. Entries at index >= tsize are never read, and
  // tsize may exceed blockSize: each thread first folds a strided subset.
  // GRAD_T is unused; it only keeps the template signature parallel to the
  // other passes.
  template <typename T, typename GRAD_T, int blockSize>
  __global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
  {
      cg::thread_block cta = cg::this_thread_block();
      const unsigned int threadIdInBlock = cta.thread_rank();
      T* s_a = SharedMemory<T>();
      T* s_b = s_a + cta.size();

      // Bounds-checked load. Previously g_a[tid] / g_b[tid] were read before
      // the tid >= tsize test, i.e. past the valid partial sums.
      T a = T(0);
      T b = T(0);
      for (size_t k = threadIdInBlock; k < tsize; k += cta.size()) {
          a += g_a[k];
          b += g_b[k];
      }
      s_a[threadIdInBlock] = a;
      s_b[threadIdInBlock] = b;

      reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
  }
  // LAMB pass 3: applies the trust-ratio-scaled update to the parameters.
  //
  // Reads the squared norms left in w_l2_i[0] (parameters) and u_l2_i[0]
  // (update) and forms the trust ratio
  //     lamb_coeff = ||w|| / ||u||, clamped to max_coeff then to min_coeff,
  // falling back to 1 when either norm is zero. Block 0 / thread 0 stores it
  // in lamb_coeff_val[0]. Then, for every element j < tsize (each thread
  // strides by the total launched thread count; 1-D or 2-D grid):
  //     update = m[j] / denom + decay * p[j]
  //     p[j]   = p[j] - step_size * lamb_coeff * update
  // with denom = sqrt(v[j] + eps) for ADAM_MODE_0 and sqrt(v[j]) + eps
  // otherwise. If p_copy is non-NULL it receives the new value cast to GRAD_T.
  // m and v are only read here; g, b1, b2 and grad_scale are unused (kept so
  // the signature mirrors pass 1).
  template <typename T, typename GRAD_T>
  __global__ void lamb_cuda_kernel_part3(
      T* __restrict__ p,
      GRAD_T* __restrict__ p_copy,  // For mixed precision training, pass NULL if not needed
      T* __restrict__ m,
      T* __restrict__ v,
      const GRAD_T* __restrict__ g,
      const float b1,
      const float b2,
      const float max_coeff,
      const float min_coeff,
      const float eps,
      const float grad_scale,
      const float step_size,
      const size_t tsize,
      adamMode_t mode,
      const float decay,
      T* __restrict__ w_l2_i,
      T* __restrict__ u_l2_i,
      T* __restrict__ lamb_coeff_val)
  {
      // Linearize a (possibly 2-D) grid of (possibly 2-D) blocks.
      const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
      const int threadsPerBlock = blockDim.x * blockDim.y;
      const int threadIdInBlock = cg::this_thread_block().thread_rank();
      const int i = (blockId * threadsPerBlock + threadIdInBlock);
      const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;

      // Everything is computed in T (sqrt is overloaded for float and double),
      // so the double path keeps full precision instead of passing through
      // sqrtf and float temporaries.
      const T w_norm = sqrt(w_l2_i[0]);
      const T u_norm = sqrt(u_l2_i[0]);
      T lamb_coeff = T(1);
      if (w_norm != 0 && u_norm != 0) {
          lamb_coeff = w_norm / u_norm;
          if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
          if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
      }
      if (blockId == 0 && threadIdInBlock == 0) { lamb_coeff_val[0] = lamb_coeff; }

      for (size_t j = i; j < tsize; j += totThreads) {
          // Load p[j] as T, not via float, so double parameters are not truncated.
          T pj = p[j];
          const T mj = m[j];
          const T vj = v[j];
          const T denom = (mode == ADAM_MODE_0) ? sqrt(vj + eps) : sqrt(vj) + eps;
          const T update = (mj / denom) + (decay * pj);
          pj = pj - (step_size * lamb_coeff * update);
          p[j] = pj;
          if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
      }
  }
  // Enqueues the three LAMB passes on `stream`.
  //
  // T is the parameter / optimizer-state type and GRAD_T the gradient (and
  // p_copy) type. p_copy may be null, in which case pass 3 skips the copy.
  // w_l2_i / u_l2_i must hold at least num_blocks entries (one partial sum per
  // pass-1 block); lamb_coeff receives the trust ratio in element 0.
  template <typename T, typename GRAD_T, int kThreadsPerBlock>
  static void launch_lamb_kernels(T* p,
                                  GRAD_T* p_copy,
                                  T* m,
                                  T* v,
                                  const GRAD_T* g,
                                  float beta1,
                                  float beta2,
                                  float max_coeff,
                                  float min_coeff,
                                  float eps,
                                  float grad_scale,
                                  float step_size,
                                  int tsize,
                                  adamMode_t mode,
                                  float decay,
                                  T* w_l2_i,
                                  T* u_l2_i,
                                  T* lamb_coeff,
                                  int num_blocks,
                                  cudaStream_t stream)
  {
      const dim3 blocks(num_blocks);
      const dim3 threads(kThreadsPerBlock);
      // The block reductions keep two kThreadsPerBlock-long arrays of T in
      // dynamic shared memory.
      const size_t smem_bytes = 2 * kThreadsPerBlock * sizeof(T);

      lamb_cuda_kernel_part1<T, GRAD_T, kThreadsPerBlock><<<blocks, threads, smem_bytes, stream>>>(
          p, p_copy, m, v, g, beta1, beta2, eps, grad_scale, step_size, tsize, mode, decay,
          w_l2_i, u_l2_i);
      // A single block folds the num_blocks partial sums into w_l2_i[0] / u_l2_i[0].
      lamb_cuda_kernel_part2<T, GRAD_T, kThreadsPerBlock><<<1, threads, smem_bytes, stream>>>(
          num_blocks, w_l2_i, u_l2_i);
      lamb_cuda_kernel_part3<T, GRAD_T><<<blocks, threads, smem_bytes, stream>>>(
          p, p_copy, m, v, g, beta1, beta2, max_coeff, min_coeff, eps, grad_scale, step_size,
          tsize, mode, decay, w_l2_i, u_l2_i, lamb_coeff);
  }

  // Host entry point for one fused LAMB step on a single parameter tensor.
  //
  // Updates m, v and p in place on the current CUDA stream and writes the
  // trust ratio into lamb_coeff[0]. For half-precision gradients, p/m/v and
  // the scratch tensors are expected to be fp32, and a non-empty p_copy
  // receives the updated parameters as fp16. Otherwise every tensor shares
  // g's floating type and p_copy is ignored. w_l2_i / u_l2_i are scratch
  // buffers for per-block partial norms.
  void fused_lamb_cuda(at::Tensor& p,
                       at::Tensor& p_copy,
                       at::Tensor& m,
                       at::Tensor& v,
                       at::Tensor& g,
                       float lr,
                       float beta1,
                       float beta2,
                       float max_coeff,
                       float min_coeff,
                       float eps,
                       float grad_scale,
                       int step,
                       int mode,
                       int bias_correction,
                       float decay,
                       at::Tensor& w_l2_i,
                       at::Tensor& u_l2_i,
                       at::Tensor& lamb_coeff)
  {
      // Validate before narrowing numel() to int.
      AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
                 "parameter tensor is too large to be indexed with int32");
      AT_ASSERTM(m.numel() == p.numel() && v.numel() == p.numel() && g.numel() == p.numel(),
                 "p, m, v and g must have the same number of elements");
      const int tsize = static_cast<int>(p.numel());

      // At most kMaxBlocks blocks; the kernels' strided loops cover any
      // remaining elements. Keep kMaxBlocks <= kThreadsPerBlock so pass 2's
      // single block covers every partial sum.
      constexpr int kThreadsPerBlock = 512;
      constexpr int kMaxBlocks = 512;
      int num_blocks = (tsize + kThreadsPerBlock - 1) / kThreadsPerBlock;
      if (num_blocks > kMaxBlocks) num_blocks = kMaxBlocks;
      // An empty tensor would otherwise request a zero-block launch, which
      // CUDA rejects as an invalid configuration. One block simply produces
      // zero norms, lamb_coeff = 1, and no element updates.
      if (num_blocks < 1) num_blocks = 1;
      AT_ASSERTM(w_l2_i.numel() >= num_blocks && u_l2_i.numel() >= num_blocks,
                 "w_l2_i and u_l2_i need one slot per launched block");
      AT_ASSERTM(lamb_coeff.numel() >= 1, "lamb_coeff must have at least one element");

      // Adam bias correction folded into a single step size.
      float step_size = lr;
      if (bias_correction == 1) {
          const float bias_correction1 = 1 - std::pow(beta1, step);
          const float bias_correction2 = 1 - std::pow(beta2, step);
          step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
      }

      cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      const adamMode_t adam_mode = static_cast<adamMode_t>(mode);

      if (g.scalar_type() == at::ScalarType::Half) {
          // all other values should be fp32 for half gradients
          AT_ASSERTM(p.scalar_type() == at::ScalarType::Float,
                     "expected parameter to be of float type");
          // dispatch is done on the gradient type
          using namespace at;  // prevents "toString is undefined" errors
          AT_DISPATCH_FLOATING_TYPES_AND_HALF(
              g.scalar_type(), "lamb_cuda_kernel", ([&] {
                  using accscalar_t = at::acc_type<scalar_t, true>;
                  launch_lamb_kernels<accscalar_t, scalar_t, kThreadsPerBlock>(
                      p.data<accscalar_t>(),
                      p_copy.numel() ? p_copy.data<scalar_t>() : nullptr,
                      m.data<accscalar_t>(),
                      v.data<accscalar_t>(),
                      g.data<scalar_t>(),
                      beta1,
                      beta2,
                      max_coeff,
                      min_coeff,
                      eps,
                      grad_scale,
                      step_size,
                      tsize,
                      adam_mode,
                      decay,
                      w_l2_i.data<accscalar_t>(),
                      u_l2_i.data<accscalar_t>(),
                      lamb_coeff.data<accscalar_t>(),
                      num_blocks,
                      stream);
              }));
      } else {
          using namespace at;
          AT_DISPATCH_FLOATING_TYPES(
              g.scalar_type(), "lamb_cuda_kernel", ([&] {
                  launch_lamb_kernels<scalar_t, scalar_t, kThreadsPerBlock>(
                      p.data<scalar_t>(),
                      nullptr,  // don't output p_copy for fp32, it's wasted write
                      m.data<scalar_t>(),
                      v.data<scalar_t>(),
                      g.data<scalar_t>(),
                      beta1,
                      beta2,
                      max_coeff,
                      min_coeff,
                      eps,
                      grad_scale,
                      step_size,
                      tsize,
                      adam_mode,
                      decay,
                      w_l2_i.data<scalar_t>(),
                      u_l2_i.data<scalar_t>(),
                      lamb_coeff.data<scalar_t>(),
                      num_blocks,
                      stream);
              }));
      }
      THCudaCheck(cudaGetLastError());
  }
  402. // template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
  403. // float* g_b, cg::grid_group &cgg);