sparse_attn.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .builder import OpBuilder

try:
    from packaging import version as pkg_version
except ImportError:
    pkg_version = None


class SparseAttnBuilder(OpBuilder):
    BUILD_VAR = "DS_BUILD_SPARSE_ATTN"
    NAME = "sparse_attn"

    def __init__(self):
        super().__init__(name=self.NAME)

    def absolute_name(self):
        return f'deepspeed.ops.sparse_attention.{self.NAME}_op'

    def sources(self):
        return ['csrc/sparse_attention/utils.cpp']

    def cxx_args(self):
        return ['-O2', '-fopenmp']

    def is_compatible(self, verbose=True):
        # Check to see if llvm and cmake are installed since they are dependencies
        #required_commands = ['llvm-config|llvm-config-9', 'cmake']
        #command_status = list(map(self.command_exists, required_commands))
        #deps_compatible = all(command_status)
        if self.is_rocm_pytorch():
            self.warning(f'{self.NAME} is not compatible with ROCM')
            return False

        try:
            import torch
        except ImportError:
            self.warning("unable to import torch, please install it first")
            return False

        # torch-cpu will not have a cuda version
        if torch.version.cuda is None:
            cuda_compatible = False
            self.warning(f"{self.NAME} cuda is not available from torch")
        else:
            major, minor = torch.version.cuda.split('.')[:2]
            cuda_compatible = (int(major) == 10 and int(minor) >= 1) or (int(major) >= 11)
            if not cuda_compatible:
                self.warning(f"{self.NAME} requires CUDA version 10.1+")

        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        torch_compatible = (TORCH_MAJOR == 1 and TORCH_MINOR >= 5)
        if not torch_compatible:
            self.warning(
                f'{self.NAME} requires a torch version >= 1.5 and < 2.0 but detected {TORCH_MAJOR}.{TORCH_MINOR}')

        try:
            import triton
        except ImportError:
            # auto-install of triton is broken on some systems, reverting to manual install for now
            # see this issue: https://github.com/microsoft/DeepSpeed/issues/1710
            self.warning("please install triton==1.0.0 if you want to use sparse attention")
            return False

        if pkg_version:
            installed_triton = pkg_version.parse(triton.__version__)
            triton_mismatch = installed_triton != pkg_version.parse("1.0.0")
        else:
            installed_triton = triton.__version__
            triton_mismatch = installed_triton != "1.0.0"

        if triton_mismatch:
            self.warning(f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible")
            return False

        return super().is_compatible(verbose) and torch_compatible and cuda_compatible
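

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# how this builder is typically exercised. Only is_compatible(), NAME, and
# BUILD_VAR are defined in this file; load() is assumed to be provided by the
# OpBuilder base class in builder.py, as with other DeepSpeed op builders, and
# to JIT-build the sources returned by sources().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    builder = SparseAttnBuilder()
    if builder.is_compatible(verbose=True):
        # Assumption: OpBuilder.load() compiles csrc/sparse_attention/utils.cpp
        # (or imports a pre-built extension) and returns the loaded module.
        sparse_attn_op = builder.load()
        print(f"{builder.NAME} op built and loaded: {sparse_attn_op}")
    else:
        # Assumption: the DS_BUILD_* environment variable named by BUILD_VAR
        # controls whether this op is pre-built during installation.
        print(f"{builder.NAME} is not compatible with this environment; "
              f"set {builder.BUILD_VAR}=0 to skip pre-building it.")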