# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Test on a single node with 8 GPUs:
    NDEV_PER_NODE=2 torchrun --nnodes 1 --nproc-per-node 8 test_mics_config.py
"""
import os
import json
import argparse

import torch
from torch.utils.data.distributed import DistributedSampler

import deepspeed
import deepspeed.comm as dist


class SimpleModel(torch.nn.Module):
    """A single linear layer followed by a cross-entropy loss."""

    def __init__(self, hidden_dim, empty_grad=False):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            # Extra parameters that are never used in forward(), so they
            # receive no gradients.
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = x
        hidden = self.linear(hidden)
        return self.cross_entropy_loss(hidden, y)


def create_config_from_dict(tmpdir, config_dict):
    """Write the config dict to a temporary JSON file and return its path."""
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path


def get_data_loader(model, total_samples, hidden_dim, device):
    """Build a DataLoader of random features and labels, sharded across ranks."""
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.float)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    return train_loader


def get_args(tmpdir, config_dict):
    """Parse CLI overrides, fold them into the config, and write it to disk."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--zero', type=int, default=3)
    parser.add_argument('--local_rank', type=int)
    parser.add_argument('--mics_shard_size', default=2, type=int)
    parser.add_argument('--mics_hierarchical_params_gather', default=False, action='store_true')
    args = parser.parse_args()

    config_dict["zero_optimization"]["stage"] = args.zero
    config_dict["zero_optimization"]["mics_shard_size"] = args.mics_shard_size
    config_dict["zero_optimization"]["mics_hierarchical_params_gather"] = args.mics_hierarchical_params_gather
    config_path = create_config_from_dict(tmpdir, config_dict)
    args.deepspeed_config = config_path
    return args


def print0(msg):
    """Print from rank 0 only."""
    if dist.get_rank() == 0:
        print(msg, flush=True)


# torchrun sets RANK for every process; seed each rank differently.
rank = int(os.environ['RANK'])
print('seed:', 2222 + rank)
torch.random.manual_seed(2222 + rank)

config_dict = {
    "train_batch_size": 8,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015,
        }
    },
    "fp16": {
        "enabled": False,
        "initial_scale_power": 15
    },
    "zero_optimization": {
        "stage": 3,
        "reduce_bucket_size": 20,
        "mics_shard_size": 4,
        "mics_hierarchical_params_gather": True,
        "stage3_model_persistence_threshold": 10
    }
}
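
# Note on the MiCS settings above (explanatory comment, not from the original
# script): with mics_shard_size=4 on an 8-rank run, MiCS partitions the ZeRO-3
# model states within each group of 4 ranks and replicates those shards across
# the 2 groups, instead of sharding across all 8 ranks as plain ZeRO-3 would.
# mics_hierarchical_params_gather enables the two-level (intra-node, then
# inter-node) all-gather that MiCS uses to reduce cross-node traffic.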

args = get_args('/tmp/', config_dict)
hidden_dim = 32

# Allocate the model under MiCS_Init so parameters are partitioned at
# construction time according to mics_shard_size.
with deepspeed.zero.MiCS_Init(config_dict_or_path=config_dict):
    model = SimpleModel(hidden_dim, empty_grad=False)

model, _, _, _ = deepspeed.initialize(args=args,
                                      model=model,
                                      model_parameters=model.parameters(),
                                      dist_init_required=True)


def print_params(tag, model):
    if dist.get_rank() == 0:
        for n, p in model.named_parameters():
            print0("{} {}:{}".format(tag, n, p))


data_loader = get_data_loader(model=model, total_samples=1000, hidden_dim=hidden_dim, device=model.device)
#print_params('pre-train', model)
for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    if dist.get_rank() == 0:
        print("LOSS:", loss.item())
    model.backward(loss)
    model.step()
    #print_params('step={}'.format(n), model)
    if n == 5:
        break
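
# Optional sanity check (a sketch added here, not part of the original test;
# assumes MiCS-partitioned parameters support DeepSpeed's standard ZeRO-3
# gather context): reassemble the full parameters and print their shapes.
with deepspeed.zero.GatheredParameters(list(model.module.parameters())):
    for n, p in model.module.named_parameters():
        print0("gathered {}: {}".format(n, tuple(p.shape)))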