builder.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. """
  2. Copyright 2020 The Microsoft DeepSpeed Team
  3. """
  4. import os
  5. import sys
  6. import time
  7. import importlib
  8. from pathlib import Path
  9. import subprocess
  10. import shlex
  11. import shutil
  12. import tempfile
  13. import distutils.ccompiler
  14. import distutils.log
  15. import distutils.sysconfig
  16. from distutils.errors import CompileError, LinkError
  17. from abc import ABC, abstractmethod
  18. YELLOW = '\033[93m'
  19. END = '\033[0m'
  20. WARNING = f"{YELLOW} [WARNING] {END}"
  21. DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
  22. DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"
  23. try:
  24. import torch
  25. except ImportError:
  26. print(
  27. f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops."
  28. )
  29. def installed_cuda_version():
  30. import torch.utils.cpp_extension
  31. cuda_home = torch.utils.cpp_extension.CUDA_HOME
  32. assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
  33. # Ensure there is not a cuda version mismatch between torch and nvcc compiler
  34. output = subprocess.check_output([cuda_home + "/bin/nvcc",
  35. "-V"],
  36. universal_newlines=True)
  37. output_split = output.split()
  38. release_idx = output_split.index("release")
  39. release = output_split[release_idx + 1].replace(',', '').split(".")
  40. # Ignore patch versions, only look at major + minor
  41. cuda_major, cuda_minor = release[:2]
  42. installed_cuda_version = ".".join(release[:2])
  43. return int(cuda_major), int(cuda_minor)
  44. def get_default_compute_capatabilities():
  45. compute_caps = DEFAULT_COMPUTE_CAPABILITIES
  46. import torch.utils.cpp_extension
  47. if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version(
  48. )[0] >= 11:
  49. if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
  50. # Special treatment of CUDA 11.0 because compute_86 is not supported.
  51. compute_caps += ";8.0"
  52. else:
  53. compute_caps += ";8.0;8.6"
  54. return compute_caps
  55. # list compatible minor CUDA versions - so that for example pytorch built with cuda-11.0 can be used
  56. # to build deepspeed and system-wide installed cuda 11.2
  57. cuda_minor_mismatch_ok = {
  58. 10: ["10.0",
  59. "10.1",
  60. "10.2"],
  61. 11: ["11.0",
  62. "11.1",
  63. "11.2",
  64. "11.3",
  65. "11.4"],
  66. }
  67. def assert_no_cuda_mismatch():
  68. cuda_major, cuda_minor = installed_cuda_version()
  69. sys_cuda_version = f'{cuda_major}.{cuda_minor}'
  70. torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
  71. # This is a show-stopping error, should probably not proceed past this
  72. if sys_cuda_version != torch_cuda_version:
  73. if (cuda_major in cuda_minor_mismatch_ok
  74. and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major]
  75. and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]):
  76. print(f"Installed CUDA version {sys_cuda_version} does not match the "
  77. f"version torch was compiled with {torch.version.cuda} "
  78. "but since the APIs are compatible, accepting this combination")
  79. return
  80. raise Exception(
  81. f"Installed CUDA version {sys_cuda_version} does not match the "
  82. f"version torch was compiled with {torch.version.cuda}, unable to compile "
  83. "cuda/cpp extensions without a matching cuda version.")
  84. def assert_torch_info(torch_info):
  85. install_torch_version = torch_info['version']
  86. install_cuda_version = torch_info['cuda_version']
  87. current_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
  88. current_torch_version = ".".join(torch.__version__.split('.')[:2])
  89. if install_cuda_version != current_cuda_version or install_torch_version != current_torch_version:
  90. raise RuntimeError(
  91. "PyTorch and CUDA version mismatch! DeepSpeed ops were compiled and installed "
  92. "with a different version than what is being used at runtime. Please re-install "
  93. f"DeepSpeed or switch torch versions. DeepSpeed install versions: "
  94. f"torch={install_torch_version}, cuda={install_cuda_version}, runtime versions:"
  95. f"torch={current_torch_version}, cuda={current_cuda_version}")
  96. class OpBuilder(ABC):
  97. def __init__(self, name):
  98. self.name = name
  99. self.jit_mode = False
  100. @abstractmethod
  101. def absolute_name(self):
  102. '''
  103. Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam
  104. will be installed as something like: deepspeed/ops/adam/cpu_adam.so
  105. '''
  106. pass
  107. @abstractmethod
  108. def sources(self):
  109. '''
  110. Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
  111. '''
  112. pass
  113. def include_paths(self):
  114. '''
  115. Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed)
  116. '''
  117. return []
  118. def nvcc_args(self):
  119. '''
  120. Returns optional list of compiler flags to forward to nvcc when building CUDA sources
  121. '''
  122. return []
  123. def cxx_args(self):
  124. '''
  125. Returns optional list of compiler flags to forward to the build
  126. '''
  127. return []
  128. def is_compatible(self):
  129. '''
  130. Check if all non-python dependencies are satisfied to build this op
  131. '''
  132. return True
  133. def extra_ldflags(self):
  134. return []
  135. def libraries_installed(self, libraries):
  136. valid = False
  137. check_cmd = 'dpkg -l'
  138. for lib in libraries:
  139. result = subprocess.Popen(f'dpkg -l {lib}',
  140. stdout=subprocess.PIPE,
  141. stderr=subprocess.PIPE,
  142. shell=True)
  143. valid = valid or result.wait() == 0
  144. return valid
  145. def has_function(self, funcname, libraries, verbose=False):
  146. '''
  147. Test for existence of a function within a tuple of libraries.
  148. This is used as a smoke test to check whether a certain library is avaiable.
  149. As a test, this creates a simple C program that calls the specified function,
  150. and then distutils is used to compile that program and link it with the specified libraries.
  151. Returns True if both the compile and link are successful, False otherwise.
  152. '''
  153. tempdir = None # we create a temporary directory to hold various files
  154. filestderr = None # handle to open file to which we redirect stderr
  155. oldstderr = None # file descriptor for stderr
  156. try:
  157. # Echo compile and link commands that are used.
  158. if verbose:
  159. distutils.log.set_verbosity(1)
  160. # Create a compiler object.
  161. compiler = distutils.ccompiler.new_compiler(verbose=verbose)
  162. # Configure compiler and linker to build according to Python install.
  163. distutils.sysconfig.customize_compiler(compiler)
  164. # Create a temporary directory to hold test files.
  165. tempdir = tempfile.mkdtemp()
  166. # Define a simple C program that calls the function in question
  167. prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % (
  168. funcname,
  169. funcname)
  170. # Write the test program to a file.
  171. filename = os.path.join(tempdir, 'test.c')
  172. with open(filename, 'w') as f:
  173. f.write(prog)
  174. # Redirect stderr file descriptor to a file to silence compile/link warnings.
  175. if not verbose:
  176. filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w')
  177. oldstderr = os.dup(sys.stderr.fileno())
  178. os.dup2(filestderr.fileno(), sys.stderr.fileno())
  179. # Attempt to compile the C program into an object file.
  180. cflags = shlex.split(os.environ.get('CFLAGS', ""))
  181. objs = compiler.compile([filename],
  182. extra_preargs=self.strip_empty_entries(cflags))
  183. # Attempt to link the object file into an executable.
  184. # Be sure to tack on any libraries that have been specified.
  185. ldflags = shlex.split(os.environ.get('LDFLAGS', ""))
  186. compiler.link_executable(objs,
  187. os.path.join(tempdir,
  188. 'a.out'),
  189. extra_preargs=self.strip_empty_entries(ldflags),
  190. libraries=libraries)
  191. # Compile and link succeeded
  192. return True
  193. except CompileError:
  194. return False
  195. except LinkError:
  196. return False
  197. except:
  198. return False
  199. finally:
  200. # Restore stderr file descriptor and close the stderr redirect file.
  201. if oldstderr is not None:
  202. os.dup2(oldstderr, sys.stderr.fileno())
  203. if filestderr is not None:
  204. filestderr.close()
  205. # Delete the temporary directory holding the test program and stderr files.
  206. if tempdir is not None:
  207. shutil.rmtree(tempdir)
  208. def strip_empty_entries(self, args):
  209. '''
  210. Drop any empty strings from the list of compile and link flags
  211. '''
  212. return [x for x in args if len(x) > 0]
  213. def cpu_arch(self):
  214. if not self.command_exists('lscpu'):
  215. self.warning(
  216. f"{self.name} attempted to query 'lscpu' to detect the CPU architecture. "
  217. "However, 'lscpu' does not appear to exist on "
  218. "your system, will fall back to use -march=native.")
  219. return '-march=native'
  220. result = subprocess.check_output('lscpu', shell=True)
  221. result = result.decode('utf-8').strip().lower()
  222. if 'ppc64le' in result:
  223. # gcc does not provide -march on PowerPC, use -mcpu instead
  224. return '-mcpu=native'
  225. return '-march=native'
  226. def simd_width(self):
  227. if not self.command_exists('lscpu'):
  228. self.warning(
  229. f"{self.name} attempted to query 'lscpu' to detect the existence "
  230. "of AVX instructions. However, 'lscpu' does not appear to exist on "
  231. "your system, will fall back to non-vectorized execution.")
  232. return '-D__SCALAR__'
  233. try:
  234. result = subprocess.check_output('lscpu', shell=True)
  235. result = result.decode('utf-8').strip().lower()
  236. except Exception as e:
  237. print(
  238. f"{WARNING} {self.name} SIMD_WIDTH cannot be recognized due to {str(e)}!"
  239. )
  240. return '-D__SCALAR__'
  241. if 'genuineintel' in result:
  242. if 'avx512' in result:
  243. return '-D__AVX512__'
  244. elif 'avx2' in result:
  245. return '-D__AVX256__'
  246. return '-D__SCALAR__'
  247. def python_requirements(self):
  248. '''
  249. Override if op wants to define special dependencies, otherwise will
  250. take self.name and load requirements-<op-name>.txt if it exists.
  251. '''
  252. path = f'requirements/requirements-{self.name}.txt'
  253. requirements = []
  254. if os.path.isfile(path):
  255. with open(path, 'r') as fd:
  256. requirements = [r.strip() for r in fd.readlines()]
  257. return requirements
  258. def command_exists(self, cmd):
  259. if '|' in cmd:
  260. cmds = cmd.split("|")
  261. else:
  262. cmds = [cmd]
  263. valid = False
  264. for cmd in cmds:
  265. result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
  266. valid = valid or result.wait() == 0
  267. if not valid and len(cmds) > 1:
  268. print(
  269. f"{WARNING} {self.name} requires one of the following commands '{cmds}', but it does not exist!"
  270. )
  271. elif not valid and len(cmds) == 1:
  272. print(
  273. f"{WARNING} {self.name} requires the '{cmd}' command, but it does not exist!"
  274. )
  275. return valid
  276. def warning(self, msg):
  277. print(f"{WARNING} {msg}")
  278. def deepspeed_src_path(self, code_path):
  279. if os.path.isabs(code_path):
  280. return code_path
  281. else:
  282. return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
  283. def builder(self):
  284. from torch.utils.cpp_extension import CppExtension
  285. return CppExtension(
  286. name=self.absolute_name(),
  287. sources=self.strip_empty_entries(self.sources()),
  288. include_dirs=self.strip_empty_entries(self.include_paths()),
  289. extra_compile_args={'cxx': self.strip_empty_entries(self.cxx_args())},
  290. extra_link_args=self.strip_empty_entries(self.extra_ldflags()))
  291. def load(self, verbose=True):
  292. from ...git_version_info import installed_ops, torch_info
  293. if installed_ops[self.name]:
  294. # Ensure the op we're about to load was compiled with the same
  295. # torch/cuda versions we are currently using at runtime.
  296. if isinstance(self, CUDAOpBuilder):
  297. assert_torch_info(torch_info)
  298. return importlib.import_module(self.absolute_name())
  299. else:
  300. return self.jit_load(verbose)
  301. def jit_load(self, verbose=True):
  302. if not self.is_compatible():
  303. raise RuntimeError(
  304. f"Unable to JIT load the {self.name} op due to it not being compatible due to hardware/software issue."
  305. )
  306. try:
  307. import ninja
  308. except ImportError:
  309. raise RuntimeError(
  310. f"Unable to JIT load the {self.name} op due to ninja not being installed."
  311. )
  312. if isinstance(self, CUDAOpBuilder):
  313. assert_no_cuda_mismatch()
  314. self.jit_mode = True
  315. from torch.utils.cpp_extension import load
  316. # Ensure directory exists to prevent race condition in some cases
  317. ext_path = os.path.join(
  318. os.environ.get('TORCH_EXTENSIONS_DIR',
  319. DEFAULT_TORCH_EXTENSION_PATH),
  320. self.name)
  321. os.makedirs(ext_path, exist_ok=True)
  322. start_build = time.time()
  323. sources = [self.deepspeed_src_path(path) for path in self.sources()]
  324. extra_include_paths = [
  325. self.deepspeed_src_path(path) for path in self.include_paths()
  326. ]
  327. op_module = load(
  328. name=self.name,
  329. sources=self.strip_empty_entries(sources),
  330. extra_include_paths=self.strip_empty_entries(extra_include_paths),
  331. extra_cflags=self.strip_empty_entries(self.cxx_args()),
  332. extra_cuda_cflags=self.strip_empty_entries(self.nvcc_args()),
  333. extra_ldflags=self.strip_empty_entries(self.extra_ldflags()),
  334. verbose=verbose)
  335. build_duration = time.time() - start_build
  336. if verbose:
  337. print(f"Time to load {self.name} op: {build_duration} seconds")
  338. return op_module
  339. class CUDAOpBuilder(OpBuilder):
  340. def compute_capability_args(self, cross_compile_archs=None):
  341. """
  342. Returns nvcc compute capability compile flags.
  343. 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
  344. 2. If neither is set default compute capabilities will be used
  345. 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
  346. Format:
  347. - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
  348. TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6" pip install ...
  349. TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" pip install ...
  350. - `cross_compile_archs` uses ; separator.
  351. """
  352. ccs = []
  353. if self.jit_mode:
  354. # Compile for underlying architectures since we know those at runtime
  355. for i in range(torch.cuda.device_count()):
  356. CC_MAJOR, CC_MINOR = torch.cuda.get_device_capability(i)
  357. cc = f"{CC_MAJOR}.{CC_MINOR}"
  358. if cc not in ccs:
  359. ccs.append(cc)
  360. ccs = sorted(ccs)
  361. ccs[-1] += '+PTX'
  362. else:
  363. # Cross-compile mode, compile for various architectures
  364. # env override takes priority
  365. cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
  366. if cross_compile_archs_env is not None:
  367. if cross_compile_archs is not None:
  368. print(
  369. f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`"
  370. )
  371. cross_compile_archs = cross_compile_archs_env.replace(' ', ';')
  372. else:
  373. if cross_compile_archs is None:
  374. cross_compile_archs = get_default_compute_capatabilities()
  375. ccs = cross_compile_archs.split(';')
  376. args = []
  377. for cc in ccs:
  378. num = cc[0] + cc[2]
  379. args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
  380. if cc.endswith('+PTX'):
  381. args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
  382. return args
  383. def version_dependent_macros(self):
  384. # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456
  385. TORCH_MAJOR = int(torch.__version__.split('.')[0])
  386. TORCH_MINOR = int(torch.__version__.split('.')[1])
  387. version_ge_1_1 = []
  388. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
  389. version_ge_1_1 = ['-DVERSION_GE_1_1']
  390. version_ge_1_3 = []
  391. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
  392. version_ge_1_3 = ['-DVERSION_GE_1_3']
  393. version_ge_1_5 = []
  394. if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
  395. version_ge_1_5 = ['-DVERSION_GE_1_5']
  396. return version_ge_1_1 + version_ge_1_3 + version_ge_1_5
  397. def is_compatible(self):
  398. return super().is_compatible()
  399. def builder(self):
  400. from torch.utils.cpp_extension import CUDAExtension
  401. assert_no_cuda_mismatch()
  402. return CUDAExtension(name=self.absolute_name(),
  403. sources=self.strip_empty_entries(self.sources()),
  404. include_dirs=self.strip_empty_entries(self.include_paths()),
  405. libraries=self.strip_empty_entries(self.libraries_args()),
  406. extra_compile_args={
  407. 'cxx': self.strip_empty_entries(self.cxx_args()),
  408. 'nvcc': self.strip_empty_entries(self.nvcc_args())
  409. })
  410. def cxx_args(self):
  411. if sys.platform == "win32":
  412. return ['-O2']
  413. else:
  414. return ['-O3', '-std=c++14', '-g', '-Wno-reorder']
  415. def nvcc_args(self):
  416. args = [
  417. '-O3',
  418. '--use_fast_math',
  419. '-std=c++17' if sys.platform == "win32" else '-std=c++14',
  420. '-U__CUDA_NO_HALF_OPERATORS__',
  421. '-U__CUDA_NO_HALF_CONVERSIONS__',
  422. '-U__CUDA_NO_HALF2_OPERATORS__'
  423. ]
  424. return args + self.compute_capability_args()
  425. def libraries_args(self):
  426. if sys.platform == "win32":
  427. return ['cublas', 'curand']
  428. else:
  429. return []