// softmax.h
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
  4. #pragma once
  5. #include <cuda.h>
  6. #include <cuda_fp16.h>
  7. #include <stdio.h>
  8. #include "custom_cuda_layers.h"
  9. #include <fstream>
  10. using namespace std;
  11. template <typename T>
  12. class Softmax {
  13. public:
  14. struct Config {
  15. size_t batchSize;
  16. size_t heads;
  17. size_t seq_length;
  18. size_t prob_depth;
  19. float temperature;
  20. bool mem_alloc;
  21. Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
  22. : batchSize(batch),
  23. heads(h),
  24. seq_length(seq),
  25. prob_depth(prob_size),
  26. temperature(1.0),
  27. mem_alloc(mem_alloc)
  28. {
  29. }
  30. };
  31. Softmax(Config config) : config_(config) {}
  32. ~Softmax() {}
  33. void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream)
  34. {
  35. launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
  36. }
  37. void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream)
  38. {
  39. launch_attn_softmax_backward_v2<T>(
  40. out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
  41. }
  42. inline size_t GetProbDepth() const { return config_.prob_depth; }
  43. inline size_t GetBatchSize() const { return config_.batchSize; }
  44. inline size_t GetNumHeads() const { return config_.heads; }
  45. inline size_t GetSeqLength() const { return config_.seq_length; }
  46. inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
  47. private:
  48. Config config_;
  49. };