test_multi_output_model.py

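# Unit tests for models that return multiple losses per forward pass:
# each scalar loss is checked against an expected value, and the summed
# loss is verified to be scaled by the number of gradient accumulation
# steps during backward().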
import torch
import deepspeed
import argparse
import pytest
from pytest import approx
import json
import os
from common import distributed_test
from simple_model import args_from_dict
from multi_output_model import MultiOutputModel, multi_output_dataloader


def create_config_dict(micro_batch_size, grad_accumulation_steps, world_size):
    # Minimal DeepSpeed config: fp16 Adam with the requested micro-batch size
    # and gradient accumulation; train_batch_size must equal their product
    # times the world size.
    return {
        "train_micro_batch_size_per_gpu": micro_batch_size,
        "gradient_accumulation_steps": grad_accumulation_steps,
        "train_batch_size": micro_batch_size * grad_accumulation_steps * world_size,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        },
        "fp16": {
            "enabled": True
        }
    }

def test_two_output_model(tmpdir):
    gradient_accumulation_steps = 2
    micro_batch_size = 1
    world_size = 1
    config_dict = create_config_dict(micro_batch_size,
                                     gradient_accumulation_steps,
                                     world_size)
    hidden_dim = 10
    weight_value = 0.1

    args = args_from_dict(tmpdir, config_dict)
    model = MultiOutputModel(hidden_dim, weight_value)

    @distributed_test(world_size=[1])
    def _test_two_output_model(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        total_samples = 4
        data_loader = multi_output_dataloader(model=model,
                                              total_samples=total_samples,
                                              hidden_dim=hidden_dim,
                                              device=model.device,
                                              inputs=[1.0, 2.0],
                                              targets=[1, 2])
        for n, batch in enumerate(data_loader):
            assert len(batch) % 2 == 0, \
                "multi_output_dataloader failed to return an even number of data samples (input+target)"
            midpoint = len(batch) // 2
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)

            # Each output head should produce a scalar loss with the expected value.
            expected_loss = torch.tensor(2.302734375,
                                         dtype=torch.half,
                                         device=model.device)
            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())

            # The engine scales the summed loss by the gradient accumulation steps.
            summed_loss = sum(loss_tuple)
            scaled_loss = model.backward(summed_loss)
            expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
            assert scaled_loss.item() == approx(expected_scaled_loss.item())

            model.step()

    _test_two_output_model(args=args, model=model, hidden_dim=hidden_dim)

def test_three_output_model(tmpdir):
    gradient_accumulation_steps = 3
    micro_batch_size = 1
    world_size = 1
    config_dict = create_config_dict(micro_batch_size,
                                     gradient_accumulation_steps,
                                     world_size)
    hidden_dim = 10
    weight_value = 0.1

    args = args_from_dict(tmpdir, config_dict)
    model = MultiOutputModel(hidden_dim, weight_value)

    @distributed_test(world_size=[1])
    def _test_three_output_model(args, model, hidden_dim):
        model, _, _, _ = deepspeed.initialize(args=args,
                                              model=model,
                                              model_parameters=model.parameters())
        total_samples = gradient_accumulation_steps * micro_batch_size * 2
        data_loader = multi_output_dataloader(model=model,
                                              total_samples=total_samples,
                                              hidden_dim=hidden_dim,
                                              device=model.device,
                                              inputs=[1.0, 2.0, 3.0],
                                              targets=[1, 2, 3])
        for n, batch in enumerate(data_loader):
            assert len(batch) % 2 == 0, \
                "multi_output_dataloader failed to return an even number of data samples (input+target)"
            midpoint = len(batch) // 2
            inputs, targets = batch[:midpoint], batch[midpoint:]
            loss_tuple = model(inputs, targets)
            assert len(loss_tuple) == 3

            # Each output head should produce a scalar loss with the expected value.
            expected_loss = torch.tensor(2.302734375,
                                         dtype=torch.half,
                                         device=model.device)
            for loss in loss_tuple:
                assert loss.shape == torch.Size([])
                assert loss.item() == approx(expected_loss.item())

            # The engine scales the summed loss by the gradient accumulation steps.
            summed_loss = sum(loss_tuple)
            scaled_loss = model.backward(summed_loss)
            expected_scaled_loss = summed_loss.float() / gradient_accumulation_steps
            assert scaled_loss.item() == approx(expected_scaled_loss.item())

            model.step()

    _test_three_output_model(args=args, model=model, hidden_dim=hidden_dim)