builder.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import re
  6. import sys
  7. import time
  8. import importlib
  9. from pathlib import Path
  10. import subprocess
  11. import shlex
  12. import shutil
  13. import tempfile
  14. import distutils.ccompiler
  15. import distutils.log
  16. import distutils.sysconfig
  17. from distutils.errors import CompileError, LinkError
  18. from abc import ABC, abstractmethod
  19. from typing import List
# ANSI escape codes used to render build warnings in yellow on the terminal.
YELLOW = '\033[93m'
END = '\033[0m'
WARNING = f"{YELLOW} [WARNING] {END}"

# Fallback build directory for torch extensions, and the compute capabilities
# targeted when neither TORCH_CUDA_ARCH_LIST nor cross_compile_archs is given.
DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"

try:
    import torch
except ImportError:
    print(f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops.")
else:
    # Cached torch version components; only defined when torch imports cleanly.
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
class MissingCUDAException(Exception):
    """Raised when CUDA_HOME cannot be located, so CUDA ops cannot be compiled."""
    pass
class CUDAMismatchException(Exception):
    """Raised when the system CUDA toolkit version conflicts with torch's CUDA build."""
    pass
def installed_cuda_version(name=""):
    """Return the (major, minor) version of the system-installed CUDA toolkit.

    Queries ``nvcc -V`` under the CUDA_HOME that torch detected. ``name`` is
    accepted for caller context (the op being built) but is not used here.

    Raises:
        MissingCUDAException: if torch could not locate CUDA_HOME.
    """
    import torch.utils.cpp_extension
    cuda_home = torch.utils.cpp_extension.CUDA_HOME
    if cuda_home is None:
        raise MissingCUDAException("CUDA_HOME does not exist, unable to compile CUDA op(s)")
    # Ensure there is not a cuda version mismatch between torch and nvcc compiler
    output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
    output_split = output.split()
    # nvcc prints "... release 12.1, V12.1.105"; take the token after "release".
    release_idx = output_split.index("release")
    release = output_split[release_idx + 1].replace(',', '').split(".")
    # Ignore patch versions, only look at major + minor
    cuda_major, cuda_minor = release[:2]
    return int(cuda_major), int(cuda_minor)
  49. def get_default_compute_capabilities():
  50. compute_caps = DEFAULT_COMPUTE_CAPABILITIES
  51. import torch.utils.cpp_extension
  52. if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11:
  53. if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
  54. # Special treatment of CUDA 11.0 because compute_86 is not supported.
  55. compute_caps += ";8.0"
  56. else:
  57. compute_caps += ";8.0;8.6"
  58. return compute_caps
# list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used
# to build deepspeed and system-wide installed cuda 11.2
# Keyed by major version; a mismatch is tolerated only when both the system
# and torch CUDA versions appear in the same major's list.
cuda_minor_mismatch_ok = {
    10: ["10.0", "10.1", "10.2"],
    11: ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"],
    12: ["12.0", "12.1", "12.2", "12.3", "12.4", "12.5", "12.6"],
}
def assert_no_cuda_mismatch(name=""):
    """Check that the system CUDA toolkit matches the CUDA torch was built with.

    Minor-version differences listed in ``cuda_minor_mismatch_ok`` are accepted,
    as is any mismatch when the env var DS_SKIP_CUDA_CHECK=1.

    Returns:
        True when the combination is acceptable.

    Raises:
        CUDAMismatchException: on an incompatible version mismatch.
        MissingCUDAException: propagated from installed_cuda_version() when
            CUDA_HOME is absent.
    """
    cuda_major, cuda_minor = installed_cuda_version(name)
    sys_cuda_version = f'{cuda_major}.{cuda_minor}'
    torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    # This is a show-stopping error, should probably not proceed past this
    if sys_cuda_version != torch_cuda_version:
        if (cuda_major in cuda_minor_mismatch_ok and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major]
                and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]):
            print(f"Installed CUDA version {sys_cuda_version} does not match the "
                  f"version torch was compiled with {torch.version.cuda} "
                  "but since the APIs are compatible, accepting this combination")
            return True
        elif os.getenv("DS_SKIP_CUDA_CHECK", "0") == "1":
            print(
                f"{WARNING} DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
                f"version torch was compiled with {torch.version.cuda}."
                "Detected `DS_SKIP_CUDA_CHECK=1`: Allowing this combination of CUDA, but it may result in unexpected behavior."
            )
            return True
        raise CUDAMismatchException(
            f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the "
            f"version torch was compiled with {torch.version.cuda}, unable to compile "
            "cuda/cpp extensions without a matching cuda version.")
    return True
