# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

try:
    from packaging import version as pkg_version
except ImportError:
    pkg_version = None

from .builder import CUDAOpBuilder, installed_cuda_version


class FPQuantizerBuilder(CUDAOpBuilder):
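    """Op builder for DeepSpeed's floating-point (FP) quantization kernels.

    Compiles the fp_quantizer CUDA extension and gates the build on an
    Ampere-or-newer GPU, CUDA 11+, and a known-compatible triton version.
    """
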
    BUILD_VAR = "DS_BUILD_FP_QUANTIZER"
    NAME = "fp_quantizer"

    def __init__(self, name=None):
        name = self.NAME if name is None else name
        super().__init__(name=name)

    def absolute_name(self):
        return f'deepspeed.ops.fp_quantizer.{self.NAME}_op'

    def is_compatible(self, verbose=False):
        try:
            import torch
        except ImportError:
            if verbose:
                self.warning("Please install torch if trying to pre-compile inference kernels")
            return False

        cuda_okay = True
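        # The kernels require an Ampere-or-newer GPU (compute capability 8.0+)
        # and CUDA 11+ from both the system toolkit and the torch build.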
        if not self.is_rocm_pytorch() and torch.cuda.is_available():  #ignore-cuda
            sys_cuda_major, _ = installed_cuda_version()
            torch_cuda_major = int(torch.version.cuda.split('.')[0])
            cuda_capability = torch.cuda.get_device_properties(0).major  #ignore-cuda
            if cuda_capability < 8:
                if verbose:
                    self.warning("NVIDIA Inference is only supported on Ampere and newer architectures")
                cuda_okay = False
            if cuda_capability >= 8:
                if torch_cuda_major < 11 or sys_cuda_major < 11:
                    if verbose:
                        self.warning("On Ampere and higher architectures please use CUDA 11+")
                    cuda_okay = False

        try:
            import triton
        except ImportError:
            if verbose:
                self.warning("Please install triton==2.3.0, 2.3.1 or 3.0.0 if you want to use the FP Quantizer kernels")
            return False

        # triton 2.3.{0,1} and 3.0.0 are ok.
        allowed_versions = ("2.3", "3.0")
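        # Prefer packaging for robust version parsing; fall back to naive
        # string splitting when packaging is unavailable.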
        if pkg_version:
            allowed = (pkg_version.parse(v) for v in allowed_versions)
            installed_triton = pkg_version.parse(triton.__version__)
            triton_mismatch = all(installed_triton.major != a.major or installed_triton.minor != a.minor
                                  for a in allowed)
        else:
            installed_triton = triton.__version__
            # Slice instead of unpacking three fields so versions with extra
            # components (e.g. "2.3.0.post1") do not raise a ValueError.
            major, minor = installed_triton.split(".")[:2]
            allowed = (v.split(".") for v in allowed_versions)
            triton_mismatch = all(major != v[0] or minor != v[1] for v in allowed)

        if triton_mismatch:
            if verbose:
                self.warning(
                    f"FP Quantizer is using an untested triton version ({installed_triton}); only 2.3.0, 2.3.1 and 3.0.0 are known to be compatible with these kernels"
                )
            return False

        return super().is_compatible(verbose) and cuda_okay

    def filter_ccs(self, ccs):
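        # Prune compute capabilities below 8.0 (pre-Ampere); the kernels
        # cannot be built for those architectures.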
        ccs_retained = []
        ccs_pruned = []
        for cc in ccs:
            if int(cc[0]) >= 8:
                ccs_retained.append(cc)
            else:
                ccs_pruned.append(cc)
        if len(ccs_pruned) > 0:
            self.warning(f"Filtered compute capabilities {ccs_pruned}")
        return ccs_retained

    def sources(self):
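        # The CUDA kernels and their C++ extension entry point.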
        return [
            "csrc/fp_quantizer/fp_quantize.cu",
            "csrc/fp_quantizer/fp_quantize.cpp",
        ]

    def extra_ldflags(self):
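        # The quantization kernels depend on cuRAND, so link against it
        # explicitly.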
        return ['-lcurand']

    def include_paths(self):
        return ['csrc/fp_quantizer/includes', 'csrc/includes']
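
# Usage sketch (illustrative, not part of this module): DeepSpeed's op
# registry normally instantiates the builder, but it can also be driven
# directly, e.g. to check compatibility before a JIT build:
#
#   from deepspeed.ops.op_builder import FPQuantizerBuilder
#   builder = FPQuantizerBuilder()
#   if builder.is_compatible(verbose=True):
#       fp_quantizer_module = builder.load()  # JIT-compiles the extension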