123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- import pytest
- import torch
- import deepspeed.comm as dist
- from deepspeed.runtime.utils import partition_uniform
- from deepspeed.runtime.utils import partition_balanced
- from deepspeed.runtime.utils import prefix_sum_inc
- from deepspeed.runtime.utils import PartitionedTensor
- from .common import distributed_test
@distributed_test(world_size=4)
def test_partitioned_tensor():
    """Partition a tensor across every rank and rebuild it with ``full()``.

    A random ``(world * 4, 3)`` CUDA tensor is broadcast from rank 0 over a
    group containing all ranks, wrapped in a ``PartitionedTensor``, and then
    reconstructed. The test checks that the local size is one-dimensional,
    that the local size times the world size equals the element count of the
    original tensor, and that the reconstruction equals the original.
    """
    world = dist.get_world_size()
    group = dist.new_group(ranks=list(range(world)))

    # The row count is a multiple of the world size.
    rows = world * 4
    cols = 3

    full = torch.rand(rows, cols).cuda()
    # Each rank generated its own random values; replace them with rank 0's.
    dist.broadcast(full, src=0, group=group)

    part = PartitionedTensor(full, group=group)
    assert len(part.local_size()) == 1
    assert part.local_size()[0] * world == full.numel()

    reconstructed = part.full()
    assert torch.equal(full, reconstructed)
@distributed_test(world_size=4)
def test_partitioned_tensor_meta():
    """Rebuild a ``PartitionedTensor`` from its metadata and local data.

    A random ``(world * 7, 3)`` CUDA tensor is broadcast from rank 0 over a
    group containing all ranks and partitioned. A second ``PartitionedTensor``
    is then created with ``from_meta`` using the first one's ``to_meta()``
    output and ``local_data``; its ``full()`` result must equal the original
    tensor.
    """
    world = dist.get_world_size()
    group = dist.new_group(ranks=list(range(world)))

    rows = world * 7
    cols = 3

    full = torch.rand(rows, cols).cuda()
    # Each rank generated its own random values; replace them with rank 0's.
    dist.broadcast(full, src=0, group=group)
    part = PartitionedTensor(full, group=group)

    my_meta = PartitionedTensor.from_meta(part.to_meta(), part.local_data, group)
    assert torch.equal(full, my_meta.full())
def assert_valid_partition(weights, parts, P):
    """Assert that ``parts`` is a well-formed split of ``weights`` into ``P`` parts.

    Args:
        weights: sized collection of per-item weights; only its length is used.
        parts: indexable sequence of partition boundaries.
        P: expected number of partitions.

    Raises:
        AssertionError: if ``parts`` does not hold exactly ``P + 1`` boundaries,
            does not start at 0, does not end at ``len(weights)``, or is not
            non-decreasing.
    """
    N = len(weights)
    assert len(parts) == P + 1, \
        "expected {} boundaries for P={}, got {}: {}".format(P + 1, P, len(parts), parts)
    assert parts[0] == 0, \
        "first boundary must be 0, got {}: {}".format(parts[0], parts)
    assert parts[P] == N, \
        "last boundary must be {}, got {}: {}".format(N, parts[P], parts)
    for idx in range(P):
        # Equal neighbours are allowed: they denote an empty partition.
        assert parts[idx] <= parts[idx + 1], \
            "boundaries decrease at index {}: {}".format(idx, parts)
def get_partition_weights(weights, parts):
    """ Return the amount of weight in each partition.

    Args:
        weights: sliceable sequence of per-item weights.
        parts: sequence of partition boundaries; partition ``p`` covers
            ``weights[parts[p]:parts[p + 1]]``.

    Returns:
        A list holding one summed weight per pair of adjacent boundaries
        (empty when ``parts`` has fewer than two entries).
    """
    # Pair each boundary with its successor to get (start, stop) slices.
    return [sum(weights[start:stop]) for start, stop in zip(parts, parts[1:])]
def test_prefix_sum():
    """``prefix_sum_inc`` of ``[3, 4, 5]`` must be ``[3, 7, 12]``."""
    values = [3, 4, 5]
    expected = [3, 7, 12]
    assert prefix_sum_inc(values) == expected
