evoformer_attn.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import CUDAOpBuilder, installed_cuda_version
import os


class EvoformerAttnBuilder(CUDAOpBuilder):
    """Op builder for the DeepSpeed4Science Evoformer attention CUDA kernel."""
    BUILD_VAR = "DS_BUILD_EVOFORMER_ATTN"
    NAME = "evoformer_attn"

    def __init__(self, name=None):
        name = self.NAME if name is None else name
        super().__init__(name=name)
        # CUTLASS is a build-time dependency; its location comes from the environment.
        self.cutlass_path = os.environ.get('CUTLASS_PATH')

    def absolute_name(self):
        return f'deepspeed.ops.{self.NAME}_op'

    def extra_ldflags(self):
        # Link against cuRAND on CUDA builds; ROCm builds need no extra libraries.
        if not self.is_rocm_pytorch():
            return ['-lcurand']
        else:
            return []

    def sources(self):
        src_dir = 'csrc/deepspeed4science/evoformer_attn'
        return [f'{src_dir}/attention.cpp', f'{src_dir}/attention_back.cu', f'{src_dir}/attention_cu.cu']

    def nvcc_args(self):
        args = super().nvcc_args()
        try:
            import torch
        except ImportError:
            self.warning("Please install torch if trying to pre-compile kernels")
            return args
        # Target the compute capability of the first visible GPU.
        major = torch.cuda.get_device_properties(0).major  #ignore-cuda
        minor = torch.cuda.get_device_properties(0).minor  #ignore-cuda
        args.append(f"-DGPU_ARCH={major}{minor}")
        return args

    def is_compatible(self, verbose=True):
        # Require torch, a CUTLASS checkout (>= 3.1.0), compute capability >= 7.0,
        # and CUDA 11+ in both the system toolkit and the torch build.
        try:
            import torch
        except ImportError:
            self.warning("Please install torch if trying to pre-compile kernels")
            return False
        if self.cutlass_path is None:
            self.warning("Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH")
            return False
        with open(f'{self.cutlass_path}/CHANGELOG.md', 'r') as f:
            if '3.1.0' not in f.read():
                self.warning("Please use CUTLASS version >= 3.1.0")
                return False
        cuda_okay = True
        if not self.is_rocm_pytorch() and torch.cuda.is_available():  #ignore-cuda
            sys_cuda_major, _ = installed_cuda_version()
            torch_cuda_major = int(torch.version.cuda.split('.')[0])
            cuda_capability = torch.cuda.get_device_properties(0).major  #ignore-cuda
            if cuda_capability < 7:
                self.warning("Please use a GPU with compute capability >= 7.0")
                cuda_okay = False
            if torch_cuda_major < 11 or sys_cuda_major < 11:
                self.warning("Please use CUDA 11+")
                cuda_okay = False
        return super().is_compatible(verbose) and cuda_okay

    def include_paths(self):
        includes = [f'{self.cutlass_path}/include', f'{self.cutlass_path}/tools/util/include']
        return includes
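
# --------------------------------------------------------------------------
# Illustrative usage sketch (an addition, not part of the upstream builder):
# it only calls methods defined above and reports whether the kernel can be
# built on this machine. Because of the relative import at the top, it runs
# only as a module inside its package (e.g. `python -m <package>.evoformer_attn`);
# the exact package path is assumed to match the local DeepSpeed checkout.
if __name__ == "__main__":
    builder = EvoformerAttnBuilder()
    if builder.is_compatible(verbose=True):
        print("Evoformer attention kernel can be built on this machine")
        print("sources:", builder.sources())
        print("include paths:", builder.include_paths())
        print("nvcc args:", builder.nvcc_args())
    else:
        print("Evoformer attention kernel is not buildable in this environment")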