transformer_inference.py

from .builder import CUDAOpBuilder, installed_cuda_version


class InferenceBuilder(CUDAOpBuilder):
    BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE"
    NAME = "transformer_inference"

    def __init__(self, name=None):
        name = self.NAME if name is None else name
        super().__init__(name=name)

    def absolute_name(self):
        return f'deepspeed.ops.transformer.inference.{self.NAME}_op'

    def is_compatible(self, verbose=True):
        try:
            import torch
        except ImportError:
            self.warning("Please install torch if trying to pre-compile inference kernels")
            return False

        cuda_okay = True
        if not self.is_rocm_pytorch() and torch.cuda.is_available():
            sys_cuda_major, _ = installed_cuda_version()
            torch_cuda_major = int(torch.version.cuda.split('.')[0])
            cuda_capability = torch.cuda.get_device_properties(0).major
            # Ampere (compute capability 8.x) and newer GPUs require CUDA 11+
            # from both the system toolkit and the torch build.
            if cuda_capability >= 8:
                if torch_cuda_major < 11 or sys_cuda_major < 11:
                    self.warning("On Ampere and higher architectures please use CUDA 11+")
                    cuda_okay = False
        return super().is_compatible(verbose) and cuda_okay

    def sources(self):
        return [
            'csrc/transformer/inference/csrc/pt_binding.cpp',
            'csrc/transformer/inference/csrc/gelu.cu',
            'csrc/transformer/inference/csrc/relu.cu',
            'csrc/transformer/inference/csrc/normalize.cu',
            'csrc/transformer/inference/csrc/softmax.cu',
            'csrc/transformer/inference/csrc/dequantize.cu',
            'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu',
            'csrc/transformer/inference/csrc/transform.cu',
        ]

    def extra_ldflags(self):
        # cuRAND is only linked on CUDA builds; it is not available under ROCm.
        if not self.is_rocm_pytorch():
            return ['-lcurand']
        else:
            return []

    def include_paths(self):
        return ['csrc/transformer/inference/includes']
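
Usage sketch (an illustrative addition, not part of the original file): DeepSpeed op builders are normally driven through the load() method inherited from OpBuilder, which JIT-compiles sources() with include_paths() and extra_ldflags() on first call. The snippet below assumes the standard deepspeed.ops.op_builder package layout and a source checkout so the relative csrc/ paths resolve; the variable names and print messages are illustrative.

from deepspeed.ops.op_builder import InferenceBuilder

builder = InferenceBuilder()
if builder.is_compatible():
    # load() JIT-builds the listed CUDA/C++ sources into a torch extension
    # (assumption: run from a DeepSpeed checkout so the csrc/ paths resolve).
    inference_module = builder.load()
    print(f"Loaded {builder.absolute_name()}")
else:
    print("transformer_inference kernels are not buildable on this system")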