test_run.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import pytest
  5. from deepspeed.launcher import runner as dsrun
  6. def test_parser_mutual_exclusive():
  7. '''Ensure dsrun.parse_resource_filter() raises a ValueError when include_str and
  8. exclude_str are both provided.
  9. '''
  10. with pytest.raises(ValueError):
  11. dsrun.parse_resource_filter({}, include_str='A', exclude_str='B')
  12. def test_parser_local():
  13. ''' Test cases with only one node. '''
  14. # First try no include/exclude
  15. hosts = {'worker-0': [0, 1, 2, 3]}
  16. ret = dsrun.parse_resource_filter(hosts)
  17. assert (ret == hosts)
  18. # exclude slots
  19. ret = dsrun.parse_resource_filter(hosts, exclude_str='worker-0:1')
  20. assert (ret == {'worker-0': [0, 2, 3]})
  21. ret = dsrun.parse_resource_filter(hosts, exclude_str='worker-0:1,2')
  22. assert (ret == {'worker-0': [0, 3]})
  23. # only use one slot
  24. ret = dsrun.parse_resource_filter(hosts, include_str='worker-0:1')
  25. assert (ret == {'worker-0': [1]})
  26. # including slots multiple times shouldn't break things
  27. ret = dsrun.parse_resource_filter(hosts, include_str='worker-0:1,1')
  28. assert (ret == {'worker-0': [1]})
  29. ret = dsrun.parse_resource_filter(hosts, include_str='worker-0:1@worker-0:0,1')
  30. assert (ret == {'worker-0': [0, 1]})
  31. # including just 'worker-0' without : should still use all GPUs
  32. ret = dsrun.parse_resource_filter(hosts, include_str='worker-0')
  33. assert (ret == hosts)
  34. # excluding just 'worker-0' without : should eliminate everything
  35. ret = dsrun.parse_resource_filter(hosts, exclude_str='worker-0')
  36. assert (ret == {})
  37. # exclude all slots manually
  38. ret = dsrun.parse_resource_filter(hosts, exclude_str='worker-0:0,1,2,3')
  39. assert (ret == {})
  40. def test_parser_multinode():
  41. # First try no include/exclude
  42. hosts = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}
  43. ret = dsrun.parse_resource_filter(hosts)
  44. assert (ret == hosts)
  45. # include a node
  46. ret = dsrun.parse_resource_filter(hosts, include_str='worker-1:0,3')
  47. assert (ret == {'worker-1': [0, 3]})
  48. # exclude a node
  49. ret = dsrun.parse_resource_filter(hosts, exclude_str='worker-1')
  50. assert (ret == {'worker-0': [0, 1, 2, 3]})
  51. # exclude part of each node
  52. ret = dsrun.parse_resource_filter(hosts, exclude_str='worker-0:0,1@worker-1:3')
  53. assert (ret == {'worker-0': [2, 3], 'worker-1': [0, 1, 2]})
  54. def test_parser_errors():
  55. '''Ensure we catch errors. '''
  56. hosts = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}
  57. # host does not exist
  58. with pytest.raises(ValueError):
  59. dsrun.parse_resource_filter(hosts, include_str='jeff')
  60. with pytest.raises(ValueError):
  61. dsrun.parse_resource_filter(hosts, exclude_str='jeff')
  62. # slot does not exist
  63. with pytest.raises(ValueError):
  64. dsrun.parse_resource_filter(hosts, include_str='worker-1:4')
  65. with pytest.raises(ValueError):
  66. dsrun.parse_resource_filter(hosts, exclude_str='worker-1:4')
  67. # formatting
  68. with pytest.raises(ValueError):
  69. dsrun.parse_resource_filter(hosts, exclude_str='worker-1@worker-0:1@5')
  70. def test_num_plus_parser():
  71. ''' Ensure we catch errors relating to num_nodes/num_gpus + -i/-e being mutually exclusive'''
  72. # inclusion
  73. with pytest.raises(ValueError):
  74. dsrun.main(args="--num_nodes 1 -i localhost foo.py".split())
  75. with pytest.raises(ValueError):
  76. dsrun.main(args="--num_nodes 1 --num_gpus 1 -i localhost foo.py".split())
  77. with pytest.raises(ValueError):
  78. dsrun.main(args="--num_gpus 1 -i localhost foo.py".split())
  79. # exclusion
  80. with pytest.raises(ValueError):
  81. dsrun.main(args="--num_nodes 1 -e localhost foo.py".split())
  82. with pytest.raises(ValueError):
  83. dsrun.main(args="--num_nodes 1 --num_gpus 1 -e localhost foo.py".split())
  84. with pytest.raises(ValueError):
  85. dsrun.main(args="--num_gpus 1 -e localhost foo.py".split())
  86. def test_hostfile_good():
  87. # good hostfile w. empty lines and comment
  88. hostfile = """
  89. worker-1 slots=2
  90. worker-2 slots=2
  91. localhost slots=1
  92. 123.23.12.10 slots=2
  93. #worker-1 slots=3
  94. # this is a comment
  95. """
  96. r = dsrun._parse_hostfile(hostfile.splitlines())
  97. assert "worker-1" in r
  98. assert "worker-2" in r
  99. assert "localhost" in r
  100. assert "123.23.12.10" in r
  101. assert r["worker-1"] == 2
  102. assert r["worker-2"] == 2
  103. assert r["localhost"] == 1
  104. assert r["123.23.12.10"] == 2
  105. assert len(r) == 4
  106. def test_hostfiles_bad():
  107. # duplicate host
  108. hostfile = """
  109. worker-1 slots=2
  110. worker-2 slots=1
  111. worker-1 slots=1
  112. """
  113. with pytest.raises(ValueError):
  114. dsrun._parse_hostfile(hostfile.splitlines())
  115. # incorrect whitespace
  116. hostfile = """
  117. this is bad slots=1
  118. """
  119. with pytest.raises(ValueError):
  120. dsrun._parse_hostfile(hostfile.splitlines())
  121. # no whitespace
  122. hostfile = """
  123. missingslots
  124. """
  125. with pytest.raises(ValueError):
  126. dsrun._parse_hostfile(hostfile.splitlines())
  127. # empty
  128. hostfile = """
  129. """
  130. with pytest.raises(ValueError):
  131. dsrun._parse_hostfile(hostfile.splitlines())
  132. # mix of good/bad
  133. hostfile = """
  134. worker-1 slots=2
  135. this is bad slots=1
  136. worker-2 slots=4
  137. missingslots
  138. """
  139. with pytest.raises(ValueError):
  140. dsrun._parse_hostfile(hostfile.splitlines())