test_multinode_runner.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from copy import deepcopy
  5. from deepspeed.launcher import multinode_runner as mnrunner
  6. from deepspeed.launcher.runner import encode_world_info, parse_args
  7. import os
  8. import pytest
  9. @pytest.fixture
  10. def runner_info():
  11. hosts = {'worker-0': 4, 'worker-1': 4}
  12. world_info = encode_world_info(hosts)
  13. env = deepcopy(os.environ)
  14. args = parse_args(['test_launcher.py'])
  15. return env, hosts, world_info, args
  16. def test_pdsh_runner(runner_info):
  17. env, resource_pool, world_info, args = runner_info
  18. runner = mnrunner.PDSHRunner(args, world_info)
  19. cmd, kill_cmd, env = runner.get_cmd(env, resource_pool)
  20. assert cmd[0] == 'pdsh'
  21. assert env['PDSH_RCMD_TYPE'] == 'ssh'
  22. def test_openmpi_runner(runner_info):
  23. env, resource_pool, world_info, args = runner_info
  24. runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
  25. cmd = runner.get_cmd(env, resource_pool)
  26. assert cmd[0] == 'mpirun'
  27. def test_mpich_runner(runner_info):
  28. env, resource_pool, world_info, args = runner_info
  29. runner = mnrunner.MPICHRunner(args, world_info, resource_pool)
  30. cmd = runner.get_cmd(env, resource_pool)
  31. assert cmd[0] == 'mpirun'
  32. def test_slurm_runner(runner_info):
  33. env, resource_pool, world_info, args = runner_info
  34. runner = mnrunner.SlurmRunner(args, world_info, resource_pool)
  35. cmd = runner.get_cmd(env, resource_pool)
  36. assert cmd[0] == 'srun'
  37. def test_mvapich_runner(runner_info):
  38. env, resource_pool, world_info, args = runner_info
  39. runner = mnrunner.MVAPICHRunner(args, world_info, resource_pool)
  40. cmd = runner.get_cmd(env, resource_pool)
  41. assert cmd[0] == 'mpirun'