dropout.h 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  // Copyright (c) Microsoft Corporation.
  // SPDX-License-Identifier: Apache-2.0
  // DeepSpeed Team
  4. #pragma once
  #include <cuda.h>
  #include <cuda_fp16.h>
  #include <stdint.h>
  #include <stdio.h>

  #include <stdexcept>
  8. template <typename T>
  9. class Dropout {
  10. public:
  11. struct Config {
  12. float ratio;
  13. uint32_t dim;
  14. bool training;
  15. Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {}
  16. float RATIO() const { return training ? ratio : 0.0; }
  17. inline void SetDim(uint32_t d) { dim = d; }
  18. };
  19. Dropout(const Config& config) : _config(config), _mask(nullptr) {}
  20. virtual ~Dropout() {}
  21. void Forward(int bsz, T* out, const T* vals, cudaStream_t stream, bool bwd = false)
  22. {
  23. launch_dropout<T>(
  24. out, vals, _mask, bsz * _config.dim, _config.dim, _config.RATIO(), stream, bwd);
  25. }
  26. void ForwardWithBias(int bsz, T* vals, const T* bias, cudaStream_t stream)
  27. {
  28. launch_dropout<T>(vals, bias, _mask, bsz, _config.dim, _config.RATIO(), stream);
  29. }
  30. void ForwardWithBias(int bsz,
  31. T* out,
  32. const T* vals,
  33. const T* residual,
  34. const T* bias,
  35. cudaStream_t stream)
  36. {
  37. launch_dropout<T>(
  38. out, vals, residual, bias, _mask, bsz, _config.dim, _config.RATIO(), stream);
  39. }
  40. void Backward(int bsz, T* d_vals, cudaStream_t stream)
  41. {
  42. launch_dropout_grad<T>(d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream);
  43. }
  44. void Backward(int bsz, T* d_vals_out, const T* d_vals, cudaStream_t stream)
  45. {
  46. launch_dropout_grad<T>(
  47. d_vals_out, d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream);
  48. }
  49. bool HasDropout() const { return _config.RATIO() > 0.0; }
  50. void SetTrainingMode(bool training) { _config.training = training; }
  51. void SetMask(uint8_t* mask)
  52. {
  53. if (!mask) { throw std::runtime_error("Dropout mask is null."); }
  54. _mask = mask;
  55. }
  56. Config GetConfig() const { return _config; }
  57. inline void SetDimension(uint32_t dim) { _config.SetDim(dim); }
  58. private:
  59. uint8_t* _mask;
  60. Config _config;
  61. };