def test_valid_partition():
    """Balanced partitioning of ten unit weights into one part is valid."""
    num_parts = 1
    unit_weights = [1 for _ in range(10)]
    boundaries = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, boundaries, num_parts)
def test_short_partition_uniform():
    """Uniform partitioning stays valid with fewer items (2) than parts (4)."""
    num_parts = 4
    unit_weights = [1 for _ in range(2)]
    boundaries = partition_uniform(len(unit_weights), num_parts)
    assert_valid_partition(unit_weights, boundaries, num_parts)
def test_short_partition():
    """Balanced partitioning stays valid with fewer items (2) than parts (4)."""
    num_parts = 4
    unit_weights = [1 for _ in range(2)]
    boundaries = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, boundaries, num_parts)
def test_easy_balance_uniform():
    """Uniformly splitting eight unit weights four ways gives weight 2 per part."""
    num_parts = 4
    unit_weights = [1 for _ in range(8)]
    boundaries = partition_uniform(len(unit_weights), num_parts)
    assert_valid_partition(unit_weights, boundaries, num_parts)

    per_part = get_partition_weights(unit_weights, boundaries)
    # Equivalent to requiring every partition cost to equal 2.
    assert not any(cost != 2 for cost in per_part)
def test_easy_balance_balanced():
    """Balanced split of eight unit weights four ways gives weight 2 per part."""
    num_parts = 4
    unit_weights = [1 for _ in range(8)]
    boundaries = partition_balanced(unit_weights, num_parts)
    assert_valid_partition(unit_weights, boundaries, num_parts)

    per_part = get_partition_weights(unit_weights, boundaries)
    # On failure the per-partition costs are reported as the assertion message.
    assert not any(cost != 2 for cost in per_part), per_part
def test_int_balanced():
    """Integer weights ``[0, 1, 2, 3, 3, 3]`` split four ways into parts of weight 3."""
    num_parts = 4
    int_weights = [0, 1, 2, 3, 3, 3]
    expected_boundaries = [0, 3, 4, 5, 6]

    boundaries = partition_balanced(int_weights, num_parts)
    assert boundaries == expected_boundaries
    assert_valid_partition(int_weights, boundaries, num_parts)

    per_part = get_partition_weights(int_weights, boundaries)
    assert not any(cost != 3 for cost in per_part)
def test_float_balanced():
    """Float weights are balanced four ways into boundaries ``[0, 3, 4, 5, 6]``."""
    num_parts = 4
    float_weights = [0., 1.1, 1.9, 3., 3., 3.]
    expected_boundaries = [0, 3, 4, 5, 6]

    boundaries = partition_balanced(float_weights, num_parts)
    assert_valid_partition(float_weights, boundaries, num_parts)
    assert boundaries == expected_boundaries
@pytest.mark.skip(reason="Variance-minimizing partitioning returns different result.")
def test_float_lastheavy():
    """Two-way split where the final item (weight 30) dominates; currently skipped."""
    num_parts = 2
    float_weights = [0., 1.1, 1.9, 3., 30.]
    expected_boundaries = [0, 4, 5]

    boundaries = partition_balanced(float_weights, num_parts)
    assert_valid_partition(float_weights, boundaries, num_parts)
    assert boundaries == expected_boundaries
def test_float_midheavy():
    """Three-way split where a middle item (weight 30) dominates."""
    num_parts = 3
    mixed_weights = [0., 1.1, 30, 3.]
    expected_boundaries = [0, 2, 3, 4]

    boundaries = partition_balanced(mixed_weights, num_parts)
    assert_valid_partition(mixed_weights, boundaries, num_parts)
    assert boundaries == expected_boundaries
def test_balance_bert():
    """Balanced eight-way split of a BERT-like per-layer parameter profile is valid."""
    # Parameters per layer for a transformer model with 24 transformers and hidden dim 1024
    edge_layer_params = 52559872
    transformer_params = 12596224
    layer_params = (
        [edge_layer_params]
        + [transformer_params] * 24
        + [0, edge_layer_params]
    )

    num_parts = 8
    boundaries = partition_balanced(layer_params, num_parts)
    assert_valid_partition(layer_params, boundaries, num_parts)
|