// context.h
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "gemm_test.h"

#define WARP_SIZE 32

#define CUDA_CHECK(callstr)                                                    \
    {                                                                          \
        cudaError_t error_code = callstr;                                      \
        if (error_code != cudaSuccess) {                                       \
            std::cerr << "CUDA error " << error_code << " ("                   \
                      << cudaGetErrorString(error_code) << ") at " << __FILE__ \
                      << ":" << __LINE__ << std::endl;                         \
            assert(0);                                                         \
        }                                                                      \
    }
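
// Usage sketch for CUDA_CHECK (illustrative only; `buf` and `n` are
// hypothetical names, not part of this header):
//
//   float* buf = nullptr;
//   CUDA_CHECK(cudaMalloc(&buf, n * sizeof(float)));
//   CUDA_CHECK(cudaMemsetAsync(buf, 0, n * sizeof(float)));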

#define CUDA_1D_KERNEL_LOOP(i, n) \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                                                          \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
        for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
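
// A minimal sketch of a kernel built on the 1D grid-stride loop above
// (`scale_kernel` is a hypothetical example, not part of this header). Each
// thread starts at its global index and strides by the total number of
// launched threads, so any grid size covers all n elements:
//
//   __global__ void scale_kernel(float* data, float alpha, size_t n)
//   {
//       CUDA_1D_KERNEL_LOOP(i, n) { data[i] *= alpha; }
//   }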

#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144

inline int DS_GET_BLOCKS(const int N)
{
    return (std::max)(
        (std::min)((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow an empty grid.
        1);
}
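
// DS_GET_BLOCKS pairs with DS_CUDA_NUM_THREADS to form a launch configuration.
// Sketch, reusing the hypothetical scale_kernel from the comment above:
//
//   scale_kernel<<<DS_GET_BLOCKS(n), DS_CUDA_NUM_THREADS, 0,
//                  Context::Instance().GetCurrentStream()>>>(data, 2.0f, n);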

class Context {
public:
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
    {
        curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
        // Note: the generator itself is seeded with a fixed value; _seed is
        // only used for the (seed, offset) pairs handed out by IncrementOffset().
        curandSetPseudoRandomGeneratorSeed(_gen, 123);
        if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
            auto message = std::string("Failed to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
    }

    virtual ~Context()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }
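
    // Instance() is a Meyers singleton: every caller shares one cuBLAS handle
    // and one cuRAND generator. A typical call site binds the handle to the
    // current PyTorch stream before a GEMM (sketch, assuming the caller wants
    // cuBLAS work on that stream):
    //
    //   cublasSetStream(Context::Instance().GetCublasHandle(),
    //                   Context::Instance().GetCurrentStream());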

    void SetWorkSpace(void* workspace)
    {
        if (!workspace) { throw std::runtime_error("Workspace is null."); }
        _workspace = workspace;
    }

    void* GetWorkSpace() { return _workspace; }

    curandGenerator_t& GetRandGenerator() { return _gen; }

    cudaStream_t GetCurrentStream()
    {
        // Get the current PyTorch stream.
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    cudaStream_t GetNewStream() { return at::cuda::getStreamFromPool(); }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }
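
    // IncrementOffset supports the (seed, offset) scheme of counter-based RNGs
    // such as Philox: the seed stays fixed while the offset advances by the
    // number of random values the next kernel will consume, so successive
    // launches draw from disjoint parts of the random stream. Sketch of a
    // caller (`count_per_launch` is a hypothetical value):
    //
    //   auto seed_offset = Context::Instance().IncrementOffset(count_per_launch);
    //   // seed_offset.first is the seed, seed_offset.second the starting
    //   // offset, e.g. forwarded to curand_init() inside a dropout kernel.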

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    void TestGemmFP16(bool test_gemm, int batch_size, int seq_len, int head_num, int size_per_head)
    {
        // Avoid rerunning the search once algorithms have been selected.
        if (_gemm_algos.size() > 0) return;

        if (test_gemm) {
            cublasHandle_t handle = GetCublasHandle();

            std::unique_ptr<GemmTest<__half>> test_qkv_fw(
                new GemmTest<__half>(batch_size * seq_len,      // M
                                     head_num * size_per_head,  // N
                                     head_num * size_per_head,  // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<GemmTest<__half>> test_inter(
                new GemmTest<__half>(batch_size * seq_len,          // M
                                     4 * head_num * size_per_head,  // N
                                     head_num * size_per_head,      // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<GemmTest<__half>> test_output(
                new GemmTest<__half>(batch_size * seq_len,          // M
                                     head_num * size_per_head,      // N
                                     4 * head_num * size_per_head,  // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<StridedGemmTest<__half>> test_attn_scores(
                new StridedGemmTest<__half>(batch_size * head_num,  // batch
                                            seq_len,                // M
                                            seq_len,                // N
                                            size_per_head,          // K
                                            CUBLAS_OP_T,
                                            CUBLAS_OP_N,
                                            handle));

            std::unique_ptr<StridedGemmTest<__half>> test_attn_context(
                new StridedGemmTest<__half>(batch_size * head_num,  // batch
                                            size_per_head,          // M
                                            seq_len,                // N
                                            seq_len,                // K
                                            CUBLAS_OP_N,
                                            CUBLAS_OP_N,
                                            handle));

            _gemm_algos.push_back(test_qkv_fw->TestAlgo(100));
            _gemm_algos.push_back(test_inter->TestAlgo(100));
            _gemm_algos.push_back(test_output->TestAlgo(100));
            _gemm_algos.push_back(test_attn_scores->TestAlgo(100));
            _gemm_algos.push_back(test_attn_context->TestAlgo(100));
        } else {
            // Use the default algorithm (99 == CUBLAS_GEMM_DEFAULT_TENSOR_OP).
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
        }
    }
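
    // Sketch of driving the tuner from host setup code (the shape values below
    // are illustrative, not defaults of this header):
    //
    //   Context::Instance().TestGemmFP16(/*test_gemm=*/true,
    //                                    /*batch_size=*/8,
    //                                    /*seq_len=*/128,
    //                                    /*head_num=*/12,
    //                                    /*size_per_head=*/64);
    //   const auto& algos = Context::Instance().GetGemmAlgos();  // five entries, one per GEMM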

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

private:
    curandGenerator_t _gen;
    cublasHandle_t _cublasHandle;
    void* _workspace;
    uint64_t _seed;
    uint64_t _curr_offset;
    std::vector<std::array<int, 3>> _gemm_algos;
};