test_multi_output_model.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import deepspeed
from pytest import approx
from unit.common import DistributedTest
from unit.multi_output_model import MultiOutputModel, multi_output_dataloader
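
# MultiOutputModel and multi_output_dataloader are unit-test helpers (unit/multi_output_model.py).
# As exercised below, the model is assumed to return one scalar cross-entropy loss per
# (input, target) pair, and the dataloader to yield flat batches laid out as [*inputs, *targets].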


class TestTwoOutputModel(DistributedTest):
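    """Two-output model under fp16 and gradient accumulation: check each per-output loss,
    then that backward() returns the summed loss scaled by 1/grad_accumulation_steps."""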
    world_size = 1

    def test(self, tmpdir):
        grad_accumulation_steps = 2
        micro_batch_size = 1
        world_size = self.world_size
        config_dict = {
            "train_micro_batch_size_per_gpu": micro_batch_size,
            "gradient_accumulation_steps": grad_accumulation_steps,
            "train_batch_size": micro_batch_size * grad_accumulation_steps * world_size,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            },
            "fp16": {
                "enabled": True
            }
        }
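        # DeepSpeed validates that train_batch_size equals
        # micro_batch_size * gradient_accumulation_steps * world_size,
        # so the batch-size fields above are derived from one another rather than hard-coded.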
        hidden_dim = 10
        weight_value = 0.1

        model = MultiOutputModel(hidden_dim, weight_value)
        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
        total_samples = 4
        data_loader = multi_output_dataloader(model=model,
                                              total_samples=total_samples,
                                              hidden_dim=hidden_dim,
                                              device=model.device,
                                              inputs=[1.0, 2.0],
                                              targets=[1, 2])
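        # Each batch is a flat sequence with inputs in the first half and the matching targets
        # in the second half, hence the midpoint split below.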
        for n, batch in enumerate(data_loader):
            assert len(batch) % 2 == 0, \
                f"multi_output_dataloader failed to return even number of data samples (input+target)"

            midpoint = len(batch) // 2
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)
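
            # 2.302734375 is ln(10) rounded to the nearest fp16 value: the cross-entropy of a
            # uniform distribution over hidden_dim=10 classes, which the constant-weight model
            # presumably produces for every output here.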
            expected_loss = torch.tensor(2.302734375, dtype=torch.half, device=model.device)
            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())
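
            # The engine's backward() returns the loss it actually backpropagated; with gradient
            # accumulation, DeepSpeed scales the summed loss by 1/grad_accumulation_steps, which
            # the assertion below verifies.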
            summed_loss = sum(loss_tuple)
            scaled_loss = model.backward(summed_loss)
            expected_scaled_loss = summed_loss.float() / grad_accumulation_steps
            assert scaled_loss.item() == approx(expected_scaled_loss.item())

            model.step()


class TestThreeOutputModel(DistributedTest):
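    """Same checks as TestTwoOutputModel, but with three (input, target) pairs per batch
    and three gradient-accumulation steps."""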
    world_size = 1

    def test(self, tmpdir):
        grad_accumulation_steps = 3
        micro_batch_size = 1
        world_size = 1
        config_dict = {
            "train_micro_batch_size_per_gpu": micro_batch_size,
            "gradient_accumulation_steps": grad_accumulation_steps,
            "train_batch_size": micro_batch_size * grad_accumulation_steps * world_size,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 0.00015
                }
            },
            "fp16": {
                "enabled": True
            }
        }
        hidden_dim = 10
        weight_value = 0.1

        model = MultiOutputModel(hidden_dim, weight_value)
        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
        total_samples = grad_accumulation_steps * micro_batch_size * 2
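        # Enough micro-batches for two full gradient-accumulation windows (two optimizer updates).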
        data_loader = multi_output_dataloader(model=model,
                                              total_samples=total_samples,
                                              hidden_dim=hidden_dim,
                                              device=model.device,
                                              inputs=[1.0, 2.0, 3.0],
                                              targets=[1, 2, 3])
        for n, batch in enumerate(data_loader):
            assert len(batch) % 2 == 0, \
                f"multi_output_dataloader failed to return even number of data samples (input+target)"

            midpoint = len(batch) // 2
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)
            assert len(loss_tuple) == 3

            expected_loss = torch.tensor(2.302734375, dtype=torch.half, device=model.device)
            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())

            summed_loss = sum(loss_tuple)
            scaled_loss = model.backward(summed_loss)
            expected_scaled_loss = summed_loss.float() / grad_accumulation_steps
            assert scaled_loss.item() == approx(expected_scaled_loss.item())

            model.step()