// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "gemm_test.h"

#define WARP_SIZE 32

#define CUDA_CHECK(callstr)                                                                    \
    {                                                                                          \
        cudaError_t error_code = callstr;                                                      \
        if (error_code != cudaSuccess) {                                                       \
            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
            assert(0);                                                                         \
        }                                                                                      \
    }
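
// Usage sketch (hypothetical call sites): wrap any CUDA runtime call that
// returns cudaError_t so failures abort with a file/line diagnostic:
//
//   CUDA_CHECK(cudaMemsetAsync(buf, 0, bytes, stream));
//   CUDA_CHECK(cudaGetLastError());  // after a kernel launch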

#define CUDA_1D_KERNEL_LOOP(i, n) \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                                                          \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
        for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m);                          \
             j += blockDim.y * gridDim.y)
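
// Sketch of a grid-stride kernel built on CUDA_1D_KERNEL_LOOP (scale_kernel
// is hypothetical, not part of this file): each thread strides through the
// index space, so any grid size covers all n elements:
//
//   __global__ void scale_kernel(float* x, float alpha, size_t n)
//   {
//       CUDA_1D_KERNEL_LOOP(i, n) { x[i] *= alpha; }
//   }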

#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144

inline int DS_GET_BLOCKS(const int N)
{
    return (std::max)(
        (std::min)((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}
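
// Pairing the helpers above (hypothetical launch, reusing scale_kernel from
// the sketch earlier; DS_GET_BLOCKS caps the grid and guarantees >= 1 block):
//
//   scale_kernel<<<DS_GET_BLOCKS(n), DS_CUDA_NUM_THREADS, 0, stream>>>(x, 2.0f, n);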

class TrainingContext {
public:
    TrainingContext() : _workspace(nullptr), _seed(42), _curr_offset(0)
    {
        // Note: the curand generator gets a fixed seed (123) that is
        // independent of _seed, which instead drives the (seed, offset)
        // pairs handed out by IncrementOffset() below.
        curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(_gen, 123);
        cublasStatus_t stat = cublasCreate(&_cublasHandle);
        if (stat != CUBLAS_STATUS_SUCCESS) {
            // It would be nice to use cublasGetStatusName and
            // cublasGetStatusString, but they were only added in CUDA 11.4.2.
            auto message = std::string("Failed to create cublas handle: cublasStatus_t was ") +
                           std::to_string(stat);
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
    }

    virtual ~TrainingContext()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
    }

    static TrainingContext& Instance()
    {
        static TrainingContext _ctx;
        return _ctx;
    }
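
    // Usage sketch: Instance() is a Meyers singleton, so the context is
    // constructed once (thread-safely under C++11) and shared process-wide:
    //
    //   cublasHandle_t handle = TrainingContext::Instance().GetCublasHandle();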

    void SetWorkSpace(void* workspace)
    {
        if (!workspace) { throw std::runtime_error("Workspace is null."); }
        _workspace = workspace;
    }

    void* GetWorkSpace() { return _workspace; }
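
    // Usage sketch: the destructor calls cudaFree on whatever pointer was
    // last set, so the workspace should come from cudaMalloc (workspace_bytes
    // is a hypothetical size chosen by the caller):
    //
    //   void* buf = nullptr;
    //   CUDA_CHECK(cudaMalloc(&buf, workspace_bytes));
    //   TrainingContext::Instance().SetWorkSpace(buf);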

    curandGenerator_t& GetRandGenerator() { return _gen; }

    cudaStream_t GetCurrentStream()
    {
        // get current pytorch stream.
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    cudaStream_t GetNewStream() { return at::cuda::getStreamFromPool(); }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }
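
    // Sketch: cuBLAS work should land on PyTorch's current stream, so bind
    // the handle before issuing GEMMs (cublasSetStream is the standard
    // cuBLAS call for this):
    //
    //   auto& ctx = TrainingContext::Instance();
    //   cublasSetStream(ctx.GetCublasHandle(), ctx.GetCurrentStream());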

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }
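
    // Usage sketch (hypothetical dropout launch): each call reserves a
    // disjoint slice of the counter space, so kernels that seed a
    // counter-based RNG with the returned (seed, offset) pair never reuse
    // random numbers across launches:
    //
    //   auto seed_offset = TrainingContext::Instance().IncrementOffset(total_count);
    //   launch_dropout(..., seed_offset.first, seed_offset.second);  // hypothetical wrapper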

    void TestGemmFP16(bool test_gemm, int batch_size, int seq_len, int head_num, int size_per_head)
    {
        // avoid rerun.
        if (_gemm_algos.size() > 0) return;

        if (test_gemm) {
            cublasHandle_t handle = GetCublasHandle();

            std::unique_ptr<GemmTest<__half>> test_qkv_fw(
                new GemmTest<__half>(batch_size * seq_len,      // M
                                     head_num * size_per_head,  // N
                                     head_num * size_per_head,  // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<GemmTest<__half>> test_inter(
                new GemmTest<__half>(batch_size * seq_len,          // M
                                     4 * head_num * size_per_head,  // N
                                     head_num * size_per_head,      // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<GemmTest<__half>> test_output(
                new GemmTest<__half>(batch_size * seq_len,          // M
                                     head_num * size_per_head,      // N
                                     4 * head_num * size_per_head,  // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<StridedGemmTest<__half>> test_attn_scores(
                new StridedGemmTest<__half>(batch_size * head_num,  // batch
                                            seq_len,                // M
                                            seq_len,                // N
                                            size_per_head,          // K
                                            CUBLAS_OP_T,
                                            CUBLAS_OP_N,
                                            handle));

            std::unique_ptr<StridedGemmTest<__half>> test_attn_context(
                new StridedGemmTest<__half>(batch_size * head_num,  // batch
                                            size_per_head,          // M
                                            seq_len,                // N
                                            seq_len,                // K
                                            CUBLAS_OP_N,
                                            CUBLAS_OP_N,
                                            handle));

            _gemm_algos.push_back(test_qkv_fw->TestAlgo(100));
            _gemm_algos.push_back(test_inter->TestAlgo(100));
            _gemm_algos.push_back(test_output->TestAlgo(100));
            _gemm_algos.push_back(test_attn_scores->TestAlgo(100));
            _gemm_algos.push_back(test_attn_context->TestAlgo(100));
        } else {
            // Use default algo.
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
        }
    }
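
    // Usage sketch (hypothetical shapes): run once at model setup. The
    // default value 99 corresponds to cublasGemmAlgo_t
    // CUBLAS_GEMM_DEFAULT_TENSOR_OP, and the five entries map to the QKV,
    // intermediate, output, attention-score, and attention-context GEMMs:
    //
    //   auto& ctx = TrainingContext::Instance();
    //   ctx.TestGemmFP16(true, /*batch_size=*/8, /*seq_len=*/128,
    //                    /*head_num=*/12, /*size_per_head=*/64);
    //   const auto& algos = ctx.GetGemmAlgos();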

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

private:
    curandGenerator_t _gen;
    cublasHandle_t _cublasHandle;
    void* _workspace;
    uint64_t _seed;
    uint64_t _curr_offset;
    std::vector<std::array<int, 3>> _gemm_algos;
};