builder.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import os
  5. import sys
  6. import time
  7. import importlib
  8. from pathlib import Path
  9. import subprocess
  10. import shlex
  11. import shutil
  12. import tempfile
  13. import distutils.ccompiler
  14. import distutils.log
  15. import distutils.sysconfig
  16. from distutils.errors import CompileError, LinkError
  17. from abc import ABC, abstractmethod
  18. YELLOW = '\033[93m'
  19. END = '\033[0m'
  20. WARNING = f"{YELLOW} [WARNING] {END}"
  21. DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
  22. DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"
  23. try:
  24. import torch
  25. except ImportError:
  26. print(
  27. f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops."
  28. )
  29. else:
  30. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  31. TORCH_MINOR = int(torch.__version__.split('.')[1])
  32. def installed_cuda_version():
  33. import torch.utils.cpp_extension
  34. cuda_home = torch.utils.cpp_extension.CUDA_HOME
  35. assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
  36. # Ensure there is not a cuda version mismatch between torch and nvcc compiler
  37. output = subprocess.check_output([cuda_home + "/bin/nvcc",
  38. "-V"],
  39. universal_newlines=True)
  40. output_split = output.split()
  41. release_idx = output_split.index("release")
  42. release = output_split[release_idx + 1].replace(',', '').split(".")
  43. # Ignore patch versions, only look at major + minor
  44. cuda_major, cuda_minor = release[:2]
  45. installed_cuda_version = ".".join(release[:2])
  46. return int(cuda_major), int(cuda_minor)
  47. def get_default_compute_capabilities():
  48. compute_caps = DEFAULT_COMPUTE_CAPABILITIES
  49. import torch.utils.cpp_extension
  50. if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version(
  51. )[0] >= 11:
  52. if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
  53. # Special treatment of CUDA 11.0 because compute_86 is not supported.
  54. compute_caps += ";8.0"
  55. else:
  56. compute_caps += ";8.0;8.6"
  57. return compute_caps
  58. # list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used
  59. # to build deepspeed and system-wide installed cuda 11.2
  60. cuda_minor_mismatch_ok = {
  61. 10: [
  62. "10.0",
  63. "10.1",
  64. "10.2",
  65. ],
  66. 11: ["11.0",
  67. "11.1",
  68. "11.2",
  69. "11.3",
  70. "11.4",
  71. "11.5",
  72. "11.6",
  73. "11.7"],
  74. }
  75. def assert_no_cuda_mismatch():
  76. cuda_major, cuda_minor = installed_cuda_version()
  77. sys_cuda_version = f'{cuda_major}.{cuda_minor}'
  78. torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
  79. # This is a show-stopping error, should probably not proceed past this
  80. if sys_cuda_version != torch_cuda_version:
  81. if (cuda_major in cuda_minor_mismatch_ok
  82. and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major]
  83. and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]):
  84. print(f"Installed CUDA version {sys_cuda_version} does not match the "
  85. f"version torch was compiled with {torch.version.cuda} "
  86. "but since the APIs are compatible, accepting this combination")
  87. return
  88. raise Exception(
  89. f"Installed CUDA version {sys_cuda_version} does not match the "
  90. f"version torch was compiled with {torch.version.cuda}, unable to compile "
  91. "cuda/cpp extensions without a matching cuda version.")
  92. class OpBuilder(ABC):
  93. _rocm_version = None
  94. _is_rocm_pytorch = None
  95. def __init__(self, name):
  96. self.name = name
  97. self.jit_mode = False
  98. self.error_log = None
  99. @abstractmethod
  100. def absolute_name(self):
  101. '''
  102. Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
  103. will be installed as something like: deepspeed/ops/adam/cpu_adam.so
  104. '''
  105. pass
  106. @abstractmethod
  107. def sources(self):
  108. '''
  109. Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
  110. '''
  111. pass
  112. def hipify_extension(self):
  113. pass
  114. @staticmethod
  115. def assert_torch_info(torch_info):
  116. install_torch_version = torch_info['version']
  117. install_cuda_version = torch_info['cuda_version']
  118. install_hip_version = torch_info['hip_version']
  119. if not OpBuilder.is_rocm_pytorch():
  120. current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
  121. else:
  122. current_hip_version = ".".join(torch.version.hip.split('.')[:2])
  123. current_torch_version = ".".join(torch.__version__.split('.')[:2])
  124. if not OpBuilder.is_rocm_pytorch():
  125. if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version:
  126. raise RuntimeError(
  127. "PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed "
  128. "with a different version than what is being used at runtime. Please re-install "
  129. f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
  130. f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:"
  131. f"torch={current_torch_version}, cuda={current_cuda_version}")
  132. else:
  133. if install_hip_version != current_hip_version or install_torch_version != current_torch_version:
  134. raise RuntimeError(
  135. "PyTorch and HIP version mismatch! DeepSpeed ops were compiled and installed "
  136. "with a different version than what is being used at runtime. Please re-install "
  137. f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
  138. f"torch={install_torch_version}, hip={install_hip_version}, runtime versions:"
  139. f"torch={current_torch_version}, hip={current_hip_version}")
  140. @staticmethod
  141. def is_rocm_pytorch():
  142. if OpBuilder._is_rocm_pytorch is not None:
  143. return OpBuilder._is_rocm_pytorch
  144. _is_rocm_pytorch = False
  145. try:
  146. import torch
  147. except ImportError:
  148. pass
  149. else:
  150. if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
  151. _is_rocm_pytorch = hasattr(torch.version,
  152. 'hip') and torch.version.hip is not None
  153. if _is_rocm_pytorch:
  154. from torch.utils.cpp_extension import ROCM_HOME
  155. _is_rocm_pytorch = ROCM_HOME is not None
  156. OpBuilder._is_rocm_pytorch = _is_rocm_pytorch
  157. return OpBuilder._is_rocm_pytorch
  158. @staticmethod
  159. def installed_rocm_version():
  160. if OpBuilder._rocm_version:
  161. return OpBuilder._rocm_version
  162. ROCM_MAJOR = '0'
  163. ROCM_MINOR = '0'
  164. if OpBuilder.is_rocm_pytorch():
  165. from torch.utils.cpp_extension import ROCM_HOME
  166. rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version-dev")
  167. if rocm_ver_file.is_file():
  168. with open(rocm_ver_file, 'r') as file:
  169. ROCM_VERSION_DEV_RAW = file.read()
  170. elif "rocm" in torch.__version__:
  171. ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1]
  172. else:
  173. assert False, "Could not detect ROCm version"
  174. assert ROCM_VERSION_DEV_RAW != "", "Could not detect ROCm version"
  175. ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0]
  176. ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1]
  177. OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR))
  178. return OpBuilder._rocm_version
  179. def include_paths(self):
  180. '''
  181. Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
  182. '''
  183. return []
  184. def nvcc_args(self):
  185. '''
  186. Returns optional list of compiler flags to forward to nvcc when building CUDA sources
  187. '''
  188. return []
  189. def cxx_args(self):
  190. '''
  191. Returns optional list of compiler flags to forward to the build
  192. '''
  193. return []
  194. def is_compatible(self, verbose=True):
  195. '''
  196. Check if all non-python dependencies are satisfied to build this op
  197. '''
  198. return True
  199. def extra_ldflags(self):
  200. return []
  201. def libraries_installed(self, libraries):
  202. valid = False
  203. check_cmd = 'dpkg -l'
  204. for lib in libraries:
  205. result = subprocess.Popen(f'dpkg -l {lib}',
  206. stdout=subprocess.PIPE,
  207. stderr=subprocess.PIPE,
  208. shell=True)
  209. valid = valid or result.wait() == 0
  210. return valid
  211. def has_function(self, funcname, libraries, verbose=False):
  212. '''
  213. Test for existence of a function within a tuple of libraries.
  214. This is used as a smoke test to check whether a certain library is available.
  215. As a test, this creates a simple C program that calls the specified function,
  216. and then distutils is used to compile that program and link it with the specified libraries.
  217. Returns True if both the compile and link are successful, False otherwise.
  218. '''
  219. tempdir = None # we create a temporary directory to hold various files
  220. filestderr = None # handle to open file to which we redirect stderr
  221. oldstderr = None # file descriptor for stderr
  222. try:
  223. # Echo compile and link commands that are used.
  224. if verbose:
  225. distutils.log.set_verbosity(1)
  226. # Create a compiler object.
  227. compiler = distutils.ccompiler.new_compiler(verbose=verbose)
  228. # Configure compiler and linker to build according to Python install.
  229. distutils.sysconfig.customize_compiler(compiler)
  230. # Create a temporary directory to hold test files.
  231. tempdir = tempfile.mkdtemp()
  232. # Define a simple C program that calls the function in question
  233. prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (
  234. funcname,
  235. funcname)
  236. # Write the test program to a file.
  237. filename = os.path.join(tempdir, 'test.c')
  238. with open(filename, 'w') as f:
  239. f.write(prog)
  240. # Redirect stderr file descriptor to a file to silence compile/link warnings.
  241. if not verbose:
  242. filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w')
  243. oldstderr = os.dup(sys.stderr.fileno())
  244. os.dup2(filestderr.fileno(), sys.stderr.fileno())
  245. # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames()
  246. # Otherwise, a local directory will be used instead of tempdir
  247. drive, driveless_filename = os.path.splitdrive(filename)
  248. root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else ''
  249. output_dir = os.path.join(drive, root_dir)
  250. # Attempt to compile the C program into an object file.
  251. cflags = shlex.split(os.environ.get('CFLAGS', ""))
  252. objs = compiler.compile([filename],
  253. output_dir=output_dir,
  254. extra_preargs=self.strip_empty_entries(cflags))
  255. # Attempt to link the object file into an executable.
  256. # Be sure to tack on any libraries that have been specified.
  257. ldflags = shlex.split(os.environ.get('LDFLAGS', ""))
  258. compiler.link_executable(objs,
  259. os.path.join(tempdir,
  260. 'a.out'),
  261. extra_preargs=self.strip_empty_entries(ldflags),
  262. libraries=libraries)
  263. # Compile and link succeeded
  264. return True
  265. except CompileError:
  266. return False
  267. except LinkError:
  268. return False
  269. except:
  270. return False
  271. finally:
  272. # Restore stderr file descriptor and close the stderr redirect file.
  273. if oldstderr is not None:
  274. os.dup2(oldstderr, sys.stderr.fileno())
  275. if filestderr is not None:
  276. filestderr.close()
  277. # Delete the temporary directory holding the test program and stderr files.
  278. if tempdir is not None:
  279. shutil.rmtree(tempdir)
  280. def strip_empty_entries(self, args):
  281. '''
  282. Drop any empty strings from the list of compile and link flags
  283. '''
  284. return [x for x in args if len(x) > 0]
  285. def cpu_arch(self):
  286. try:
  287. from cpuinfo import get_cpu_info
  288. except ImportError as e:
  289. cpu_info = self._backup_cpuinfo()
  290. if cpu_info is None:
  291. return "-march=native"
  292. try:
  293. cpu_info = get_cpu_info()
  294. except Exception as e:
  295. self.warning(
  296. f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
  297. "falling back to `lscpu` to get this information.")
  298. cpu_info = self._backup_cpuinfo()
  299. if cpu_info is None:
  300. return "-march=native"
  301. if cpu_info['arch'].startswith('PPC_'):
  302. # gcc does not provide -march on PowerPC, use -mcpu instead
  303. return '-mcpu=native'
  304. return '-march=native'
  305. def _backup_cpuinfo(self):
  306. # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides
  307. if not self.command_exists('lscpu'):
  308. self.warning(
  309. f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo "
  310. "to detect the CPU architecture. 'lscpu' does not appear to exist on "
  311. "your system, will fall back to use -march=native and non-vectorized execution."
  312. )
  313. return None
  314. result = subprocess.check_output('lscpu', shell=True)
  315. result = result.decode('utf-8').strip().lower()
  316. cpu_info = {}
  317. cpu_info['arch'] = None
  318. cpu_info['flags'] = ""
  319. if 'genuineintel' in result or 'authenticamd' in result:
  320. cpu_info['arch'] = 'X86_64'
  321. if 'avx512' in result:
  322. cpu_info['flags'] += 'avx512,'
  323. if 'avx2' in result:
  324. cpu_info['flags'] += 'avx2'
  325. elif 'ppc64le' in result:
  326. cpu_info['arch'] = "PPC_"
  327. return cpu_info
  328. def simd_width(self):
  329. try:
  330. from cpuinfo import get_cpu_info
  331. except ImportError as e:
  332. cpu_info = self._backup_cpuinfo()
  333. if cpu_info is None:
  334. return '-D__SCALAR__'
  335. try:
  336. cpu_info = get_cpu_info()
  337. except Exception as e:
  338. self.warning(
  339. f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
  340. "falling back to `lscpu` to get this information.")
  341. cpu_info = self._backup_cpuinfo()
  342. if cpu_info is None:
  343. return '-D__SCALAR__'
  344. if cpu_info['arch'] == 'X86_64':
  345. if 'avx512' in cpu_info['flags']:
  346. return '-D__AVX512__'
  347. elif 'avx2' in cpu_info['flags']:
  348. return '-D__AVX256__'
  349. return '-D__SCALAR__'
  350. def python_requirements(self):
  351. '''
  352. Override if op wants to define special dependencies, otherwise will
  353. take self.name and load requirements-<op-name>.txt if it exists.
  354. '''
  355. path = f'requirements/requirements-{self.name}.txt'
  356. requirements = []
  357. if os.path.isfile(path):
  358. with open(path, 'r') as fd:
  359. requirements = [r.strip() for r in fd.readlines()]
  360. return requirements
  361. def command_exists(self, cmd):
  362. if '|' in cmd:
  363. cmds = cmd.split("|")
  364. else:
  365. cmds = [cmd]
  366. valid = False
  367. for cmd in cmds:
  368. result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
  369. valid = valid or result.wait() == 0
  370. if not valid and len(cmds) > 1:
  371. print(
  372. f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!"
  373. )
  374. elif not valid and len(cmds) == 1:
  375. print(
  376. f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!"
  377. )
  378. return valid
  379. def warning(self, msg):
  380. self.error_log = f"{msg}"
  381. print(f"{WARNING} {msg}")
  382. def deepspeed_src_path(self, code_path):
  383. if os.path.isabs(code_path):
  384. return code_path
  385. else:
  386. return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
  387. def builder(self):
  388. from torch.utils.cpp_extension import CppExtension
  389. return CppExtension(
  390. name=self.absolute_name(),
  391. sources=self.strip_empty_entries(self.sources()),
  392. include_dirs=self.strip_empty_entries(self.include_paths()),
  393. extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
  394. extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
  395. def load(self, verbose=True):
  396. from ...git_version_info import installed_ops, torch_info
  397. if installed_ops[self.name]:
  398. # Ensure the op we're about to load was compiled with the same
  399. # torch/cuda versions we are currently using at runtime.
  400. if isinstance(self, CUDAOpBuilder):
  401. self.assert_torch_info(torch_info)
  402. return importlib.import_module(self.absolute_name())
  403. else:
  404. return self.jit_load(verbose)
  405. def jit_load(self, verbose=True):
  406. if not self.is_compatible(verbose):
  407. raise RuntimeError(
  408. f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}"
  409. )
  410. try:
  411. import ninja # noqa: F401
  412. except ImportError:
  413. raise RuntimeError(
  414. f"Unable to JIT load the {self.name} op due to ninja not being installed."
  415. )
  416. if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch():
  417. assert_no_cuda_mismatch()
  418. self.jit_mode = True
  419. from torch.utils.cpp_extension import load
  420. # Ensure directory exists to prevent race condition in some cases
  421. ext_path = os.path.join(
  422. os.environ.get('TORCH_EXTENSIONS_DIR',
  423. DEFAULT_TORCH_EXTENSION_PATH),
  424. self.name)
  425. os.makedirs(ext_path, exist_ok=True)
  426. start_build = time.time()
  427. sources = [self.deepspeed_src_path(path) for path in self.sources()]
  428. extra_include_paths = [
  429. self.deepspeed_src_path(path) for path in self.include_paths()
  430. ]
  431. # Torch will try and apply whatever CCs are in the arch list at compile time,
  432. # we have already set the intended targets ourselves we know that will be
  433. # needed at runtime. This prevents CC collisions such as multiple __half
  434. # implementations. Stash arch list to reset after build.
  435. torch_arch_list = None
  436. if "TORCH_CUDA_ARCH_LIST" in os.environ:
  437. torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
  438. os.environ["TORCH_CUDA_ARCH_LIST"] = ""
  439. op_module = load(
  440. name=self.name,
  441. sources=self.strip_empty_entries(sources),
  442. extra_include_paths=self.strip_empty_entries(extra_include_paths),
  443. extra_cflags=self.strip_empty_entries(self.cxx_args()),
  444. extra_cuda_cflags=self.strip_empty_entries(self.nvcc_args()),
  445. extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
  446. verbose=verbose)
  447. build_duration = time.time() - start_build
  448. if verbose:
  449. print(f"Time to load {self.name} op: {build_duration} seconds")
  450. # Reset arch list so we are not silently removing it for other possible use cases
  451. if torch_arch_list:
  452. os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
  453. return op_module
  454. class CUDAOpBuilder(OpBuilder):
  455. def compute_capability_args(self, cross_compile_archs=None):
  456. """
  457. Returns nvcc compute capability compile flags.
  458. 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
  459. 2. If neither is set default compute capabilities will be used
  460. 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
  461. Format:
  462. - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
  463. TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
  464. TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
  465. - `cross_compile_archs` uses ; separator.
  466. """
  467. ccs = []
  468. if self.jit_mode:
  469. # Compile for underlying architectures since we know those at runtime
  470. for i in range(torch.cuda.device_count()):
  471. CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
  472. cc = f"{CC_MAJOR}.{CC_MINOR}"
  473. if cc not in ccs:
  474. ccs.append(cc)
  475. ccs = sorted(ccs)
  476. ccs[-1] += '+PTX'
  477. else:
  478. # Cross-compile mode, compile for various architectures
  479. # env override takes priority
  480. cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
  481. if cross_compile_archs_env is not None:
  482. if cross_compile_archs is not None:
  483. print(
  484. f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
  485. )
  486. cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
  487. else:
  488. if cross_compile_archs is None:
  489. cross_compile_archs = get_default_compute_capabilities()
  490. ccs = cross_compile_archs.split(';')
  491. args = []
  492. for cc in ccs:
  493. num = cc[0] + cc[2]
  494. args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
  495. if cc.endswith('+PTX'):
  496. args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
  497. return args
  498. def version_dependent_macros(self):
  499. # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
  500. version_ge_1_1 = []
  501. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
  502. version_ge_1_1 = ['-DVERSION_GE_1_1']
  503. version_ge_1_3 = []
  504. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
  505. version_ge_1_3 = ['-DVERSION_GE_1_3']
  506. version_ge_1_5 = []
  507. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
  508. version_ge_1_5 = ['-DVERSION_GE_1_5']
  509. return version_ge_1_1 + version_ge_1_3 + version_ge_1_5
  510. def is_compatible(self, verbose=True):
  511. return super().is_compatible(verbose)
  512. def builder(self):
  513. from torch.utils.cpp_extension import CUDAExtension
  514. if not self.is_rocm_pytorch():
  515. assert_no_cuda_mismatch()
  516. cuda_ext = CUDAExtension(
  517. name=self.absolute_name(),
  518. sources=self.strip_empty_entries(self.sources()),
  519. include_dirs=self.strip_empty_entries(self.include_paths()),
  520. libraries=self.strip_empty_entries(self.libraries_args()),
  521. extra_compile_args={
  522. 'cxx': self.strip_empty_entries(self.cxx_args()),
  523. 'nvcc': self.strip_empty_entries(self.nvcc_args())
  524. })
  525. if self.is_rocm_pytorch():
  526. # hip converts paths to absolute, this converts back to relative
  527. sources = cuda_ext.sources
  528. curr_file = Path(__file__).parent.parent # ds root
  529. for i in range(len(sources)):
  530. src = Path(sources[i])
  531. sources[i] = str(src.relative_to(curr_file))
  532. cuda_ext.sources = sources
  533. return cuda_ext
  534. def hipify_extension(self):
  535. if self.is_rocm_pytorch():
  536. from torch.utils.hipify import hipify_python
  537. hipify_python.hipify(
  538. project_directory=os.getcwd(),
  539. output_directory=os.getcwd(),
  540. header_include_dirs=self.include_paths(),
  541. includes=[os.path.join(os.getcwd(),
  542. '*')],
  543. extra_files=[os.path.abspath(s) for s in self.sources()],
  544. show_detailed=True,
  545. is_pytorch_extension=True,
  546. hipify_extra_files_only=True,
  547. )
  548. def cxx_args(self):
  549. if sys.platform == "win32":
  550. return ['-O2']
  551. else:
  552. return ['-O3', '-std=c++14', '-g', '-Wno-reorder']
  553. def nvcc_args(self):
  554. args = ['-O3']
  555. if self.is_rocm_pytorch():
  556. ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version()
  557. args += [
  558. '-std=c++14',
  559. '-U__HIP_NO_HALF_OPERATORS__',
  560. '-U__HIP_NO_HALF_CONVERSIONS__',
  561. '-U__HIP_NO_HALF2_OPERATORS__',
  562. '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR,
  563. '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR
  564. ]
  565. else:
  566. cuda_major, _ = installed_cuda_version()
  567. args += [
  568. '--use_fast_math',
  569. '-std=c++17'
  570. if sys.platform == "win32" and cuda_major > 10 else '-std=c++14',
  571. '-U__CUDA_NO_HALF_OPERATORS__',
  572. '-U__CUDA_NO_HALF_CONVERSIONS__',
  573. '-U__CUDA_NO_HALF2_OPERATORS__'
  574. ]
  575. args += self.compute_capability_args()
  576. return args
  577. def libraries_args(self):
  578. if sys.platform == "win32":
  579. return ['cublas', 'curand']
  580. else:
  581. return []
  582. class TorchCPUOpBuilder(CUDAOpBuilder):
  583. def extra_ldflags(self):
  584. if not self.is_rocm_pytorch():
  585. return ['-lcurand']
  586. else:
  587. return []
  588. def cxx_args(self):
  589. import torch
  590. if not self.is_rocm_pytorch():
  591. CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
  592. else:
  593. CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
  594. CPU_ARCH = self.cpu_arch()
  595. SIMD_WIDTH = self.simd_width()
  596. args = super().cxx_args()
  597. args += [
  598. f'-L{CUDA_LIB64}',
  599. '-lcudart',
  600. '-lcublas',
  601. '-g',
  602. CPU_ARCH,
  603. '-fopenmp',
  604. SIMD_WIDTH,
  605. ]
  606. return args