test_full_model.py

import pytest
import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler

from test_utils import *

from src.bloom.model import BloomForCausalLM
from src.client.remote_model import DistributedBloomForCausalLM

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
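

# The distributed model should match a local reference BLOOM both for a single
# batched forward pass and for step-by-step inference over the remote layers.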
@pytest.mark.forked
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    assert isinstance(model, DistributedBloomForCausalLM)
    assert len(model.transformer.h) == model.config.n_layer

    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
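
    # No gradients are needed in this test, so run all checks under inference_mode.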
    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")
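
        # Replay the same prompt one token at a time through an inference session; the
        # session is expected to keep attention caches on the servers between steps, so
        # the concatenated outputs should match the single-pass logits up to atol_inference.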
        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
        recurrent_outputs = []
        with model.transformer.h.inference_session() as sess:
            for t in range(embs.shape[1]):
                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
        logger.info("Inference is consistent with forward")

        del model, embs, recurrent_outputs
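
        # If a reference checkpoint is configured, compare the distributed forward pass
        # against a plain local transformers.BloomForCausalLM loaded from REF_NAME.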
        if REF_NAME:
            ref_model = transformers.BloomForCausalLM.from_pretrained(
                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
            )
            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
            # note: this creates a dummy mask to make the test compatible with older transformers versions
            # prior to https://github.com/huggingface/transformers/pull/17837
            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
            del ref_model, ref_outputs, dummy_mask
        else:
            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
            assert False
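

# Greedy decoding through DistributedBloomForCausalLM.generate() should reproduce
# HuggingFace's greedy_search applied to the same model, for both a single prompt
# and a padded two-prompt batch.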
@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
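
    # Repeat the comparison with a padded batch of two prompts.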
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    remote_outputs_batch = model.generate(
        inputs_batch,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs_batch = BloomForCausalLM.greedy_search(
        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
    )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search results are not identical to HF in multibatch mode"