common.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import re
import time
import inspect
import socket
import subprocess
from abc import ABC, abstractmethod
from pathlib import Path

import torch
import torch.multiprocessing as mp

import deepspeed
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

import pytest
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker

# Worker timeout for tests that hang
DEEPSPEED_TEST_TIMEOUT = 600


def is_rocm_pytorch():
    return hasattr(torch.version, 'hip') and torch.version.hip is not None


def get_xdist_worker_id():
    xdist_worker = os.environ.get('PYTEST_XDIST_WORKER', None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace('gw', '')
        return int(xdist_worker_id)
    return None


def get_master_port(base_port=29500, port_range_size=1000):
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        # Make xdist workers use different port ranges to avoid race conditions
        base_port += port_range_size * xdist_worker_id

    # Select first open port in range
    port = base_port
    max_port = base_port + port_range_size
    sock = socket.socket()
    while port < max_port:
        try:
            sock.bind(('', port))
            sock.close()
            return str(port)
        except OSError:
            port += 1
    raise IOError('no free ports')
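
# Illustration (not executed): with the defaults above (base_port=29500,
# port_range_size=1000), each pytest-xdist worker scans a disjoint window of
# ports, so concurrent workers do not race for the same MASTER_PORT:
#   gw0 -> 29500-30499
#   gw1 -> 30500-31499
#   gw2 -> 31500-32499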


def set_accelerator_visible():
    cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is None:
        xdist_worker_id = 0
    if cuda_visible is None:
        # CUDA_VISIBLE_DEVICES is not set, discover it using accelerator specific command instead
        if get_accelerator().device_name() == 'cuda':
            if is_rocm_pytorch():
                rocm_smi = subprocess.check_output(['rocm-smi', '--showid'])
                gpu_ids = filter(lambda s: 'GPU' in s, rocm_smi.decode('utf-8').strip().split('\n'))
                num_accelerators = len(list(gpu_ids))
            else:
                nvidia_smi = subprocess.check_output(['nvidia-smi', '--list-gpus'])
                num_accelerators = len(nvidia_smi.decode('utf-8').strip().split('\n'))
        elif get_accelerator().device_name() == 'xpu':
            clinfo = subprocess.check_output(['clinfo'])
            lines = clinfo.decode('utf-8').strip().split('\n')
            num_accelerators = 0
            for line in lines:
                match = re.search('Device Type.*GPU', line)
                if match:
                    num_accelerators += 1
        else:
            assert get_accelerator().device_name() == 'cpu'
            cpu_sockets = int(
                subprocess.check_output('cat /proc/cpuinfo | grep "physical id" | sort -u | wc -l', shell=True))
            num_accelerators = cpu_sockets
        cuda_visible = ",".join(map(str, range(num_accelerators)))

    # rotate list based on xdist worker id, example below
    # wid=0 -> ['0', '1', '2', '3']
    # wid=1 -> ['1', '2', '3', '0']
    # wid=2 -> ['2', '3', '0', '1']
    # wid=3 -> ['3', '0', '1', '2']
    dev_id_list = cuda_visible.split(",")
    dev_id_list = dev_id_list[xdist_worker_id:] + dev_id_list[:xdist_worker_id]
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(dev_id_list)


class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture.
    """
    world_size = 2
    backend = get_accelerator().communication_backend_name()
    init_distributed = True
    set_dist_env = True
    requires_cuda_env = True
    reuse_dist_env = False
    _pool_cache = {}
    exec_timeout = DEEPSPEED_TEST_TIMEOUT

    @abstractmethod
    def run(self):
        ...

    def __call__(self, request=None):
        self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
        world_size = self.world_size
        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)

    def _get_fixture_kwargs(self, request, func):
        if not request:
            return {}
        # Grab fixture / parametrize kwargs from pytest request object
        fixture_kwargs = {}
        params = inspect.getfullargspec(func).args
        params.remove("self")
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                pass  # test methods can have kwargs that are not fixtures
        return fixture_kwargs

    def _launch_procs(self, num_procs):
        # Verify we have enough accelerator devices to run this test
        if get_accelerator().is_available() and get_accelerator().device_count() < num_procs:
            pytest.skip(
                f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
            )

        # Set start method to `forkserver` (or `fork`)
        mp.set_start_method('forkserver', force=True)

        # Create process pool or use cached one
        master_port = None
        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                master_port = get_master_port()
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)
            master_port = get_master_port()

        # Run the test
        args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        try:
            skip_msgs = skip_msgs_async.get(self.exec_timeout)
        except mp.TimeoutError:
            # Shortcut to exit pytest in the case of a hung test. This usually
            # means an environment error and the rest of the tests will also
            # hang (causing super long unit test runtimes)
            pytest.exit("Test hanged, exiting", returncode=0)

        # Tear down distributed environment and close process pools
        self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
            assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
            pytest.skip(skip_msgs[0])

    def _dist_run(self, local_rank, num_procs, master_port):
        skip_msg = ''
        if not dist.is_initialized():
            """ Initialize deepspeed.comm and execute the user function. """
            if self.set_dist_env:
                os.environ['MASTER_ADDR'] = '127.0.0.1'
                os.environ['MASTER_PORT'] = str(master_port)
                os.environ['LOCAL_RANK'] = str(local_rank)
                # NOTE: unit tests don't support multi-node so local_rank == global rank
                os.environ['RANK'] = str(local_rank)
                # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE
                # DeepSpeed single node launcher would also set LOCAL_SIZE accordingly
                os.environ['LOCAL_SIZE'] = str(num_procs)
                os.environ['WORLD_SIZE'] = str(num_procs)

            # turn off NCCL logging if set
            os.environ.pop('NCCL_DEBUG', None)

            if get_accelerator().is_available():
                set_accelerator_visible()

            if self.init_distributed:
                deepspeed.init_distributed(dist_backend=self.backend)
                dist.barrier()

        if get_accelerator().is_available():
            get_accelerator().set_device(local_rank)

        try:
            self.run(**self._fixture_kwargs)
        except BaseException as e:
            if isinstance(e, Skipped):
                skip_msg = e.msg
            else:
                raise e

        return skip_msg

    def _dist_destroy(self):
        if (dist is not None) and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False):
        if force or not self.reuse_dist_env:
            msg = pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
            pool.close()
            pool.join()
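
# Minimal sketch (comments only, not part of the test suite) of how a subclass
# plugs into DistributedExec; the class name below is hypothetical, and
# DistributedTest / DistributedFixture below are the real entry points.
#
#   class _ExampleExec(DistributedExec):
#       world_size = 2
#
#       def run(self):
#           # run() executes once per rank with RANK/WORLD_SIZE/LOCAL_RANK set
#           assert int(os.environ["WORLD_SIZE"]) == self.world_size
#
#   _ExampleExec()()  # __call__ spawns the worker processes and waits for them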


class DistributedFixture(DistributedExec):
    """
    Implementation that extends @pytest.fixture to allow for distributed execution.
    This is primarily meant to be used when a test requires executing two pieces of
    code with different world sizes.

    There are 2 parameters that can be modified:
        - world_size: int = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside fixture
        - can be reused by multiple tests
        - can accept other fixtures as input

    Limitations:
        - cannot use @pytest.mark.parametrize
        - world_size cannot be modified after definition and only one world_size value is accepted
        - any fixtures used must also be used in the test that uses this fixture (see example below)
        - return values cannot be returned. Passing values to a DistributedTest
          object can be achieved using class_tmpdir and writing to file (see example below)

    Usage:
        - must implement a run(self, ...) method
        - fixture can be used by making the class name input to a test function

    Example:
        @pytest.fixture(params=[10,20])
        def regular_pytest_fixture(request):
            return request.param

        class distributed_fixture_example(DistributedFixture):
            world_size = 4

            def run(self, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                local_rank = os.environ["LOCAL_RANK"]
                print(f"Rank {local_rank} with value {regular_pytest_fixture}")
                with open(os.path.join(class_tmpdir, f"{local_rank}.txt"), "w") as f:
                    f.write(f"{local_rank},{regular_pytest_fixture}")

        class TestExample(DistributedTest):
            world_size = 1

            def test(self, distributed_fixture_example, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                for rank in range(4):
                    with open(os.path.join(class_tmpdir, f"{rank}.txt"), "r") as f:
                        assert f.read() == f"{rank},{regular_pytest_fixture}"
    """
    is_dist_fixture = True

    # These values are just placeholders so that pytest recognizes this as a fixture
    _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None)
    __name__ = ""

    def __init__(self):
        assert isinstance(self.world_size, int), "Only one world size is allowed for distributed fixtures"
        self.__name__ = type(self).__name__
        _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None, name=self.__name__)


class DistributedTest(DistributedExec):
    """
    Implementation for running pytest with distributed execution.

    There are 2 parameters that can be modified:
        - world_size: Union[int,List[int]] = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside tests
        - works with pytest fixtures, parametrize, mark, etc.
        - can contain multiple tests (each of which can be parametrized separately)
        - class methods can be fixtures (usable by tests in this class only)
        - world_size can be changed for individual tests using @pytest.mark.world_size(world_size)
        - class_tmpdir is a fixture that can be used to get a tmpdir shared among
          all tests (including DistributedFixture)

    Usage:
        - class name must start with "Test"
        - must implement one or more test*(self, ...) methods

    Example:
        @pytest.fixture(params=[10,20])
        def val1(request):
            return request.param

        @pytest.mark.fast
        @pytest.mark.parametrize("val2", [30,40])
        class TestExample(DistributedTest):
            world_size = 2

            @pytest.fixture(params=[50,60])
            def val3(self, request):
                return request.param

            def test_1(self, val1, val2, str1="hello world"):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                assert all(val1, val2, str1)

            @pytest.mark.world_size(1)
            @pytest.mark.parametrize("val4", [70,80])
            def test_2(self, val1, val2, val3, val4):
                assert int(os.environ["WORLD_SIZE"]) == 1
                assert all(val1, val2, val3, val4)
    """
    is_dist_test = True

    # Temporary directory that is shared among test methods in a class
    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        fn = tmpdir_factory.mktemp(self.__class__.__name__)
        return fn

    def run(self, **fixture_kwargs):
        self._current_test(**fixture_kwargs)

    def __call__(self, request):
        self._current_test = self._get_current_test_func(request)
        self._fixture_kwargs = self._get_fixture_kwargs(request, self._current_test)

        if self.requires_cuda_env and not get_accelerator().is_available():
            pytest.skip("only supported in accelerator environments.")

        # Catch world_size override pytest mark
        for mark in getattr(request.function, "pytestmark", []):
            if mark.name == "world_size":
                world_size = mark.args[0]
                break
        else:
            world_size = self.world_size

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)
            time.sleep(0.5)

    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)


def get_test_path(filename):
    curr_path = Path(__file__).parent
    return str(curr_path.joinpath(filename))
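
# Example usage (illustrative only; "ds_config.json" is a hypothetical file that
# would live alongside the unit tests):
#
#   import json
#   with open(get_test_path("ds_config.json")) as f:
#       ds_config = json.load(f)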