builder.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import sys
  6. import time
  7. import importlib
  8. from pathlib import Path
  9. import subprocess
  10. import shlex
  11. import shutil
  12. import tempfile
  13. import distutils.ccompiler
  14. import distutils.log
  15. import distutils.sysconfig
  16. from distutils.errors import CompileError, LinkError
  17. from abc import ABC, abstractmethod
  18. from typing import List
  19. YELLOW = '\033[93m'
  20. END = '\033[0m'
  21. WARNING = f"{YELLOW} [WARNING] {END}"
  22. DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
  23. DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"
  24. try:
  25. import torch
  26. except ImportError:
  27. print(f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops.")
  28. else:
  29. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  30. TORCH_MINOR = int(torch.__version__.split('.')[1])
  31. class MissingCUDAException(Exception):
  32. pass
  33. class CUDAMismatchException(Exception):
  34. pass
  35. def installed_cuda_version(name=""):
  36. import torch.utils.cpp_extension
  37. cuda_home = torch.utils.cpp_extension.CUDA_HOME
  38. if cuda_home is None:
  39. raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)")
  40. # Ensure there is not a cuda version mismatch between torch and nvcc compiler
  41. output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
  42. output_split = output.split()
  43. release_idx = output_split.index("release")
  44. release = output_split[release_idx + 1].replace(',', '').split(".")
  45. # Ignore patch versions, only look at major + minor
  46. cuda_major, cuda_minor = release[:2]
  47. return int(cuda_major), int(cuda_minor)
  48. def get_default_compute_capabilities():
  49. compute_caps = DEFAULT_COMPUTE_CAPABILITIES
  50. import torch.utils.cpp_extension
  51. if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11:
  52. if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
  53. # Special treatment of CUDA 11.0 because compute_86 is not supported.
  54. compute_caps += ";8.0"
  55. else:
  56. compute_caps += ";8.0;8.6"
  57. return compute_caps
  58. # list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used
  59. # to build deepspeed and system-wide installed cuda 11.2
  60. cuda_minor_mismatch_ok = {
  61. 10: ["10.0", "10.1", "10.2"],
  62. 11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"],
  63. 12: ["12.0", "12.1", "12.2", "12.3"],
  64. }
  65. def assert_no_cuda_mismatch(name=""):
  66. cuda_major, cuda_minor = installed_cuda_version(name)
  67. sys_cuda_version = f'{cuda_major}.{cuda_minor}'
  68. torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
  69. # This is a show-stopping error, should probably not proceed past this
  70. if sys_cuda_version != torch_cuda_version:
  71. if (cuda_major in cuda_minor_mismatch_ok and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major]
  72. and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]):
  73. print(f"Installed CUDA version {sys_cuda_version} does not match the "
  74. f"version torch was compiled with {torch.version.cuda} "
  75. "but since the APIs are compatible, accepting this combination")
  76. return True
  77. elif os.getenv("DS_SKIP_CUDA_CHECK", "0") == "1":
  78. print(
  79. f"{WARNING} DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
  80. f"version torch was compiled with {torch.version.cuda}."
  81. "Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior."
  82. )
  83. return True
  84. raise CUDAMismatchException(
  85. f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
  86. f"version torch was compiled with {torch.version.cuda}, unable to compile "
  87. "cuda/cpp extensions without a matching cuda version.")
  88. return True
  89. class OpBuilder(ABC):
  90. _rocm_version = None
  91. _is_rocm_pytorch = None
  92. _is_sycl_enabled = None
  93. _loaded_ops = {}
  94. def __init__(self, name):
  95. self.name = name
  96. self.jit_mode = False
  97. self.build_for_cpu = False
  98. self.enable_bf16 = False
  99. self.error_log = None
  100. @abstractmethod
  101. def absolute_name(self):
  102. '''
  103. Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
  104. will be installed as something like: deepspeed/ops/adam/cpu_adam.so
  105. '''
  106. pass
  107. @abstractmethod
  108. def sources(self):
  109. '''
  110. Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
  111. '''
  112. pass
  113. def hipify_extension(self):
  114. pass
  115. def sycl_extension(self):
  116. pass
  117. @staticmethod
  118. def validate_torch_version(torch_info):
  119. install_torch_version = torch_info['version']
  120. current_torch_version = ".".join(torch.__version__.split('.')[:2])
  121. if install_torch_version != current_torch_version:
  122. raise RuntimeError("PyTorch version mismatch! DeepSpeed ops were compiled and installed "
  123. "with a different version than what is being used at runtime. "
  124. f"Please re-install DeepSpeed or switch torch versions. "
  125. f"Install torch version={install_torch_version}, "
  126. f"Runtime torch version={current_torch_version}")
  127. @staticmethod
  128. def validate_torch_op_version(torch_info):
  129. if not OpBuilder.is_rocm_pytorch():
  130. current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
  131. install_cuda_version = torch_info['cuda_version']
  132. if install_cuda_version != current_cuda_version:
  133. raise RuntimeError("CUDA version mismatch! DeepSpeed ops were compiled and installed "
  134. "with a different version than what is being used at runtime. "
  135. f"Please re-install DeepSpeed or switch torch versions. "
  136. f"Install CUDA version={install_cuda_version}, "
  137. f"Runtime CUDA version={current_cuda_version}")
  138. else:
  139. current_hip_version = ".".join(torch.version.hip.split('.')[:2])
  140. install_hip_version = torch_info['hip_version']
  141. if install_hip_version != current_hip_version:
  142. raise RuntimeError("HIP version mismatch! DeepSpeed ops were compiled and installed "
  143. "with a different version than what is being used at runtime. "
  144. f"Please re-install DeepSpeed or switch torch versions. "
  145. f"Install HIP version={install_hip_version}, "
  146. f"Runtime HIP version={current_hip_version}")
  147. @staticmethod
  148. def is_rocm_pytorch():
  149. if OpBuilder._is_rocm_pytorch is not None:
  150. return OpBuilder._is_rocm_pytorch
  151. _is_rocm_pytorch = False
  152. try:
  153. import torch
  154. except ImportError:
  155. pass
  156. else:
  157. if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
  158. _is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None
  159. if _is_rocm_pytorch:
  160. from torch.utils.cpp_extension import ROCM_HOME
  161. _is_rocm_pytorch = ROCM_HOME is not None
  162. OpBuilder._is_rocm_pytorch = _is_rocm_pytorch
  163. return OpBuilder._is_rocm_pytorch
  164. @staticmethod
  165. def is_sycl_enabled():
  166. if OpBuilder._is_sycl_enabled is not None:
  167. return OpBuilder._is_sycl_enabled
  168. _is_sycl_enabled = False
  169. try:
  170. result = subprocess.run(["c2s", "--version"], capture_output=True)
  171. except:
  172. pass
  173. else:
  174. _is_sycl_enabled = True
  175. OpBuilder._is_sycl_enabled = _is_sycl_enabled
  176. return OpBuilder._is_sycl_enabled
  177. @staticmethod
  178. def installed_rocm_version():
  179. if OpBuilder._rocm_version:
  180. return OpBuilder._rocm_version
  181. ROCM_MAJOR = '0'
  182. ROCM_MINOR = '0'
  183. if OpBuilder.is_rocm_pytorch():
  184. from torch.utils.cpp_extension import ROCM_HOME
  185. rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version-dev")
  186. if rocm_ver_file.is_file():
  187. with open(rocm_ver_file, 'r') as file:
  188. ROCM_VERSION_DEV_RAW = file.read()
  189. elif "rocm" in torch.__version__:
  190. ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1]
  191. else:
  192. assert False, "Could not detect ROCm version"
  193. assert ROCM_VERSION_DEV_RAW != "", "Could not detect ROCm version"
  194. ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0]
  195. ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1]
  196. OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR))
  197. return OpBuilder._rocm_version
  198. def include_paths(self):
  199. '''
  200. Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
  201. '''
  202. return []
  203. def nvcc_args(self):
  204. '''
  205. Returns optional list of compiler flags to forward to nvcc when building CUDA sources
  206. '''
  207. return []
  208. def cxx_args(self):
  209. '''
  210. Returns optional list of compiler flags to forward to the build
  211. '''
  212. return []
  213. def is_compatible(self, verbose=True):
  214. '''
  215. Check if all non-python dependencies are satisfied to build this op
  216. '''
  217. return True
  218. def extra_ldflags(self):
  219. return []
  220. def has_function(self, funcname, libraries, verbose=False):
  221. '''
  222. Test for existence of a function within a tuple of libraries.
  223. This is used as a smoke test to check whether a certain library is available.
  224. As a test, this creates a simple C program that calls the specified function,
  225. and then distutils is used to compile that program and link it with the specified libraries.
  226. Returns True if both the compile and link are successful, False otherwise.
  227. '''
  228. tempdir = None # we create a temporary directory to hold various files
  229. filestderr = None # handle to open file to which we redirect stderr
  230. oldstderr = None # file descriptor for stderr
  231. try:
  232. # Echo compile and link commands that are used.
  233. if verbose:
  234. distutils.log.set_verbosity(1)
  235. # Create a compiler object.
  236. compiler = distutils.ccompiler.new_compiler(verbose=verbose)
  237. # Configure compiler and linker to build according to Python install.
  238. distutils.sysconfig.customize_compiler(compiler)
  239. # Create a temporary directory to hold test files.
  240. tempdir = tempfile.mkdtemp()
  241. # Define a simple C program that calls the function in question
  242. prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname)
  243. # Write the test program to a file.
  244. filename = os.path.join(tempdir, 'test.c')
  245. with open(filename, 'w') as f:
  246. f.write(prog)
  247. # Redirect stderr file descriptor to a file to silence compile/link warnings.
  248. if not verbose:
  249. filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w')
  250. oldstderr = os.dup(sys.stderr.fileno())
  251. os.dup2(filestderr.fileno(), sys.stderr.fileno())
  252. # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames()
  253. # Otherwise, a local directory will be used instead of tempdir
  254. drive, driveless_filename = os.path.splitdrive(filename)
  255. root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else ''
  256. output_dir = os.path.join(drive, root_dir)
  257. # Attempt to compile the C program into an object file.
  258. cflags = shlex.split(os.environ.get('CFLAGS', ""))
  259. objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags))
  260. # Attempt to link the object file into an executable.
  261. # Be sure to tack on any libraries that have been specified.
  262. ldflags = shlex.split(os.environ.get('LDFLAGS', ""))
  263. compiler.link_executable(objs,
  264. os.path.join(tempdir, 'a.out'),
  265. extra_preargs=self.strip_empty_entries(ldflags),
  266. libraries=libraries)
  267. # Compile and link succeeded
  268. return True
  269. except CompileError:
  270. return False
  271. except LinkError:
  272. return False
  273. except:
  274. return False
  275. finally:
  276. # Restore stderr file descriptor and close the stderr redirect file.
  277. if oldstderr is not None:
  278. os.dup2(oldstderr, sys.stderr.fileno())
  279. if filestderr is not None:
  280. filestderr.close()
  281. # Delete the temporary directory holding the test program and stderr files.
  282. if tempdir is not None:
  283. shutil.rmtree(tempdir)
  284. def strip_empty_entries(self, args):
  285. '''
  286. Drop any empty strings from the list of compile and link flags
  287. '''
  288. return [x for x in args if len(x) > 0]
  289. def cpu_arch(self):
  290. try:
  291. from cpuinfo import get_cpu_info
  292. except ImportError as e:
  293. cpu_info = self._backup_cpuinfo()
  294. if cpu_info is None:
  295. return "-march=native"
  296. try:
  297. cpu_info = get_cpu_info()
  298. except Exception as e:
  299. self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
  300. "falling back to `lscpu` to get this information.")
  301. cpu_info = self._backup_cpuinfo()
  302. if cpu_info is None:
  303. return "-march=native"
  304. if cpu_info['arch'].startswith('PPC_'):
  305. # gcc does not provide -march on PowerPC, use -mcpu instead
  306. return '-mcpu=native'
  307. return '-march=native'
  308. def is_cuda_enable(self):
  309. try:
  310. assert_no_cuda_mismatch(self.name)
  311. return '-D__ENABLE_CUDA__'
  312. except MissingCUDAException:
  313. print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
  314. "only cpu ops can be compiled!")
  315. return '-D__DISABLE_CUDA__'
  316. return '-D__DISABLE_CUDA__'
  317. def _backup_cpuinfo(self):
  318. # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides
  319. if not self.command_exists('lscpu'):
  320. self.warning(f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo "
  321. "to detect the CPU architecture. 'lscpu' does not appear to exist on "
  322. "your system, will fall back to use -march=native and non-vectorized execution.")
  323. return None
  324. result = subprocess.check_output('lscpu', shell=True)
  325. result = result.decode('utf-8').strip().lower()
  326. cpu_info = {}
  327. cpu_info['arch'] = None
  328. cpu_info['flags'] = ""
  329. if 'genuineintel' in result or 'authenticamd' in result:
  330. cpu_info['arch'] = 'X86_64'
  331. if 'avx512' in result:
  332. cpu_info['flags'] += 'avx512,'
  333. elif 'avx512f' in result:
  334. cpu_info['flags'] += 'avx512f,'
  335. if 'avx2' in result:
  336. cpu_info['flags'] += 'avx2'
  337. elif 'ppc64le' in result:
  338. cpu_info['arch'] = "PPC_"
  339. return cpu_info
  340. def simd_width(self):
  341. try:
  342. from cpuinfo import get_cpu_info
  343. except ImportError as e:
  344. cpu_info = self._backup_cpuinfo()
  345. if cpu_info is None:
  346. return '-D__SCALAR__'
  347. try:
  348. cpu_info = get_cpu_info()
  349. except Exception as e:
  350. self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
  351. "falling back to `lscpu` to get this information.")
  352. cpu_info = self._backup_cpuinfo()
  353. if cpu_info is None:
  354. return '-D__SCALAR__'
  355. if cpu_info['arch'] == 'X86_64':
  356. if 'avx512' in cpu_info['flags'] or 'avx512f' in cpu_info['flags']:
  357. return '-D__AVX512__'
  358. elif 'avx2' in cpu_info['flags']:
  359. return '-D__AVX256__'
  360. return '-D__SCALAR__'
  361. def command_exists(self, cmd):
  362. if '|' in cmd:
  363. cmds = cmd.split("|")
  364. else:
  365. cmds = [cmd]
  366. valid = False
  367. for cmd in cmds:
  368. result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
  369. valid = valid or result.wait() == 0
  370. if not valid and len(cmds) > 1:
  371. print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!")
  372. elif not valid and len(cmds) == 1:
  373. print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!")
  374. return valid
  375. def warning(self, msg):
  376. self.error_log = f"{msg}"
  377. print(f"{WARNING} {msg}")
  378. def deepspeed_src_path(self, code_path):
  379. if os.path.isabs(code_path):
  380. return code_path
  381. else:
  382. return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
  383. def builder(self):
  384. from torch.utils.cpp_extension import CppExtension
  385. return CppExtension(name=self.absolute_name(),
  386. sources=self.strip_empty_entries(self.sources()),
  387. include_dirs=self.strip_empty_entries(self.include_paths()),
  388. extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
  389. extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
  390. def load(self, verbose=True):
  391. if self.name in __class__._loaded_ops:
  392. return __class__._loaded_ops[self.name]
  393. from deepspeed.git_version_info import installed_ops, torch_info
  394. if installed_ops.get(self.name, False):
  395. # Ensure the op we're about to load was compiled with the same
  396. # torch/cuda versions we are currently using at runtime.
  397. self.validate_torch_version(torch_info)
  398. if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder):
  399. self.validate_torch_op_version(torch_info)
  400. op_module = importlib.import_module(self.absolute_name())
  401. __class__._loaded_ops[self.name] = op_module
  402. return op_module
  403. else:
  404. return self.jit_load(verbose)
  405. def jit_load(self, verbose=True):
  406. if not self.is_compatible(verbose):
  407. raise RuntimeError(
  408. f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}"
  409. )
  410. try:
  411. import ninja # noqa: F401 # type: ignore
  412. except ImportError:
  413. raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.")
  414. if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch():
  415. self.build_for_cpu = not torch.cuda.is_available()
  416. self.jit_mode = True
  417. from torch.utils.cpp_extension import load
  418. start_build = time.time()
  419. sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()]
  420. extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()]
  421. # Torch will try and apply whatever CCs are in the arch list at compile time,
  422. # we have already set the intended targets ourselves we know that will be
  423. # needed at runtime. This prevents CC collisions such as multiple __half
  424. # implementations. Stash arch list to reset after build.
  425. torch_arch_list = None
  426. if "TORCH_CUDA_ARCH_LIST" in os.environ:
  427. torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
  428. os.environ["TORCH_CUDA_ARCH_LIST"] = ""
  429. nvcc_args = self.strip_empty_entries(self.nvcc_args())
  430. cxx_args = self.strip_empty_entries(self.cxx_args())
  431. if isinstance(self, CUDAOpBuilder):
  432. if not self.build_for_cpu and self.enable_bf16:
  433. cxx_args.append("-DBF16_AVAILABLE")
  434. nvcc_args.append("-DBF16_AVAILABLE")
  435. nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__")
  436. nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__")
  437. if self.is_rocm_pytorch():
  438. cxx_args.append("-D__HIP_PLATFORM_AMD__=1")
  439. op_module = load(name=self.name,
  440. sources=self.strip_empty_entries(sources),
  441. extra_include_paths=self.strip_empty_entries(extra_include_paths),
  442. extra_cflags=cxx_args,
  443. extra_cuda_cflags=nvcc_args,
  444. extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
  445. verbose=verbose)
  446. build_duration = time.time() - start_build
  447. if verbose:
  448. print(f"Time to load {self.name} op: {build_duration} seconds")
  449. # Reset arch list so we are not silently removing it for other possible use cases
  450. if torch_arch_list:
  451. os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
  452. __class__._loaded_ops[self.name] = op_module
  453. return op_module
  454. class CUDAOpBuilder(OpBuilder):
  455. def compute_capability_args(self, cross_compile_archs=None):
  456. """
  457. Returns nvcc compute capability compile flags.
  458. 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
  459. 2. If neither is set default compute capabilities will be used
  460. 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
  461. Format:
  462. - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
  463. TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
  464. TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
  465. - `cross_compile_archs` uses ; separator.
  466. """
  467. ccs = []
  468. if self.jit_mode:
  469. # Compile for underlying architectures since we know those at runtime
  470. for i in range(torch.cuda.device_count()):
  471. CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
  472. cc = f"{CC_MAJOR}.{CC_MINOR}"
  473. if cc not in ccs:
  474. ccs.append(cc)
  475. ccs = sorted(ccs)
  476. ccs[-1] += '+PTX'
  477. else:
  478. # Cross-compile mode, compile for various architectures
  479. # env override takes priority
  480. cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
  481. if cross_compile_archs_env is not None:
  482. if cross_compile_archs is not None:
  483. print(
  484. f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
  485. )
  486. cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
  487. else:
  488. if cross_compile_archs is None:
  489. cross_compile_archs = get_default_compute_capabilities()
  490. ccs = cross_compile_archs.split(';')
  491. ccs = self.filter_ccs(ccs)
  492. if len(ccs) == 0:
  493. raise RuntimeError(
  494. f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")
  495. args = []
  496. self.enable_bf16 = True
  497. for cc in ccs:
  498. num = cc[0] + cc[2]
  499. args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
  500. if cc.endswith('+PTX'):
  501. args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
  502. if int(cc[0]) <= 7:
  503. self.enable_bf16 = False
  504. return args
  505. def filter_ccs(self, ccs: List[str]):
  506. """
  507. Prune any compute capabilities that are not compatible with the builder. Should log
  508. which CCs have been pruned.
  509. """
  510. return ccs
  511. def version_dependent_macros(self):
  512. # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
  513. version_ge_1_1 = []
  514. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
  515. version_ge_1_1 = ['-DVERSION_GE_1_1']
  516. version_ge_1_3 = []
  517. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
  518. version_ge_1_3 = ['-DVERSION_GE_1_3']
  519. version_ge_1_5 = []
  520. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
  521. version_ge_1_5 = ['-DVERSION_GE_1_5']
  522. return version_ge_1_1 + version_ge_1_3 + version_ge_1_5
  523. def is_compatible(self, verbose=True):
  524. return super().is_compatible(verbose)
  525. def builder(self):
  526. try:
  527. if not self.is_rocm_pytorch():
  528. assert_no_cuda_mismatch(self.name)
  529. self.build_for_cpu = False
  530. except MissingCUDAException:
  531. self.build_for_cpu = True
  532. if self.build_for_cpu:
  533. from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
  534. else:
  535. from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder
  536. compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \
  537. {'cxx': self.strip_empty_entries(self.cxx_args()), \
  538. 'nvcc': self.strip_empty_entries(self.nvcc_args())}
  539. if not self.build_for_cpu and self.enable_bf16:
  540. compile_args['cxx'].append("-DBF16_AVAILABLE")
  541. if self.is_rocm_pytorch():
  542. compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1")
  543. cuda_ext = ExtensionBuilder(name=self.absolute_name(),
  544. sources=self.strip_empty_entries(self.sources()),
  545. include_dirs=self.strip_empty_entries(self.include_paths()),
  546. libraries=self.strip_empty_entries(self.libraries_args()),
  547. extra_compile_args=compile_args,
  548. extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
  549. if self.is_rocm_pytorch():
  550. # hip converts paths to absolute, this converts back to relative
  551. sources = cuda_ext.sources
  552. curr_file = Path(__file__).parent.parent # ds root
  553. for i in range(len(sources)):
  554. src = Path(sources[i])
  555. if src.is_absolute():
  556. sources[i] = str(src.relative_to(curr_file))
  557. else:
  558. sources[i] = str(src)
  559. cuda_ext.sources = sources
  560. return cuda_ext
  561. def hipify_extension(self):
  562. if self.is_rocm_pytorch():
  563. from torch.utils.hipify import hipify_python
  564. hipify_python.hipify(
  565. project_directory=os.getcwd(),
  566. output_directory=os.getcwd(),
  567. header_include_dirs=self.include_paths(),
  568. includes=[os.path.join(os.getcwd(), '*')],
  569. extra_files=[os.path.abspath(s) for s in self.sources()],
  570. show_detailed=True,
  571. is_pytorch_extension=True,
  572. hipify_extra_files_only=True,
  573. )
  574. def cxx_args(self):
  575. if sys.platform == "win32":
  576. return ['-O2']
  577. else:
  578. return ['-O3', '-std=c++17', '-g', '-Wno-reorder']
  579. def nvcc_args(self):
  580. if self.build_for_cpu:
  581. return []
  582. args = ['-O3']
  583. if self.is_rocm_pytorch():
  584. ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version()
  585. args += [
  586. '-std=c++17', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__',
  587. '-U__HIP_NO_HALF2_OPERATORS__',
  588. '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR,
  589. '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR
  590. ]
  591. else:
  592. cuda_major, _ = installed_cuda_version()
  593. args += [
  594. '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math',
  595. '-std=c++17' if cuda_major > 10 else '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__',
  596. '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__'
  597. ]
  598. if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1':
  599. args.append('--ptxas-options=-v')
  600. args += self.compute_capability_args()
  601. return args
  602. def libraries_args(self):
  603. if self.build_for_cpu:
  604. return []
  605. if sys.platform == "win32":
  606. return ['cublas', 'curand']
  607. else:
  608. return []
  609. class TorchCPUOpBuilder(CUDAOpBuilder):
  610. def extra_ldflags(self):
  611. if self.build_for_cpu:
  612. return ['-fopenmp']
  613. if not self.is_rocm_pytorch():
  614. return ['-lcurand']
  615. return []
  616. def cxx_args(self):
  617. import torch
  618. args = []
  619. if not self.build_for_cpu:
  620. if not self.is_rocm_pytorch():
  621. CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
  622. if not os.path.exists(CUDA_LIB64):
  623. CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib")
  624. else:
  625. CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
  626. args += super().cxx_args()
  627. args += [
  628. f'-L{CUDA_LIB64}',
  629. '-lcudart',
  630. '-lcublas',
  631. '-g',
  632. ]
  633. CPU_ARCH = self.cpu_arch()
  634. SIMD_WIDTH = self.simd_width()
  635. CUDA_ENABLE = self.is_cuda_enable()
  636. args += [
  637. CPU_ARCH,
  638. '-fopenmp',
  639. SIMD_WIDTH,
  640. CUDA_ENABLE,
  641. ]
  642. return args