builder.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695
  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import os
  5. import sys
  6. import time
  7. import json
  8. import importlib
  9. from pathlib import Path
  10. import subprocess
  11. import shlex
  12. import shutil
  13. import tempfile
  14. import distutils.ccompiler
  15. import distutils.log
  16. import distutils.sysconfig
  17. from distutils.errors import CompileError, LinkError
  18. from abc import ABC, abstractmethod
  19. YELLOW = '\033[93m'
  20. END = '\033[0m'
  21. WARNING = f"{YELLOW} [WARNING] {END}"
  22. DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
  23. DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"
  24. try:
  25. import torch
  26. except ImportError:
  27. print(
  28. f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops."
  29. )
  30. else:
  31. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  32. TORCH_MINOR = int(torch.__version__.split('.')[1])
  33. def installed_cuda_version():
  34. import torch.utils.cpp_extension
  35. cuda_home = torch.utils.cpp_extension.CUDA_HOME
  36. assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
  37. # Ensure there is not a cuda version mismatch between torch and nvcc compiler
  38. output = subprocess.check_output([cuda_home + "/bin/nvcc",
  39. "-V"],
  40. universal_newlines=True)
  41. output_split = output.split()
  42. release_idx = output_split.index("release")
  43. release = output_split[release_idx + 1].replace(',', '').split(".")
  44. # Ignore patch versions, only look at major + minor
  45. cuda_major, cuda_minor = release[:2]
  46. installed_cuda_version = ".".join(release[:2])
  47. return int(cuda_major), int(cuda_minor)
  48. def get_default_compute_capabilities():
  49. compute_caps = DEFAULT_COMPUTE_CAPABILITIES
  50. import torch.utils.cpp_extension
  51. if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version(
  52. )[0] >= 11:
  53. if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
  54. # Special treatment of CUDA 11.0 because compute_86 is not supported.
  55. compute_caps += ";8.0"
  56. else:
  57. compute_caps += ";8.0;8.6"
  58. return compute_caps
  59. # list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used
  60. # to build deepspeed and system-wide installed cuda 11.2
  61. cuda_minor_mismatch_ok = {
  62. 10: [
  63. "10.0",
  64. "10.1",
  65. "10.2",
  66. ],
  67. 11: [
  68. "11.0",
  69. "11.1",
  70. "11.2",
  71. "11.3",
  72. "11.4",
  73. "11.5",
  74. "11.6",
  75. ],
  76. }
  77. def assert_no_cuda_mismatch():
  78. cuda_major, cuda_minor = installed_cuda_version()
  79. sys_cuda_version = f'{cuda_major}.{cuda_minor}'
  80. torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
  81. # This is a show-stopping error, should probably not proceed past this
  82. if sys_cuda_version != torch_cuda_version:
  83. if (cuda_major in cuda_minor_mismatch_ok
  84. and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major]
  85. and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]):
  86. print(f"Installed CUDA version {sys_cuda_version} does not match the "
  87. f"version torch was compiled with {torch.version.cuda} "
  88. "but since the APIs are compatible, accepting this combination")
  89. return
  90. raise Exception(
  91. f"Installed CUDA version {sys_cuda_version} does not match the "
  92. f"version torch was compiled with {torch.version.cuda}, unable to compile "
  93. "cuda/cpp extensions without a matching cuda version.")
  94. class OpBuilder(ABC):
  95. _rocm_version = None
  96. _is_rocm_pytorch = None
  97. def __init__(self, name):
  98. self.name = name
  99. self.jit_mode = False
  100. @abstractmethod
  101. def absolute_name(self):
  102. '''
  103. Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
  104. will be installed as something like: deepspeed/ops/adam/cpu_adam.so
  105. '''
  106. pass
  107. @abstractmethod
  108. def sources(self):
  109. '''
  110. Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
  111. '''
  112. pass
  113. def hipify_extension(self):
  114. pass
  115. @staticmethod
  116. def assert_torch_info(torch_info):
  117. install_torch_version = torch_info['version']
  118. install_cuda_version = torch_info['cuda_version']
  119. install_hip_version = torch_info['hip_version']
  120. if not OpBuilder.is_rocm_pytorch():
  121. current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
  122. else:
  123. current_hip_version = ".".join(torch.version.hip.split('.')[:2])
  124. current_torch_version = ".".join(torch.__version__.split('.')[:2])
  125. if not OpBuilder.is_rocm_pytorch():
  126. if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version:
  127. raise RuntimeError(
  128. "PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed "
  129. "with a different version than what is being used at runtime. Please re-install "
  130. f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
  131. f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:"
  132. f"torch={current_torch_version}, cuda={current_cuda_version}")
  133. else:
  134. if install_hip_version != current_hip_version or install_torch_version != current_torch_version:
  135. raise RuntimeError(
  136. "PyTorch and HIP version mismatch! DeepSpeed ops were compiled and installed "
  137. "with a different version than what is being used at runtime. Please re-install "
  138. f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
  139. f"torch={install_torch_version}, hip={install_hip_version}, runtime versions:"
  140. f"torch={current_torch_version}, hip={current_hip_version}")
  141. @staticmethod
  142. def is_rocm_pytorch():
  143. if OpBuilder._is_rocm_pytorch is not None:
  144. return OpBuilder._is_rocm_pytorch
  145. _is_rocm_pytorch = False
  146. try:
  147. import torch
  148. except ImportError:
  149. pass
  150. else:
  151. if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
  152. _is_rocm_pytorch = hasattr(torch.version,
  153. 'hip') and torch.version.hip is not None
  154. if _is_rocm_pytorch:
  155. from torch.utils.cpp_extension import ROCM_HOME
  156. _is_rocm_pytorch = ROCM_HOME is not None
  157. OpBuilder._is_rocm_pytorch = _is_rocm_pytorch
  158. return OpBuilder._is_rocm_pytorch
  159. @staticmethod
  160. def installed_rocm_version():
  161. if OpBuilder._rocm_version:
  162. return OpBuilder._rocm_version
  163. ROCM_MAJOR = '0'
  164. ROCM_MINOR = '0'
  165. if OpBuilder.is_rocm_pytorch():
  166. from torch.utils.cpp_extension import ROCM_HOME
  167. with open('/opt/rocm/.info/version-dev', 'r') as file:
  168. ROCM_VERSION_DEV_RAW = file.read()
  169. ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0]
  170. ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1]
  171. OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR))
  172. return OpBuilder._rocm_version
  173. def include_paths(self):
  174. '''
  175. Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
  176. '''
  177. return []
  178. def nvcc_args(self):
  179. '''
  180. Returns optional list of compiler flags to forward to nvcc when building CUDA sources
  181. '''
  182. return []
  183. def cxx_args(self):
  184. '''
  185. Returns optional list of compiler flags to forward to the build
  186. '''
  187. return []
  188. def is_compatible(self, verbose=True):
  189. '''
  190. Check if all non-python dependencies are satisfied to build this op
  191. '''
  192. return True
  193. def extra_ldflags(self):
  194. return []
  195. def libraries_installed(self, libraries):
  196. valid = False
  197. check_cmd = 'dpkg -l'
  198. for lib in libraries:
  199. result = subprocess.Popen(f'dpkg -l {lib}',
  200. stdout=subprocess.PIPE,
  201. stderr=subprocess.PIPE,
  202. shell=True)
  203. valid = valid or result.wait() == 0
  204. return valid
  205. def has_function(self, funcname, libraries, verbose=False):
  206. '''
  207. Test for existence of a function within a tuple of libraries.
  208. This is used as a smoke test to check whether a certain library is available.
  209. As a test, this creates a simple C program that calls the specified function,
  210. and then distutils is used to compile that program and link it with the specified libraries.
  211. Returns True if both the compile and link are successful, False otherwise.
  212. '''
  213. tempdir = None # we create a temporary directory to hold various files
  214. filestderr = None # handle to open file to which we redirect stderr
  215. oldstderr = None # file descriptor for stderr
  216. try:
  217. # Echo compile and link commands that are used.
  218. if verbose:
  219. distutils.log.set_verbosity(1)
  220. # Create a compiler object.
  221. compiler = distutils.ccompiler.new_compiler(verbose=verbose)
  222. # Configure compiler and linker to build according to Python install.
  223. distutils.sysconfig.customize_compiler(compiler)
  224. # Create a temporary directory to hold test files.
  225. tempdir = tempfile.mkdtemp()
  226. # Define a simple C program that calls the function in question
  227. prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (
  228. funcname,
  229. funcname)
  230. # Write the test program to a file.
  231. filename = os.path.join(tempdir, 'test.c')
  232. with open(filename, 'w') as f:
  233. f.write(prog)
  234. # Redirect stderr file descriptor to a file to silence compile/link warnings.
  235. if not verbose:
  236. filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w')
  237. oldstderr = os.dup(sys.stderr.fileno())
  238. os.dup2(filestderr.fileno(), sys.stderr.fileno())
  239. # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames()
  240. # Otherwise, a local directory will be used instead of tempdir
  241. drive, driveless_filename = os.path.splitdrive(filename)
  242. root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else ''
  243. output_dir = os.path.join(drive, root_dir)
  244. # Attempt to compile the C program into an object file.
  245. cflags = shlex.split(os.environ.get('CFLAGS', ""))
  246. objs = compiler.compile([filename],
  247. output_dir=output_dir,
  248. extra_preargs=self.strip_empty_entries(cflags))
  249. # Attempt to link the object file into an executable.
  250. # Be sure to tack on any libraries that have been specified.
  251. ldflags = shlex.split(os.environ.get('LDFLAGS', ""))
  252. compiler.link_executable(objs,
  253. os.path.join(tempdir,
  254. 'a.out'),
  255. extra_preargs=self.strip_empty_entries(ldflags),
  256. libraries=libraries)
  257. # Compile and link succeeded
  258. return True
  259. except CompileError:
  260. return False
  261. except LinkError:
  262. return False
  263. except:
  264. return False
  265. finally:
  266. # Restore stderr file descriptor and close the stderr redirect file.
  267. if oldstderr is not None:
  268. os.dup2(oldstderr, sys.stderr.fileno())
  269. if filestderr is not None:
  270. filestderr.close()
  271. # Delete the temporary directory holding the test program and stderr files.
  272. if tempdir is not None:
  273. shutil.rmtree(tempdir)
  274. def strip_empty_entries(self, args):
  275. '''
  276. Drop any empty strings from the list of compile and link flags
  277. '''
  278. return [x for x in args if len(x) > 0]
  279. def cpu_arch(self):
  280. try:
  281. from cpuinfo import get_cpu_info
  282. except ImportError as e:
  283. cpu_info = self._backup_cpuinfo()
  284. if cpu_info is None:
  285. return "-march=native"
  286. try:
  287. cpu_info = get_cpu_info()
  288. except Exception as e:
  289. self.warning(
  290. f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
  291. "falling back to `lscpu` to get this information.")
  292. cpu_info = self._backup_cpuinfo()
  293. if cpu_info is None:
  294. return "-march=native"
  295. if cpu_info['arch'].startswith('PPC_'):
  296. # gcc does not provide -march on PowerPC, use -mcpu instead
  297. return '-mcpu=native'
  298. return '-march=native'
  299. def _backup_cpuinfo(self):
  300. # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides
  301. if not self.command_exists('lscpu'):
  302. self.warning(
  303. f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo "
  304. "to detect the CPU architecture. 'lscpu' does not appear to exist on "
  305. "your system, will fall back to use -march=native and non-vectorized execution."
  306. )
  307. return None
  308. result = subprocess.check_output('lscpu', shell=True)
  309. result = result.decode('utf-8').strip().lower()
  310. cpu_info = {}
  311. cpu_info['arch'] = None
  312. cpu_info['flags'] = ""
  313. if 'genuineintel' in result or 'authenticamd' in result:
  314. cpu_info['arch'] = 'X86_64'
  315. if 'avx512' in result:
  316. cpu_info['flags'] += 'avx512,'
  317. if 'avx2' in result:
  318. cpu_info['flags'] += 'avx2'
  319. elif 'ppc64le' in result:
  320. cpu_info['arch'] = "PPC_"
  321. return cpu_info
  322. def simd_width(self):
  323. try:
  324. from cpuinfo import get_cpu_info
  325. except ImportError as e:
  326. cpu_info = self._backup_cpuinfo()
  327. if cpu_info is None:
  328. return '-D__SCALAR__'
  329. try:
  330. cpu_info = get_cpu_info()
  331. except Exception as e:
  332. self.warning(
  333. f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
  334. "falling back to `lscpu` to get this information.")
  335. cpu_info = self._backup_cpuinfo()
  336. if cpu_info is None:
  337. return '-D__SCALAR__'
  338. if cpu_info['arch'] == 'X86_64':
  339. if 'avx512' in cpu_info['flags']:
  340. return '-D__AVX512__'
  341. elif 'avx2' in cpu_info['flags']:
  342. return '-D__AVX256__'
  343. return '-D__SCALAR__'
  344. def python_requirements(self):
  345. '''
  346. Override if op wants to define special dependencies, otherwise will
  347. take self.name and load requirements-<op-name>.txt if it exists.
  348. '''
  349. path = f'requirements/requirements-{self.name}.txt'
  350. requirements = []
  351. if os.path.isfile(path):
  352. with open(path, 'r') as fd:
  353. requirements = [r.strip() for r in fd.readlines()]
  354. return requirements
  355. def command_exists(self, cmd):
  356. if '|' in cmd:
  357. cmds = cmd.split("|")
  358. else:
  359. cmds = [cmd]
  360. valid = False
  361. for cmd in cmds:
  362. result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
  363. valid = valid or result.wait() == 0
  364. if not valid and len(cmds) > 1:
  365. print(
  366. f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!"
  367. )
  368. elif not valid and len(cmds) == 1:
  369. print(
  370. f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!"
  371. )
  372. return valid
  373. def warning(self, msg):
  374. print(f"{WARNING} {msg}")
  375. def deepspeed_src_path(self, code_path):
  376. if os.path.isabs(code_path):
  377. return code_path
  378. else:
  379. return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
  380. def builder(self):
  381. from torch.utils.cpp_extension import CppExtension
  382. return CppExtension(
  383. name=self.absolute_name(),
  384. sources=self.strip_empty_entries(self.sources()),
  385. include_dirs=self.strip_empty_entries(self.include_paths()),
  386. extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
  387. extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
  388. def load(self, verbose=True):
  389. from ...git_version_info import installed_ops, torch_info
  390. if installed_ops[self.name]:
  391. # Ensure the op we're about to load was compiled with the same
  392. # torch/cuda versions we are currently using at runtime.
  393. if isinstance(self, CUDAOpBuilder):
  394. self.assert_torch_info(torch_info)
  395. return importlib.import_module(self.absolute_name())
  396. else:
  397. return self.jit_load(verbose)
  398. def jit_load(self, verbose=True):
  399. if not self.is_compatible(verbose):
  400. raise RuntimeError(
  401. f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue."
  402. )
  403. try:
  404. import ninja
  405. except ImportError:
  406. raise RuntimeError(
  407. f"Unable to JIT load the {self.name} op due to ninja not being installed."
  408. )
  409. if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch():
  410. assert_no_cuda_mismatch()
  411. self.jit_mode = True
  412. from torch.utils.cpp_extension import load
  413. # Ensure directory exists to prevent race condition in some cases
  414. ext_path = os.path.join(
  415. os.environ.get('TORCH_EXTENSIONS_DIR',
  416. DEFAULT_TORCH_EXTENSION_PATH),
  417. self.name)
  418. os.makedirs(ext_path, exist_ok=True)
  419. start_build = time.time()
  420. sources = [self.deepspeed_src_path(path) for path in self.sources()]
  421. extra_include_paths = [
  422. self.deepspeed_src_path(path) for path in self.include_paths()
  423. ]
  424. # Torch will try and apply whatever CCs are in the arch list at compile time,
  425. # we have already set the intended targets ourselves we know that will be
  426. # needed at runtime. This prevents CC collisions such as multiple __half
  427. # implementations. Stash arch list to reset after build.
  428. torch_arch_list = None
  429. if "TORCH_CUDA_ARCH_LIST" in os.environ:
  430. torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
  431. os.environ["TORCH_CUDA_ARCH_LIST"] = ""
  432. op_module = load(
  433. name=self.name,
  434. sources=self.strip_empty_entries(sources),
  435. extra_include_paths=self.strip_empty_entries(extra_include_paths),
  436. extra_cflags=self.strip_empty_entries(self.cxx_args()),
  437. extra_cuda_cflags=self.strip_empty_entries(self.nvcc_args()),
  438. extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
  439. verbose=verbose)
  440. build_duration = time.time() - start_build
  441. if verbose:
  442. print(f"Time to load {self.name} op: {build_duration} seconds")
  443. # Reset arch list so we are not silently removing it for other possible use cases
  444. if torch_arch_list:
  445. os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
  446. return op_module
  447. class CUDAOpBuilder(OpBuilder):
  448. def compute_capability_args(self, cross_compile_archs=None):
  449. """
  450. Returns nvcc compute capability compile flags.
  451. 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
  452. 2. If neither is set default compute capabilities will be used
  453. 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
  454. Format:
  455. - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
  456. TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
  457. TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
  458. - `cross_compile_archs` uses ; separator.
  459. """
  460. ccs = []
  461. if self.jit_mode:
  462. # Compile for underlying architectures since we know those at runtime
  463. for i in range(torch.cuda.device_count()):
  464. CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
  465. cc = f"{CC_MAJOR}.{CC_MINOR}"
  466. if cc not in ccs:
  467. ccs.append(cc)
  468. ccs = sorted(ccs)
  469. ccs[-1] += '+PTX'
  470. else:
  471. # Cross-compile mode, compile for various architectures
  472. # env override takes priority
  473. cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
  474. if cross_compile_archs_env is not None:
  475. if cross_compile_archs is not None:
  476. print(
  477. f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
  478. )
  479. cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
  480. else:
  481. if cross_compile_archs is None:
  482. cross_compile_archs = get_default_compute_capabilities()
  483. ccs = cross_compile_archs.split(';')
  484. args = []
  485. for cc in ccs:
  486. num = cc[0] + cc[2]
  487. args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
  488. if cc.endswith('+PTX'):
  489. args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
  490. return args
  491. def version_dependent_macros(self):
  492. # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
  493. version_ge_1_1 = []
  494. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
  495. version_ge_1_1 = ['-DVERSION_GE_1_1']
  496. version_ge_1_3 = []
  497. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
  498. version_ge_1_3 = ['-DVERSION_GE_1_3']
  499. version_ge_1_5 = []
  500. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
  501. version_ge_1_5 = ['-DVERSION_GE_1_5']
  502. return version_ge_1_1 + version_ge_1_3 + version_ge_1_5
  503. def is_compatible(self, verbose=True):
  504. return super().is_compatible(verbose)
  505. def builder(self):
  506. from torch.utils.cpp_extension import CUDAExtension
  507. if not self.is_rocm_pytorch():
  508. assert_no_cuda_mismatch()
  509. cuda_ext = CUDAExtension(
  510. name=self.absolute_name(),
  511. sources=self.strip_empty_entries(self.sources()),
  512. include_dirs=self.strip_empty_entries(self.include_paths()),
  513. libraries=self.strip_empty_entries(self.libraries_args()),
  514. extra_compile_args={
  515. 'cxx': self.strip_empty_entries(self.cxx_args()),
  516. 'nvcc': self.strip_empty_entries(self.nvcc_args())
  517. })
  518. if self.is_rocm_pytorch():
  519. # hip converts paths to absolute, this converts back to relative
  520. sources = cuda_ext.sources
  521. curr_file = Path(__file__).parent.parent # ds root
  522. for i in range(len(sources)):
  523. src = Path(sources[i])
  524. sources[i] = str(src.relative_to(curr_file))
  525. cuda_ext.sources = sources
  526. return cuda_ext
  527. def hipify_extension(self):
  528. if self.is_rocm_pytorch():
  529. from torch.utils.hipify import hipify_python
  530. hipify_python.hipify(
  531. project_directory=os.getcwd(),
  532. output_directory=os.getcwd(),
  533. header_include_dirs=self.include_paths(),
  534. includes=[os.path.join(os.getcwd(),
  535. '*')],
  536. extra_files=[os.path.abspath(s) for s in self.sources()],
  537. show_detailed=True,
  538. is_pytorch_extension=True,
  539. hipify_extra_files_only=True,
  540. )
  541. def cxx_args(self):
  542. if sys.platform == "win32":
  543. return ['-O2']
  544. else:
  545. return ['-O3', '-std=c++14', '-g', '-Wno-reorder']
  546. def nvcc_args(self):
  547. args = ['-O3']
  548. if self.is_rocm_pytorch():
  549. ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version()
  550. args += [
  551. '-std=c++14',
  552. '-U__HIP_NO_HALF_OPERATORS__',
  553. '-U__HIP_NO_HALF_CONVERSIONS__',
  554. '-U__HIP_NO_HALF2_OPERATORS__',
  555. '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR,
  556. '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR
  557. ]
  558. else:
  559. cuda_major, _ = installed_cuda_version()
  560. args += [
  561. '--use_fast_math',
  562. '-std=c++17'
  563. if sys.platform == "win32" and cuda_major > 10 else '-std=c++14',
  564. '-U__CUDA_NO_HALF_OPERATORS__',
  565. '-U__CUDA_NO_HALF_CONVERSIONS__',
  566. '-U__CUDA_NO_HALF2_OPERATORS__'
  567. ]
  568. args += self.compute_capability_args()
  569. return args
  570. def libraries_args(self):
  571. if sys.platform == "win32":
  572. return ['cublas', 'curand']
  573. else:
  574. return []
  575. class TorchCPUOpBuilder(CUDAOpBuilder):
  576. def extra_ldflags(self):
  577. if not self.is_rocm_pytorch():
  578. return ['-lcurand']
  579. else:
  580. return []
  581. def cxx_args(self):
  582. import torch
  583. if not self.is_rocm_pytorch():
  584. CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
  585. else:
  586. CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
  587. CPU_ARCH = self.cpu_arch()
  588. SIMD_WIDTH = self.simd_width()
  589. args = super().cxx_args()
  590. args += [
  591. f'-L{CUDA_LIB64}',
  592. '-lcudart',
  593. '-lcublas',
  594. '-g',
  595. CPU_ARCH,
  596. '-fopenmp',
  597. SIMD_WIDTH,
  598. ]
  599. return args