spatial_inference.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from .builder import CUDAOpBuilder, installed_cuda_version
  5. class SpatialInferenceBuilder(CUDAOpBuilder):
  6. BUILD_VAR = "DS_BUILD_SPATIAL_INFERENCE"
  7. NAME = "spatial_inference"
  8. def __init__(self, name=None):
  9. name = self.NAME if name is None else name
  10. super().__init__(name=name)
  11. def absolute_name(self):
  12. return f'deepspeed.ops.spatial.{self.NAME}_op'
  13. def is_compatible(self, verbose=True):
  14. try:
  15. import torch
  16. except ImportError:
  17. self.warning("Please install torch if trying to pre-compile inference kernels")
  18. return False
  19. cuda_okay = True
  20. if not self.is_rocm_pytorch() and torch.cuda.is_available():
  21. sys_cuda_major, _ = installed_cuda_version()
  22. torch_cuda_major = int(torch.version.cuda.split('.')[0])
  23. cuda_capability = torch.cuda.get_device_properties(0).major
  24. if cuda_capability >= 8:
  25. if torch_cuda_major < 11 or sys_cuda_major < 11:
  26. self.warning("On Ampere and higher architectures please use CUDA 11+")
  27. cuda_okay = False
  28. return super().is_compatible(verbose) and cuda_okay
  29. def sources(self):
  30. return [
  31. 'csrc/spatial/csrc/opt_bias_add.cu',
  32. 'csrc/spatial/csrc/pt_binding.cpp',
  33. ]
  34. def include_paths(self):
  35. return ['csrc/spatial/includes', 'csrc/includes']