# test_config.py
# A test on its own
import torch
import pytest
import json
import argparse
from common import distributed_test
from simple_model import SimpleModel, create_config_from_dict, random_dataloader

import torch.distributed as dist

# A test on its own
import deepspeed
from deepspeed.runtime.config import DeepSpeedConfig
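
# distributed_test (imported from the test suite's common module) is assumed
# to launch the wrapped function under torch.distributed across world_size
# processes, so the assertions below run on every rank; this description is
# inferred from how the decorator is used in this file, not from its source.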


def test_cuda():
    assert torch.cuda.is_available()


def test_check_version():
    assert hasattr(deepspeed, "__git_hash__")
    assert hasattr(deepspeed, "__git_branch__")
    assert hasattr(deepspeed, "__version__")
    assert hasattr(deepspeed, "__version_major__")
    assert hasattr(deepspeed, "__version_minor__")
    assert hasattr(deepspeed, "__version_patch__")


def _run_batch_config(ds_config, train_batch=None, micro_batch=None, gas=None):
    ds_config.train_batch_size = train_batch
    ds_config.train_micro_batch_size_per_gpu = micro_batch
    ds_config.gradient_accumulation_steps = gas
    success = True
    try:
        ds_config._configure_train_batch_size()
    except AssertionError:
        success = False
    return success
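
# _configure_train_batch_size() is expected to enforce DeepSpeed's batch-size
# relation (stated here as a reference, not taken from this file):
#     train_batch_size == train_micro_batch_size_per_gpu
#                         * gradient_accumulation_steps
#                         * world_size
# e.g. with world_size=2: 32 == 16 * 1 * 2 holds, while 33 == 17 * 2 * 2 does not.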


def _batch_assert(status, ds_config, batch, micro_batch, gas, success):
    if not success:
        assert not status
        print("Failed but All is well")
        return

    assert ds_config.train_batch_size == batch
    assert ds_config.train_micro_batch_size_per_gpu == micro_batch
    assert ds_config.gradient_accumulation_steps == gas
    print("All is well")


# Tests different batch configs provided in the deepspeed JSON file.
# A case succeeds iff batch == micro_batch * gas * num_ranks.
@pytest.mark.parametrize('num_ranks,batch,micro_batch,gas,success',
                         [(2, 32, 16, 1, True),
                          (2, 32, 8, 2, True),
                          (2, 33, 17, 2, False),
                          (2, 32, 18, 1, False)])  # yapf: disable
def test_batch_config(num_ranks, batch, micro_batch, gas, success):
    @distributed_test(world_size=2)
    def _test_batch_config(num_ranks, batch, micro_batch, gas, success):
        assert dist.get_world_size() == num_ranks, \
            f'The test assumes a world size of {num_ranks}'

        ds_batch_config = 'tests/unit/ds_batch_config.json'
        ds_config = DeepSpeedConfig(ds_batch_config)

        # Test cases when all parameters are provided
        status = _run_batch_config(ds_config,
                                   train_batch=batch,
                                   micro_batch=micro_batch,
                                   gas=gas)
        _batch_assert(status, ds_config, batch, micro_batch, gas, success)

        # Test cases when two out of three parameters are provided
        status = _run_batch_config(ds_config, train_batch=batch, micro_batch=micro_batch)
        _batch_assert(status, ds_config, batch, micro_batch, gas, success)

        if success:
            # When gas is provided with one more parameter
            status = _run_batch_config(ds_config, train_batch=batch, gas=gas)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)

            status = _run_batch_config(ds_config, micro_batch=micro_batch, gas=gas)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)

            # Test the case when only micro_batch or train_batch is provided
            if gas == 1:
                status = _run_batch_config(ds_config, micro_batch=micro_batch)
                _batch_assert(status, ds_config, batch, micro_batch, gas, success)

                status = _run_batch_config(ds_config, train_batch=batch)
                _batch_assert(status, ds_config, batch, micro_batch, gas, success)
        else:
            # When only gas is provided
            status = _run_batch_config(ds_config, gas=gas)
            _batch_assert(status, ds_config, batch, micro_batch, gas, success)

            # When gas is provided with something else and gas does not divide batch
            if gas != 1:
                status = _run_batch_config(ds_config, train_batch=batch, gas=gas)
                _batch_assert(status, ds_config, batch, micro_batch, gas, success)

    # Run the batch config test
    _test_batch_config(num_ranks, batch, micro_batch, gas, success)
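
# To run just these cases from the repository root (assuming the standard
# DeepSpeed test layout, since ds_batch_config.json is loaded from
# tests/unit/):
#   pytest tests/unit/test_config.py -k test_batch_config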


def test_temp_config_json(tmpdir):
    config_dict = {
        "train_batch_size": 1,
    }
    config_path = create_config_from_dict(tmpdir, config_dict)
    config_json = json.load(open(config_path, 'r'))
    assert 'train_batch_size' in config_json
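
# For reference, create_config_from_dict (imported from simple_model) is
# assumed to serialize the dict to a JSON file under tmpdir and return the
# path; a minimal sketch of such a helper (the name 'temp_config.json' is
# illustrative, not taken from this file):
#
#     import os
#     def create_config_from_dict(tmpdir, config_dict):
#         config_path = os.path.join(str(tmpdir), 'temp_config.json')
#         with open(config_path, 'w') as fd:
#             json.dump(config_dict, fd)
#         return config_path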


def test_deprecated_deepscale_config(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    config_path = create_config_from_dict(tmpdir, config_dict)
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepscale_config = config_path
    args.local_rank = 0

    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[1])
    def _test_deprecated_deepscale_config(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=5,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_deprecated_deepscale_config(args=args, model=model, hidden_dim=hidden_dim)
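
# The test above exercises 'deepscale_config', which deepspeed.initialize is
# expected to accept as a deprecated alias of 'deepspeed_config' (presumably a
# holdover from the project's earlier 'deepscale' naming), issuing a warning
# rather than failing.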


def test_dist_init_true(tmpdir):
    config_dict = {
        "train_batch_size": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True
        }
    }
    config_path = create_config_from_dict(tmpdir, config_dict)
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepscale_config = config_path
    args.local_rank = 0

    hidden_dim = 10
    model = SimpleModel(hidden_dim)

    @distributed_test(world_size=[1])
    def _test_dist_init_true(args, model, hidden_dim):
        # dist_init_required=True asks deepspeed.initialize to set up
        # torch.distributed itself rather than assume the caller already has.
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters(),
                                              dist_init_required=True)
        data_loader = random_dataloader(model=model,
                                        total_samples=5,
                                        hidden_dim=hidden_dim,
                                        device=model.device)
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()

    _test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim)