# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import json
import argparse
import torch
import deepspeed
from torch.utils.data.distributed import DistributedSampler
import deepspeed.comm as dist
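

# Toy model: two stacked Linear layers feeding a cross-entropy loss,
# kept small so ZeRO/fp16 behavior is easy to inspect.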
class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])  #QuantizeLinear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden = x
        hidden1 = self.linear(hidden)
        hidden2 = self.linear2(hidden1)
        return self.cross_entropy_loss(hidden2, y)


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path
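

# Synthetic dataset: random fp16 features with integer class labels in
# [0, hidden_dim). DistributedSampler shards the samples across ranks.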
def get_data_loader(model, total_samples, hidden_dim, device):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    sampler = DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    return train_loader
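

# Command-line knobs let one script cover several configurations: --zero
# selects the ZeRO optimization stage (0-3) and --zero_hpz_partition_size sets
# the ZeRO++ hierarchical partitioning (hpZ) secondary partition size; both
# are patched into the config dict before it is written to disk.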
def get_args(tmpdir, config_dict):
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument('--zero', type=int, default=0)
    parser.add_argument('--zero_hpz_partition_size', type=int, default=1)
    args = parser.parse_args()

    config_dict["zero_optimization"]["stage"] = args.zero
    config_dict["zero_optimization"]["zero_hpz_partition_size"] = args.zero_hpz_partition_size
    print('config_dict["zero_optimization"]', config_dict["zero_optimization"])

    config_path = create_config_from_dict(tmpdir, config_dict)
    args.deepspeed_config = config_path
    return args


def print0(msg):
    if dist.get_rank() == 0:
        print(msg, flush=True)
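

# Seed each rank differently so the synthetic data differs across workers.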
rank = int(os.environ['RANK'])
print('seed:', 2222 + rank)
torch.random.manual_seed(2222 + rank)
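
# Base DeepSpeed config. fp16 training starts with a loss scale of 2**8, the
# deliberately tiny reduce_bucket_size forces frequent gradient reductions,
# and the ZeRO fields are placeholders that get_args() overrides from the CLI.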
config_dict = {
    "train_batch_size": 256,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015,
        }
    },
    "fp16": {
        "enabled": True,
        "initial_scale_power": 8  # "initial_scale_power": 15
    },
    "zero_optimization": {
        "stage": 0,
        "reduce_bucket_size": 20,
        "zero_hpz_partition_size": 1,
        "reduce_scatter": True,
        "zero_quantized_weights": False,
        "zero_quantized_gradients": False
    }
}

args = get_args('/tmp/', config_dict)
hidden_dim = 4 * 1024

model = SimpleModel(hidden_dim, empty_grad=False)
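
# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler);
# only the engine is kept. It wraps the module and reads the JSON config
# referenced by args.deepspeed_config.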
model, _, _, _ = deepspeed.initialize(args=args,
                                      model=model,
                                      model_parameters=model.parameters(),
                                      dist_init_required=True)


def print_params(tag, model):
    if dist.get_rank() == 0:
        for n, p in model.named_parameters():
            print0("{} {}:{}".format(tag, n, p))


data_loader = get_data_loader(model=model, total_samples=256, hidden_dim=hidden_dim, device=model.device)
#print_params('pre-train', model)
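
# Training loop: the engine's backward()/step() replace loss.backward() and
# optimizer.step(), handling fp16 loss scaling and ZeRO partitioning internally.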
for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    if dist.get_rank() == 0:
        print("LOSS:", loss.item())
    model.backward(loss)
    model.step()
    #print_params('step={}'.format(n), model)
    #if n == 5: break
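
# Example launch (the GPU count and stage below are illustrative, not from
# this repo):
#   deepspeed --num_gpus 8 test_model.py --zero 3 --zero_hpz_partition_size 8
# The deepspeed launcher sets the RANK environment variable this script reads
# at startup.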