// context.h
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
  10. #define WARP_SIZE 32
  11. #define CUDA_CHECK(callstr) \
  12. { \
  13. cudaError_t error_code = callstr; \
  14. if (error_code != cudaSuccess) { \
  15. std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
  16. assert(0); \
  17. } \
  18. }
  19. #define CUDA_1D_KERNEL_LOOP(i, n) \
  20. for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
  21. #define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
  22. for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
  23. for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
  24. #define DS_CUDA_NUM_THREADS 512
  25. #define DS_MAXIMUM_NUM_BLOCKS 262144
  26. inline int DS_GET_BLOCKS(const int N)
  27. {
  28. return std::max(
  29. std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
  30. // Use at least 1 block, since CUDA does not allow empty block
  31. 1);
  32. }
  33. class Context {
  34. public:
  35. Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
  36. {
  37. curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
  38. curandSetPseudoRandomGeneratorSeed(_gen, 123);
  39. if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
  40. auto message = std::string("Fail to create cublas handle.");
  41. std::cerr << message << std::endl;
  42. throw std::runtime_error(message);
  43. }
  44. cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
  45. }
  46. virtual ~Context()
  47. {
  48. cublasDestroy(_cublasHandle);
  49. cudaFree(_workspace);
  50. }
  51. static Context& Instance()
  52. {
  53. static Context _ctx;
  54. return _ctx;
  55. }
  56. void GenWorkSpace(size_t size)
  57. {
  58. if (!_workspace) {
  59. assert(_workspace == nullptr);
  60. cudaMalloc(&_workspace, size);
  61. } else if (_workSpaceSize < size) {
  62. cudaFree(_workspace);
  63. cudaMalloc(&_workspace, size);
  64. }
  65. _workSpaceSize = size;
  66. }
  67. void* GetWorkSpace() { return _workspace; }
  68. curandGenerator_t& GetRandGenerator() { return _gen; }
  69. cudaStream_t GetCurrentStream()
  70. {
  71. // get current pytorch stream.
  72. cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  73. return stream;
  74. }
  75. cublasHandle_t GetCublasHandle() { return _cublasHandle; }
  76. std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
  77. {
  78. uint64_t offset = _curr_offset;
  79. _curr_offset += offset_inc;
  80. return std::pair<uint64_t, uint64_t>(_seed, offset);
  81. }
  82. void SetSeed(uint64_t new_seed) { _seed = new_seed; }
  83. const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }
  84. private:
  85. curandGenerator_t _gen;
  86. cublasHandle_t _cublasHandle;
  87. void* _workspace;
  88. uint64_t _seed;
  89. uint64_t _curr_offset;
  90. size_t _workSpaceSize;
  91. std::vector<std::array<int, 3>> _gemm_algos;
  92. };