class OpBuilder(ABC):
    """Abstract base for building/loading DeepSpeed C++/CUDA op extensions.

    Subclasses describe one op (sources, include paths, flags) and this class
    handles probing the toolchain, pre-compiled loading, and JIT compilation.
    """

    # Class-level caches shared by every builder instance (populated lazily).
    _rocm_version = None
    _rocm_gpu_arch = None
    _rocm_wavefront_size = None
    _is_rocm_pytorch = None
    _is_sycl_enabled = None
    _loaded_ops = {}

    def __init__(self, name):
        self.name = name
        # True only while jit_load() is compiling the op at runtime.
        self.jit_mode = False
        # Set when CUDA is unavailable/mismatched and only CPU sources build.
        self.build_for_cpu = False
        self.enable_bf16 = False
        # Last message passed to self.warning(); surfaced in jit_load errors.
        self.error_log = None

    @abstractmethod
    def absolute_name(self):
        '''
        Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
        will be installed as something like: deepspeed/ops/adam/cpu_adam.so
        '''
        pass

    @abstractmethod
    def sources(self):
        '''
        Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
        '''
        pass
    def hipify_extension(self):
        # Hook for converting CUDA sources to HIP; overridden by CUDA builders.
        pass

    def sycl_extension(self):
        # Hook for SYCL-specific source handling; no-op in the base builder.
        pass
  120. @staticmethod
  121. def validate_torch_version(torch_info):
  122. install_torch_version = torch_info['version']
  123. current_torch_version = ".".join(torch.__version__.split('.')[:2])
  124. if install_torch_version != current_torch_version:
  125. raise RuntimeError("PyTorch version mismatch! DeepSpeed ops were compiled and installed "
  126. "with a different version than what is being used at runtime. "
  127. f"Please re-install DeepSpeed or switch torch versions. "
  128. f"Install torch version={install_torch_version}, "
  129. f"Runtime torch version={current_torch_version}")
    @staticmethod
    def validate_torch_op_version(torch_info):
        """Verify the CUDA (or HIP, on ROCm) version ops were built with matches runtime.

        Args:
            torch_info: dict with 'cuda_version' / 'hip_version' captured at install time.

        Raises:
            RuntimeError: on a major.minor version mismatch.
        """
        if not OpBuilder.is_rocm_pytorch():
            current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
            install_cuda_version = torch_info['cuda_version']
            if install_cuda_version != current_cuda_version:
                raise RuntimeError("CUDA version mismatch! DeepSpeed ops were compiled and installed "
                                   "with a different version than what is being used at runtime. "
                                   f"Please re-install DeepSpeed or switch torch versions. "
                                   f"Install CUDA version={install_cuda_version}, "
                                   f"Runtime CUDA version={current_cuda_version}")
        else:
            current_hip_version = ".".join(torch.version.hip.split('.')[:2])
            install_hip_version = torch_info['hip_version']
            if install_hip_version != current_hip_version:
                raise RuntimeError("HIP version mismatch! DeepSpeed ops were compiled and installed "
                                   "with a different version than what is being used at runtime. "
                                   f"Please re-install DeepSpeed or switch torch versions. "
                                   f"Install HIP version={install_hip_version}, "
                                   f"Runtime HIP version={current_hip_version}")
    @staticmethod
    def is_rocm_pytorch():
        """Return True when torch is a ROCm (HIP) build; result is cached on the class."""
        if OpBuilder._is_rocm_pytorch is not None:
            return OpBuilder._is_rocm_pytorch
        _is_rocm_pytorch = False
        try:
            import torch
        except ImportError:
            pass
        else:
            # ROCm support requires torch >= 1.5; check the hip version attribute.
            if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5):
                _is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None
                if _is_rocm_pytorch:
                    # Also require a usable ROCm install on the system.
                    from torch.utils.cpp_extension import ROCM_HOME
                    _is_rocm_pytorch = ROCM_HOME is not None
        OpBuilder._is_rocm_pytorch = _is_rocm_pytorch
        return OpBuilder._is_rocm_pytorch
  167. @staticmethod
  168. def is_sycl_enabled():
  169. if OpBuilder._is_sycl_enabled is not None:
  170. return OpBuilder._is_sycl_enabled
  171. _is_sycl_enabled = False
  172. try:
  173. result = subprocess.run(["c2s", "--version"], capture_output=True)
  174. except:
  175. pass
  176. else:
  177. _is_sycl_enabled = True
  178. OpBuilder._is_sycl_enabled = _is_sycl_enabled
  179. return OpBuilder._is_sycl_enabled
    @staticmethod
    def installed_rocm_version():
        """Return (major, minor) of the installed ROCm; cached on the class.

        Tries ROCM_HOME/.info/version first, then the torch version string,
        then /usr/include/rocm_version.h.

        Raises:
            AssertionError: if no ROCm version could be detected.
        """
        if OpBuilder._rocm_version:
            return OpBuilder._rocm_version
        ROCM_MAJOR = '0'
        ROCM_MINOR = '0'
        ROCM_VERSION_DEV_RAW = ""
        if OpBuilder.is_rocm_pytorch():
            from torch.utils.cpp_extension import ROCM_HOME
            rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version")
            if rocm_ver_file.is_file():
                with open(rocm_ver_file, 'r') as file:
                    ROCM_VERSION_DEV_RAW = file.read()
            elif "rocm" in torch.__version__:
                # e.g. "2.1.0+rocm5.6" -> "5.6"
                ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1]
            if ROCM_VERSION_DEV_RAW != "":
                ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0]
                ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1]
            else:
                # Fall back to parsing /usr/include/rocm_version.h
                rocm_ver_file = Path("/usr/include/rocm_version.h")
                if rocm_ver_file.is_file():
                    with open(rocm_ver_file, 'r') as file:
                        for ln in file.readlines():
                            if "#define ROCM_VERSION_MAJOR" in ln:
                                ROCM_MAJOR = re.findall(r'\S+', ln)[2]
                            elif "#define ROCM_VERSION_MINOR" in ln:
                                ROCM_MINOR = re.findall(r'\S+', ln)[2]
            if ROCM_MAJOR == '0':
                assert False, "Could not detect ROCm version"
            OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR))
        return OpBuilder._rocm_version
    @staticmethod
    def get_rocm_gpu_arch():
        """Return the GPU arch string (e.g. 'gfx90a') reported by rocminfo; cached.

        Returns "" when rocminfo is unavailable or reports no gfx target.
        """
        if OpBuilder._rocm_gpu_arch:
            return OpBuilder._rocm_gpu_arch
        rocm_info = Path("/opt/rocm/bin/rocminfo")
        if (not rocm_info.is_file()):
            # Fall back to rocminfo on PATH.
            rocm_info = Path("rocminfo")
        # Fixed command string piped through grep; no untrusted input reaches the shell.
        rocm_gpu_arch_cmd = str(rocm_info) + " | grep -o -m 1 'gfx.*'"
        try:
            result = subprocess.check_output(rocm_gpu_arch_cmd, shell=True)
            rocm_gpu_arch = result.decode('utf-8').strip()
        except subprocess.CalledProcessError:
            rocm_gpu_arch = ""
        OpBuilder._rocm_gpu_arch = rocm_gpu_arch
        return OpBuilder._rocm_gpu_arch
    @staticmethod
    def get_rocm_wavefront_size():
        """Return the GPU wavefront size reported by rocminfo (as a string); cached.

        Falls back to "32" when rocminfo is unavailable.
        """
        if OpBuilder._rocm_wavefront_size:
            return OpBuilder._rocm_wavefront_size
        rocm_info = Path("/opt/rocm/bin/rocminfo")
        if (not rocm_info.is_file()):
            # Fall back to rocminfo on PATH.
            rocm_info = Path("rocminfo")
        rocm_wavefront_size_cmd = str(
            rocm_info) + " | grep -Eo -m1 'Wavefront Size:[[:space:]]+[0-9]+' | grep -Eo '[0-9]+'"
        try:
            result = subprocess.check_output(rocm_wavefront_size_cmd, shell=True)
            rocm_wavefront_size = result.decode('utf-8').strip()
        except subprocess.CalledProcessError:
            rocm_wavefront_size = "32"
        OpBuilder._rocm_wavefront_size = rocm_wavefront_size
        return OpBuilder._rocm_wavefront_size
    def include_paths(self):
        '''
        Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
        '''
        return []

    def nvcc_args(self):
        '''
        Returns optional list of compiler flags to forward to nvcc when building CUDA sources
        '''
        return []

    def cxx_args(self):
        '''
        Returns optional list of compiler flags to forward to the build
        '''
        return []

    def is_compatible(self, verbose=False):
        '''
        Check if all non-python dependencies are satisfied to build this op
        '''
        return True

    def extra_ldflags(self):
        # Optional extra linker flags; subclasses override as needed.
        return []
    def has_function(self, funcname, libraries, library_dirs=None, verbose=False):
        '''
        Test for existence of a function within a tuple of libraries.

        This is used as a smoke test to check whether a certain library is available.
        As a test, this creates a simple C program that calls the specified function,
        and then distutils is used to compile that program and link it with the specified libraries.
        Returns True if both the compile and link are successful, False otherwise.

        Note: the statement order below matters — stderr is redirected at the OS
        file-descriptor level and must be restored in `finally` before the
        redirect file is closed.
        '''
        tempdir = None  # we create a temporary directory to hold various files
        filestderr = None  # handle to open file to which we redirect stderr
        oldstderr = None  # file descriptor for stderr
        try:
            # Echo compile and link commands that are used.
            if verbose:
                distutils.log.set_verbosity(1)
            # Create a compiler object.
            compiler = distutils.ccompiler.new_compiler(verbose=verbose)
            # Configure compiler and linker to build according to Python install.
            distutils.sysconfig.customize_compiler(compiler)
            # Create a temporary directory to hold test files.
            tempdir = tempfile.mkdtemp()
            # Define a simple C program that calls the function in question
            prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (funcname, funcname)
            # Write the test program to a file.
            filename = os.path.join(tempdir, 'test.c')
            with open(filename, 'w') as f:
                f.write(prog)
            # Redirect stderr file descriptor to a file to silence compile/link warnings.
            if not verbose:
                filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w')
                oldstderr = os.dup(sys.stderr.fileno())
                os.dup2(filestderr.fileno(), sys.stderr.fileno())
            # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames()
            # Otherwise, a local directory will be used instead of tempdir
            drive, driveless_filename = os.path.splitdrive(filename)
            root_dir = driveless_filename[0] if os.path.isabs(driveless_filename) else ''
            output_dir = os.path.join(drive, root_dir)
            # Attempt to compile the C program into an object file.
            cflags = shlex.split(os.environ.get('CFLAGS', ""))
            objs = compiler.compile([filename], output_dir=output_dir, extra_preargs=self.strip_empty_entries(cflags))
            # Attempt to link the object file into an executable.
            # Be sure to tack on any libraries that have been specified.
            ldflags = shlex.split(os.environ.get('LDFLAGS', ""))
            compiler.link_executable(objs,
                                     os.path.join(tempdir, 'a.out'),
                                     extra_preargs=self.strip_empty_entries(ldflags),
                                     libraries=libraries,
                                     library_dirs=library_dirs)
            # Compile and link succeeded
            return True
        except CompileError:
            return False
        except LinkError:
            return False
        except:
            # Any other failure also means the function is unusable; this is a
            # deliberate best-effort probe, not an error path.
            return False
        finally:
            # Restore stderr file descriptor and close the stderr redirect file.
            if oldstderr is not None:
                os.dup2(oldstderr, sys.stderr.fileno())
            if filestderr is not None:
                filestderr.close()
            # Delete the temporary directory holding the test program and stderr files.
            if tempdir is not None:
                shutil.rmtree(tempdir)
  330. def strip_empty_entries(self, args):
  331. '''
  332. Drop any empty strings from the list of compile and link flags
  333. '''
  334. return [x for x in args if len(x) > 0]
    def cpu_arch(self):
        """Return the host-arch compiler flag: '-mcpu=native' on PowerPC, else '-march=native'.

        Uses py-cpuinfo when available; falls back to parsing `lscpu` output.
        """
        try:
            from cpuinfo import get_cpu_info
        except ImportError as e:
            cpu_info = self._backup_cpuinfo()
            if cpu_info is None:
                return "-march=native"
        try:
            cpu_info = get_cpu_info()
        except Exception as e:
            self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
                         "falling back to `lscpu` to get this information.")
            cpu_info = self._backup_cpuinfo()
            if cpu_info is None:
                return "-march=native"
        if cpu_info['arch'].startswith('PPC_'):
            # gcc does not provide -march on PowerPC, use -mcpu instead
            return '-mcpu=native'
        return '-march=native'
  354. def is_cuda_enable(self):
  355. try:
  356. assert_no_cuda_mismatch(self.name)
  357. return '-D__ENABLE_CUDA__'
  358. except MissingCUDAException:
  359. print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
  360. "only cpu ops can be compiled!")
  361. return '-D__DISABLE_CUDA__'
  362. return '-D__DISABLE_CUDA__'
    def _backup_cpuinfo(self):
        # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides
        # Returns None when lscpu is not available.
        if not self.command_exists('lscpu'):
            self.warning(f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo "
                         "to detect the CPU architecture. 'lscpu' does not appear to exist on "
                         "your system, will fall back to use -march=native and non-vectorized execution.")
            return None
        result = subprocess.check_output(['lscpu'])
        result = result.decode('utf-8').strip().lower()
        cpu_info = {}
        cpu_info['arch'] = None
        cpu_info['flags'] = ""
        if 'genuineintel' in result or 'authenticamd' in result:
            cpu_info['arch'] = 'X86_64'
            if 'avx512' in result:
                cpu_info['flags'] += 'avx512,'
            elif 'avx512f' in result:
                # NOTE(review): this branch can never fire — any string containing
                # 'avx512f' also contains 'avx512'. Mirrors simd_width's checks;
                # confirm whether it is intentional.
                cpu_info['flags'] += 'avx512f,'
            if 'avx2' in result:
                cpu_info['flags'] += 'avx2'
        elif 'ppc64le' in result:
            cpu_info['arch'] = "PPC_"
        return cpu_info
    def simd_width(self):
        """Return the SIMD macro for the host CPU: -D__AVX512__, -D__AVX256__, or -D__SCALAR__.

        Uses py-cpuinfo when available; falls back to parsing `lscpu` output.
        """
        try:
            from cpuinfo import get_cpu_info
        except ImportError as e:
            cpu_info = self._backup_cpuinfo()
            if cpu_info is None:
                return '-D__SCALAR__'
        try:
            cpu_info = get_cpu_info()
        except Exception as e:
            self.warning(f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), "
                         "falling back to `lscpu` to get this information.")
            cpu_info = self._backup_cpuinfo()
            if cpu_info is None:
                return '-D__SCALAR__'
        if cpu_info['arch'] == 'X86_64':
            if 'avx512' in cpu_info['flags'] or 'avx512f' in cpu_info['flags']:
                return '-D__AVX512__'
            elif 'avx2' in cpu_info['flags']:
                return '-D__AVX256__'
        return '-D__SCALAR__'
  407. def command_exists(self, cmd):
  408. if '|' in cmd:
  409. cmds = cmd.split("|")
  410. else:
  411. cmds = [cmd]
  412. valid = False
  413. for cmd in cmds:
  414. safe_cmd = ["bash", "-c", f"type {cmd}"]
  415. result = subprocess.Popen(safe_cmd, stdout=subprocess.PIPE)
  416. valid = valid or result.wait() == 0
  417. if not valid and len(cmds) > 1:
  418. print(f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!")
  419. elif not valid and len(cmds) == 1:
  420. print(f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!")
  421. return valid
  422. def warning(self, msg):
  423. self.error_log = f"{msg}"
  424. print(f"{WARNING} {msg}")
  425. def deepspeed_src_path(self, code_path):
  426. if os.path.isabs(code_path):
  427. return code_path
  428. else:
  429. return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
    def builder(self):
        """Construct a setuptools CppExtension describing this op's pre-compile build."""
        from torch.utils.cpp_extension import CppExtension
        include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())]
        return CppExtension(name=self.absolute_name(),
                            sources=self.strip_empty_entries(self.sources()),
                            include_dirs=include_dirs,
                            extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
                            extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
    def load(self, verbose=True):
        """Return the op module: from cache, from a pre-compiled install, or via JIT build.

        Pre-compiled ops are only used when they were built for the current
        accelerator and their recorded torch/CUDA versions match runtime.
        """
        if self.name in __class__._loaded_ops:
            return __class__._loaded_ops[self.name]
        from deepspeed.git_version_info import installed_ops, torch_info, accelerator_name
        from deepspeed.accelerator import get_accelerator
        if installed_ops.get(self.name, False) and accelerator_name == get_accelerator()._name:
            # Ensure the op we're about to load was compiled with the same
            # torch/cuda versions we are currently using at runtime.
            self.validate_torch_version(torch_info)
            if torch.cuda.is_available() and isinstance(self, CUDAOpBuilder):
                self.validate_torch_op_version(torch_info)
            op_module = importlib.import_module(self.absolute_name())
            __class__._loaded_ops[self.name] = op_module
            return op_module
        else:
            return self.jit_load(verbose)
    def jit_load(self, verbose=True):
        """Compile and load this op at runtime with torch's JIT extension loader.

        Raises:
            RuntimeError: when the op is incompatible with this machine or
                ninja is not installed.
        """
        if not self.is_compatible(verbose):
            raise RuntimeError(
                f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue. {self.error_log}"
            )
        try:
            import ninja  # noqa: F401 # type: ignore
        except ImportError:
            raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.")
        if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch():
            self.build_for_cpu = not torch.cuda.is_available()
        self.jit_mode = True
        from torch.utils.cpp_extension import load
        start_build = time.time()
        sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()]
        extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()]
        # Torch will try and apply whatever CCs are in the arch list at compile time,
        # we have already set the intended targets ourselves we know that will be
        # needed at runtime. This prevents CC collisions such as multiple __half
        # implementations. Stash arch list to reset after build.
        torch_arch_list = None
        if "TORCH_CUDA_ARCH_LIST" in os.environ:
            torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
            os.environ["TORCH_CUDA_ARCH_LIST"] = ""
        nvcc_args = self.strip_empty_entries(self.nvcc_args())
        cxx_args = self.strip_empty_entries(self.cxx_args())
        if isinstance(self, CUDAOpBuilder):
            if not self.build_for_cpu and self.enable_bf16:
                # Expose bfloat16 support to both host and device compilers.
                cxx_args.append("-DBF16_AVAILABLE")
                nvcc_args.append("-DBF16_AVAILABLE")
                nvcc_args.append("-U__CUDA_NO_BFLOAT16_OPERATORS__")
                nvcc_args.append("-U__CUDA_NO_BFLOAT162_OPERATORS__")
                nvcc_args.append("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
        if self.is_rocm_pytorch():
            cxx_args.append("-D__HIP_PLATFORM_AMD__=1")
            os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch()
            cxx_args.append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size())
        op_module = load(name=self.name,
                         sources=self.strip_empty_entries(sources),
                         extra_include_paths=self.strip_empty_entries(extra_include_paths),
                         extra_cflags=cxx_args,
                         extra_cuda_cflags=nvcc_args,
                         extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
                         verbose=verbose)
        build_duration = time.time() - start_build
        if verbose:
            print(f"Time to load {self.name} op: {build_duration} seconds")
        # Reset arch list so we are not silently removing it for other possible use cases
        if torch_arch_list:
            os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
        __class__._loaded_ops[self.name] = op_module
        return op_module
  506. class CUDAOpBuilder(OpBuilder):
  507. def compute_capability_args(self, cross_compile_archs=None):
  508. """
  509. Returns nvcc compute capability compile flags.
  510. 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
  511. 2. If neither is set default compute capabilities will be used
  512. 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
  513. Format:
  514. - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
  515. TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
  516. TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
  517. - `cross_compile_archs` uses ; separator.
  518. """
  519. ccs = []
  520. if self.jit_mode:
  521. # Compile for underlying architectures since we know those at runtime
  522. for i in range(torch.cuda.device_count()):
  523. CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
  524. cc = f"{CC_MAJOR}.{CC_MINOR}"
  525. if cc not in ccs:
  526. ccs.append(cc)
  527. ccs = sorted(ccs)
  528. ccs[-1] += '+PTX'
  529. else:
  530. # Cross-compile mode, compile for various architectures
  531. # env override takes priority
  532. cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
  533. if cross_compile_archs_env is not None:
  534. if cross_compile_archs is not None:
  535. print(
  536. f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
  537. )
  538. cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
  539. else:
  540. if cross_compile_archs is None:
  541. cross_compile_archs = get_default_compute_capabilities()
  542. ccs = cross_compile_archs.split(';')
  543. ccs = self.filter_ccs(ccs)
  544. if len(ccs) == 0:
  545. raise RuntimeError(
  546. f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")
  547. args = []
  548. self.enable_bf16 = True
  549. for cc in ccs:
  550. num = cc[0] + cc[2]
  551. args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
  552. if cc.endswith('+PTX'):
  553. args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
  554. if int(cc[0]) <= 7:
  555. self.enable_bf16 = False
  556. return args
    def filter_ccs(self, ccs: List[str]):
        """
        Prune any compute capabilities that are not compatible with the builder. Should log
        which CCs have been pruned.

        Base implementation keeps everything; subclasses override to filter.
        """
        return ccs
  563. def version_dependent_macros(self):
  564. # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
  565. version_ge_1_1 = []
  566. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
  567. version_ge_1_1 = ['-DVERSION_GE_1_1']
  568. version_ge_1_3 = []
  569. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
  570. version_ge_1_3 = ['-DVERSION_GE_1_3']
  571. version_ge_1_5 = []
  572. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
  573. version_ge_1_5 = ['-DVERSION_GE_1_5']
  574. return version_ge_1_1 + version_ge_1_3 + version_ge_1_5
    def is_compatible(self, verbose=False):
        # Delegates to the base check; concrete CUDA ops add their own checks.
        return super().is_compatible(verbose)
    def builder(self):
        """Construct a setuptools extension (CUDAExtension or CppExtension) for this op.

        Falls back to a CPU-only CppExtension when CUDA is missing; on ROCm,
        adds HIP defines and rewrites hipified source paths back to relative.
        """
        try:
            if not self.is_rocm_pytorch():
                assert_no_cuda_mismatch(self.name)
            self.build_for_cpu = False
        except MissingCUDAException:
            self.build_for_cpu = True
        if self.build_for_cpu:
            from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
        else:
            from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder
        include_dirs = [os.path.abspath(x) for x in self.strip_empty_entries(self.include_paths())]
        compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \
            {'cxx': self.strip_empty_entries(self.cxx_args()), \
             'nvcc': self.strip_empty_entries(self.nvcc_args())}
        if not self.build_for_cpu and self.enable_bf16:
            compile_args['cxx'].append("-DBF16_AVAILABLE")
            compile_args['nvcc'].append("-DBF16_AVAILABLE")
        if self.is_rocm_pytorch():
            compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1")
            #cxx compiler args are required to compile cpp files
            compile_args['cxx'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size())
            #nvcc compiler args are required to compile hip files
            compile_args['nvcc'].append('-DROCM_WAVEFRONT_SIZE=%s' % self.get_rocm_wavefront_size())
            if self.get_rocm_gpu_arch():
                os.environ["PYTORCH_ROCM_ARCH"] = self.get_rocm_gpu_arch()
        cuda_ext = ExtensionBuilder(name=self.absolute_name(),
                                    sources=self.strip_empty_entries(self.sources()),
                                    include_dirs=include_dirs,
                                    libraries=self.strip_empty_entries(self.libraries_args()),
                                    extra_compile_args=compile_args,
                                    extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
        if self.is_rocm_pytorch():
            # hip converts paths to absolute, this converts back to relative
            sources = cuda_ext.sources
            curr_file = Path(__file__).parent.parent  # ds root
            for i in range(len(sources)):
                src = Path(sources[i])
                if src.is_absolute():
                    sources[i] = str(src.relative_to(curr_file))
                else:
                    sources[i] = str(src)
            cuda_ext.sources = sources
        return cuda_ext
    def hipify_extension(self):
        """On ROCm builds, run torch's hipify over this op's sources in-place."""
        if self.is_rocm_pytorch():
            from torch.utils.hipify import hipify_python
            hipify_python.hipify(
                project_directory=os.getcwd(),
                output_directory=os.getcwd(),
                header_include_dirs=self.include_paths(),
                includes=[os.path.join(os.getcwd(), '*')],
                extra_files=[os.path.abspath(s) for s in self.sources()],
                show_detailed=True,
                is_pytorch_extension=True,
                hipify_extra_files_only=True,
            )
  634. def cxx_args(self):
  635. if sys.platform == "win32":
  636. return ['-O2']
  637. else:
  638. return ['-O3', '-std=c++17', '-g', '-Wno-reorder']
  639. def nvcc_args(self):
  640. if self.build_for_cpu:
  641. return []
  642. args = ['-O3']
  643. if self.is_rocm_pytorch():
  644. ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version()
  645. args += [
  646. '-std=c++17', '-U__HIP_NO_HALF_OPERATORS__', '-U__HIP_NO_HALF_CONVERSIONS__',
  647. '-U__HIP_NO_HALF2_OPERATORS__',
  648. '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR,
  649. '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR
  650. ]
  651. else:
  652. try:
  653. nvcc_threads = int(os.getenv("DS_NVCC_THREADS", ""))
  654. if nvcc_threads <= 0:
  655. raise ValueError("")
  656. except ValueError:
  657. nvcc_threads = min(os.cpu_count(), 8)
  658. cuda_major, cuda_minor = installed_cuda_version()
  659. if cuda_major > 10:
  660. if cuda_major == 12 and cuda_minor >= 5:
  661. std_lib = '-std=c++20'
  662. else:
  663. std_lib = '-std=c++17'
  664. else:
  665. std_lib = '-std=c++14'
  666. args += [
  667. '-allow-unsupported-compiler' if sys.platform == "win32" else '', '--use_fast_math', std_lib,
  668. '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
  669. f'--threads={nvcc_threads}'
  670. ]
  671. if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1':
  672. args.append('--ptxas-options=-v')
  673. args += self.compute_capability_args()
  674. return args
  675. def libraries_args(self):
  676. if self.build_for_cpu:
  677. return []
  678. if sys.platform == "win32":
  679. return ['cublas', 'curand']
  680. else:
  681. return []
class TorchCPUOpBuilder(CUDAOpBuilder):
    """Builder for ops with CPU code paths that may still link against CUDA/ROCm libs."""

    def get_cuda_lib64_path(self):
        """Return the CUDA (or ROCm) runtime library directory used in link flags."""
        import torch
        # NOTE(review): relies on torch.utils.cpp_extension having been imported
        # elsewhere so the attribute chain resolves — confirm.
        if not self.is_rocm_pytorch():
            CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib64")
            # Some toolkit layouts (e.g. conda installs) ship libraries under lib/.
            if not os.path.exists(CUDA_LIB64):
                CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, "lib")
        else:
            CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, "lib")
        return CUDA_LIB64
  692. def extra_ldflags(self):
  693. if self.build_for_cpu:
  694. return ['-fopenmp']
  695. if not self.is_rocm_pytorch():
  696. ld_flags = ['-lcurand']
  697. if not self.build_for_cpu:
  698. ld_flags.append(f'-L{self.get_cuda_lib64_path()}')
  699. return ld_flags
  700. return []
    def cxx_args(self):
        """Host compile flags: base CUDA flags plus CPU arch/SIMD macros.

        CPU-only builds skip the CUDA link options and the parent's flags.
        """
        args = []
        if not self.build_for_cpu:
            CUDA_LIB64 = self.get_cuda_lib64_path()
            args += super().cxx_args()
            # NOTE(review): link-style flags (-L/-l) are passed as compile args;
            # presumably forwarded to the link step by the extension builder — confirm.
            args += [
                f'-L{CUDA_LIB64}',
                '-lcudart',
                '-lcublas',
                '-g',
            ]
        CPU_ARCH = self.cpu_arch()
        SIMD_WIDTH = self.simd_width()
        CUDA_ENABLE = self.is_cuda_enable()
        args += [
            CPU_ARCH,
            '-fopenmp',
            SIMD_WIDTH,
            CUDA_ENABLE,
        ]
        return args