test_speculative_generation.py

import random

import pytest
import torch
import transformers

from petals import (
    AutoDistributedConfig,
    AutoDistributedSpeculativeModel,
    DistributedLlamaForSpeculativeGeneration,
    RemoteSequential,
)
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block

from test_utils import *
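
# Note: these tests expect a reachable Petals swarm (INITIAL_PEERS) and the MODEL_NAME /
# REF_NAME constants, all imported from test_utils. The first test exercises partial cache
# invalidation on a remote block; the second checks end-to-end speculative generation with
# a perturbed draft model against plain greedy decoding.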


@pytest.mark.forked
def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    remote_sequential = RemoteSequential(config)
    block_index = random.randint(0, config.num_hidden_layers - 1)
    remote_block = remote_sequential[block_index]

    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
    short_inputs[:, :2, :] = inputs[:, :2, :]

    initial_outputs_inference = None
    secondary_outputs_inference = None
    with torch.inference_mode():
        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
            initial_outputs_inference = sess.step(inputs)

            # Rewind the session to position 2 and re-run the remaining tokens,
            # overwriting the previously cached states past that position.
            sess.position = 2
            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)

    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
    (outputs_local,) = ref_block(short_inputs)

    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
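

# The "noisy" model below serves as the draft model for speculative generation: it is the
# reference checkpoint with small Gaussian noise added to its LM head, so its proposals
# differ slightly from the target model without drifting too far from it.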
@pytest.fixture
def noisy_model():
    noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    lm_head = noisy_model.get_output_embeddings()
    assert isinstance(lm_head, torch.nn.Linear)
    with torch.no_grad():
        lm_head.weight += torch.randn_like(lm_head.weight) * 0.02

    return noisy_model


@pytest.fixture
def model():
    return transformers.AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )


@pytest.fixture
def tokenizer():
    # We set use_fast=False since LlamaTokenizerFast is slow on load
    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
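

# With do_sample=False, speculative generation is expected to reproduce the target model's
# greedy output exactly: the full model verifies every token proposed by the noisy draft,
# so the two generated sequences should match token for token.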
@pytest.mark.forked
@pytest.mark.skipif(
    "llama" not in MODEL_NAME.lower(),
    reason="Speculative generation currently works only for Llama models",
)
def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
    speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
    )

    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
    generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)

    assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)