context.h 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. // Copyright (c) Microsoft Corporation.
  2. // SPDX-License-Identifier: Apache-2.0
  3. // DeepSpeed Team
  4. #pragma once
  5. #include <ATen/cuda/CUDAContext.h>
  6. #include <cuda_runtime_api.h>
  7. #include <cassert>
  8. #include <iostream>
  9. #include <vector>
  10. #include "cublas_v2.h"
  11. #include "cuda.h"
  12. #include "curand.h"
  13. #include <cuda.h>
  14. #include <cuda_runtime_api.h>
  15. #include <stdlib.h>
  16. #include <sys/time.h>
  17. #include <map>
  18. #include <memory>
  19. #include <stack>
  20. #include <string>
  21. #define WARP_SIZE 32
  22. class FPContext {
  23. public:
  24. FPContext() : _seed(42)
  25. {
  26. curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
  27. curandSetPseudoRandomGeneratorSeed(_gen, 123);
  28. }
  29. virtual ~FPContext() {}
  30. static FPContext& Instance()
  31. {
  32. static FPContext _ctx;
  33. return _ctx;
  34. }
  35. curandGenerator_t& GetRandGenerator() { return _gen; }
  36. cudaStream_t GetCurrentStream()
  37. {
  38. // get current pytorch stream.
  39. cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  40. return stream;
  41. }
  42. std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
  43. {
  44. uint64_t offset = _curr_offset;
  45. _curr_offset += offset_inc;
  46. return std::pair<uint64_t, uint64_t>(_seed, offset);
  47. }
  48. void SetSeed(uint64_t new_seed) { _seed = new_seed; }
  49. private:
  50. curandGenerator_t _gen;
  51. cublasHandle_t _cublasHandle;
  52. uint64_t _seed;
  53. uint64_t _curr_offset;
  54. };