# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.utils import logger
from torch import Tensor
from torch.nn import Module

from ..constants import *
from deepspeed.ops.random_ltd.dropping_utils import gpt_sample_tokens, bert_sample_tokens, GatherTokens, ScatterTokens

# Based on the Random-LTD paper: https://arxiv.org/abs/2211.11586
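
# Illustrative-only sketch of the core mechanism, independent of the fused
# GatherTokens/ScatterTokens ops imported above: keep a random subset of
# `reserved` token positions, run the layer on only that subset, then scatter
# the layer output back into the full-length sequence. All names below are
# local to this sketch and are not part of the DeepSpeed API:
#
#   import torch
#   b, s, h, reserved = 2, 8, 4, 5
#   x = torch.randn(b, s, h)                                    # (batch, seq, hidden)
#   idx = torch.argsort(torch.rand(b, s), dim=1)[:, :reserved]  # random keep-set per sample
#   idx = idx.unsqueeze(-1).expand(-1, -1, h)                   # broadcast over hidden dim
#   part = torch.gather(x, 1, idx)                              # gather kept tokens
#   part = some_layer(part)                                     # `some_layer`: placeholder layer
#   x = x.scatter(1, idx, part)                                 # scatter results back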


class RandomLayerTokenDrop(Module):
    """
    A layer wrapper for random LTD: it drops a random subset of the input
    tokens before calling the wrapped layer and restores the full sequence
    afterwards, so training compute is spent on fewer tokens per layer.
    """

    def __init__(self, layer: Module):
        super().__init__()
        self.random_ltd_layer = layer
        self.random_ltd_scheduler = None
        self.max_length = None
        self.reserved_length = -1
        self.curr_seq = -1
        self.batch_first = False

    def init_config(self, config, scheduler, random_ltd_layer_id):
        self.random_ltd_scheduler = scheduler
        self.random_ltd_layer_id = random_ltd_layer_id
        self.max_length = self.random_ltd_scheduler.state[RANDOM_LTD_MAX_VALUE]
        self.mask_name = config[RANDOM_LTD_MODEL_MASK_NAME]
        self.micro_bs = config[RANDOM_LTD_MICRO_BATCH_SIZE]
        self.random_ltd_num_layer = self.random_ltd_scheduler.random_ltd_layer_num
        hs_order = config[RANDOM_LTD_HIDDEN_STATE_ORDER]
        self.model_type = config[RANDOM_LTD_MODEL_TYPE]

        # Pick the shape reader that matches the hidden-state layout.
        if hs_order == 'batch_seq_dim':
            self.get_hidden_tensor_shape = self.get_bsh
            self.batch_first = True
        elif hs_order == 'seq_batch_dim':
            self.get_hidden_tensor_shape = self.get_sbh
            self.batch_first = False
        else:
            logger.warning("************For now, we only support batch_seq_dim or seq_batch_dim inputs. "
                           "You can easily add support for your own input dimension order************")
            raise NotImplementedError

        # Encoder and decoder models sample token indices differently.
        if self.model_type == 'encoder':
            self.index_generator = bert_sample_tokens
        elif self.model_type == 'decoder':
            self.index_generator = gpt_sample_tokens
        else:
            logger.warning("************For now, we only support encoder-only or decoder-only models************")
            raise NotImplementedError

    def get_bsh(self, hidden_states):
        # Hidden states laid out as (batch, seq, hidden).
        self.curr_seq, self.curr_micro_batch = hidden_states.size()[1], hidden_states.size()[0]

    def get_sbh(self, hidden_states):
        # Hidden states laid out as (seq, batch, hidden).
        self.curr_seq, self.curr_micro_batch = hidden_states.size()[0], hidden_states.size()[1]

    def forward(self, hidden_states, **kwargs) -> Tensor:
        if self.random_ltd_scheduler is not None:
            self.reserved_length = self.random_ltd_scheduler.get_current_seq()
            self.get_hidden_tensor_shape(hidden_states)
        # Drop tokens only during training, and only while the scheduled
        # reserved length is shorter than the current sequence length.
        if self.training and self.random_ltd_scheduler is not None and self.reserved_length < self.curr_seq:
            if self.mask_name is not None:
                mask = kwargs[self.mask_name]
            else:
                mask = None
            if self.random_ltd_layer_id == 0:
                # The first random-LTD layer samples the per-layer token indices
                # (and sliced attention masks) once; later layers reuse them from
                # the scheduler state.
                sampled_indices, part_attention_mask = self.index_generator(self.reserved_length, self.curr_seq,
                                                                            self.curr_micro_batch,
                                                                            self.random_ltd_num_layer,
                                                                            hidden_states.device, mask)
                self.random_ltd_scheduler.state[RANDOM_LTD_SAMPLE_INDEX] = sampled_indices
                self.random_ltd_scheduler.state[RANDOM_LTD_ATTENTION_MASK] = part_attention_mask
            else:
                sampled_indices = self.random_ltd_scheduler.state[RANDOM_LTD_SAMPLE_INDEX]
                part_attention_mask = self.random_ltd_scheduler.state[RANDOM_LTD_ATTENTION_MASK]
            # Gather only this layer's sampled tokens; the full-length tensor is
            # kept so the dropped tokens can be restored after the layer runs.
            hidden_states, part_hidden_states = GatherTokens.apply(hidden_states,
                                                                   sampled_indices[self.random_ltd_layer_id, :, :],
                                                                   self.batch_first)
            if self.mask_name is not None:
                if self.model_type == 'encoder':
                    kwargs[self.mask_name] = part_attention_mask[self.random_ltd_layer_id]
                else:
                    kwargs[self.mask_name] = part_attention_mask
            outputs = self.random_ltd_layer(part_hidden_states, **kwargs)

            # Scatter the layer output for the sampled tokens back into the
            # full-length hidden states before returning.
            if isinstance(outputs, tuple):
                hidden_states = ScatterTokens.apply(hidden_states, outputs[0],
                                                    sampled_indices[self.random_ltd_layer_id, :, :], self.batch_first)
                my_list = list(outputs)
                my_list[0] = hidden_states
                return tuple(my_list)
            elif isinstance(outputs, Tensor):
                hidden_states = ScatterTokens.apply(hidden_states, outputs,
                                                    sampled_indices[self.random_ltd_layer_id, :, :], self.batch_first)
                return hidden_states
            else:
                logger.warning("************For now, we only support tuple and tensor outputs. "
                               "You need to adjust the output according to the layer in your model************")
                raise NotImplementedError
        else:
            # Token dropping inactive (e.g. eval mode or full reserved length):
            # run the wrapped layer on the untouched hidden states.
            return self.random_ltd_layer(hidden_states, **kwargs)
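

# A minimal usage sketch, not part of the module's public API: it shows how a
# stack of transformer layers might be wrapped with RandomLayerTokenDrop.
# `model_layers`, `ltd_config`, and `ltd_scheduler` are hypothetical
# placeholders; in practice DeepSpeed performs this wiring itself when random
# LTD is enabled in the DeepSpeed config.
def _example_wrap_layers(model_layers, ltd_config, ltd_scheduler):
    wrapped = []
    for layer_id, layer in enumerate(model_layers):
        ltd_layer = RandomLayerTokenDrop(layer)
        # Layer ids matter: the wrapper with id 0 samples the token indices and
        # the remaining wrappers reuse them from the scheduler state.
        ltd_layer.init_config(ltd_config, ltd_scheduler, layer_id)
        wrapped.append(ltd_layer)
    return wrapped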