# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
def bsh_decoder_gather(reserved_length, hidden_states, mask):
    """Randomly keep ``reserved_length`` token positions per batch row (random-layer-token-drop).

    For each row of ``hidden_states`` (laid out batch, seq, hidden — unlike
    Megatron's seq-first layout), draws a fresh random subset of
    ``reserved_length`` sequence positions, sorts them to preserve token
    order, and gathers those positions into a smaller tensor. The attention
    mask is simply cropped to the reduced length.

    Args:
        reserved_length: number of token positions to keep per row.
        hidden_states: activations of shape (batch, seq, hidden).
        mask: attention mask sliced as ``mask[:, :, :r, :r]`` — assumes a
            4-D (batch, head, seq, seq)-style layout; TODO confirm with caller.

    Returns:
        Tuple of (gathered hidden states of shape (batch, reserved_length,
        hidden), per-row list of sorted kept-index tensors, cropped mask).
    """
    seq_len = hidden_states.size(1)
    kept_indices = []
    gathered_rows = []
    for row in range(hidden_states.size(0)):
        # Independent random subset per batch row, drawn on the same device.
        perm = torch.randperm(seq_len, device=hidden_states.device)
        row_idx, _ = perm[:reserved_length].sort()
        kept_indices.append(row_idx)
        # Keep the batch dim via row:row+1 so rows can be concatenated below.
        gathered_rows.append(hidden_states[row:row + 1, row_idx, :])
    part_hidden_states = torch.cat(gathered_rows, dim=0)
    part_mask = mask[:, :, :reserved_length, :reserved_length]
    return part_hidden_states, kept_indices, part_mask
def bsh_decoder_scatter(hidden_states, part_hidden_states, rand_list):
    """Write dropped-token activations back to their original positions.

    Inverse of the gather step: for each batch row, the reduced activations
    in ``part_hidden_states`` are scattered into ``hidden_states`` at the
    sequence positions recorded in ``rand_list``. ``hidden_states`` is
    modified in place and also returned for convenience.

    Args:
        hidden_states: full tensor of shape (batch, seq, hidden), updated in place.
        part_hidden_states: reduced tensor of shape (batch, reserved, hidden).
        rand_list: per-row index tensors produced by the gather step.

    Returns:
        The (mutated) ``hidden_states`` tensor.
    """
    for row, row_idx in enumerate(rand_list):
        hidden_states[row, row_idx, :] = part_hidden_states[row]
    return hidden_states
|