test_partition.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import pytest
  2. import torch
  3. import deepspeed.comm as dist
  4. from deepspeed.runtime.utils import partition_uniform
  5. from deepspeed.runtime.utils import partition_balanced
  6. from deepspeed.runtime.utils import prefix_sum_inc
  7. from deepspeed.runtime.utils import PartitionedTensor
  8. from .common import distributed_test
  9. @distributed_test(world_size=4)
  10. def test_partitioned_tensor():
  11. world = dist.get_world_size()
  12. rank = dist.get_rank()
  13. group = dist.new_group(ranks=list(range(world)))
  14. rows = world * 4
  15. cols = 3
  16. full = torch.rand(rows, cols).cuda()
  17. dist.broadcast(full, src=0, group=group)
  18. part = PartitionedTensor(full, group=group)
  19. assert len(part.local_size()) == 1
  20. assert part.local_size()[0] * world == full.numel()
  21. reconstructed = part.full()
  22. assert torch.equal(full, reconstructed)
  23. @distributed_test(world_size=4)
  24. def test_partitioned_tensor_meta():
  25. world = dist.get_world_size()
  26. rank = dist.get_rank()
  27. group = dist.new_group(ranks=list(range(world)))
  28. rows = world * 7
  29. cols = 3
  30. full = torch.rand(rows, cols).cuda()
  31. dist.broadcast(full, src=0, group=group)
  32. part = PartitionedTensor(full, group=group)
  33. my_meta = PartitionedTensor.from_meta(part.to_meta(), part.local_data, group)
  34. assert torch.equal(full, my_meta.full())
  35. def assert_valid_partition(weights, parts, P):
  36. N = len(weights)
  37. assert len(parts) == P + 1
  38. assert parts[0] == 0
  39. assert parts[P] == N
  40. for idx in range(P):
  41. assert parts[idx] <= parts[idx + 1]
  42. def get_partition_weights(weights, parts):
  43. """ Return the amount of weight in each partition. """
  44. costs = [0] * (len(parts) - 1)
  45. P = len(parts) - 1
  46. for p in range(P):
  47. start = parts[p]
  48. stop = parts[p + 1]
  49. costs[p] = sum(weights[start:stop])
  50. return costs
  51. def test_prefix_sum():
  52. x = [3, 4, 5]
  53. psum = prefix_sum_inc(x)
  54. assert psum == [3, 7, 12]
  55. def test_valid_partition():
  56. N = 10
  57. P = 1
  58. weights = [1] * N
  59. parts = partition_balanced(weights, P)
  60. assert_valid_partition(weights, parts, P)
  61. def test_short_partition_uniform():
  62. N = 2
  63. P = 4
  64. weights = [1] * N
  65. parts = partition_uniform(len(weights), P)
  66. assert_valid_partition(weights, parts, P)
  67. def test_short_partition():
  68. N = 2
  69. P = 4
  70. weights = [1] * N
  71. parts = partition_balanced(weights, P)
  72. assert_valid_partition(weights, parts, P)
  73. def test_easy_balance_uniform():
  74. weights = [1] * 8
  75. P = 4
  76. parts = partition_uniform(len(weights), P)
  77. assert_valid_partition(weights, parts, P)
  78. costs = get_partition_weights(weights, parts)
  79. assert all(c == 2 for c in costs)
  80. def test_easy_balance_balanced():
  81. weights = [1] * 8
  82. P = 4
  83. parts = partition_balanced(weights, P)
  84. assert_valid_partition(weights, parts, P)
  85. costs = get_partition_weights(weights, parts)
  86. assert all(c == 2 for c in costs), costs
  87. def test_int_balanced():
  88. weights = [0, 1, 2, 3, 3, 3]
  89. P = 4
  90. parts = partition_balanced(weights, P)
  91. assert parts == [0, 3, 4, 5, 6]
  92. assert_valid_partition(weights, parts, P)
  93. costs = get_partition_weights(weights, parts)
  94. assert all(c == 3 for c in costs)
  95. def test_float_balanced():
  96. weights = [0., 1.1, 1.9, 3., 3., 3.]
  97. P = 4
  98. parts = partition_balanced(weights, P)
  99. assert_valid_partition(weights, parts, P)
  100. assert parts == [0, 3, 4, 5, 6]
  101. @pytest.mark.skip(reason="Variance-minimizing partitioning returns different result.")
  102. def test_float_lastheavy():
  103. weights = [0., 1.1, 1.9, 3., 30.]
  104. P = 2
  105. parts = partition_balanced(weights, P)
  106. assert_valid_partition(weights, parts, P)
  107. assert parts == [0, 4, 5]
  108. def test_float_midheavy():
  109. weights = [0., 1.1, 30, 3.]
  110. P = 3
  111. parts = partition_balanced(weights, P)
  112. assert_valid_partition(weights, parts, P)
  113. assert parts == [0, 2, 3, 4]
  114. def test_balance_bert():
  115. # Parameters per layer for a transformer model with 24 transformers and hidden dim 1024
  116. weights = [
  117. 52559872,
  118. 12596224,
  119. 12596224,
  120. 12596224,
  121. 12596224,
  122. 12596224,
  123. 12596224,
  124. 12596224,
  125. 12596224,
  126. 12596224,
  127. 12596224,
  128. 12596224,
  129. 12596224,
  130. 12596224,
  131. 12596224,
  132. 12596224,
  133. 12596224,
  134. 12596224,
  135. 12596224,
  136. 12596224,
  137. 12596224,
  138. 12596224,
  139. 12596224,
  140. 12596224,
  141. 12596224,
  142. 0,
  143. 52559872
  144. ]
  145. P = 8
  146. parts = partition_balanced(weights, P)
  147. assert_valid_partition(weights, parts, P)