softmax.h 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. #pragma once
  2. #include <cuda.h>
  3. #include <cuda_fp16.h>
  4. #include <stdio.h>
  5. #include "custom_cuda_layers.h"
  6. #include <fstream>
  7. using namespace std;
  8. template <typename T>
  9. class Softmax {
  10. public:
  11. struct Config {
  12. size_t batchSize;
  13. size_t heads;
  14. size_t seq_length;
  15. size_t prob_depth;
  16. float temperature;
  17. bool mem_alloc;
  18. Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
  19. : batchSize(batch),
  20. heads(h),
  21. seq_length(seq),
  22. prob_depth(prob_size),
  23. temperature(1.0),
  24. mem_alloc(mem_alloc)
  25. {
  26. }
  27. };
  28. Softmax(Config config) : config_(config) {}
  29. ~Softmax() {}
  30. void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream)
  31. {
  32. launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
  33. }
  34. void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream)
  35. {
  36. launch_attn_softmax_backward_v2<T>(
  37. out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
  38. }
  39. inline size_t GetProbDepth() const { return config_.prob_depth; }
  40. inline size_t GetBatchSize() const { return config_.batchSize; }
  41. inline size_t GetNumHeads() const { return config_.heads; }
  42. inline size_t GetSeqLength() const { return config_.seq_length; }
  43. inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
  44. private:
  45. Config config_;
  46. };