# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
import time
import inspect
import socket
import subprocess
from abc import ABC, abstractmethod
from pathlib import Path

import torch
import torch.multiprocessing as mp
import deepspeed
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

import pytest
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker

# Worker timeout for tests that hang
DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))


def is_rocm_pytorch():
    return hasattr(torch.version, 'hip') and torch.version.hip is not None


def get_xdist_worker_id():
    # pytest-xdist names its workers 'gw0', 'gw1', ...; strip the prefix to
    # recover the integer worker id. Returns None outside of xdist runs.
    xdist_worker = os.environ.get('PYTEST_XDIST_WORKER', None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace('gw', '')
        return int(xdist_worker_id)
    return None


def get_master_port(base_port=29500, port_range_size=1000):
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        # Make xdist workers use different port ranges to avoid race conditions
        base_port += port_range_size * xdist_worker_id

    # Select first open port in range
    port = base_port
    max_port = base_port + port_range_size
    sock = socket.socket()
    while port < max_port:
        try:
            sock.bind(('', port))
            sock.close()
            return str(port)
        except OSError:
            port += 1
    raise IOError('no free ports')
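
# With the defaults above, xdist worker gw0 probes ports 29500-30499, gw1
# probes 30500-31499, and so on, so concurrent workers never race for the
# same MASTER_PORT. A sketch of the effect (worker id is hypothetical):
#
#   os.environ['PYTEST_XDIST_WORKER'] = 'gw2'
#   get_master_port()  # -> first free port in 31500-32499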


def _get_cpu_socket_count():
    # Count distinct physical CPU sockets; equivalent to the shell pipeline:
    #   cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l
    import shlex
    p1 = subprocess.Popen(shlex.split("cat /proc/cpuinfo"), stdout=subprocess.PIPE)
    p2 = subprocess.Popen(["grep", "physical id"], stdin=p1.stdout, stdout=subprocess.PIPE)
    p1.stdout.close()
    p3 = subprocess.Popen(shlex.split("sort -u"), stdin=p2.stdout, stdout=subprocess.PIPE)
    p2.stdout.close()
    p4 = subprocess.Popen(shlex.split("wc -l"), stdin=p3.stdout, stdout=subprocess.PIPE)
    p3.stdout.close()
    r = int(p4.communicate()[0])
    p4.stdout.close()
    return r


def set_accelerator_visible():
    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is None:
        xdist_worker_id = 0

    if cuda_visible is None:
        # CUDA_VISIBLE_DEVICES is not set, discover it using accelerator specific command instead
        if get_accelerator().device_name() == 'cuda':
            if is_rocm_pytorch():
                rocm_smi = subprocess.check_output(['rocm-smi', '--showid'])
                gpu_ids = filter(lambda s: 'GPU' in s, rocm_smi.decode('utf-8').strip().split('\n'))
                num_accelerators = len(list(gpu_ids))
            else:
                nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus'])
                num_accelerators = len(nvidia_smi.decode('utf-8').strip().split('\n'))
        elif get_accelerator().device_name() == 'xpu':
            clinfo = subprocess.check_output(['clinfo'])
            lines = clinfo.decode('utf-8').strip().split('\n')
            num_accelerators = 0
            for line in lines:
                match = re.search('Device Type.*GPU', line)
                if match:
                    num_accelerators += 1
        elif get_accelerator().device_name() == 'hpu':
            try:
                hl_smi = subprocess.check_output(['hl-smi', "-L"])
                num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode())
            except FileNotFoundError:
                sim_list = subprocess.check_output(['ls', '-1', '/dev/accel'])
                num_accelerators = re.findall(r"accel(\d+)", sim_list.decode())
            num_accelerators = sorted(num_accelerators, key=int)
            os.environ["HABANA_VISIBLE_MODULES"] = ",".join(num_accelerators)
        elif get_accelerator().device_name() == 'npu':
            npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
            num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
        else:
            assert get_accelerator().device_name() == 'cpu'
            num_accelerators = _get_cpu_socket_count()

        # hpu discovery yields a list of module ids; the other branches yield a count
        if isinstance(num_accelerators, list):
            cuda_visible = ",".join(num_accelerators)
        else:
            cuda_visible = ",".join(map(str, range(num_accelerators)))

    # rotate list based on xdist worker id, example below
    # wid=0 -> ['0', '1', '2', '3']
    # wid=1 -> ['1', '2', '3', '0']
    # wid=2 -> ['2', '3', '0', '1']
    # wid=3 -> ['3', '0', '1', '2']
    dev_id_list = cuda_visible.split(",")
    dev_id_list = dev_id_list[xdist_worker_id:] + dev_id_list[:xdist_worker_id]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)


class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture.
    """
    world_size = 2
    backend = get_accelerator().communication_backend_name()
    init_distributed = True
    set_dist_env = True
    requires_cuda_env = True
    reuse_dist_env = False
    non_daemonic_procs = False
    _pool_cache = {}
    exec_timeout = DEEPSPEED_TEST_TIMEOUT

    @abstractmethod
    def run(self):
        ...

    def __call__(self, request=None):
        self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
        world_size = self.world_size
        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)

    def _get_fixture_kwargs(self, request, func):
        if not request:
            return {}
        # Grab fixture / parametrize kwargs from pytest request object
        fixture_kwargs = {}
        params = inspect.getfullargspec(func).args
        params.remove("self")
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                pass  # test methods can have kwargs that are not fixtures
        return fixture_kwargs

    def _launch_daemonic_procs(self, num_procs):
        # Create process pool or use cached one
        master_port = None

        if get_accelerator().device_name() == 'hpu':
            if self.reuse_dist_env:
                print("Ignoring reuse_dist_env for hpu")
                self.reuse_dist_env = False

        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                master_port = get_master_port()
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)
            master_port = get_master_port()

        # Run the test
        args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        try:
            skip_msgs = skip_msgs_async.get(self.exec_timeout)
        except mp.TimeoutError:
            # Shortcut to exit pytest in the case of a hung test. This
            # usually means an environment error and the rest of the tests will
            # hang (causing super long unit test runtimes)
            pytest.exit("Test hung, exiting", returncode=1)
        finally:
            # Regardless of the outcome, ensure proper teardown:
            # tear down the distributed environment and close process pools
            self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
            assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
            pytest.skip(skip_msgs[0])

    def _launch_non_daemonic_procs(self, num_procs):
        assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes"

        master_port = get_master_port()
        skip_msg = mp.Queue()  # Allows forked processes to share pytest.skip reason
        processes = []
        prev_start_method = mp.get_start_method()
        mp.set_start_method('spawn', force=True)
        for local_rank in range(num_procs):
            p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg))
            p.start()
            processes.append(p)
        mp.set_start_method(prev_start_method, force=True)

        # Now loop and wait for a test to complete. The spin-wait here isn't a big
        # deal because the number of processes will be O(#GPUs) << O(#CPUs).
        any_done = False
        start = time.time()
        while (not any_done) and ((time.time() - start) < self.exec_timeout):
            for p in processes:
                if not p.is_alive():
                    any_done = True
                    break
            time.sleep(.1)  # So we don't hog CPU

        # If we hit the timeout, then presume a test is hung
        if not any_done:
            for p in processes:
                p.terminate()
            pytest.exit("Test hung, exiting", returncode=1)

        # Wait for all other processes to complete
        for p in processes:
            p.join(self.exec_timeout)

        failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
        for rank, p in failed:
            # If it still hasn't terminated, kill it because it hung.
            if p.exitcode is None:
                p.terminate()
                pytest.fail(f'Worker {rank} hung.', pytrace=False)
            if p.exitcode < 0:
                pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}', pytrace=False)
            if p.exitcode > 0:
                pytest.fail(f'Worker {rank} exited with code {p.exitcode}', pytrace=False)

        if not skip_msg.empty():
            # This assumes all skip messages are the same; it may be useful to
            # add a check here asserting that all skip messages are equal
            pytest.skip(skip_msg.get())

    def _launch_procs(self, num_procs):
        # Verify we have enough accelerator devices to run this test
        if get_accelerator().is_available() and get_accelerator().device_count() < num_procs:
            pytest.skip(
                f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
            )

        if get_accelerator().device_name() == 'xpu':
            self.non_daemonic_procs = True
            self.reuse_dist_env = False

        # Set start method to `forkserver` (or `fork`)
        mp.set_start_method('forkserver', force=True)

        if self.non_daemonic_procs:
            self._launch_non_daemonic_procs(num_procs)
        else:
            self._launch_daemonic_procs(num_procs)

    def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""):
        if not dist.is_initialized():
            # Initialize deepspeed.comm and execute the user function.
            if self.set_dist_env:
                os.environ['MASTER_ADDR'] = '127.0.0.1'
                os.environ['MASTER_PORT'] = str(master_port)
                os.environ['LOCAL_RANK'] = str(local_rank)
                # NOTE: unit tests don't support multi-node so local_rank == global rank
                os.environ['RANK'] = str(local_rank)
                # In case of multiprocess launching LOCAL_SIZE should be the same as WORLD_SIZE
                # The DeepSpeed single-node launcher would also set LOCAL_SIZE accordingly
                os.environ['LOCAL_SIZE'] = str(num_procs)
                os.environ['WORLD_SIZE'] = str(num_procs)

            # turn off NCCL logging if set
            os.environ.pop('NCCL_DEBUG', None)

            if get_accelerator().is_available():
                set_accelerator_visible()

            if get_accelerator().is_available():
                get_accelerator().set_device(local_rank)

            if self.init_distributed:
                deepspeed.init_distributed(dist_backend=self.backend)
                dist.barrier()

        try:
            self.run(**self._fixture_kwargs)
        except BaseException as e:
            if isinstance(e, Skipped):
                if self.non_daemonic_procs:
                    skip_msg.put(e.msg)
                else:
                    skip_msg = e.msg
            else:
                raise e

        return skip_msg

    def _dist_destroy(self):
        if (dist is not None) and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False):
        if force or not self.reuse_dist_env:
            msg = pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
            pool.close()
            pool.join()
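

# A minimal sketch of a direct DistributedExec subclass (hypothetical; in
# practice tests use the DistributedTest / DistributedFixture wrappers below):
#
#   class _AllReduceSmoke(DistributedExec):
#       world_size = 2
#
#       def run(self):
#           x = torch.ones(1).to(get_accelerator().device_name())
#           dist.all_reduce(x)
#           assert x.item() == dist.get_world_size()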


class DistributedFixture(DistributedExec):
    """
    Implementation that extends @pytest.fixture to allow for distributed execution.
    This is primarily meant to be used when a test requires executing two pieces of
    code with different world sizes.

    There are 2 parameters that can be modified:
        - world_size: int = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside fixture
        - can be reused by multiple tests
        - can accept other fixtures as input

    Limitations:
        - cannot use @pytest.mark.parametrize
        - world_size cannot be modified after definition and only one world_size value is accepted
        - any fixtures used must also be used in the test that uses this fixture (see example below)
        - the fixture cannot return values. Passing values to a DistributedTest
          object can be achieved using class_tmpdir and writing to file (see example below)

    Usage:
        - must implement a run(self, ...) method
        - fixture can be used by making the class name an input to a test function

    Example:
        @pytest.fixture(params=[10,20])
        def regular_pytest_fixture(request):
            return request.param

        class distributed_fixture_example(DistributedFixture):
            world_size = 4

            def run(self, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                local_rank = os.environ["LOCAL_RANK"]
                print(f"Rank {local_rank} with value {regular_pytest_fixture}")
                with open(os.path.join(class_tmpdir, f"{local_rank}.txt"), "w") as f:
                    f.write(f"{local_rank},{regular_pytest_fixture}")

        class TestExample(DistributedTest):
            world_size = 1

            def test(self, distributed_fixture_example, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                for rank in range(4):
                    with open(os.path.join(class_tmpdir, f"{rank}.txt"), "r") as f:
                        assert f.read() == f"{rank},{regular_pytest_fixture}"
    """
    is_dist_fixture = True

    # These values are just placeholders so that pytest recognizes this as a fixture
    _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None)
    __name__ = ""

    def __init__(self):
        assert isinstance(self.world_size, int), "Only one world size is allowed for distributed fixtures"
        self.__name__ = type(self).__name__
        _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None, name=self.__name__)


class DistributedTest(DistributedExec):
    """
    Implementation for running pytest with distributed execution.

    There are 2 parameters that can be modified:
        - world_size: Union[int,List[int]] = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside tests
        - works with pytest fixtures, parametrize, mark, etc.
        - can contain multiple tests (each of which can be parametrized separately)
        - class methods can be fixtures (usable by tests in this class only)
        - world_size can be changed for individual tests using @pytest.mark.world_size(world_size)
        - class_tmpdir is a fixture that can be used to get a tmpdir shared among
          all tests (including DistributedFixture)

    Usage:
        - class name must start with "Test"
        - must implement one or more test*(self, ...) methods

    Example:
        @pytest.fixture(params=[10,20])
        def val1(request):
            return request.param

        @pytest.mark.fast
        @pytest.mark.parametrize("val2", [30,40])
        class TestExample(DistributedTest):
            world_size = 2

            @pytest.fixture(params=[50,60])
            def val3(self, request):
                return request.param

            def test_1(self, val1, val2, str1="hello world"):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                assert all((val1, val2, str1))

            @pytest.mark.world_size(1)
            @pytest.mark.parametrize("val4", [70,80])
            def test_2(self, val1, val2, val3, val4):
                assert int(os.environ["WORLD_SIZE"]) == 1
                assert all((val1, val2, val3, val4))
    """
    is_dist_test = True

    # Temporary directory that is shared among test methods in a class
    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        fn = tmpdir_factory.mktemp(self.__class__.__name__)
        return fn

    def run(self, **fixture_kwargs):
        self._current_test(**fixture_kwargs)

    def __call__(self, request):
        self._current_test = self._get_current_test_func(request)
        self._fixture_kwargs = self._get_fixture_kwargs(request, self._current_test)

        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        # Catch world_size override pytest mark
        for mark in getattr(request.function, "pytestmark", []):
            if mark.name == "world_size":
                world_size = mark.args[0]
                break
        else:
            world_size = self._fixture_kwargs.get("world_size", self.world_size)

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)
            time.sleep(0.5)

    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)


def get_test_path(filename):
    curr_path = Path(__file__).parent
    return str(curr_path.joinpath(filename))


# fp16 > bf16 > fp32
def preferred_dtype():
    if get_accelerator().is_fp16_supported():
        return torch.float16
    elif get_accelerator().is_bf16_supported():
        return torch.bfloat16
    else:
        return torch.float32
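

# Example (a sketch; TestPreferredDtype is hypothetical): combine
# preferred_dtype() with DistributedTest so a test runs in the best
# precision the current accelerator supports:
#
#   class TestPreferredDtype(DistributedTest):
#       world_size = 1
#
#       def test(self):
#           dtype = preferred_dtype()
#           model = torch.nn.Linear(8, 8).to(dtype)
#           out = model(torch.randn(4, 8, dtype=dtype))
#           assert out.dtype == dtype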