transformer.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .builder import CUDAOpBuilder
  5. class TransformerBuilder(CUDAOpBuilder):
  6. BUILD_VAR = "DS_BUILD_TRANSFORMER"
  7. NAME = "transformer"
  8. def __init__(self, name=None):
  9. name = self.NAME if name is None else name
  10. super().__init__(name=name)
  11. def absolute_name(self):
  12. return f'deepspeed.ops.transformer.{self.NAME}_op'
  13. def extra_ldflags(self):
  14. if not self.is_rocm_pytorch():
  15. return ['-lcurand']
  16. else:
  17. return []
  18. def sources(self):
  19. return [
  20. 'csrc/transformer/ds_transformer_cuda.cpp', 'csrc/transformer/cublas_wrappers.cu',
  21. 'csrc/transformer/transform_kernels.cu', 'csrc/transformer/gelu_kernels.cu',
  22. 'csrc/transformer/dropout_kernels.cu', 'csrc/transformer/normalize_kernels.cu',
  23. 'csrc/transformer/softmax_kernels.cu', 'csrc/transformer/general_kernels.cu'
  24. ]
  25. def include_paths(self):
  26. includes = ['csrc/includes']
  27. return includes