# test_remote_sequential.py

import pytest
import torch
from hivemind import DHT, get_logger, use_hivemind_log_handler

from test_utils import *  # provides MODEL_NAME and INITIAL_PEERS used below

from src import RemoteSequential
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_model import DistributedBloomConfig

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


@pytest.mark.forked
def test_remote_sequential():
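    # Join the test swarm as a client-only DHT node; RemoteSequential uses it
    # to discover which servers host the model blocks.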
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)

    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
    grad_proj = torch.randn(1, 5, config.hidden_size)
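    # Reference run: forward through the full remote chain, then backprop a random
    # projection of the outputs to obtain input gradients.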
    sequential = RemoteSequential(config, dht)

    full_outputs = sequential(test_inputs)
    (full_outputs * grad_proj).sum().backward()
    assert test_inputs.grad is not None
    full_grad = test_inputs.grad.clone()
    test_inputs.grad.data.zero_()
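    # RemoteSequential supports slicing; the two halves together must cover
    # every block exactly once.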
    first_half = sequential[: config.n_layer // 2]
    second_half = sequential[config.n_layer // 2 :]
    assert len(first_half) + len(second_half) == len(sequential)
    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
    for m in sequential, first_half, second_half:
        assert isinstance(repr(m), str)
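    # Running the two halves back-to-back must reproduce the full pipeline,
    # both for activations and for input gradients.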
    hidden = first_half(test_inputs)
    assert isinstance(hidden, torch.Tensor)
    assert hidden.shape == test_inputs.shape
    assert hidden.requires_grad
    second_half_outputs = second_half(hidden)
    assert torch.allclose(second_half_outputs, full_outputs)

    (second_half_outputs * grad_proj).sum().backward()
    assert torch.allclose(test_inputs.grad, full_grad)


@pytest.mark.forked
def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
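    # Checks per-block ("deep") prompts: trainable input prompts extend the sequence,
    # and a separate prompt tensor is injected before each transformer block.
    # The remote result is verified against a local block-by-block reference.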
    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    remote_sequential = RemoteSequential(config, dht)
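    # input_prompts extends the input sequence itself; intermediate_prompts holds one
    # (batch_size, pre_seq_len, hidden_size) prompt per transformer block.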
    inputs = torch.randn(batch_size, seq_len, config.hidden_size)
    output_proj = torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size)
    input_prompts = torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)

    # detach() re-creates both prompts as leaf tensors, so .grad is populated
    # directly on them after backward()
    input_prompts = input_prompts.detach().requires_grad_(True)
    intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
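    # Append the trainable input prompts along the sequence dimension.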
    inputs_with_prompts = torch.cat([inputs, input_prompts], dim=1)
    assert inputs_with_prompts.shape == (batch_size, seq_len + pre_seq_len, config.hidden_size)
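    # Remote forward/backward pass, with one prompt per block passed alongside
    # the inputs.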
    outputs = remote_sequential(inputs_with_prompts, prompts=intermediate_prompts)

    (outputs * output_proj).sum().backward()
    assert intermediate_prompts.grad is not None
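    # Rebuild the same computation locally, using fresh leaf copies of the prompts.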
    input_prompts_ref = input_prompts.clone().detach().requires_grad_(True)
    intermediate_prompts_ref = intermediate_prompts.clone().detach().requires_grad_(True)

    assert input_prompts_ref.grad is None
    assert intermediate_prompts_ref.grad is None
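    # Local reference: before each block, add that block's prompt to the first
    # pre_seq_len positions, then run the locally loaded block.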
    outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
    for block_index in range(config.n_layer):
        block_prompt = intermediate_prompts_ref[block_index]
        outputs_ref[:, : block_prompt.shape[1]] += block_prompt

        block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32)
        (outputs_ref,) = block(outputs_ref)
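    # The local reference must match the remote outputs and yield identical
    # gradients for both prompt tensors.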
    assert torch.allclose(outputs_ref, outputs)

    (outputs_ref * output_proj).sum().backward()
    assert input_prompts_ref.grad is not None
    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad)
    assert intermediate_prompts_ref.grad is not None
    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad)