12345678910111213141516171819202122232425 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from deepspeed.runtime import utils as ds_utils
def check_partition(weights, num_parts, target_diff):
    """Assert the spread of per-part weight sums for a balanced partition.

    Calls ``ds_utils.partition_balanced`` and treats each pair of adjacent
    entries in its result as the ``[start, end)`` slice bounds of one part
    of ``weights``.

    Args:
        weights: Sequence of item weights to partition.
        num_parts: Number of parts requested from ``partition_balanced``.
        target_diff: Expected difference between the largest and the
            smallest per-part weight sum.

    Raises:
        AssertionError: If the observed difference is not ``target_diff``.
    """
    result = ds_utils.partition_balanced(weights=weights, num_parts=num_parts)

    # An empty slice sums to 0, so a part that receives no items counts as 0.
    loads = [sum(weights[start:end]) for start, end in zip(result[:-1], result[1:])]
    spread = max(loads) - min(loads)

    assert spread == target_diff, (
        f"ds_utils.partition_balanced(weights={weights}, num_parts={num_parts}) return {result}")
def test_partition_balanced():
    """Run ``check_partition`` over a handful of small weight lists."""
    # Each entry is (weights, expected max-minus-min of the per-part sums);
    # every case asks for four parts.
    cases = (
        ([1, 2, 1], 2),  # fewer weights than parts
        ([1, 1, 1, 1], 0),
        ([1, 1, 1, 1, 1], 1),
        ([1, 1, 1, 1, 0, 1], 1),  # includes a zero weight
    )
    for weights, expected_diff in cases:
        check_partition(weights, 4, target_diff=expected_diff)
|