transformer_inference.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  # Copyright (c) Microsoft Corporation.
  # SPDX-License-Identifier: Apache-2.0
  # DeepSpeed Team
  4. from .builder import CUDAOpBuilder, installed_cuda_version
  class InferenceBuilder(CUDAOpBuilder):
      """Op builder for the ``transformer_inference`` extension.

      Supplies the op name, a compatibility check, a compute-capability
      filter, the source list, extra linker flags and the include paths.
      """

      BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE"
      NAME = "transformer_inference"

      def __init__(self, name=None):
          """Initialise the builder, defaulting ``name`` to ``NAME``."""
          name = self.NAME if name is None else name
          super().__init__(name=name)

      def absolute_name(self):
          """Return the dotted module path of the op, derived from ``NAME``."""
          return f'deepspeed.ops.transformer.inference.{self.NAME}_op'

      def is_compatible(self, verbose=False):
          """Report whether the op can be built in the current environment.

          Returns ``False`` straight away when ``torch`` cannot be imported.
          Otherwise the result is the parent class check combined with the
          CUDA checks below, which only run on a non-ROCm PyTorch that
          reports CUDA as available:

          * device 0 must have a major compute capability of at least 6;
          * with a major compute capability of 8 or more, both the CUDA
            major version PyTorch reports and the one returned by
            ``installed_cuda_version()`` must be at least 11.

          Args:
              verbose: when true, emit a warning for each failed check.

          Returns:
              bool: ``True`` only if every check passes.
          """
          try:
              import torch
          except ImportError:
              if verbose:
                  self.warning("Please install torch if trying to pre-compile inference kernels")
              return False

          cuda_okay = True
          if not self.is_rocm_pytorch() and torch.cuda.is_available():
              sys_cuda_major, _ = installed_cuda_version()
              torch_cuda_major = int(torch.version.cuda.split('.')[0])
              # Only device 0 is inspected.
              cuda_capability = torch.cuda.get_device_properties(0).major
              if cuda_capability < 6:
                  if verbose:
                      self.warning("NVIDIA Inference is only supported on Pascal and newer architectures")
                  cuda_okay = False
              if cuda_capability >= 8:
                  if torch_cuda_major < 11 or sys_cuda_major < 11:
                      if verbose:
                          self.warning("On Ampere and higher architectures please use CUDA 11+")
                      cuda_okay = False
          # The parent check is always evaluated, even if a CUDA check failed.
          return super().is_compatible(verbose) and cuda_okay

      def filter_ccs(self, ccs):
          """Drop compute capabilities whose major version is below 6.

          Each entry may be a string such as ``"8.6"`` or ``"8.0+PTX"``, or
          an already-split sequence such as ``["8", "6"]``. For strings the
          major version is everything before the first ``.``, so a
          multi-digit major like ``"10.0"`` is read as 10 rather than as
          its first character. Entries are returned unchanged.

          NOTE(review): the exact entry type passed by the caller is not
          visible in this file -- confirm against the base builder.

          Args:
              ccs: iterable of compute-capability entries.

          Returns:
              list: the retained entries, in their original order.
          """
          ccs_retained = []
          ccs_pruned = []
          for cc in ccs:
              # Indexing a string yields one character; split it instead.
              major = cc.split('.')[0] if isinstance(cc, str) else cc[0]
              if int(major) >= 6:
                  ccs_retained.append(cc)
              else:
                  ccs_pruned.append(cc)
          if ccs_pruned:
              self.warning(f"Filtered compute capabilities {ccs_pruned}")
          return ccs_retained

      def sources(self):
          """Return the list of C++/CUDA source paths for this op."""
          return [
              'csrc/transformer/inference/csrc/pt_binding.cpp',
              'csrc/transformer/inference/csrc/gelu.cu',
              'csrc/transformer/inference/csrc/relu.cu',
              'csrc/transformer/inference/csrc/layer_norm.cu',
              'csrc/transformer/inference/csrc/rms_norm.cu',
              'csrc/transformer/inference/csrc/softmax.cu',
              'csrc/transformer/inference/csrc/dequantize.cu',
              'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu',
              'csrc/transformer/inference/csrc/transform.cu',
              'csrc/transformer/inference/csrc/pointwise_ops.cu',
          ]

      def extra_ldflags(self):
          """Return ``['-lcurand']`` on non-ROCm PyTorch, else an empty list."""
          return [] if self.is_rocm_pytorch() else ['-lcurand']

      def include_paths(self):
          """Return the include directories used when compiling the sources."""
          return ['csrc/transformer/inference/includes', 'csrc/includes']