12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from copy import deepcopy
- from deepspeed.launcher import multinode_runner as mnrunner
- from deepspeed.launcher.runner import encode_world_info, parse_args
- import os
- import pytest
@pytest.fixture
def runner_info():
    """Provide the shared inputs used by the multinode runner tests.

    Returns:
        tuple: ``(env, hosts, world_info, args)`` where ``env`` is a plain
        ``dict`` snapshot of the process environment, ``hosts`` maps two
        worker names to ``4`` each (presumably per-host slot counts --
        confirm against the launcher), ``world_info`` is
        ``encode_world_info(hosts)`` and ``args`` is the result of
        ``parse_args(['test_launcher.py'])``.
    """
    hosts = {'worker-0': 4, 'worker-1': 4}
    world_info = encode_world_info(hosts)
    # os.environ.copy() yields a detached dict. deepcopy(os.environ) would
    # return another os._Environ whose writes still call putenv(), so any
    # variable a runner sets would leak into the real process environment
    # and into subsequent tests.
    env = os.environ.copy()
    args = parse_args(['test_launcher.py'])
    return env, hosts, world_info, args
def test_pdsh_runner(runner_info):
    """Check the PDSH runner's command and returned environment.

    The command returned by ``PDSHRunner.get_cmd`` must start with ``pdsh``
    and the returned environment must have ``PDSH_RCMD_TYPE`` set to ``ssh``.
    """
    base_env, pool, encoded_world, parsed_args = runner_info
    launcher = mnrunner.PDSHRunner(parsed_args, encoded_world)
    # get_cmd returns a 3-tuple; the kill command is not checked here.
    launch_cmd, _kill_cmd, launch_env = launcher.get_cmd(base_env, pool)
    assert launch_cmd[0] == 'pdsh'
    assert launch_env['PDSH_RCMD_TYPE'] == 'ssh'
def test_openmpi_runner(runner_info):
    """Check that the OpenMPI runner's command starts with ``mpirun``."""
    environment, pool, encoded_world, parsed_args = runner_info
    launcher = mnrunner.OpenMPIRunner(parsed_args, encoded_world, pool)
    launch_cmd = launcher.get_cmd(environment, pool)
    assert launch_cmd[0] == 'mpirun'
def test_mpich_runner(runner_info):
    """Check that the MPICH runner's command starts with ``mpirun``."""
    environment, pool, encoded_world, parsed_args = runner_info
    launcher = mnrunner.MPICHRunner(parsed_args, encoded_world, pool)
    launch_cmd = launcher.get_cmd(environment, pool)
    assert launch_cmd[0] == 'mpirun'
def test_slurm_runner(runner_info):
    """Check that the Slurm runner's command starts with ``srun``."""
    environment, pool, encoded_world, parsed_args = runner_info
    launcher = mnrunner.SlurmRunner(parsed_args, encoded_world, pool)
    launch_cmd = launcher.get_cmd(environment, pool)
    assert launch_cmd[0] == 'srun'
def test_mvapich_runner(runner_info):
    """Check that the MVAPICH runner's command starts with ``mpirun``."""
    environment, pool, encoded_world, parsed_args = runner_info
    launcher = mnrunner.MVAPICHRunner(parsed_args, encoded_world, pool)
    launch_cmd = launcher.get_cmd(environment, pool)
    assert launch_cmd[0] == 'mpirun'
|