tp_shard.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist

# Module-level state shared by the setters/getters below. Initializing to None
# (a bare `global` statement at module scope does not create the variable)
# keeps the `is not None` check in get_shard_size well-defined even before
# set_num_kv_heads has been called.
num_kv_heads = None
n_embd = None


def set_num_kv_heads(num):
    global num_kv_heads
    num_kv_heads = num


def set_n_embd(num):
    global n_embd
    n_embd = num


def get_num_kv_heads():
    global num_kv_heads
    return num_kv_heads
def get_shard_size(total_size, mp_size, rank=None):
    global num_kv_heads
    # When num_kv_heads is defined, uneven division is possible; otherwise enforce even division.
    if num_kv_heads is not None:
        if rank is None:
            rank = dist.get_rank()
        # Distribute KV heads round-robin: every rank gets the floor share, and
        # the first (num_kv_heads % mp_size) ranks each take one extra head.
        my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
        return total_size * my_slices // num_kv_heads
    else:
        if total_size % mp_size == 0:
            return total_size // mp_size
        else:
            assert False, f"total_size ({total_size}) must be divisible by mp_size ({mp_size})"
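
# Illustrative example (numbers are ours, not from the original file): with
# num_kv_heads = 10 and mp_size = 4, the floor share is 2 heads per rank and the
# 10 % 4 = 2 extra heads go to ranks 0 and 1, giving per-rank head counts of
# [3, 3, 2, 2]; get_shard_size then scales total_size by each rank's share.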
def get_n_embd():
    global n_embd
    return n_embd


def get_shard_size_list(total_size, mp_size):
    # Compute the shard size for every rank so callers can split a tensor unevenly.
    shard_sizes = []
    for i in range(mp_size):
        shard_sizes.append(get_shard_size(total_size, mp_size, i))
    return shard_sizes
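

# A minimal usage sketch (ours, not part of the original file; the sizes below
# are illustrative). get_shard_size_list passes explicit ranks to get_shard_size,
# so dist.get_rank() is never hit and this runs without a distributed backend.
if __name__ == "__main__":
    set_num_kv_heads(10)  # 10 KV heads do not divide evenly across 4 ranks
    sizes = get_shard_size_list(1280, 4)  # e.g. hidden size 1280 = 10 heads * 128 dims
    print(sizes)  # [384, 384, 256, 256] -- ranks 0-1 hold 3 heads, ranks 2-3 hold 2
    assert sum(sizes) == 1280  # the uneven shards still cover the full dimension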