test_run.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import pytest
  2. from deepspeed.launcher import runner as dsrun
  def test_parser_mutual_exclusive():
      '''Supplying include_str and exclude_str in the same call to
      dsrun.parse_resource_filter() must be rejected with a ValueError.
      '''
      both_filters = {'include_str': 'A', 'exclude_str': 'B'}
      with pytest.raises(ValueError):
          dsrun.parse_resource_filter({}, **both_filters)
  def test_parser_local():
      ''' Test cases with only one node ('worker-0', slots 0-3). '''
      hosts = {'worker-0': [0, 1, 2, 3]}

      # Each entry pairs the filter keyword arguments with the expected
      # result; entries are evaluated in the order listed.
      scenarios = [
          # no include/exclude filter: result equals the input
          ({}, hosts),
          # exclude individual slots
          ({'exclude_str': 'worker-0:1'}, {'worker-0': [0, 2, 3]}),
          ({'exclude_str': 'worker-0:1,2'}, {'worker-0': [0, 3]}),
          # only use one slot
          ({'include_str': 'worker-0:1'}, {'worker-0': [1]}),
          # including slots multiple times shouldn't break things
          ({'include_str': 'worker-0:1,1'}, {'worker-0': [1]}),
          ({'include_str': 'worker-0:1@worker-0:0,1'}, {'worker-0': [0, 1]}),
          # including just 'worker-0' without ':' should still use every slot
          ({'include_str': 'worker-0'}, hosts),
          # excluding just 'worker-0' without ':' should eliminate everything
          ({'exclude_str': 'worker-0'}, {}),
          # exclude all slots manually
          ({'exclude_str': 'worker-0:0,1,2,3'}, {}),
      ]
      for filter_kwargs, expected in scenarios:
          filtered = dsrun.parse_resource_filter(hosts, **filter_kwargs)
          assert filtered == expected
  def test_parser_multinode():
      ''' Test cases with two nodes ('worker-0' and 'worker-1', slots 0-3 each). '''
      hosts = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}

      # (filter keyword arguments, expected result), evaluated in order.
      scenarios = [
          # no include/exclude filter: result equals the input
          ({}, hosts),
          # include a node
          ({'include_str': 'worker-1:0,3'}, {'worker-1': [0, 3]}),
          # exclude a node
          ({'exclude_str': 'worker-1'}, {'worker-0': [0, 1, 2, 3]}),
          # exclude part of each node
          ({'exclude_str': 'worker-0:0,1@worker-1:3'},
           {'worker-0': [2, 3], 'worker-1': [0, 1, 2]}),
      ]
      for filter_kwargs, expected in scenarios:
          filtered = dsrun.parse_resource_filter(hosts, **filter_kwargs)
          assert filtered == expected
  def test_parser_errors():
      '''Ensure invalid filters are rejected with a ValueError. '''
      hosts = {'worker-0': [0, 1, 2, 3], 'worker-1': [0, 1, 2, 3]}

      # Every entry is expected to make parse_resource_filter() raise.
      invalid_filters = [
          # host does not exist
          {'include_str': 'jeff'},
          {'exclude_str': 'jeff'},
          # slot does not exist
          {'include_str': 'worker-1:4'},
          {'exclude_str': 'worker-1:4'},
          # formatting
          {'exclude_str': 'worker-1@worker-0:1@5'},
      ]
      for filter_kwargs in invalid_filters:
          with pytest.raises(ValueError):
              dsrun.parse_resource_filter(hosts, **filter_kwargs)
  def test_num_plus_parser():
      ''' Ensure dsrun.main() raises a ValueError when --num_nodes/--num_gpus
      are combined with -i/-e, since they are mutually exclusive.
      '''
      conflicting_commands = (
          # inclusion
          "--num_nodes 1 -i localhost foo.py",
          "--num_nodes 1 --num_gpus 1 -i localhost foo.py",
          "--num_gpus 1 -i localhost foo.py",
          # exclusion
          "--num_nodes 1 -e localhost foo.py",
          "--num_nodes 1 --num_gpus 1 -e localhost foo.py",
          "--num_gpus 1 -e localhost foo.py",
      )
      for command_line in conflicting_commands:
          with pytest.raises(ValueError):
              dsrun.main(args=command_line.split())