test_full_model.py

import pytest
import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *  # provides MODEL_NAME, INITIAL_PEERS, and REF_NAME used below

from src.bloom.model import BloomForCausalLM
from src.client.remote_model import DistributedBloomForCausalLM

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
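

# The distributed model's parallel forward pass, its token-by-token inference session,
# and (if REF_NAME is set) a local reference model should all produce matching logits.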
@pytest.mark.forked
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    config = model.config
    assert isinstance(model, DistributedBloomForCausalLM)
    assert len(model.transformer.h) == model.config.n_layer

    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
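
    # A single parallel forward pass over the remote transformer blocks should yield finite logits.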
    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")

        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
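        # Re-run the same sequence one token at a time through a remote inference session;
        # the result should match the parallel forward pass within atol_inference.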
        recurrent_outputs = []
        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
            for t in range(embs.shape[1]):
                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
        logger.info("Inference is consistent with forward")

        del model, embs, recurrent_outputs
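
        # Optionally compare against a local (non-distributed) BLOOM reference checkpoint.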
        if REF_NAME:
            ref_model = transformers.BloomForCausalLM.from_pretrained(
                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
            )
            if config.vocab_size < ref_model.config.vocab_size:
                ref_model.resize_token_embeddings(config.vocab_size)
                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
            # note: this creates a dummy mask to make the test compatible with older transformers versions
            # prior to https://github.com/huggingface/transformers/pull/17837
            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
            del ref_model, ref_outputs, dummy_mask
        else:
            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
            # Fail explicitly so that a missing REF_NAME does not let the test pass silently.
            assert False
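

# Checks that remote greedy generation matches HuggingFace's greedy_search,
# for both a single prompt and a padded batch of prompts.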
@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    remote_outputs_batch = model.generate(
        inputs_batch,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs_batch = BloomForCausalLM.greedy_search(
        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
    )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search results are not identical to HF in multibatch mode"