common.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
import time
import inspect
import socket
import subprocess
from abc import ABC, abstractmethod
from pathlib import Path

import torch
import torch.multiprocessing as mp
import deepspeed
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

import pytest
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker

# Worker timeout *after* the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 120

# Worker timeout for tests that hang
DEEPSPEED_TEST_TIMEOUT = 600


def is_rocm_pytorch():
    return hasattr(torch.version, 'hip') and torch.version.hip is not None


def get_xdist_worker_id():
    xdist_worker = os.environ.get('PYTEST_XDIST_WORKER', None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace('gw', '')
        return int(xdist_worker_id)
    return None


def get_master_port(base_port=29500, port_range_size=1000):
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        # Make xdist workers use different port ranges to avoid race conditions
        base_port += port_range_size * xdist_worker_id

    # Select first open port in range
    port = base_port
    max_port = base_port + port_range_size
    sock = socket.socket()
    while port < max_port:
        try:
            sock.bind(('', port))
            sock.close()
            return str(port)
        except OSError:
            port += 1
    raise IOError('no free ports')
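

# Note: with get_master_port's default arguments (base_port=29500,
# port_range_size=1000), pytest-xdist worker gw0 searches ports 29500-30499,
# gw1 searches 30500-31499, and so on; the first port that binds successfully
# is returned as a string.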


def set_accelerator_visible():
    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is None:
        xdist_worker_id = 0

    if cuda_visible is None:
        # CUDA_VISIBLE_DEVICES is not set, discover it using accelerator specific command instead
        if get_accelerator().device_name() == 'cuda':
            if is_rocm_pytorch():
                rocm_smi = subprocess.check_output(['rocm-smi', '--showid'])
                gpu_ids = filter(lambda s: 'GPU' in s, rocm_smi.decode('utf-8').strip().split('\n'))
                num_accelerators = len(list(gpu_ids))
            else:
                nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus'])
                num_accelerators = len(nvidia_smi.decode('utf-8').strip().split('\n'))
        elif get_accelerator().device_name() == 'xpu':
            clinfo = subprocess.check_output(['clinfo'])
            lines = clinfo.decode('utf-8').strip().split('\n')
            num_accelerators = 0
            for line in lines:
                match = re.search('Device Type.*GPU', line)
                if match:
                    num_accelerators += 1
        else:
            assert get_accelerator().device_name() == 'cpu'
            cpu_sockets = int(
                subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
            num_accelerators = cpu_sockets
        cuda_visible = ",".join(map(str, range(num_accelerators)))

    # rotate list based on xdist worker id, example below
    # wid=0 -> ['0', '1', '2', '3']
    # wid=1 -> ['1', '2', '3', '0']
    # wid=2 -> ['2', '3', '0', '1']
    # wid=3 -> ['3', '0', '1', '2']
    dev_id_list = cuda_visible.split(",")
    dev_id_list = dev_id_list[xdist_worker_id:] + dev_id_list[:xdist_worker_id]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)


class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture.
    """
    world_size = 2
    backend = get_accelerator().communication_backend_name()
    init_distributed = True
    set_dist_env = True
    requires_cuda_env = True
    reuse_dist_env = False
    _pool_cache = {}
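
    # Note: when reuse_dist_env is True, _launch_procs caches one process pool per
    # world size in _pool_cache and skips tearing down the distributed environment
    # between tests (see _close_pool); otherwise a fresh pool is created and torn
    # down for every test.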

    @abstractmethod
    def run(self):
        ...

    def __call__(self, request=None):
        self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
        world_size = self.world_size
        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)

    def _get_fixture_kwargs(self, request, func):
        if not request:
            return {}
        # Grab fixture / parametrize kwargs from pytest request object
        fixture_kwargs = {}
        params = inspect.getfullargspec(func).args
        params.remove("self")
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                pass  # test methods can have kwargs that are not fixtures
        return fixture_kwargs

    def _launch_procs(self, num_procs):
        # Verify we have enough accelerator devices to run this test
        if get_accelerator().is_available() and get_accelerator().device_count() < num_procs:
            pytest.skip(
                f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
            )

        # Set start method to `forkserver` (or `fork`)
        mp.set_start_method('forkserver', force=True)

        # Create process pool or use cached one
        master_port = None
        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                master_port = get_master_port()
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)
            master_port = get_master_port()

        # Run the test
        args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        try:
            skip_msgs = skip_msgs_async.get(DEEPSPEED_TEST_TIMEOUT)
        except mp.TimeoutError:
            # Shortcut to exit pytest in the case of a hung test. This usually
            # means an environment error, and the rest of the tests will also
            # hang (causing very long unit test runtimes).
            pytest.exit("Test hung, exiting", returncode=0)

        # Tear down distributed environment and close process pools
        self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
            assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
            pytest.skip(skip_msgs[0])

    def _dist_run(self, local_rank, num_procs, master_port):
        skip_msg = ''
        if not dist.is_initialized():
            """ Initialize deepspeed.comm and execute the user function. """
            if self.set_dist_env:
                os.environ['MASTER_ADDR'] = '127.0.0.1'
                os.environ['MASTER_PORT'] = str(master_port)
                os.environ['LOCAL_RANK'] = str(local_rank)
                # NOTE: unit tests don't support multi-node so local_rank == global rank
                os.environ['RANK'] = str(local_rank)
                # In case of multiprocess launching, LOCAL_SIZE should be the same as WORLD_SIZE
                # (the DeepSpeed single-node launcher would also set LOCAL_SIZE accordingly)
                os.environ['LOCAL_SIZE'] = str(num_procs)
                os.environ['WORLD_SIZE'] = str(num_procs)

            # turn off NCCL logging if set
            os.environ.pop('NCCL_DEBUG', None)

            if get_accelerator().is_available():
                set_accelerator_visible()

            if self.init_distributed:
                deepspeed.init_distributed(dist_backend=self.backend)
                dist.barrier()

        if get_accelerator().is_available():
            get_accelerator().set_device(local_rank)

        try:
            self.run(**self._fixture_kwargs)
        except BaseException as e:
            if isinstance(e, Skipped):
                skip_msg = e.msg
            else:
                raise e

        return skip_msg

    def _dist_destroy(self):
        if (dist is not None) and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False):
        if force or not self.reuse_dist_env:
            msg = pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
            pool.close()
            pool.join()


class DistributedFixture(DistributedExec):
    """
    Implementation that extends @pytest.fixture to allow for distributed execution.
    This is primarily meant to be used when a test requires executing two pieces of
    code with different world sizes.

    There are 2 parameters that can be modified:
        - world_size: int = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside fixture
        - can be reused by multiple tests
        - can accept other fixtures as input

    Limitations:
        - cannot use @pytest.mark.parametrize
        - world_size cannot be modified after definition and only one world_size value is accepted
        - any fixtures used must also be used in the test that uses this fixture (see example below)
        - cannot return values; values can be passed to a DistributedTest object by
          writing them to a file under class_tmpdir (see example below)

    Usage:
        - must implement a run(self, ...) method
        - the fixture is used by adding the class name as an input argument to a test function

    Example:
        @pytest.fixture(params=[10,20])
        def regular_pytest_fixture(request):
            return request.param

        class distributed_fixture_example(DistributedFixture):
            world_size = 4

            def run(self, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                local_rank = os.environ["LOCAL_RANK"]
                print(f"Rank {local_rank} with value {regular_pytest_fixture}")
                with open(os.path.join(class_tmpdir, f"{local_rank}.txt"), "w") as f:
                    f.write(f"{local_rank},{regular_pytest_fixture}")

        class TestExample(DistributedTest):
            world_size = 1

            def test(self, distributed_fixture_example, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                for rank in range(4):
                    with open(os.path.join(class_tmpdir, f"{rank}.txt"), "r") as f:
                        assert f.read() == f"{rank},{regular_pytest_fixture}"
    """
    is_dist_fixture = True

    # These values are just placeholders so that pytest recognizes this as a fixture
    _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None)
    __name__ = ""

    def __init__(self):
        assert isinstance(self.world_size, int), "Only one world size is allowed for distributed fixtures"
        self.__name__ = type(self).__name__
        _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None, name=self.__name__)


class DistributedTest(DistributedExec):
    """
    Implementation for running pytest with distributed execution.

    There are 2 parameters that can be modified:
        - world_size: Union[int,List[int]] = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside tests
        - works with pytest fixtures, parametrize, mark, etc.
        - can contain multiple tests (each of which can be parametrized separately)
        - class methods can be fixtures (usable by tests in this class only)
        - world_size can be changed for individual tests using @pytest.mark.world_size(world_size)
        - class_tmpdir is a fixture that can be used to get a tmpdir shared among
          all tests (including DistributedFixture)

    Usage:
        - class name must start with "Test"
        - must implement one or more test*(self, ...) methods

    Example:
        @pytest.fixture(params=[10,20])
        def val1(request):
            return request.param

        @pytest.mark.fast
        @pytest.mark.parametrize("val2", [30,40])
        class TestExample(DistributedTest):
            world_size = 2

            @pytest.fixture(params=[50,60])
            def val3(self, request):
                return request.param

            def test_1(self, val1, val2, str1="hello world"):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                assert all([val1, val2, str1])

            @pytest.mark.world_size(1)
            @pytest.mark.parametrize("val4", [70,80])
            def test_2(self, val1, val2, val3, val4):
                assert int(os.environ["WORLD_SIZE"]) == 1
                assert all([val1, val2, val3, val4])
    """
    is_dist_test = True

    # Temporary directory that is shared among test methods in a class
    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        fn = tmpdir_factory.mktemp(self.__class__.__name__)
        return fn

    def run(self, **fixture_kwargs):
        self._current_test(**fixture_kwargs)

    def __call__(self, request):
        self._current_test = self._get_current_test_func(request)
        self._fixture_kwargs = self._get_fixture_kwargs(request, self._current_test)

        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        # Catch world_size override pytest mark
        for mark in getattr(request.function, "pytestmark", []):
            if mark.name == "world_size":
                world_size = mark.args[0]
                break
        else:
            world_size = self.world_size

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)
            time.sleep(0.5)

    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)


def get_test_path(filename):
    curr_path = Path(__file__).parent
    return str(curr_path.joinpath(filename))
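

# Minimal usage sketch (illustrative only -- `TestSimpleAllReduce` is a hypothetical
# example, not part of this module):
#
#   class TestSimpleAllReduce(DistributedTest):
#       world_size = 2
#
#       def test(self):
#           # each rank contributes rank + 1, so the reduced value is 1 + 2 = 3
#           value = torch.ones(1).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
#           dist.all_reduce(value)
#           assert value.item() == sum(range(1, dist.get_world_size() + 1))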