test_dist.py

import torch
import torch.distributed as dist
from common import distributed_test
import pytest


@distributed_test(world_size=3)
def test_init():
    assert dist.is_initialized()
    assert dist.get_world_size() == 3
    assert dist.get_rank() < 3


# Demonstration of pytest's parametrization combined with the distributed_test helper.
@pytest.mark.parametrize('number,color', [(1138, 'purple')])
def test_dist_args(number, color):
    """Outer test function with inputs from pytest.mark.parametrize(). Uses a
    distributed helper function.
    """
    @distributed_test(world_size=2)
    def _test_dist_args_helper(x, color='red'):
        assert dist.get_world_size() == 2
        assert x == 1138
        assert color == 'purple'

    # Ensure that args are passed through to distributed_test-decorated functions.
    _test_dist_args_helper(number, color=color)


@distributed_test(world_size=[1, 2, 4])
def test_dist_allreduce():
    # Each rank contributes a tensor filled with (rank + 1); after all_reduce,
    # every rank should hold the sum 1 + 2 + ... + world_size.
    x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
    sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
    result = torch.ones(1, 3).cuda() * sum_of_ranks
    dist.all_reduce(x)
    assert torch.all(x == result)
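

# ----------------------------------------------------------------------------
# Illustration only: a minimal sketch of how a decorator like `distributed_test`
# (imported from `common` above) could be built on torch.multiprocessing and
# torch.distributed. The helper's actual name, backend, and launch details in
# `common.py` are assumptions here, not the real implementation.
import os
import functools
import torch.distributed as dist
import torch.multiprocessing as mp


def _run_in_process(rank, world_size, fn, args, kwargs):
    # Each spawned process joins the same process group before running the test body.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    try:
        fn(*args, **kwargs)
    finally:
        dist.destroy_process_group()


def simple_distributed_test(world_size=2):
    """Sketch: launch `world_size` processes and run the wrapped test in each.
    Accepts an int or a list of ints, mirroring the usage in the tests above."""
    world_sizes = world_size if isinstance(world_size, list) else [world_size]

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Note: mp.spawn pickles `fn`, so this sketch only handles
            # module-level test functions, unlike the real helper.
            for ws in world_sizes:
                mp.spawn(_run_in_process,
                         args=(ws, fn, args, kwargs),
                         nprocs=ws,
                         join=True)
        return wrapper
    return decorator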