test_partition_balanced.py 799 B

12345678910111213141516171819202122232425
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
  4. from deepspeed.runtime import utils as ds_utils
  5. def check_partition(weights, num_parts, target_diff):
  6. result = ds_utils.partition_balanced(weights=weights, num_parts=num_parts)
  7. parts_sum = []
  8. for b, e in zip(result[:-1], result[1:]):
  9. parts_sum.append(sum(weights[b:e]))
  10. assert max(parts_sum) - min(
  11. parts_sum
  12. ) == target_diff, f"ds_utils.partition_balanced(weights={weights}, num_parts={num_parts}) return {result}"
  13. def test_partition_balanced():
  14. check_partition([1, 2, 1], 4, target_diff=2)
  15. check_partition([1, 1, 1, 1], 4, target_diff=0)
  16. check_partition([1, 1, 1, 1, 1], 4, target_diff=1)
  17. check_partition([1, 1, 1, 1, 0, 1], 4, target_diff=1)