builder.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import time
import importlib

try:
    # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed
    # if successful this also means we're doing a local install and not JIT compile path
    from op_builder import __deepspeed__  # noqa: F401 # type: ignore
    from op_builder.builder import OpBuilder, TORCH_MAJOR, TORCH_MINOR
except ImportError:
    from deepspeed.ops.op_builder.builder import OpBuilder, TORCH_MAJOR, TORCH_MINOR
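

# SYCLOpBuilder adapts DeepSpeed's OpBuilder to Intel XPU devices: ops are compiled
# as SYCL/DPC++ extensions through intel_extension_for_pytorch (IPEX) rather than CUDA.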
class SYCLOpBuilder(OpBuilder):
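
    # Install-time path: package the op as a DPCPPExtension so setup.py can compile
    # it ahead of time, linking with the fixed AOT flags below.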
    def builder(self):
        try:
            from intel_extension_for_pytorch.xpu.cpp_extension import DPCPPExtension
        except ImportError:
            from intel_extension_for_pytorch.xpu.utils import DPCPPExtension
        print("dpcpp sources = {}".format(self.sources()))
        dpcpp_ext = DPCPPExtension(name=self.absolute_name(),
                                   sources=self.strip_empty_entries(self.sources()),
                                   include_dirs=self.strip_empty_entries(self.include_paths()),
                                   extra_compile_args={
                                       'cxx': self.strip_empty_entries(self.cxx_args()),
                                   },
                                   extra_link_args=self.strip_empty_entries(self.fixed_aotflags()))
        return dpcpp_ext

    def version_dependent_macros(self):
        # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
        version_ge_1_1 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
            version_ge_1_1 = ['-DVERSION_GE_1_1']
        version_ge_1_3 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
            version_ge_1_3 = ['-DVERSION_GE_1_3']
        version_ge_1_5 = []
        if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
            version_ge_1_5 = ['-DVERSION_GE_1_5']
        return version_ge_1_1 + version_ge_1_3 + version_ge_1_5
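
    # DPC++/SYCL compile flags; '-fsycl-targets=spir64_gen' enables ahead-of-time
    # code generation for Intel GPUs. Setting the USE_MKL_GEMM environment variable
    # adds -DUSE_MKL_GEMM for kernels that provide a oneMKL GEMM path.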
    def cxx_args(self):
        cxx_flags = [
            '-fsycl', '-fsycl-targets=spir64_gen', '-g', '-gdwarf-4', '-O3', '-std=c++17', '-fPIC', '-DMKL_ILP64',
            '-fno-strict-aliasing'
        ]
        if os.environ.get('USE_MKL_GEMM'):
            cxx_flags.append('-DUSE_MKL_GEMM')
        return cxx_flags
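
    # Link flags for JIT builds: ahead-of-time compile for the 'pvc' device
    # (Intel Data Center GPU Max) with auto large-GRF mode enabled.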
    def extra_ldflags(self):
        return [
            '-fPIC', '-fsycl', '-fsycl-targets=spir64_gen', '-fsycl-max-parallel-link-jobs=8',
            '-Xs "-options -cl-poison-unsupported-fp64-kernels,cl-intel-enable-auto-large-GRF-mode"',
            '-Xs "-device pvc"', '-Wl,-export-dynamic'
        ]

    def fixed_aotflags(self):
        return [
            '-fsycl', '-fsycl-targets=spir64_gen', '-fsycl-max-parallel-link-jobs=8', '-Xs',
            "-options -cl-poison-unsupported-fp64-kernels,cl-intel-enable-auto-large-GRF-mode", '-Xs', "-device pvc"
        ]
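
    # Prefer the module that was pre-compiled at install time; fall back to JIT
    # compilation when the op was not built with the wheel.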
    def load(self, verbose=True):
        from deepspeed.git_version_info import installed_ops, torch_info  # noqa: F401
        if installed_ops.get(self.name, False):
            return importlib.import_module(self.absolute_name())
        else:
            return self.jit_load(verbose)
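
    # JIT path: build the op at runtime with IPEX's cpp_extension.load() (analogous
    # to torch.utils.cpp_extension.load()), which drives the DPC++ toolchain via ninja.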
    def jit_load(self, verbose=True):
        if not self.is_compatible(verbose):
            raise RuntimeError(
                f"Unable to JIT load the {self.name} op: it is not compatible with the current hardware/software setup. {self.error_log}"
            )
        try:
            import ninja  # noqa: F401
        except ImportError:
            raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.")

        self.jit_mode = True
        from intel_extension_for_pytorch.xpu.cpp_extension import load

        start_build = time.time()
        # Resolve relative paths to absolute paths for the JIT load
        sources = [self.deepspeed_src_path(path) for path in self.sources()]
        extra_include_paths = [self.deepspeed_src_path(path) for path in self.include_paths()]

        # Torch will try to apply whatever CCs are in the arch list at compile time;
        # we have already set the intended targets ourselves and know what will be
        # needed at runtime. This prevents CC collisions such as multiple __half
        # implementations. Stash the arch list and reset it after the build.
        '''
        torch_arch_list = None
        if "TORCH_CUDA_ARCH_LIST" in os.environ:
            torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
            os.environ["TORCH_CUDA_ARCH_LIST"] = ""
        '''

        op_module = load(
            name=self.name,
            sources=self.strip_empty_entries(sources),
            extra_include_paths=self.strip_empty_entries(extra_include_paths),
            extra_cflags=self.strip_empty_entries(self.cxx_args()),
            # extra_cuda_cflags=self.strip_empty_entries(self.nvcc_args()),
            extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
            verbose=verbose)

        build_duration = time.time() - start_build
        if verbose:
            print(f"Time to load {self.name} op: {build_duration} seconds")
        '''
        # Reset arch list so we are not silently removing it for other possible use cases
        if torch_arch_list:
            os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
        '''
        return op_module
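

# Usage sketch (illustrative only; 'ExampleSYCLOpBuilder' and its paths are
# hypothetical, not part of DeepSpeed). A concrete op builder subclasses
# SYCLOpBuilder, names the op, lists its sources/includes, and calls load():
#
#   class ExampleSYCLOpBuilder(SYCLOpBuilder):
#       BUILD_VAR = "DS_BUILD_EXAMPLE_OP"
#       NAME = "example_op"
#
#       def __init__(self):
#           super().__init__(name=self.NAME)
#
#       def absolute_name(self):
#           return f'deepspeed.ops.example.{self.NAME}_op'
#
#       def sources(self):
#           return ['csrc/example/example_op.dp.cpp']
#
#       def include_paths(self):
#           return ['csrc/includes']
#
#   example_op_module = ExampleSYCLOpBuilder().load()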