#pragma once

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"

#include <ATen/cuda/CUDAContext.h>
#define WARP_SIZE 32

// Evaluate a CUDA runtime call and abort with a readable diagnostic on
// failure.  Wrapped in do/while(0) so the macro behaves as a single
// statement inside unbraced if/else; the message now includes
// cudaGetErrorString and a newline so it is not lost in the stream.
#define CUDA_CHECK(callstr)                                                              \
    do {                                                                                 \
        cudaError_t error_code = callstr;                                                \
        if (error_code != cudaSuccess) {                                                 \
            std::cerr << "CUDA error " << cudaGetErrorString(error_code) << " ("         \
                      << error_code << ") at " << __FILE__ << ":" << __LINE__            \
                      << std::endl;                                                      \
            assert(0);                                                                   \
        }                                                                                \
    } while (0)

// Grid-stride loop over a 1D range: correct for any launch configuration,
// including grids smaller than n.
#define CUDA_1D_KERNEL_LOOP(i, n) \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

// Grid-stride loop over a 2D range (x covers n, y covers m).
#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                                                       \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
        for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)

// Default launch configuration: threads per block and the cap on grid size.
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
// Number of blocks needed to cover N elements at DS_CUDA_NUM_THREADS
// threads per block, clamped to [1, DS_MAXIMUM_NUM_BLOCKS].
inline int DS_GET_BLOCKS(const int N)
{
    const int blocks = (N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS;
    // CUDA does not allow an empty grid, so launch at least one block.
    if (blocks < 1) return 1;
    // Cap the grid; grid-stride loops above make the kernels correct anyway.
    if (blocks > DS_MAXIMUM_NUM_BLOCKS) return DS_MAXIMUM_NUM_BLOCKS;
    return blocks;
}
// Process-wide singleton owning the GPU resources shared by this file's
// kernels: a cuRAND generator, a cuBLAS handle, and a lazily grown device
// workspace buffer.  Obtain it via Context::Instance().
class Context {
public:
    // Creates the cuRAND generator and cuBLAS handle.
    // Throws std::runtime_error if either library fails to initialize
    // (previously only the cuBLAS failure was detected).
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0), _workSpaceSize(0)
    {
        if (curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT) != CURAND_STATUS_SUCCESS) {
            auto message = std::string("Fail to create curand generator.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
        // NOTE(review): the generator seed (123) is independent of _seed (42),
        // which is only reported through IncrementOffset — confirm intended.
        curandSetPseudoRandomGeneratorSeed(_gen, 123);
        if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
        // Enable Tensor Core math paths in cuBLAS.
        cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
    }

    virtual ~Context()
    {
        // Destroy the generator too (it previously leaked).  Return codes are
        // deliberately ignored: this may run during process teardown, after
        // the CUDA runtime has already shut down.
        curandDestroyGenerator(_gen);
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
    }

    // Non-copyable: the singleton uniquely owns its GPU handles and buffer.
    Context(const Context&) = delete;
    Context& operator=(const Context&) = delete;

    // Returns the process-wide instance (constructed on first use).
    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    // Ensures the workspace holds at least `size` bytes, growing it when
    // needed.  An already-large-enough buffer is reused as-is, and
    // _workSpaceSize always tracks the true allocation size — so a smaller
    // request no longer forgets existing capacity and triggers a needless
    // realloc on the next growth.  CUDA calls are now checked.
    void GenWorkSpace(size_t size)
    {
        if (_workspace == nullptr) {
            CUDA_CHECK(cudaMalloc(&_workspace, size));
            _workSpaceSize = size;
        } else if (_workSpaceSize < size) {
            CUDA_CHECK(cudaFree(_workspace));
            CUDA_CHECK(cudaMalloc(&_workspace, size));
            _workSpaceSize = size;
        }
    }

    // Device pointer to the workspace (nullptr before the first GenWorkSpace).
    void* GetWorkSpace() { return _workspace; }

    curandGenerator_t& GetRandGenerator() { return _gen; }

    // Current PyTorch CUDA stream; launch kernels on this stream to stay
    // ordered with the surrounding torch ops.
    cudaStream_t GetCurrentStream()
    {
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }

    // Reserves `offset_inc` values of the running RNG offset and returns the
    // (seed, previous offset) pair.  Not thread-safe — callers presumably
    // invoke this from a single thread; verify before concurrent use.
    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    // Sets the seed reported by IncrementOffset (does not reseed _gen).
    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

private:
    curandGenerator_t _gen;
    cublasHandle_t _cublasHandle;
    void* _workspace;       // device buffer owned by this context
    uint64_t _seed;         // seed reported by IncrementOffset
    uint64_t _curr_offset;  // running RNG offset
    size_t _workSpaceSize;  // true byte size of _workspace (was uninitialized)
    std::vector<std::array<int, 3>> _gemm_algos;  // NOTE(review): never populated here
};