# test_ds_arguments.py (2.6 KB)

  1. import argparse
  2. import pytest
  3. import deepspeed
  4. def basic_parser():
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument('--num_epochs', type=int)
  7. return parser
  8. def test_no_ds_arguments_no_ds_parser():
  9. parser = basic_parser()
  10. args = parser.parse_args(['--num_epochs', '2'])
  11. assert args
  12. assert hasattr(args, 'num_epochs')
  13. assert args.num_epochs == 2
  14. assert not hasattr(args, 'deepspeed')
  15. assert not hasattr(args, 'deepspeed_config')
  16. def test_no_ds_arguments():
  17. parser = basic_parser()
  18. parser = deepspeed.add_config_arguments(parser)
  19. args = parser.parse_args(['--num_epochs', '2'])
  20. assert args
  21. assert hasattr(args, 'num_epochs')
  22. assert args.num_epochs == 2
  23. assert hasattr(args, 'deepspeed')
  24. assert args.deepspeed == False
  25. assert hasattr(args, 'deepspeed_config')
  26. assert args.deepspeed_config == None
  27. def test_no_ds_enable_argument():
  28. parser = basic_parser()
  29. parser = deepspeed.add_config_arguments(parser)
  30. args = parser.parse_args(['--num_epochs', '2', '--deepspeed_config', 'foo.json'])
  31. assert args
  32. assert hasattr(args, 'num_epochs')
  33. assert args.num_epochs == 2
  34. assert hasattr(args, 'deepspeed')
  35. assert args.deepspeed == False
  36. assert hasattr(args, 'deepspeed_config')
  37. assert type(args.deepspeed_config) == str
  38. assert args.deepspeed_config == 'foo.json'
  39. def test_no_ds_config_argument():
  40. parser = basic_parser()
  41. parser = deepspeed.add_config_arguments(parser)
  42. args = parser.parse_args(['--num_epochs', '2', '--deepspeed'])
  43. assert args
  44. assert hasattr(args, 'num_epochs')
  45. assert args.num_epochs == 2
  46. assert hasattr(args, 'deepspeed')
  47. assert type(args.deepspeed) == bool
  48. assert args.deepspeed == True
  49. assert hasattr(args, 'deepspeed_config')
  50. assert args.deepspeed_config == None
  51. def test_no_ds_parser():
  52. parser = basic_parser()
  53. with pytest.raises(SystemExit):
  54. args = parser.parse_args(['--num_epochs', '2', '--deepspeed'])
  55. def test_core_deepscale_arguments():
  56. parser = basic_parser()
  57. parser = deepspeed.add_config_arguments(parser)
  58. args = parser.parse_args(
  59. ['--num_epochs',
  60. '2',
  61. '--deepspeed',
  62. '--deepspeed_config',
  63. 'foo.json'])
  64. assert args
  65. assert hasattr(args, 'num_epochs')
  66. assert args.num_epochs == 2
  67. assert hasattr(args, 'deepspeed')
  68. assert type(args.deepspeed) == bool
  69. assert args.deepspeed == True
  70. assert hasattr(args, 'deepspeed_config')
  71. assert type(args.deepspeed_config) == str
  72. assert args.deepspeed_config == 'foo.json'