test_ds_arguments.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import argparse
  5. import pytest
  6. import deepspeed
  7. from deepspeed.utils.numa import parse_range_list
  def basic_parser():
      """Build a bare user-style parser that only knows ``--num_epochs``.

      Returns:
          argparse.ArgumentParser: a fresh parser with one optional integer
          argument, ``--num_epochs``, and no DeepSpeed-specific arguments.
      """
      result = argparse.ArgumentParser()
      result.add_argument('--num_epochs', type=int)
      return result
  def test_no_ds_arguments_no_ds_parser():
      """A plain parser handles user flags and carries no DeepSpeed attributes."""
      parsed = basic_parser().parse_args(['--num_epochs', '2'])
      assert parsed
      assert hasattr(parsed, 'num_epochs')
      assert parsed.num_epochs == 2
      # Neither DeepSpeed attribute may appear without add_config_arguments.
      for absent_attr in ('deepspeed', 'deepspeed_config'):
          assert not hasattr(parsed, absent_attr)
  def test_no_ds_arguments():
      """DeepSpeed arguments stay at their defaults when not given.

      After ``deepspeed.add_config_arguments``, a command line with neither
      ``--deepspeed`` nor ``--deepspeed_config`` must yield
      ``args.deepspeed is False`` and ``args.deepspeed_config is None``,
      while the user's own ``--num_epochs`` still parses.
      """
      parser = basic_parser()
      parser = deepspeed.add_config_arguments(parser)
      args = parser.parse_args(['--num_epochs', '2'])
      assert args
      assert hasattr(args, 'num_epochs')
      assert args.num_epochs == 2
      assert hasattr(args, 'deepspeed')
      # Identity check: ``== False`` would also accept 0; pin the bool singleton.
      assert args.deepspeed is False
      assert hasattr(args, 'deepspeed_config')
      assert args.deepspeed_config is None
  def test_no_ds_enable_argument():
      """``--deepspeed_config`` alone sets the config path but not the flag.

      The config path must be stored as the given string, and
      ``args.deepspeed`` must remain ``False`` because ``--deepspeed`` was not
      passed.
      """
      parser = basic_parser()
      parser = deepspeed.add_config_arguments(parser)
      args = parser.parse_args(['--num_epochs', '2', '--deepspeed_config', 'foo.json'])
      assert args
      assert hasattr(args, 'num_epochs')
      assert args.num_epochs == 2
      assert hasattr(args, 'deepspeed')
      # Identity check: ``== False`` would also accept 0; pin the bool singleton.
      assert args.deepspeed is False
      assert hasattr(args, 'deepspeed_config')
      assert isinstance(args.deepspeed_config, str)
      assert args.deepspeed_config == 'foo.json'
  def test_no_ds_config_argument():
      """``--deepspeed`` alone enables the flag and leaves the config unset.

      ``args.deepspeed`` must be the boolean ``True`` and
      ``args.deepspeed_config`` must stay ``None``.
      """
      parser = basic_parser()
      parser = deepspeed.add_config_arguments(parser)
      args = parser.parse_args(['--num_epochs', '2', '--deepspeed'])
      assert args
      assert hasattr(args, 'num_epochs')
      assert args.num_epochs == 2
      assert hasattr(args, 'deepspeed')
      # ``is True`` holds only for the bool singleton, so it also pins the type
      # (stronger than the former ``type(...) == bool`` plus ``== True``).
      assert args.deepspeed is True
      assert hasattr(args, 'deepspeed_config')
      assert args.deepspeed_config is None
  def test_no_ds_parser():
      """Without ``add_config_arguments``, ``--deepspeed`` is rejected.

      The plain parser does not know ``--deepspeed``, so argparse reports the
      unrecognized argument and exits. This surfaces as ``SystemExit`` with
      argparse's error status 2.
      """
      parser = basic_parser()
      with pytest.raises(SystemExit) as excinfo:
          parser.parse_args(['--num_epochs', '2', '--deepspeed'])
      # Status 2 is argparse's usage-error exit code; it distinguishes this
      # from an unrelated SystemExit.
      assert excinfo.value.code == 2
  def test_core_deepscale_arguments():
      """``--deepspeed`` and ``--deepspeed_config`` together are both honored.

      ``args.deepspeed`` must be the boolean ``True`` and
      ``args.deepspeed_config`` the given path string, alongside the user's
      ``--num_epochs``.
      """
      parser = basic_parser()
      parser = deepspeed.add_config_arguments(parser)
      args = parser.parse_args(['--num_epochs', '2', '--deepspeed', '--deepspeed_config', 'foo.json'])
      assert args
      assert hasattr(args, 'num_epochs')
      assert args.num_epochs == 2
      assert hasattr(args, 'deepspeed')
      # ``is True`` holds only for the bool singleton, so it also pins the type.
      assert args.deepspeed is True
      assert hasattr(args, 'deepspeed_config')
      assert isinstance(args.deepspeed_config, str)
      assert args.deepspeed_config == 'foo.json'
  def test_core_binding_arguments():
      """``parse_range_list`` expands core specs and rejects malformed ones.

      A valid spec mixes single cores and inclusive ``a-b`` ranges and is
      expanded into a flat ascending list. Overlapping ranges, entries out of
      ascending order, and descending ranges must raise ``ValueError``.
      """
      core_list = parse_range_list("0,2-4,6,8-9")
      assert core_list == [0, 2, 3, 4, 6, 8, 9]

      # Negative case: overlapping ranges ("2-6" and "5-9" share 5 and 6).
      with pytest.raises(ValueError):
          parse_range_list("0,2-6,5-9")

      # Negative case: reverse order across entries ("8" precedes "2-6").
      with pytest.raises(ValueError):
          parse_range_list("8,2-6")

      # Negative case: reverse order within a range ("6-2" is descending).
      with pytest.raises(ValueError):
          parse_range_list("1,6-2")