test_averaging_sparse_gradients.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import pytest
import deepspeed
from unit.common import DistributedTest
from unit.util import skip_on_arch
from deepspeed.accelerator import get_accelerator

if get_accelerator().device_name() == 'hpu':
    pytest.skip("sparse_gradients not supported by HPU.", allow_module_level=True)


class Model(torch.nn.Module):
    """Tiny model whose sparse EmbeddingBag produces sparse gradients."""

    def __init__(self):
        super().__init__()
        self.emb = torch.nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
        self.linear = torch.nn.Linear(3, 1)

    def forward(self, x, offsets):
        return self.linear(self.emb(x, offsets))


class Adam(torch.optim.Optimizer):
    """Routes dense parameters to Adam and sparse parameters to SparseAdam."""

    def __init__(self, dense_params, sparse_params):
        super().__init__(dense_params + sparse_params, defaults={})
        self.adam = torch.optim.Adam(dense_params)
        self.adam_sparse = torch.optim.SparseAdam(sparse_params)

    @torch.no_grad()
    def step(self, closure=None):
        loss_1 = self.adam.step(closure)
        loss_2 = self.adam_sparse.step(closure)
        if loss_1 is not None and loss_2 is not None:
            return loss_1 + loss_2
        return loss_1 or loss_2


def get_model_optimizer():
    torch.manual_seed(0)
    model = Model()
    optimizer = Adam(list(model.linear.parameters()), list(model.emb.parameters()))
    return model, optimizer


def get_data(device):
    x = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long, device=device)
    offsets = torch.tensor([0, 4], dtype=torch.long, device=device)
    y = torch.tensor([[1.0], [0.0]], device=device)
    return x, offsets, y


class TestSparseAdam(DistributedTest):
    world_size = 2

    def test(self):
        skip_on_arch(min_arch=7)

        config_dict = {"train_batch_size": 2, "steps_per_print": 1, "sparse_gradients": True}
        model, optimizer = get_model_optimizer()
        loss = torch.nn.BCEWithLogitsLoss()
        engine, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=config_dict)

        x, offsets, y = get_data(engine.device)

        # First backward pass: average gradients across the two ranks and save them.
        engine.gradient_average = True
        res = engine(x, offsets)
        engine.backward(loss(res, y))
        averaged_grads = {}
        for k, v in engine.named_parameters():
            grad = v.grad.to_dense() if v.grad.is_sparse else v.grad
            averaged_grads[k] = grad
            v.grad = None

        # Second backward pass: gradients are summed instead of averaged, so each
        # gradient should equal the averaged gradient scaled by the world size.
        engine.gradient_average = False
        res = engine(x, offsets)
        engine.backward(loss(res, y))
        for k, v in engine.named_parameters():
            grad = v.grad.to_dense() if v.grad.is_sparse else v.grad
            assert torch.allclose(grad, averaged_grads[k] * engine.world_size)