benchmark_inference.py

#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()
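
# Usage sketch (the model name below is illustrative, not prescriptive; any model
# served by the swarm you join should work):
#   python benchmark_inference.py --model bigscience/bloom --seq_len 512
# The script joins the public swarm via PUBLIC_INITIAL_PEERS by default; pass
# --initial_peers to benchmark a private swarm instead.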


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
    parser.add_argument("--seq_len", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    # Passing --n_processes n_gpus spawns one benchmark process per visible CUDA device
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_inference, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    # Average the per-process speeds (tokens/sec) reported over the pipe
    speed = np.mean([pipe_recv.recv() for _ in range(args.n_processes)])
    logger.info(f"Final result: {speed=:.2f}")
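

# Each worker process loads the tokenizer and a distributed model, opens a single
# inference session, and reports its single-token generation speed (measured after
# the warmup steps) back to main() over the pipe.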


@torch.inference_mode()
def benchmark_inference(process_idx, args, result_pipe):
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    # Using use_fast=False since LlamaTokenizerFast takes a long time to start, and we decode 1 token at a time anyway

    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    result = ""
    step_times = []
    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
        for step in range(args.seq_len):
            start_time = perf_counter()

            outputs = model.generate(max_new_tokens=1, session=sess)
            result += tokenizer.decode(outputs[0])

            # Skip the first warmup_steps iterations so session setup does not skew the average
            if step >= args.warmup_steps:
                step_times.append(perf_counter() - start_time)
                speed = 1 / np.mean(step_times)
                logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    result_pipe.send(speed)


if __name__ == "__main__":
    main()