sparse_attn.py

  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import warnings
  5. from .builder import OpBuilder
  6. try:
  7. from packaging import version as pkg_version
  8. except ImportError:
  9. pkg_version = None
  10. class SparseAttnBuilder(OpBuilder):
  11. BUILD_VAR = "DS_BUILD_SPARSE_ATTN"
  12. NAME = "sparse_attn"
  13. def __init__(self):
  14. super().__init__(name=self.NAME)
  15. def absolute_name(self):
  16. return f'deepspeed.ops.sparse_attention.{self.NAME}_op'
  17. def sources(self):
  18. return ['csrc/sparse_attention/utils.cpp']
  19. def cxx_args(self):
  20. return ['-O2', '-fopenmp']
  21. def is_compatible(self, verbose=True):
  22. # Check to see if llvm and cmake are installed since they are dependencies
  23. #required_commands = ['llvm-config|llvm-config-9', 'cmake']
  24. #command_status = list(map(self.command_exists, required_commands))
  25. #deps_compatible = all(command_status)
  26. if self.is_rocm_pytorch():
  27. self.warning(f'{self.NAME} is not compatible with ROCM')
  28. return False
  29. try:
  30. import torch
  31. except ImportError:
  32. self.warning(f"unable to import torch, please install it first")
  33. return False
  34. # torch-cpu will not have a cuda version
  35. if torch.version.cuda is None:
  36. cuda_compatible = False
  37. self.warning(f"{self.NAME} cuda is not available from torch")
  38. else:
  39. major, minor = torch.version.cuda.split('.')[:2]
  40. cuda_compatible = (int(major) == 10
  41. and int(minor) >= 1) or (int(major) >= 11)
  42. if not cuda_compatible:
  43. self.warning(f"{self.NAME} requires CUDA version 10.1+")
  44. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  45. TORCH_MINOR = int(torch.__version__.split('.')[1])
  46. torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5
  47. if not torch_compatible:
  48. self.warning(
  49. f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}'
  50. )
  51. try:
  52. import triton
  53. except ImportError:
  54. # auto-install of triton is broken on some systems, reverting to manual install for now
  55. # see this issue: https://github.com/microsoft/DeepSpeed/issues/1710
  56. self.warning(
  57. f"please install triton==1.0.0 if you want to use sparse attention")
  58. return False
  59. if pkg_version:
  60. installed_triton = pkg_version.parse(triton.__version__)
  61. triton_mismatch = installed_triton != pkg_version.parse("1.0.0")
  62. else:
  63. installed_triton = triton.__version__
  64. triton_mismatch = installed_triton != "1.0.0"
  65. if triton_mismatch:
  66. self.warning(
  67. f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible"
  68. )
  69. return False
  70. return super().is_compatible(verbose) and torch_compatible and cuda_compatible
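
For reference, a minimal usage sketch of this builder. It assumes SparseAttnBuilder is importable from deepspeed.ops.op_builder (the layout used in recent DeepSpeed releases) and that load() is provided by the OpBuilder base class to import or JIT-compile the op; the import path and load() behavior are assumptions, not shown in this file.

from deepspeed.ops.op_builder import SparseAttnBuilder

builder = SparseAttnBuilder()

# is_compatible() runs the ROCm, CUDA, torch, and triton checks defined above
# and emits a warning for each requirement that is not met.
if builder.is_compatible():
    # load() is inherited from OpBuilder (assumption): it imports a pre-built
    # op if available, otherwise it JIT-compiles the files listed in sources().
    sparse_attn_op = builder.load()
else:
    print('sparse_attn op is not supported in this environment')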