dropout.h 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #pragma once
  2. #include <cuda.h>
  3. #include <cuda_fp16.h>
  4. #include <stdio.h>
  5. template <typename T>
  6. class Dropout {
  7. public:
  8. struct Config {
  9. float ratio;
  10. uint32_t dim;
  11. bool training;
  12. Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {}
  13. float RATIO() const { return training ? ratio : 0.0; }
  14. inline void SetDim(uint32_t d) { dim = d; }
  15. };
  16. Dropout(const Config& config) : _config(config), _mask(nullptr) {}
  17. virtual ~Dropout() {}
  18. void Forward(int bsz, T* out, const T* vals, cudaStream_t stream, bool bwd = false)
  19. {
  20. launch_dropout<T>(
  21. out, vals, _mask, bsz * _config.dim, _config.dim, _config.RATIO(), stream, bwd);
  22. }
  23. void ForwardWithBias(int bsz, T* vals, const T* bias, cudaStream_t stream)
  24. {
  25. launch_dropout<T>(vals, bias, _mask, bsz, _config.dim, _config.RATIO(), stream);
  26. }
  27. void ForwardWithBias(int bsz,
  28. T* out,
  29. const T* vals,
  30. const T* residual,
  31. const T* bias,
  32. cudaStream_t stream)
  33. {
  34. launch_dropout<T>(
  35. out, vals, residual, bias, _mask, bsz, _config.dim, _config.RATIO(), stream);
  36. }
  37. void Backward(int bsz, T* d_vals, cudaStream_t stream)
  38. {
  39. launch_dropout_grad<T>(d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream);
  40. }
  41. void Backward(int bsz, T* d_vals_out, const T* d_vals, cudaStream_t stream)
  42. {
  43. launch_dropout_grad<T>(
  44. d_vals_out, d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream);
  45. }
  46. bool HasDropout() const { return _config.RATIO() > 0.0; }
  47. void SetTrainingMode(bool training) { _config.training = training; }
  48. void SetMask(uint8_t* mask)
  49. {
  50. if (!mask) { throw std::runtime_error("Dropout mask is null."); }
  51. _mask = mask;
  52. }
  53. Config GetConfig() const { return _config; }
  54. inline void SetDimension(uint32_t dim) { _config.SetDim(dim); }
  55. private:
  56. uint8_t* _mask;
  57. Config _config;
  58. };