1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- from deepspeed import comm as dist
# Default to None ("not configured") so code reading this global before
# set_num_kv_heads() is called sees None instead of an unbound name.
# A bare module-level ``global`` statement does not create a binding.
num_kv_heads = None
def set_num_kv_heads(num):
    """Store ``num`` as the module-level ``num_kv_heads`` value."""
    globals()["num_kv_heads"] = num
def set_num_attention_heads(num):
    """Store ``num`` as the module-level ``num_attention_heads`` value."""
    globals()["num_attention_heads"] = num
def set_n_embd(num):
    """Store ``num`` as the module-level ``n_embd`` value."""
    globals()["n_embd"] = num
def get_num_kv_heads():
    """Return the current module-level ``num_kv_heads`` value."""
    # Reading a module global needs no ``global`` declaration.
    return num_kv_heads
def get_num_attention_heads():
    """Return the current module-level ``num_attention_heads`` value."""
    # Reading a module global needs no ``global`` declaration.
    return num_attention_heads
def get_shard_size(total_size, mp_size, name=None, rank=None):
    """Return the portion of ``total_size`` assigned to ``rank`` out of ``mp_size`` ranks.

    Reads the module-level ``num_kv_heads`` to choose between two schemes:

    * Head-based split: used when ``num_kv_heads`` is not ``None``, it evenly
      divides ``total_size``, and ``str(name)`` neither contains ``"mlp"`` nor
      equals ``"lm_head"`` / ``"embed_out"``. Heads are dealt out to ranks
      (lower ranks take the extra heads) and the result is scaled by
      ``total_size / num_kv_heads``.
    * Near-even split: otherwise. Sizes of 64 or more are divided in units of
      64; smaller sizes are divided element by element. Lower ranks take the
      extra unit in both cases.

    Args:
        total_size: Size of the dimension being split.
        mp_size: Number of ranks the dimension is split across.
        name: Optional identifier of the layer; only its ``str()`` form is used.
        rank: Rank to compute the shard for. When ``None``, ``dist.get_rank()``
            is used.

    Returns:
        The shard size for ``rank``.
    """
    last_linear = ("lm_head", "embed_out")

    if rank is None:
        rank = dist.get_rank()

    # When we have num_kv_heads defined, uneven division is possible,
    # otherwise enforce near even division.
    if num_kv_heads is not None and total_size % num_kv_heads == 0:
        layer_name = str(name)
        if "mlp" not in layer_name and layer_name not in last_linear:
            extra_head = 1 if rank < (num_kv_heads % mp_size) else 0
            heads_on_rank = (num_kv_heads // mp_size) + extra_head
            return total_size * heads_on_rank // num_kv_heads

    if total_size >= 64:
        # NOTE(review): any remainder of total_size modulo 64 is not assigned
        # to any rank by this formula -- confirm callers only pass multiples
        # of 64 here.
        grain_count = total_size // 64
        extra_grain = 1 if rank < (grain_count % mp_size) else 0
        return (grain_count // mp_size + extra_grain) * 64

    extra_element = 1 if rank < (total_size % mp_size) else 0
    return total_size // mp_size + extra_element
def get_n_embd():
    """Return the current module-level ``n_embd`` value."""
    # Reading a module global needs no ``global`` declaration.
    return n_embd
def get_shard_size_list(total_size, mp_size, name=None):
    """Return the shard size of every rank, ordered by rank ``0 .. mp_size - 1``.

    Args:
        total_size: Size of the dimension being split.
        mp_size: Number of ranks the dimension is split across.
        name: Optional layer identifier forwarded to ``get_shard_size``.

    Returns:
        A list of ``mp_size`` entries; entry ``i`` is
        ``get_shard_size(total_size, mp_size, name, i)``.
    """
    return [get_shard_size(total_size, mp_size, name, rank) for rank in range(mp_size)]
|