common.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
import time
import inspect
import socket
import subprocess
from abc import ABC, abstractmethod
from pathlib import Path

import torch
import torch.multiprocessing as mp
import deepspeed
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

import pytest
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker

# Worker timeout for tests that hang
DEEPSPEED_TEST_TIMEOUT = int(os.environ.get('DS_UNITTEST_TIMEOUT', '600'))

warn_reuse_dist_env = False

def is_rocm_pytorch():
    return hasattr(torch.version, 'hip') and torch.version.hip is not None


def get_xdist_worker_id():
    # pytest-xdist names its workers 'gw0', 'gw1', ...; strip the 'gw' prefix
    # to recover the integer worker id.
    xdist_worker = os.environ.get('PYTEST_XDIST_WORKER', None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace('gw', '')
        return int(xdist_worker_id)
    return None

def get_master_port(base_port=29500, port_range_size=1000):
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        # Make xdist workers use different port ranges to avoid race conditions
        base_port += port_range_size * xdist_worker_id

    # Select first open port in range
    port = base_port
    max_port = base_port + port_range_size
    sock = socket.socket()
    while port < max_port:
        try:
            sock.bind(('', port))
            sock.close()
            return str(port)
        except OSError:
            port += 1
    raise IOError('no free ports')

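# Example: with the defaults above, each pytest-xdist worker probes a disjoint
# 1000-port window, so concurrently running workers do not race for the same
# master port. For instance, worker 'gw2' scans ports 31500-32499:
#
#   get_master_port()  # on worker 'gw2' -> first free port >= 29500 + 1000 * 2
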
def _get_cpu_socket_count():
    import shlex
    p1 = subprocess.Popen(shlex.split("cat /proc/cpuinfo"), stdout=subprocess.PIPE)
    p2 = subprocess.Popen(["grep", "physical id"], stdin=p1.stdout, stdout=subprocess.PIPE)
    p1.stdout.close()
    p3 = subprocess.Popen(shlex.split("sort -u"), stdin=p2.stdout, stdout=subprocess.PIPE)
    p2.stdout.close()
    p4 = subprocess.Popen(shlex.split("wc -l"), stdin=p3.stdout, stdout=subprocess.PIPE)
    p3.stdout.close()
    r = int(p4.communicate()[0])
    p4.stdout.close()
    return r

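# The pipeline above is equivalent to the shell command
#   cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l
# i.e. it counts the distinct physical CPU sockets on a Linux host.
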
def set_accelerator_visible():
    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is None:
        xdist_worker_id = 0
    if cuda_visible is None:
        # CUDA_VISIBLE_DEVICES is not set, discover it using accelerator specific command instead
        if get_accelerator().device_name() == 'cuda':
            if is_rocm_pytorch():
                rocm_smi = subprocess.check_output(['rocm-smi', '--showid'])
                gpu_ids = filter(lambda s: 'GPU' in s, rocm_smi.decode('utf-8').strip().split('\n'))
                num_accelerators = len(list(gpu_ids))
            else:
                nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus'])
                num_accelerators = len(nvidia_smi.decode('utf-8').strip().split('\n'))
        elif get_accelerator().device_name() == 'xpu':
            clinfo = subprocess.check_output(['clinfo'])
            lines = clinfo.decode('utf-8').strip().split('\n')
            num_accelerators = 0
            for line in lines:
                match = re.search('Device Type.*GPU', line)
                if match:
                    num_accelerators += 1
        elif get_accelerator().device_name() == 'hpu':
            try:
                hl_smi = subprocess.check_output(['hl-smi', "-L"])
                num_accelerators = re.findall(r"Module ID\s+:\s+(\d+)", hl_smi.decode())
            except FileNotFoundError:
                sim_list = subprocess.check_output(['ls', '-1', '/dev/accel'])
                num_accelerators = re.findall(r"accel(\d+)", sim_list.decode())
            num_accelerators = sorted(num_accelerators, key=int)
            os.environ["HABANA_VISIBLE_MODULES"] = ",".join(num_accelerators)
        elif get_accelerator().device_name() == 'npu':
            npu_smi = subprocess.check_output(['npu-smi', 'info', '-l'])
            num_accelerators = int(npu_smi.decode('utf-8').strip().split('\n')[0].split(':')[1].strip())
        else:
            assert get_accelerator().device_name() == 'cpu'
            num_accelerators = _get_cpu_socket_count()

        if isinstance(num_accelerators, list):
            cuda_visible = ",".join(num_accelerators)
        else:
            cuda_visible = ",".join(map(str, range(num_accelerators)))

    # rotate list based on xdist worker id, example below
    # wid=0 -> ['0', '1', '2', '3']
    # wid=1 -> ['1', '2', '3', '0']
    # wid=2 -> ['2', '3', '0', '1']
    # wid=3 -> ['3', '0', '1', '2']
    dev_id_list = cuda_visible.split(",")
    dev_id_list = dev_id_list[xdist_worker_id:] + dev_id_list[:xdist_worker_id]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)

class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture.
    """
    world_size = 2
    backend = get_accelerator().communication_backend_name()
    init_distributed = True
    set_dist_env = True
    requires_cuda_env = True
    reuse_dist_env = False
    non_daemonic_procs = False
    _pool_cache = {}
    exec_timeout = DEEPSPEED_TEST_TIMEOUT

    @abstractmethod
    def run(self):
        ...

    def __call__(self, request):
        self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
        world_size = self.world_size
        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        self._launch_with_file_store(request, world_size)

    def _get_fixture_kwargs(self, request, func):
        if not request:
            return {}
        # Grab fixture / parametrize kwargs from pytest request object
        fixture_kwargs = {}
        params = inspect.getfullargspec(func).args
        params.remove("self")
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                pass  # test methods can have kwargs that are not fixtures
        return fixture_kwargs

    def _launch_daemonic_procs(self, num_procs, init_method):
        # Create process pool or use cached one
        master_port = None

        if get_accelerator().device_name() == 'hpu':
            if self.reuse_dist_env:
                print("Ignoring reuse_dist_env for hpu")
                self.reuse_dist_env = False

        global warn_reuse_dist_env
        if self.reuse_dist_env and not warn_reuse_dist_env:
            # Currently we see memory leak for tests that reuse distributed environment
            print("Ignoring reuse_dist_env and forcibly setting it to False")
            warn_reuse_dist_env = True

        self.reuse_dist_env = False

        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                master_port = get_master_port()
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)
            master_port = get_master_port()

        # Run the test
        args = [(local_rank, num_procs, master_port, init_method) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        try:
            skip_msgs = skip_msgs_async.get(self.exec_timeout)
        except mp.TimeoutError:
            # Shortcut to exit pytest in the case of a hanged test. This
            # usually means an environment error and the rest of tests will
            # hang (causing super long unit test runtimes)
            pytest.exit("Test hanged, exiting", returncode=1)
        finally:
            # Regardless of the outcome, ensure proper teardown
            # Tear down distributed environment and close process pools
            self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
            assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
            pytest.skip(skip_msgs[0])

    def _launch_non_daemonic_procs(self, num_procs, init_method):
        assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes"

        master_port = get_master_port()
        skip_msg = mp.Queue()  # Allows forked processes to share pytest.skip reason
        processes = []
        prev_start_method = mp.get_start_method()
        mp.set_start_method('spawn', force=True)
        for local_rank in range(num_procs):
            p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, init_method, skip_msg))
            p.start()
            processes.append(p)
        mp.set_start_method(prev_start_method, force=True)

        # Now loop and wait for a test to complete. The spin-wait here isn't a big
        # deal because the number of processes will be O(#GPUs) << O(#CPUs).
        any_done = False
        start = time.time()
        while (not any_done) and ((time.time() - start) < self.exec_timeout):
            for p in processes:
                if not p.is_alive():
                    any_done = True
                    break
            time.sleep(.1)  # So we don't hog CPU

        # If we hit the timeout, then presume a test is hanged
        if not any_done:
            for p in processes:
                p.terminate()
            pytest.exit("Test hanged, exiting", returncode=1)

        # Wait for all other processes to complete
        for p in processes:
            p.join(self.exec_timeout)

        failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
        for rank, p in failed:
            # If it still hasn't terminated, kill it because it hung.
            if p.exitcode is None:
                p.terminate()
                pytest.fail(f'Worker {rank} hung.', pytrace=False)
            if p.exitcode < 0:
                pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}', pytrace=False)
            if p.exitcode > 0:
                pytest.fail(f'Worker {rank} exited with code {p.exitcode}', pytrace=False)

        if not skip_msg.empty():
            # This assumed all skip messages are the same, it may be useful to
            # add a check here to assert all exit messages are equal
            pytest.skip(skip_msg.get())

    def _launch_procs(self, num_procs, init_method):
        # Verify we have enough accelerator devices to run this test
        if get_accelerator().is_available() and get_accelerator().device_count() < num_procs:
            pytest.skip(
                f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
            )

        if get_accelerator().device_name() == 'xpu':
            self.non_daemonic_procs = True
            self.reuse_dist_env = False

        # Set start method to `forkserver` (or `fork`)
        mp.set_start_method('forkserver', force=True)

        if self.non_daemonic_procs:
            self._launch_non_daemonic_procs(num_procs, init_method)
        else:
            self._launch_daemonic_procs(num_procs, init_method)

    def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""):
        if not dist.is_initialized():
            """ Initialize deepspeed.comm and execute the user function. """
            if self.set_dist_env:
                os.environ['MASTER_ADDR'] = '127.0.0.1'
                os.environ['MASTER_PORT'] = str(master_port)
                os.environ['LOCAL_RANK'] = str(local_rank)
                # NOTE: unit tests don't support multi-node so local_rank == global rank
                os.environ['RANK'] = str(local_rank)
                # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE
                # DeepSpeed single node launcher would also set LOCAL_SIZE accordingly
                os.environ['LOCAL_SIZE'] = str(num_procs)
                os.environ['WORLD_SIZE'] = str(num_procs)

            # turn off NCCL logging if set
            os.environ.pop('NCCL_DEBUG', None)

            if get_accelerator().is_available():
                set_accelerator_visible()

            if get_accelerator().is_available():
                get_accelerator().set_device(local_rank)

            if self.init_distributed:
                deepspeed.init_distributed(dist_backend=self.backend,
                                           init_method=init_method,
                                           rank=local_rank,
                                           world_size=num_procs)
                dist.barrier()

        try:
            self.run(**self._fixture_kwargs)
        except BaseException as e:
            if isinstance(e, Skipped):
                if self.non_daemonic_procs:
                    skip_msg.put(e.msg)
                else:
                    skip_msg = e.msg
            else:
                raise e

        return skip_msg

    def _launch_with_file_store(self, request, world_size):
        tmpdir = request.getfixturevalue("tmpdir")
        dist_file_store = tmpdir.join("dist_file_store")
        assert not os.path.exists(dist_file_store)
        init_method = f"file://{dist_file_store}"

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            try:
                self._launch_procs(procs, init_method)
            finally:
                if os.path.exists(dist_file_store):
                    os.remove(dist_file_store)
            time.sleep(0.5)

    def _dist_destroy(self):
        if (dist is not None) and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False):
        if force or not self.reuse_dist_env:
            msg = pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
            pool.close()
            pool.join()

class DistributedFixture(DistributedExec):
    """
    Implementation that extends @pytest.fixture to allow for distributed execution.
    This is primarily meant to be used when a test requires executing two pieces of
    code with different world sizes.

    There are 2 parameters that can be modified:
        - world_size: int = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside fixture
        - can be reused by multiple tests
        - can accept other fixtures as input

    Limitations:
        - cannot use @pytest.mark.parametrize
        - world_size cannot be modified after definition and only one world_size value is accepted
        - any fixtures used must also be used in the test that uses this fixture (see example below)
        - return values cannot be returned. Passing values to a DistributedTest
          object can be achieved using class_tmpdir and writing to file (see example below)

    Usage:
        - must implement a run(self, ...) method
        - fixture can be used by making the class name input to a test function

    Example:
        @pytest.fixture(params=[10,20])
        def regular_pytest_fixture(request):
            return request.param

        class distributed_fixture_example(DistributedFixture):
            world_size = 4

            def run(self, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                local_rank = os.environ["LOCAL_RANK"]
                print(f"Rank {local_rank} with value {regular_pytest_fixture}")
                with open(os.path.join(class_tmpdir, f"{local_rank}.txt"), "w") as f:
                    f.write(f"{local_rank},{regular_pytest_fixture}")

        class TestExample(DistributedTest):
            world_size = 1

            def test(self, distributed_fixture_example, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                for rank in range(4):
                    with open(os.path.join(class_tmpdir, f"{rank}.txt"), "r") as f:
                        assert f.read() == f"{rank},{regular_pytest_fixture}"
    """
    is_dist_fixture = True

    # These values are just placeholders so that pytest recognizes this as a fixture
    _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None)
    __name__ = ""

    def __init__(self):
        assert isinstance(self.world_size, int), "Only one world size is allowed for distributed fixtures"
        self.__name__ = type(self).__name__
        _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None, name=self.__name__)

class DistributedTest(DistributedExec):
    """
    Implementation for running pytest with distributed execution.

    There are 2 parameters that can be modified:
        - world_size: Union[int,List[int]] = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside tests
        - works with pytest fixtures, parametrize, mark, etc.
        - can contain multiple tests (each of which can be parametrized separately)
        - class methods can be fixtures (usable by tests in this class only)
        - world_size can be changed for individual tests using @pytest.mark.world_size(world_size)
        - class_tmpdir is a fixture that can be used to get a tmpdir shared among
          all tests (including DistributedFixture)

    Usage:
        - class name must start with "Test"
        - must implement one or more test*(self, ...) methods

    Example:
        @pytest.fixture(params=[10,20])
        def val1(request):
            return request.param

        @pytest.mark.fast
        @pytest.mark.parametrize("val2", [30,40])
        class TestExample(DistributedTest):
            world_size = 2

            @pytest.fixture(params=[50,60])
            def val3(self, request):
                return request.param

            def test_1(self, val1, val2, str1="hello world"):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                assert all(val1, val2, str1)

            @pytest.mark.world_size(1)
            @pytest.mark.parametrize("val4", [70,80])
            def test_2(self, val1, val2, val3, val4):
                assert int(os.environ["WORLD_SIZE"]) == 1
                assert all(val1, val2, val3, val4)
    """
    is_dist_test = True

    # Temporary directory that is shared among test methods in a class
    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        fn = tmpdir_factory.mktemp(self.__class__.__name__)
        return fn

    def run(self, **fixture_kwargs):
        self._current_test(**fixture_kwargs)

    def __call__(self, request):
        self._current_test = self._get_current_test_func(request)
        self._fixture_kwargs = self._get_fixture_kwargs(request, self._current_test)

        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        # Catch world_size override pytest mark
        for mark in getattr(request.function, "pytestmark", []):
            if mark.name == "world_size":
                world_size = mark.args[0]
                break
        else:
            world_size = self._fixture_kwargs.get("world_size", self.world_size)

        self._launch_with_file_store(request, world_size)

    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)

def get_test_path(filename):
    curr_path = Path(__file__).parent
    return str(curr_path.joinpath(filename))

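# For example, with a hypothetical file ds_config.json stored next to this
# module, get_test_path("ds_config.json") returns its absolute path regardless
# of the directory pytest was invoked from.
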
# fp16 > bf16 > fp32
def preferred_dtype():
    if get_accelerator().is_fp16_supported():
        return torch.float16
    elif get_accelerator().is_bf16_supported():
        return torch.bfloat16
    else:
        return torch.float32
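

# Minimal usage sketch (illustrative only; the test class below is hypothetical
# and follows the DistributedTest conventions documented above): pytest collects
# the subclass, DistributedExec launches `world_size` worker processes, and each
# rank runs the test body with deepspeed.comm already initialized.
#
#   class TestAllReduceExample(DistributedTest):
#       world_size = 2
#
#       def test(self):
#           device = get_accelerator().current_device_name()
#           x = torch.ones(4, dtype=preferred_dtype(), device=device)
#           dist.all_reduce(x)  # default op is SUM, so every element becomes 2.0
#           assert torch.all(x == 2)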