test_full_model.py

# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import os

import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler

from src.client.remote_model import DistributedBloomForCausalLM

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()

MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
    raise RuntimeError("Must specify MODEL_NAME environment variable with the name of the model to be tested")

REF_NAME = os.environ.get("REF_NAME")
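
# Usage sketch (the values below are placeholders for illustration, not real
# peer multiaddrs or model names; adjust them to your swarm):
#
#   export INITIAL_PEERS="/ip4/127.0.0.1/tcp/31337/p2p/Qm..."  # one or more, space-separated
#   export MODEL_NAME="bigscience/test-bloomd-6b3"
#   export REF_NAME="bigscience/bloom-6b3"  # optional: enables the exact-match check
#   pytest test_full_model.py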


def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, prefix=prefix)
    assert len(model.transformer.h) == model.config.n_layer

    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    parallel_outputs = model.forward(test_inputs).logits
    assert torch.all(torch.isfinite(parallel_outputs))
    logger.info("Forward outputs are finite")
    if REF_NAME:
        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
        # note: this creates a dummy mask to make the test compatible with older transformers versions
        # prior to https://github.com/huggingface/transformers/pull/17837
        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
    else:
        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")

    # run the same input through the remote blocks one token at a time (inference mode)
    embs = model.transformer.word_embeddings(test_inputs)
    embs = model.transformer.word_embeddings_layernorm(embs)
    recurrent_outputs = []
    with model.transformer.h.inference_session() as sess:
        for t in range(embs.shape[1]):
            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
    recurrent_outputs = model.transformer.ln_f(recurrent_outputs)

    # compute logits manually by multiplying with the (tied) embedding matrix
    dictionary = model.transformer.word_embeddings.weight.t()
    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
    recurrent_outputs = (recurrent_outputs @ dictionary).float()
    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
    logger.info("Inference is consistent with forward")
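

# The token-by-token loop above mirrors autoregressive generation. The sketch
# below is a hypothetical illustration of greedy decoding on top of the same
# inference_session()/step() API; it is not part of the original test, and the
# helper name is made up.
def _greedy_decode_sketch(model, tokenizer, prompt: str, max_new_tokens: int = 8) -> str:
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    dictionary = model.transformer.word_embeddings.weight.t()  # tied LM head
    with model.transformer.h.inference_session() as sess:
        outputs = []
        # feed the prompt one token at a time, as in the test above
        embs = model.transformer.word_embeddings_layernorm(model.transformer.word_embeddings(input_ids))
        for t in range(embs.shape[1]):
            outputs.append(sess.step(embs[:, t : t + 1, :]))
        for _ in range(max_new_tokens):
            # project the latest hidden state to logits and pick the argmax token
            hidden = model.transformer.ln_f(outputs[-1])
            logits = (hidden.to(dictionary.dtype) @ dictionary).float()
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            next_emb = model.transformer.word_embeddings_layernorm(model.transformer.word_embeddings(next_token))
            outputs.append(sess.step(next_emb))
    return tokenizer.decode(input_ids[0])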