# TODO: add tests with model parallelism for activation partitioning and other features.

from copy import deepcopy

import pytest
import torch

import deepspeed
ckpt = deepspeed.checkpointing.checkpoint

from common import distributed_test
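# `ckpt` is DeepSpeed's activation-checkpointing entry point, used here in place
# of torch.utils.checkpoint.checkpoint. `distributed_test` is the decorator from
# the tests' local `common` module; it runs the wrapped function under
# torch.distributed, which checkpoint() expects to be initialized.
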
def _compute(module, *inputs, do_checkpoint=False):
    if do_checkpoint:
        outputs = ckpt(module, *inputs)
    else:
        outputs = module(*inputs)

    if torch.is_tensor(outputs):
        outputs = (outputs, )

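    # Reduce all grad-requiring outputs to a single scalar so that one backward()
    # call populates both the parameter gradients and the input gradients.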
    sum(o.sum() for o in outputs if o.requires_grad).backward()
    grads = [p.grad for p in module.parameters()]
    input_grads = [inp.grad for inp in inputs]

    return {
        'outputs': outputs,
        'module_grads': grads,
        'input_grads': input_grads,
    }


# This is distributed because checkpoint() assumes that torch.distributed is initialized.
# torch.distributed is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint(module, *inputs):
    # Move to device
    module.cuda()

    # Get rid of dropouts until we fork the RNG between tests.
    module.eval()

    module_ = deepcopy(module)
    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
    base = _compute(module_, *inputs_, do_checkpoint=False)

    module_ = deepcopy(module)
    inputs_ = tuple(deepcopy(inp).cuda() for inp in inputs)
    test = _compute(module_, *inputs_, do_checkpoint=True)

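    # Floating-point results are compared with allclose() to tolerate small
    # numerical differences from recomputation; everything else (a bool mask
    # passed through, or a grad that is None) must match exactly.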
    for group in base.keys():
        for b, t in zip(base[group], test[group]):
            # Catch grad `None`s, etc.
            if not torch.is_tensor(b):
                assert b == t
            elif b.is_floating_point():
                assert torch.allclose(b, t)
            else:
                assert torch.equal(b, t)

#
# Helpers
#


class MaskedLinear(torch.nn.Linear):
    def forward(self, x, mask):
        out = super().forward(x)
        if mask.is_floating_point():
            out = out * mask
        else:
            # must cast BoolTensor in older torch versions
            out = out * mask.type_as(out)
        return out


class MaskedLinearSeq(MaskedLinear):
    """Tests pipeline modules by also returning the mask."""
    def forward(self, x, mask):
        return super().forward(x, mask), mask


class MaskedLinearSeqDup(MaskedLinearSeq):
    """MaskedLinearSeq, but with more outputs than inputs and in a different order."""
    def forward(self, x, mask):
        dup = x.clone().detach() * 1.38  # just an arbitrary scaling
        x, mask = super().forward(x, mask)
        return dup, x, mask

HIDDEN_DIM = 20


def _mixed_mask(size=HIDDEN_DIM):
    entries = torch.randn(size)
    mask = torch.where(entries > 0, torch.ones(size), torch.zeros(size))
    mask = mask.bool()
    return mask


def _bool_to_float(btensor, dtype=torch.float32):
    """Converts a torch.BoolTensor to an equivalent tensor of the given float dtype."""
    ones = torch.ones(size=btensor.size(), dtype=dtype)
    zeros = torch.zeros(size=btensor.size(), dtype=dtype)
    return torch.where(btensor, ones, zeros)

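# The two mask helpers produce matched inputs for the parametrized tests below:
# _mixed_mask() yields a bool mask (non-differentiable), and
# _bool_to_float(_mixed_mask()) yields a float mask of the same shape, so each
# masked module is checkpointed with both a differentiable and a
# non-differentiable extra input.
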
#
# Tests
#


def test_ckpt_inputs1_outputs1():
    module = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs)

# both bool and float masks are important, as bool is not differentiable
@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs1(mask):
    module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)

@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs2(mask):
    module = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs3(mask):
    module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)

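# A minimal additional sketch (not part of the original suite): the same
# correctness check applied to a deeper, purely sequential module, assuming
# ckpt() treats a multi-layer module exactly like the single-layer cases above.
def test_ckpt_sequential_sketch():
    module = torch.nn.Sequential(
        torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
        torch.nn.ReLU(),
        torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
    )
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs)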