inference_context.h

// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team

#pragma once

#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime_api.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "cublas_v2.h"
#include "cuda.h"

#define MEGABYTE (1024 * 1024)
#define GIGABYTE (1024 * 1024 * 1024)

// TODO: refactor out
#define WARP_SIZE 32
#define CUDA_CHECK(callstr)                                                                    \
    {                                                                                          \
        cudaError_t error_code = callstr;                                                      \
        if (error_code != cudaSuccess) {                                                       \
            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
            assert(0);                                                                         \
        }                                                                                      \
    }

#define CUDA_1D_KERNEL_LOOP(i, n) \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)

#define CUDA_2D_KERNEL_LOOP(i, n, j, m)                                                          \
    for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
        for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m);                          \
             j += blockDim.y * gridDim.y)

#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
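
// Compute a 1D grid size for N elements: ceil(N / DS_CUDA_NUM_THREADS), clamped to
// [1, DS_MAXIMUM_NUM_BLOCKS]. Kernels written with CUDA_1D_KERNEL_LOOP stride over any
// elements beyond the block cap.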
inline int DS_GET_BLOCKS(const int N)
{
    return std::max(
        std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}
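
// Process-wide singleton that owns the CUDA/cuBLAS state shared by the inference kernels:
// the cuBLAS handle, compute/communication events and streams, the RNG seed/offset pair,
// and a single device workspace buffer that is sized in GenWorkSpace and then handed out
// to the kernels.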
class InferenceContext {
public:
    InferenceContext()
        : _workspace(nullptr),
          _seed(42),
          _curr_offset(0),
          _stream(0),
          _free_memory_size(0),
          _num_tokens(1),
          _token_length(0),
          _max_seq_len(0),
          _attention_unfused_workspace_offset(0),
          _workSpaceSize(0)
    {
        if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
            auto message = std::string("Failed to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
#ifndef __HIP_PLATFORM_HCC__
        cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
        cudaEventCreate(&_comp1_event);
        cudaEventCreate(&_comp2_event);
        cudaEventCreate(&_comp_event);
        cudaEventCreate(&_comm_event);
    }

    virtual ~InferenceContext()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
        cudaEventDestroy(_comp1_event);
        cudaEventDestroy(_comp2_event);
        cudaEventDestroy(_comp_event);
        cudaEventDestroy(_comm_event);
    }

    static InferenceContext& Instance()
    {
        static InferenceContext _ctx;
        return _ctx;
    }
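
    // Size and allocate the shared device workspace. The workspace must hold, per token:
    // activation buffers, an unfused-attention scratch region, and (unless the caller keeps
    // an external KV-cache) the KV-cache itself. GenWorkSpace derives the largest total token
    // count (_max_seq_len) that fits in the currently free device memory, capped at
    // max_out_tokens, and fails if even min_out_tokens cannot be accommodated.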
    void GenWorkSpace(const unsigned& num_layers,
                      const unsigned& num_heads,
                      const size_t& batch_size,
                      const size_t& prompt_len,
                      const size_t& hidden_dim,
                      const unsigned& mp_size,
                      const bool& external_cache,
                      const size_t& elem_size,
                      const unsigned& rank,
                      unsigned max_out_tokens,
                      unsigned min_out_tokens)
    {
        size_t total_size = 0;
        if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); }

        // Flash attention requires padded heads and we'll conservatively allocate
        // for that here. Flash attention is only enabled for head size <= 128 right now
        const int head_size = hidden_dim / num_heads;
        const int padded_head_size = head_size <= 32 ? 32 : (head_size <= 64 ? 64 : 128);
        const int effective_head_size = (head_size > 128) ? head_size : padded_head_size;

        size_t activation_size = 10 * (num_heads * effective_head_size) * batch_size;
        // Other sequence length dimension is added when the final workSpaceSize is calculated
        size_t temp_size = batch_size * (num_heads / mp_size) * max_out_tokens;
        size_t cache_size =
            num_layers * batch_size * ((num_heads * effective_head_size) / mp_size) * 2;
        size_t minimal_requirements =
            temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE;
        if (_free_memory_size < minimal_requirements) {
            printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
                   minimal_requirements,
                   _free_memory_size,
                   total_size);
            throw std::runtime_error("Workspace can't be allocated, not enough memory.");
        }
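
        // Spend the remaining free memory on per-token state: every additional token costs
        // (activation + temp [+ cache]) elements of elem_size bytes, so the division below
        // gives the largest total token count we can serve, capped at max_out_tokens.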
        _max_seq_len = ((_free_memory_size - minimal_requirements) / elem_size) /
                       (activation_size + temp_size + cache_size);
        _max_seq_len = std::min((size_t)max_out_tokens, _max_seq_len);
        size_t workSpaceSize = ((external_cache ? (activation_size + temp_size)
                                                : (activation_size + temp_size + cache_size))) *
                               _max_seq_len * elem_size;
        temp_size *= _max_seq_len * elem_size;
        if (_max_seq_len < min_out_tokens) {
            printf(
                "Allocatable workspace available (%ld tokens) is less than minimum requested "
                "workspace (%d tokens)\n",
                _max_seq_len,
                min_out_tokens);
            throw std::runtime_error("Workspace can't be allocated, not enough memory");
        }
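
        // Allocate the workspace lazily and only grow it: reuse the existing buffer when it
        // is already large enough, otherwise free it and allocate the new, larger size.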
        if (!_workspace) {
            assert(_workspace == nullptr);
            cudaMalloc(&_workspace, workSpaceSize);
        } else if (_workSpaceSize < workSpaceSize) {
            cudaFree(_workspace);
            cudaMalloc(&_workspace, workSpaceSize);
        }
        if (rank == 0 && (!_workspace || _workSpaceSize < workSpaceSize))
            printf(
                "------------------------------------------------------\n"
                "Free memory : %f (GigaBytes) \n"
                "Total memory: %f (GigaBytes) \n"
                "Requested memory: %f (GigaBytes) \n"
                "Setting maximum total tokens (input + output) to %lu \n"
                "WorkSpace: %p \n"
                "------------------------------------------------------\n",
                (float)_free_memory_size / GIGABYTE,
                (float)total_size / GIGABYTE,
                (float)workSpaceSize / GIGABYTE,
                _max_seq_len,
                _workspace);
        if (!_workspace) {
            printf("Requested:\t%lu\nFree:\t%lu\nTotal:\t%lu\n",
                   workSpaceSize,
                   _free_memory_size,
                   total_size);
            throw std::runtime_error("Workspace is null.");
        }
        _workSpaceSize = workSpaceSize;
        _attention_unfused_workspace_offset = workSpaceSize - temp_size;
    }

    inline size_t GetMaxTokenLength() const { return _max_seq_len; }

    cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }

    size_t get_workspace_size() const { return _workSpaceSize; }
    void* GetWorkSpace() { return _workspace; }
    void* GetAttentionUnfusedWorkspace()
    {
        return (char*)_workspace + _attention_unfused_workspace_offset;
    }
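
    // Token bookkeeping for iterative decoding: new_token bumps _token_length once per step
    // (only when called for layer 0, so every layer observes the same count), while
    // reset_tokens/advance_tokens/current_tokens maintain _num_tokens, the running number of
    // tokens processed so far.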
    inline unsigned new_token(unsigned layer_id)
    {
        if (layer_id == 0) _token_length++;
        return _token_length;
    }

    inline void reset_tokens(unsigned initial_tokens = 1)
    {
        _num_tokens = initial_tokens;
    }  //_token_length = 0;

    inline unsigned current_tokens() const { return _num_tokens; }
    inline void advance_tokens() { _num_tokens++; }
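
    // Stream accessors: the communication stream is created lazily (from the stream pool for
    // async collectives, otherwise the current PyTorch stream); GetCurrentStream returns the
    // current PyTorch stream, or a pooled side stream when other_stream is set.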
    cudaStream_t GetCommStream(bool async_op = false)
    {
        if (!_comm_stream)
            _comm_stream = async_op ? at::cuda::getStreamFromPool(true)
                                    : at::cuda::getCurrentCUDAStream();
        return _comm_stream;
    }

    cudaStream_t GetCurrentStream(bool other_stream = false)
    {
        // get current pytorch stream.
        if (other_stream) {
            if (!_stream) _stream = at::cuda::getStreamFromPool(true);
            return _stream;
        }
        return at::cuda::getCurrentCUDAStream();
    }
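
    // The workspace can be released between requests so the memory is available elsewhere
    // (e.g. to the framework allocator) and re-acquired later at the recorded size.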
    void release_workspace()
    {
        cudaFree(_workspace);
        _workspace = nullptr;
    }

    bool retake_workspace()
    {
        if (_workspace != nullptr || _workSpaceSize == 0) return true;
        cudaMalloc(&_workspace, _workSpaceSize);
        return _workspace != nullptr;
    }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }
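
    // RNG bookkeeping: return the current (seed, offset) pair and advance the offset by
    // offset_inc so successive callers consume non-overlapping random-number ranges.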
    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }
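
    // Cross-stream ordering: SynchComp makes the communication stream wait on the compute
    // stream's work recorded so far, and SynchComm does the reverse, via the shared events.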
    inline void SynchComp()
    {
        cudaEventRecord(_comp_event, _comp_stream);
        cudaStreamWaitEvent(_comm_stream, _comp_event, 0);
    }
    inline void SynchComm()
    {
        cudaEventRecord(_comm_event, _comm_stream);
        cudaStreamWaitEvent(_comp_stream, _comm_event, 0);
    }

private:
    cublasHandle_t _cublasHandle;

    cudaEvent_t _comp_event;
    cudaEvent_t _comm_event;

    void* _workspace;
    // offset from _workspace for attention unfused memory
    size_t _attention_unfused_workspace_offset;
    uint64_t _seed;
    uint64_t _curr_offset;

    size_t _workSpaceSize;
    size_t _free_memory_size;
    size_t _max_seq_len;

    cudaEvent_t _comp1_event;
    cudaEvent_t _comp2_event;

    cudaStream_t _stream;

    unsigned _token_length;
    unsigned _num_tokens;

    std::vector<std::array<int, 3>> _gemm_algos;

    cudaStream_t _comp_stream;
    cudaStream_t _comm_stream;

    std::unordered_map<int, int> _world_sizes;
};
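
// A minimal usage sketch (illustrative only; the parameter values below are assumptions, not
// values taken from DeepSpeed): size the shared workspace once per model configuration, then
// hand kernels the buffer and the current stream.
//
//   auto& ctx = InferenceContext::Instance();
//   ctx.GenWorkSpace(/*num_layers=*/32,
//                    /*num_heads=*/32,
//                    /*batch_size=*/1,
//                    /*prompt_len=*/128,
//                    /*hidden_dim=*/4096,
//                    /*mp_size=*/1,
//                    /*external_cache=*/false,
//                    /*elem_size=*/sizeof(__half),
//                    /*rank=*/0,
//                    /*max_out_tokens=*/1024,
//                    /*min_out_tokens=*/1);
//   void* ws = ctx.GetWorkSpace();
//   cudaStream_t stream = ctx.GetCurrentStream();
//   // ... launch kernels that carve their buffers out of `ws` on `stream` ...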