test_block_exact_match.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
  2. import os
  3. import hivemind
  4. import torch
  5. import transformers
  6. from src.bloom.from_pretrained import load_pretrained_block
  7. from src.client.remote_block import RemoteTransformerBlock
  8. from src.dht_utils import get_remote_module
  9. INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
  10. if not INITIAL_PEERS:
  11. raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
  12. INITIAL_PEERS = INITIAL_PEERS.split()
  13. BLOCK_UID = os.environ.get("BLOCK_UID")
  14. if not BLOCK_UID:
  15. raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
  16. REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
  17. REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
  18. def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
  19. dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
  20. remote_block = get_remote_module(dht, BLOCK_UID)
  21. assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
  22. assert isinstance(remote_block, RemoteTransformerBlock)
  23. ref_config = transformers.AutoConfig.from_pretrained(REF_NAME)
  24. inputs = torch.randn(1, 8, ref_config.hidden_size)
  25. (outputs_forward,) = remote_block(inputs)
  26. outputs_inference = []
  27. with remote_block.inference_session() as sess:
  28. for i in range(inputs.shape[1]):
  29. outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
  30. outputs_inference = torch.cat(outputs_inference, dim=1)
  31. ref_block = load_pretrained_block(REF_NAME, REF_INDEX, torch_dtype=torch.float32)
  32. (outputs_local,) = ref_block(inputs)
  33. assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
  34. assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)