test_remote_sequential.py

import pytest
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler
from test_utils import *

from src import RemoteSequential
from src.client.remote_model import DistributedBloomConfig

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


@pytest.mark.forked
def test_remote_sequential():
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
    grad_proj = torch.randn(1, 5, config.hidden_size)

    sequential = RemoteSequential(config, dht)

    # Run the full remote pipeline forward and backward to get reference gradients
    full_outputs = sequential(test_inputs)
    (full_outputs * grad_proj).sum().backward()
    assert test_inputs.grad is not None
    full_grad = test_inputs.grad.clone()
    test_inputs.grad.data.zero_()

    # Slicing a RemoteSequential should split the remote layers without losing any
    first_half = sequential[: config.n_layer // 2]
    second_half = sequential[config.n_layer // 2 :]
    assert len(first_half) + len(second_half) == len(sequential)
    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
    for m in sequential, first_half, second_half:
        assert isinstance(repr(m), str)

    # Running the two halves back to back must match the full pipeline,
    # both in forward outputs and in gradients w.r.t. the inputs
    hidden = first_half(test_inputs)
    assert isinstance(hidden, torch.Tensor)
    assert hidden.shape == test_inputs.shape
    assert hidden.requires_grad
    second_half_outputs = second_half(hidden)
    assert torch.allclose(second_half_outputs, full_outputs)

    (second_half_outputs * grad_proj).sum().backward()
    assert torch.allclose(test_inputs.grad, full_grad)