# test_configurable_parallel_mp.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch
import deepspeed
import pytest
import random
import numpy as np

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest, DistributedFixture
from unit.megatron_model import get_gpt2_model, get_megatron_version
from deepspeed.runtime.utils import required_torch_version

pytestmark = pytest.mark.skipif(not required_torch_version(min_version=1.5, max_version=1.13),
                                reason='Megatron-LM package requires PyTorch version >=1.5 and <=1.13')
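
# How these tests are typically launched (a sketch; the exact path is an
# assumption based on the standard DeepSpeed unit-test layout):
#
#   pytest tests/unit/checkpoint/test_configurable_parallel_mp.py
#
# DistributedTest spawns the per-rank worker processes itself, so no explicit
# `deepspeed`/`torchrun` launcher should be needed.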

# TODO: integrated testing of TP and ZeRO 1/2/3


def get_deepspeed_model(model):
    ds_config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {
            "type": "Lamb",
            "params": {
                "lr": 0.00015
            }
        },
    }

    from megatron import mpu
    model, _, _, _ = deepspeed.initialize(model=model,
                                          mpu=mpu,
                                          model_parameters=model.parameters(),
                                          config=ds_config_dict)
    return model
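

# Illustrative sketch (not executed by the tests): `deepspeed.initialize` above
# returns a DeepSpeedEngine, whose forward/save_checkpoint/load_checkpoint
# methods the tests below rely on. Passing `mpu` is what ties the engine to
# Megatron's model-parallel groups. Usage mirrors the tests:
#
#   model = get_gpt2_model({'num_layers': 2, 'hidden_size': 128,
#                           'num_attention_heads': 8, 'max_position_embeddings': 128})
#   engine = get_deepspeed_model(model)
#   engine.eval()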


class ConfigurableMP(DistributedTest):

    @pytest.fixture(autouse=True)
    def reset_random(self, seed=1234):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        get_accelerator().manual_seed_all(seed)

    @pytest.fixture
    def inputs(self, bs=1, seq_len=20):
        input_ids = torch.randint(low=0, high=1000, size=(bs, seq_len))
        position_ids = torch.randint(low=0, high=2, size=(bs, seq_len))
        attention_mask = torch.randint(low=0, high=2, size=(bs, seq_len), dtype=torch.bool)
        return [input_ids, position_ids, attention_mask]
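

# Shape sketch for the `inputs` fixture above (defaults bs=1, seq_len=20):
#   input_ids:      int64 [bs, seq_len], values in [0, 1000)
#   position_ids:   int64 [bs, seq_len], values in {0, 1}
#   attention_mask: bool  [bs, seq_len]
# All three are created on CPU; each test moves them to the accelerator device.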


class TestConfigurableMP(ConfigurableMP):

    @pytest.mark.world_size(1)
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_gpt2_basic(self, tmpdir, inputs):
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        model = get_gpt2_model(args_defaults)
        model = get_deepspeed_model(model)

        model.eval()
        device_name = get_accelerator().device_name()
        baseline = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))

        tag = 'mp_1'
        state_dict = {}
        state_dict['checkpoint_version'] = get_megatron_version()
        model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)
        dist.barrier()
        model.load_checkpoint(tmpdir, tag=tag, load_optimizer_states=False, load_lr_scheduler_states=False)

        # Move the inputs to the same device as the model, matching the baseline call.
        test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
        assert torch.allclose(baseline, test,
                              atol=1e-07), f"Baseline output {baseline} is not equal to save-then-load output {test}"

    @pytest.mark.world_size(2)
    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test_gpt2_mp2_no_resize(self, tmpdir, inputs):
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        model = get_gpt2_model(args_defaults, mp_size=2)
        model = get_deepspeed_model(model)

        model.eval()
        device_name = get_accelerator().device_name()
        baseline = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))

        tag = 'mp_2'
        state_dict = {}
        state_dict['checkpoint_version'] = get_megatron_version()
        model.save_checkpoint(tmpdir, tag=tag, client_state=state_dict)
        dist.barrier()
        model.load_checkpoint(tmpdir, tag=tag, load_optimizer_states=False, load_lr_scheduler_states=False)

        device_name = get_accelerator().device_name()
        test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
        assert torch.allclose(baseline, test, rtol=1.0,
                              atol=1e-07), f"Baseline output {baseline} is not equal to save-then-load output {test}"
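

# Note on the mp=2 assert above: rtol=1.0 loosens torch.allclose to
# |baseline - test| <= atol + |test|, a far weaker bound than the mp=1 test
# uses; presumably this tolerates nondeterminism in the model-parallel
# all-reduces, though the source does not say so explicitly.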


# This fixture provides the baseline model with mp=2 to TestConfigurableResizeMP
class baseline_mp2(DistributedFixture):
    world_size = 2

    def run(self, inputs, class_tmpdir):
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        model = get_gpt2_model(args_defaults, mp_size=self.world_size)
        model = get_deepspeed_model(model)

        model.eval()
        with torch.no_grad():
            device_name = get_accelerator().device_name()
            baseline = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
            if dist.get_rank() == 0:
                save_path = os.path.join(class_tmpdir, "output.pt")
                torch.save(baseline.cpu(), save_path)

        state_dict = {}
        state_dict['checkpoint_version'] = get_megatron_version()
        model.save_checkpoint(class_tmpdir, client_state=state_dict)
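

# DistributedFixture runs `run()` in its own world_size=2 distributed job
# before any dependent test starts; `class_tmpdir` is the class-scoped scratch
# directory that carries both the saved checkpoint and the rank-0 reference
# output across jobs.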


class TestConfigurableResizeMP(ConfigurableMP):
    world_size = [1, 4]

    @pytest.mark.skip(reason="megatron-lm is currently broken so this test cannot be run.")
    def test(self, baseline_mp2, inputs, class_tmpdir):
        args_defaults = {
            'num_layers': 2,
            'hidden_size': 128,
            'num_attention_heads': 8,
            'max_position_embeddings': 128,
        }

        # WORLD_SIZE is an environment string; convert it before using it as mp_size.
        world_size = int(os.environ["WORLD_SIZE"])
        model = get_gpt2_model(args_defaults, mp_size=world_size)
        model = get_deepspeed_model(model)

        model.eval()
        with torch.no_grad():
            model.load_checkpoint(class_tmpdir, load_optimizer_states=False, load_lr_scheduler_states=False)
            device_name = get_accelerator().device_name()
            test = model(inputs[0].to(device_name), inputs[1].to(device_name), inputs[2].to(device_name))
            if dist.get_rank() == 0:
                load_path = os.path.join(class_tmpdir, "output.pt")
                baseline = torch.load(load_path)
                test = test.cpu()
                assert torch.allclose(
                    baseline, test,
                    atol=1e-03), f"Baseline output {baseline} is not equal to save-then-load output {test}"
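
# End-to-end flow of the resize test above (summary):
#   1. baseline_mp2 (world_size=2) builds GPT-2 with mp_size=2, then saves the
#      checkpoint and a rank-0 reference output into class_tmpdir.
#   2. test (world_size=1 or 4) rebuilds the model at the new mp degree, loads
#      the mp=2 checkpoint, and checks the output still matches at atol=1e-03,
#      presumably looser because the weights are repartitioned across a
#      different number of model-parallel ranks before the comparison.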