test_aux_functions.py

import subprocess
import sys

import pytest
import torch
from hivemind import nested_compare, nested_flatten

from petals import AutoDistributedConfig
from petals.server.throughput import measure_compute_rps
from petals.utils.convert_block import QuantType
from petals.utils.misc import DUMMY, is_dummy
from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs
from test_utils import MODEL_NAME


def test_bnb_not_imported_when_unnecessary():
    """
    We avoid importing bitsandbytes when it's not used,
    since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that.

    If this test fails, please change your code to import bitsandbytes and/or petals.utils.peft
    inside the function/method that actually needs them instead of importing them at the beginning of the file.
    This won't slow down the code - importing a module a second time doesn't rerun the module's code.
    """

    subprocess.check_call([sys.executable, "-c", "import petals, sys; assert 'bitsandbytes' not in sys.modules"])
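

# A minimal illustrative sketch (not part of the original test module) of the lazy-import
# pattern the docstring above recommends: import bitsandbytes inside the function that
# actually needs it, so that `import petals` alone never triggers bitsandbytes' CUDA probing.
# The helper name below is hypothetical.
def _lazy_bnb_example():
    import bitsandbytes as bnb  # deferred import: executed only when the helper is called

    return bnb.__version__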


@pytest.mark.forked
@pytest.mark.parametrize("inference", [False, True])
@pytest.mark.parametrize("n_tokens", [1, 16])
@pytest.mark.parametrize("tensor_parallel", [False, True])
def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
    if tensor_parallel and config.model_type != "bloom":
        pytest.skip("Tensor parallelism is implemented only for BLOOM for now")

    tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
    compute_rps = measure_compute_rps(
        config,
        device=torch.device("cpu"),
        dtype=torch.bfloat16,
        quant_type=QuantType.NONE,
        tensor_parallel_devices=tensor_parallel_devices,
        n_tokens=n_tokens,
        n_steps=5,
        inference=inference,
    )
    assert isinstance(compute_rps, float) and compute_rps > 0


@pytest.mark.forked
def test_pack_inputs():
    x = torch.ones(3)
    y = torch.arange(5)
    z = DUMMY

    args = (x, z, None, (y, y), z)
    kwargs = dict(foo=torch.zeros(1, 1), bar={"l": "i", "g": "h", "t": ("y", "e", "a", "r", torch.rand(1), x, y)})

    flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)
    assert len(flat_tensors) == 5
    assert all(isinstance(t, torch.Tensor) for t in flat_tensors)

    restored_args, restored_kwargs = unpack_args_kwargs(flat_tensors, args_structure)
    assert len(restored_args) == len(args)
    assert torch.all(restored_args[0] == x).item() and restored_args[2] is None
    assert nested_compare((args, kwargs), (restored_args, restored_kwargs))

    for original, restored in zip(nested_flatten((args, kwargs)), nested_flatten((restored_args, restored_kwargs))):
        if isinstance(original, torch.Tensor):
            assert torch.all(original == restored)
        else:
            assert original == restored