# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
from deepspeed.runtime.zero.linear import LinearModuleForZeroStage3
from deepspeed.accelerator import get_accelerator
from unit.common import DistributedTest
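
# These tests exercise LinearModuleForZeroStage3, the Linear layer variant used
# with ZeRO Stage 3, with and without AMP autocast: without autocast (or with it
# explicitly disabled) the output dtype must match the weight dtype, while with
# autocast enabled the output must be a reduced-precision dtype.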


@pytest.mark.parametrize('half_op', [False, True])
class TestAutoCastDisable(DistributedTest):

    def test_missing_amp_autocast(self, half_op):
        hidden_dim = 4
        if half_op:
            input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half()
            ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()).half()
        else:
            input = torch.randn(hidden_dim).to(get_accelerator().device_name())
            ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())

        # No autocast context is entered, so the output keeps the weight dtype.
        output = ds_linear(input)
        assert output.dtype == ds_linear.weight.dtype
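
    # Entering an autocast region that is explicitly disabled should behave the
    # same as no autocast at all: the output keeps the weight dtype.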
    def test_disable_autocast_linear(self, half_op):
        amp = get_accelerator().amp()

        hidden_dim = 4
        if half_op:
            input = torch.randn(hidden_dim).to(get_accelerator().device_name()).half()
            ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name()).half()
        else:
            input = torch.randn(hidden_dim).to(get_accelerator().device_name())
            ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())

        with amp.autocast(False):
            output = ds_linear(input)
            assert output.dtype == ds_linear.weight.dtype
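

# With autocast enabled, the forward pass should run in reduced precision for
# every combination of fp32/fp16 inputs and weights.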
@pytest.mark.skipif(get_accelerator().amp() is None, reason='amp is not installed')
@pytest.mark.parametrize('half_input, half_weight', [(False, False), (False, True), (True, False), (True, True)])
class TestAutoCastEnable(DistributedTest):

    def test_autocast_linear(self, tmpdir, half_input, half_weight):
        amp = get_accelerator().amp()

        hidden_dim = 4
        input = torch.randn(hidden_dim).to(get_accelerator().device_name())
        ds_linear = LinearModuleForZeroStage3(hidden_dim, hidden_dim).to(get_accelerator().device_name())

        if half_input:
            input = input.half()
        if half_weight:
            ds_linear = ds_linear.half()

        # Under autocast, the output should be a reduced-precision dtype
        # (fp16 or bf16, depending on the accelerator).
        with amp.autocast():
            output = ds_linear(input)
            assert output.dtype == torch.half or output.dtype == torch.bfloat16