tp_shard.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist

global num_kv_heads


def set_num_kv_heads(num):
    global num_kv_heads
    num_kv_heads = num


def set_num_attention_heads(num):
    global num_attention_heads
    num_attention_heads = num


def set_n_embd(num):
    global n_embd
    n_embd = num


def get_num_kv_heads():
    global num_kv_heads
    return num_kv_heads


def get_num_attention_heads():
    global num_attention_heads
    return num_attention_heads
def get_shard_size(total_size, mp_size, name=None, rank=None):
    global num_kv_heads
    last_linear = ["lm_head", "embed_out"]
    # When num_kv_heads is defined, uneven division is possible (shards stay aligned
    # to whole KV heads); otherwise enforce near-even division.
    if rank is None:
        rank = dist.get_rank()
    if num_kv_heads is not None and total_size % num_kv_heads == 0 and "mlp" not in str(name) and str(
            name) not in last_linear:
        # Distribute whole KV heads across ranks; lower ranks absorb the remainder.
        my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
        return total_size * my_slices // num_kv_heads
    else:
        # No KV-head alignment (MLP layers, the final linear, or num_kv_heads unset):
        # split in 64-element grains when possible, otherwise near-evenly.
        if total_size >= 64:
            grain_size = total_size // 64
            return (grain_size // mp_size + (1 if rank < (grain_size % mp_size) else 0)) * 64
        else:
            return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)
def get_n_embd():
    global n_embd
    return n_embd


def get_shard_size_list(total_size, mp_size, name=None):
    shard_sizes = []
    for i in range(mp_size):
        shard_sizes.append(get_shard_size(total_size, mp_size, name, i))
    return shard_sizes
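
A minimal usage sketch, not part of the file above: it assumes the module is importable as deepspeed.module_inject.tp_shard (its location in the DeepSpeed source tree), and the layer names and sizes below are illustrative only, loosely following a Llama-style configuration. It exercises both sharding paths: the KV-head-aligned split used for attention projections and the 64-element-grain split used for MLP and final linear layers.

# Usage sketch (assumed import path and example names/sizes; not part of tp_shard.py).
from deepspeed.module_inject.tp_shard import set_num_kv_heads, get_shard_size_list

# Model with 8 KV heads (e.g. grouped-query attention); must be set before sharding.
set_num_kv_heads(8)

# Attention projection of width 4096 over 3 tensor-parallel ranks: shards follow
# whole KV heads, so the split is uneven but head-aligned.
print(get_shard_size_list(4096, 3, name="self_attn.o_proj"))  # [1536, 1536, 1024]

# An MLP weight ("mlp" in the name) ignores KV heads and is split in 64-element grains.
print(get_shard_size_list(11008, 3, name="mlp.down_proj"))    # [3712, 3648, 3648]

In both cases the per-rank sizes sum back to total_size; because an explicit rank is passed for each shard, this runs without initializing the DeepSpeed communication backend.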