# test_full_model.py
  1. # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
  2. import os
  3. import torch
  4. import transformers
  5. from hivemind import get_logger, use_hivemind_log_handler
  6. from src.client.remote_model import DistributedBloomForCausalLM
  7. use_hivemind_log_handler("in_root_logger")
  8. logger = get_logger(__file__)
  9. INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
  10. if not INITIAL_PEERS:
  11. raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
  12. INITIAL_PEERS = INITIAL_PEERS.split()
  13. MODEL_NAME = os.environ.get("MODEL_NAME")
  14. if not MODEL_NAME:
  15. raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
  16. REF_NAME = os.environ.get("REF_NAME")
  17. def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
  18. tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
  19. model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
  20. assert isinstance(model, DistributedBloomForCausalLM)
  21. assert len(model.transformer.h) == model.config.n_layer
  22. test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
  23. parallel_outputs = model.forward(test_inputs).logits
  24. assert torch.all(torch.isfinite(parallel_outputs))
  25. logger.info("Forward outputs are finite")
  26. if REF_NAME:
  27. with torch.no_grad():
  28. ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
  29. dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
  30. # note: this creates a dummy mask to make the test compatible with older transformer versions
  31. # prior to https://github.com/huggingface/transformers/pull/17837
  32. ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
  33. assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
  34. del ref_model, ref_outputs
  35. else:
  36. logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
  37. with torch.inference_mode():
  38. embs = model.transformer.word_embeddings(test_inputs)
  39. embs = model.transformer.word_embeddings_layernorm(embs)
  40. recurrent_outputs = []
  41. with model.transformer.h.inference_session() as sess:
  42. for t in range(embs.shape[1]):
  43. recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
  44. recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
  45. recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
  46. dictionary = model.transformer.word_embeddings.weight.t()
  47. recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
  48. recurrent_outputs = (recurrent_outputs @ dictionary).float()
  49. assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
  50. logger.info("Inference is consistent with forward")