# test_model.py — DeepSpeed fp16/ZeRO smoke test (single linear model, a few optimizer steps)
  1. import os
  2. import json
  3. import argparse
  4. import torch
  5. import deepspeed
  6. from torch.utils.data.distributed import DistributedSampler
  7. class SimpleModel(torch.nn.Module):
  8. def __init__(self, hidden_dim, empty_grad=False):
  9. super(SimpleModel, self).__init__()
  10. self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
  11. if empty_grad:
  12. self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
  13. self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
  14. def forward(self, x, y):
  15. hidden_dim = x
  16. hidden_dim = self.linear(hidden_dim)
  17. return self.cross_entropy_loss(hidden_dim, y)
  18. def create_config_from_dict(tmpdir, config_dict):
  19. config_path = os.path.join(tmpdir, 'temp_config.json')
  20. with open(config_path, 'w') as fd:
  21. json.dump(config_dict, fd)
  22. return config_path
  23. def get_data_loader(model, total_samples, hidden_dim, device):
  24. batch_size = model.train_micro_batch_size_per_gpu()
  25. train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=torch.half)
  26. train_label = torch.empty(total_samples,
  27. dtype=torch.long,
  28. device=device).random_(hidden_dim)
  29. train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
  30. sampler = DistributedSampler(train_dataset)
  31. train_loader = torch.utils.data.DataLoader(train_dataset,
  32. batch_size=batch_size,
  33. sampler=sampler)
  34. return train_loader
  35. def get_args(tmpdir, config_dict):
  36. parser = argparse.ArgumentParser()
  37. parser.add_argument("--local_rank", type=int, default=0)
  38. parser.add_argument('--zero', type=int, default=0)
  39. args = parser.parse_args() #args=''
  40. config_dict["zero_optimization"]["stage"] = args.zero
  41. print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
  42. config_path = create_config_from_dict(tmpdir, config_dict)
  43. args.deepspeed_config = config_path
  44. return args
  45. def print0(msg):
  46. if torch.distributed.get_rank() == 0:
  47. print(msg, flush=True)
  48. rank = int(os.environ['RANK'])
  49. print('seed:', 2222 + rank)
  50. torch.random.manual_seed(2222 + rank)
  51. config_dict = {
  52. "train_batch_size": 8,
  53. "steps_per_print": 1,
  54. "optimizer": {
  55. "type": "Adam",
  56. "params": {
  57. "lr": 0.00015,
  58. }
  59. },
  60. "fp16": {
  61. "enabled": True,
  62. "initial_scale_power": 15
  63. },
  64. "zero_optimization": {
  65. "stage": 0,
  66. "reduce_bucket_size": 20
  67. }
  68. }
  69. # "initial_scale_power": 15
  70. args = get_args('/tmp/', config_dict)
  71. hidden_dim = 4
  72. model = SimpleModel(hidden_dim, empty_grad=False)
  73. model, _, _,_ = deepspeed.initialize(args=args,
  74. model=model,
  75. model_parameters=model.parameters(),
  76. dist_init_required=True)
  77. def print_params(tag, model):
  78. if torch.distributed.get_rank() == 0:
  79. for n, p in model.named_parameters():
  80. print0("{} {}:{}".format(tag, n, p))
  81. data_loader = get_data_loader(model=model,
  82. total_samples=1000,
  83. hidden_dim=hidden_dim,
  84. device=model.device)
  85. #print_params('pre-train', model)
  86. for n, batch in enumerate(data_loader):
  87. loss = model(batch[0], batch[1])
  88. if torch.distributed.get_rank() == 0:
  89. print("LOSS:", loss.item())
  90. model.backward(loss)
  91. model.step()
  92. #print_params('step={}'.format(n), model)
  93. if n == 5: break