test_activation_checkpointing.py

# TODO: add tests with model parallelism for activation partitioning and other features.

from copy import deepcopy

import pytest
import torch

import deepspeed
ckpt = deepspeed.checkpointing.checkpoint

from common import distributed_test


def _compute(module, *inputs, do_checkpoint=False):
    """Run a forward/backward pass, optionally through DeepSpeed's activation
    checkpointing, and collect outputs, parameter grads, and input grads."""
    if do_checkpoint:
        outputs = ckpt(module, *inputs)
    else:
        outputs = module(*inputs)

    if torch.is_tensor(outputs):
        outputs = (outputs, )

    sum(o.sum() for o in outputs if o.requires_grad).backward()

    grads = [p.grad for p in module.parameters()]
    input_grads = [inp.grad for inp in inputs]

    return {
        'outputs': outputs,
        'module_grads': grads,
        'input_grads': input_grads,
    }


# This is distributed because checkpoint() assumes that torch.distributed is initialized.
# torch.distributed is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint(module, *inputs):
    # Move to device
    module.cuda()

    # Get rid of dropouts until we fork the RNG between tests.
    module.eval()

    module_ = deepcopy(module)
    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
    base = _compute(module_, *inputs_, do_checkpoint=False)

    module_ = deepcopy(module)
    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
    test = _compute(module_, *inputs_, do_checkpoint=True)

    for group in base.keys():
        for b, t in zip(base[group], test[group]):
            # Catch grad `None`s, etc.
            if not torch.is_tensor(b):
                assert b == t
            elif b.is_floating_point():
                assert torch.allclose(b, t)
            else:
                assert torch.equal(b, t)
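

# A small direct-usage sketch (added; not part of the original file): ckpt() is
# invoked like the module it wraps, with the module first and its inputs after,
# and the forward is recomputed during the backward pass. Only APIs already used
# above (ckpt, distributed_test, plain torch.nn.Linear) are assumed. The leading
# underscore keeps pytest from collecting it, so it is purely illustrative.
@distributed_test(world_size=1)
def _sketch_direct_ckpt_call():
    layer = torch.nn.Linear(4, 4).cuda()
    x = torch.rand(4).cuda()
    x.requires_grad = True

    out = ckpt(layer, x)    # forward through the checkpoint wrapper
    out.sum().backward()    # activations are rebuilt here, then grads flow

    assert x.grad is not None
    assert all(p.grad is not None for p in layer.parameters())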


#
# Helpers
#


class MaskedLinear(torch.nn.Linear):
    def forward(self, x, mask):
        out = super().forward(x)
        if mask.is_floating_point():
            out = out * mask
        else:
            # must cast BoolTensor in older torch versions
            out = out * mask.type_as(out)
        return out


class MaskedLinearSeq(MaskedLinear):
    """Tests pipeline modules by also returning the mask."""
    def forward(self, x, mask):
        return super().forward(x, mask), mask


class MaskedLinearSeqDup(MaskedLinearSeq):
    """MaskedLinearSeq, but with more outputs than inputs and in a different order."""
    def forward(self, x, mask):
        dup = x.clone().detach() * 1.38  # just an arbitrary scaling
        x, mask = super().forward(x, mask)
        return dup, x, mask


HIDDEN_DIM = 20


def _mixed_mask(size=HIDDEN_DIM):
    entries = torch.randn(size)
    mask = torch.where(entries > 0, torch.ones(size), torch.zeros(size))
    mask = mask.bool()
    return mask


def _bool_to_float(btensor, dtype=torch.float32):
    """Converts a torch.BoolTensor to an equivalent dtype."""
    ones = torch.ones(size=btensor.size(), dtype=dtype)
    zeros = torch.zeros(size=btensor.size(), dtype=dtype)
    return torch.where(btensor, ones, zeros)


#
# Tests
#


def test_ckpt_inputs1_outputs1():
    module = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs)


# Both bool and float masks are important, as bool is not differentiable.
@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs1(mask):
    module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs2(mask):
    module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs3(mask):
    module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)
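

# Sketch of an additional case (added; not part of the original file): a two-layer
# torch.nn.Sequential stack run through the same harness. Uses only plain PyTorch
# modules and the helpers defined above; no new DeepSpeed API is assumed.
def test_ckpt_sequential_modules():
    module = torch.nn.Sequential(
        torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
        torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
    )
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs)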