# utils.py
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import torch
  5. def bsh_decoder_gather(reserved_length, hidden_states, mask):
  6. # random-layer-token-drop
  7. rand_list = []
  8. part_hidden_states = [] # batch, seq, hidden ## different from megatron
  9. for k in range(hidden_states.size(0)):
  10. B_tmp = torch.randperm(hidden_states.size(1), device=hidden_states.device)[:reserved_length]
  11. B = B_tmp.sort()[0]
  12. rand_list.append(B)
  13. part_hidden_states.append(hidden_states[k:k + 1, B, :])
  14. part_hidden_states = torch.cat(part_hidden_states, dim=0)
  15. part_mask = mask[:, :, :reserved_length, :reserved_length]
  16. return part_hidden_states, rand_list, part_mask
  17. def bsh_decoder_scatter(hidden_states, part_hidden_states, rand_list):
  18. for k in range(hidden_states.size(0)):
  19. hidden_states[k, rand_list[k], :] = part_hidden_states[k, :, :]
  20. return hidden_states