test_topology.py 7.8 KB


  1. import pytest
  2. import torch
  3. import torch.distributed as dist
  4. from deepspeed.runtime.pipe.topology import PipelineParallelGrid as Grid
  5. from deepspeed.runtime.pipe.topology import ProcessTopology as Topo
  6. from deepspeed.runtime.pipe.topology import _prime_factors
  7. from common import distributed_test
  def test_topology_2d():
      """Check rank layout and per-axis rank lists of a 2x2 row/col topology."""
      topo = Topo(axes=['row', 'col'], dims=[2, 2])
      assert topo.world_size() == 4

      # Expected ranks follow row-major order: 'col' varies fastest.
      for row in range(2):
          for col in range(2):
              assert topo.get_rank(row=row, col=col) == row * 2 + col

      expected_axis_lists = [
          ('row', 0, [0, 1]),
          ('row', 1, [2, 3]),
          ('col', 0, [0, 2]),
          ('col', 1, [1, 3]),
      ]
      for axis_name, index, ranks in expected_axis_lists:
          assert topo.get_axis_list(axis=axis_name, idx=index) == ranks
  def test_topology_dims():
      """Check world size and per-axis dimension lookup on a 2x3x4 topology."""
      axis_names = ['a', 'b', 'c']
      axis_sizes = [2, 3, 4]
      topo = Topo(axes=axis_names, dims=axis_sizes)
      assert topo.world_size() == 2 * 3 * 4
      for name, size in zip(axis_names, axis_sizes):
          assert topo.get_dim(name) == size
  def test_topology_match():
      """Check that filter_match selects ranks by a partial coordinate.

      Filtering a 2x2x2 pipe/data/model topology on pipe=0, data=1 leaves only
      the 'model' axis free, so exactly two ranks should match.
      """
      topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
      matches = topo.filter_match(pipe=0, data=1)
      assert matches == [2, 3]
      # Previously this was only printed. Assert it so the test actually checks
      # that the matched ranks share the filtered coordinates and differ only
      # along the unfiltered 'model' axis.
      assert [topo.get_coord(r) for r in matches] == [
          topo.ProcessCoord(0, 1, 0),
          topo.ProcessCoord(0, 1, 1),
      ]
  def test_topology_rank_repr():
      """Check get_rank_repr formatting, custom separators, and axis omission."""
      topo = Topo(axes=['a', 'b'], dims=[2, 2])
      plain = ['a_00-b_00', 'a_00-b_01', 'a_01-b_00', 'a_01-b_01']
      for rank, expected in enumerate(plain):
          assert topo.get_rank_repr(rank=rank) == expected
      assert topo.get_rank_repr(rank=3, inner_sep='+') == 'a+01-b+01'
      assert topo.get_rank_repr(
          rank=3, inner_sep='🤗', outer_sep='_JEFF_') == 'a🤗01_JEFF_b🤗01'

      # Per the expected values here, the default omit_axes hides both 'pipe'
      # and 'data', so every rank renders as an empty string.
      topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
      for rank in range(4):
          assert topo.get_rank_repr(rank=rank) == ''

      data_only = ['data_00', 'data_01', 'data_00', 'data_01']
      for rank, expected in enumerate(data_only):
          assert topo.get_rank_repr(rank=rank, omit_axes=['pipe']) == expected

      nothing_omitted = [
          'pipe_00-data_00',
          'pipe_00-data_01',
          'pipe_01-data_00',
          'pipe_01-data_01',
      ]
      for rank, expected in enumerate(nothing_omitted):
          assert topo.get_rank_repr(rank=rank, omit_axes=[]) == expected

      # With a 'model' axis added, only the model coordinate survives the defaults.
      topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])
      model_reprs = ['model_00', 'model_01']
      for rank in range(8):
          assert topo.get_rank_repr(rank=rank) == model_reprs[rank % 2]
  def test_topology_3d():
      """Check ranks, axis lists, coordinates, and filtering on a 2x2x2 topology."""
      topo = Topo(axes=['a', 'b', 'c'], dims=[2, 2, 2])

      # Walk coordinates in row-major order ('c' fastest). The expected rank is
      # the running position in that walk.
      expected_rank = 0
      for a in range(2):
          for b in range(2):
              for c in range(2):
                  assert topo.get_rank(a=a, b=b, c=c) == expected_rank
                  expected_rank += 1

      axis_cases = [
          ('a', 0, [0, 1, 2, 3]),
          ('a', 1, [4, 5, 6, 7]),
          ('b', 0, [0, 1, 4, 5]),
          ('b', 1, [2, 3, 6, 7]),
          ('c', 0, [0, 2, 4, 6]),
          ('c', 1, [1, 3, 5, 7]),
      ]
      for axis_name, index, ranks in axis_cases:
          assert topo.get_axis_list(axis_name, index) == ranks

      # get_coord is the inverse of the walk above.
      expected_rank = 0
      for a in range(2):
          for b in range(2):
              for c in range(2):
                  assert topo.get_coord(expected_rank) == topo.ProcessCoord(a, b, c)
                  expected_rank += 1

      assert topo.filter_match(a=0) == [0, 1, 2, 3]
      assert topo.filter_match(b=1, c=1) == [3, 7]
      assert topo.filter_match(a=1, b=1, c=1) == [7]

      # Coordinates also expose each axis as a named attribute.
      assert topo.get_coord(0).a == 0
  def test_topology_comm_list():
      """Check per-axis communicator rank lists on a pipe/data/model topology."""
      topo = Topo(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])

      # Expected ranks follow row-major order: 'model' varies fastest.
      for pipe in range(2):
          for data in range(2):
              for model in range(2):
                  expected = pipe * 4 + data * 2 + model
                  assert topo.get_rank(pipe=pipe, data=data, model=model) == expected

      # Each inner list groups the ranks that differ only along the named axis.
      pipe_list = [
          [0, 4],  # data=0, model=0
          [1, 5],  # data=0, model=1
          [2, 6],  # data=1, model=0
          [3, 7],  # data=1, model=1
      ]
      data_list = [
          [0, 2],  # pipe=0, model=0
          [1, 3],  # pipe=0, model=1
          [4, 6],  # pipe=1, model=0
          [5, 7],  # pipe=1, model=1
      ]
      model_list = [
          [0, 1],  # pipe=0, data=0
          [2, 3],  # pipe=0, data=1
          [4, 5],  # pipe=1, data=0
          [6, 7],  # pipe=1, data=1
      ]
      assert topo.get_axis_comm_lists('pipe') == pipe_list
      assert topo.get_axis_comm_lists('data') == data_list
      assert topo.get_axis_comm_lists('model') == model_list

      # An unknown axis yields an empty list instead of raising, so generic
      # data/model/pipe-parallel code can query axes that may not exist.
      assert topo.get_axis_comm_lists('jeff') == []
  @distributed_test(world_size=4)
  def test_grid_pipe_data():
      """All-reduce each rank id over the pipe and data groups and check the sums."""
      topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
      grid = Grid(topology=topo)
      assert grid._is_grid_valid()

      rank = dist.get_rank()
      last_stage_id = grid.get_pipe_parallel_world_size() - 1
      assert grid.is_first_stage == (grid.get_stage_id() == 0)
      assert grid.is_last_stage == (grid.get_stage_id() == last_stage_id)

      # Summing rank ids across a process group must equal the sum of the
      # group's member ranks as reported by the grid.
      group_checks = [
          (grid.get_pipe_parallel_group(), grid.pp_group),
          (grid.get_data_parallel_group(), grid.dp_group),
      ]
      for process_group, member_ranks in group_checks:
          reduced = torch.LongTensor(data=[rank]).cuda()
          dist.all_reduce(reduced, group=process_group)
          assert torch.all(reduced == sum(member_ranks))
  @distributed_test(world_size=4)
  def test_stage_to_global():
      """Check stage_to_global with explicit and implicit data coordinates."""
      topo = Topo(axes=['pipe', 'data'], dims=[2, 2])
      grid = Grid(topology=topo)
      assert grid._is_grid_valid()

      explicit_cases = [
          (0, 0, 0),
          (0, 1, 1),
          (1, 0, 2),
          (1, 1, 3),
      ]
      for stage, data_idx, global_rank in explicit_cases:
          assert grid.stage_to_global(stage_id=stage, data=data_idx) == global_rank

      # Without an explicit data index, these expectations assume the lookup
      # uses the calling process's own data coordinate.
      me = topo.get_coord(rank=dist.get_rank())
      per_stage = [0, 2] if me.data == 0 else [1, 3]
      for stage, global_rank in enumerate(per_stage):
          assert grid.stage_to_global(stage_id=stage) == global_rank
  def test_primes():
      """Check that _prime_factors rejects 0 and factors 1..29 into primes."""
      with pytest.raises(ValueError):
          _prime_factors(0)

      for value in range(1, 30):
          factors = _prime_factors(value)

          # Multiplying the factors back together must recover the input.
          product = 1
          for factor in factors:
              product = product * factor
          assert product == value

          # Each factor must itself be prime, i.e. factor only into itself.
          for factor in factors:
              assert _prime_factors(factor) == [factor]