basic_layer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.utils import logger
from torch import Tensor
from torch.nn import Module

from ..constants import *
from deepspeed.ops.random_ltd.dropping_utils import gpt_sample_tokens, bert_sample_tokens, GatherTokens, ScatterTokens

##### based on the paper random-ltd: https://arxiv.org/abs/2211.11586


class RandomLayerTokenDrop(Module):
    """
    A layer wrapper for random LTD
    """

    def __init__(self, layer: Module):
        super(RandomLayerTokenDrop, self).__init__()
        self.random_ltd_layer = layer
        self.random_ltd_scheduler = None
        self.max_length = None
        self.reserved_length = -1  # number of tokens kept per layer; updated from the scheduler in forward()
        self.curr_seq = -1
        self.batch_first = False

    def init_config(self, config, scheduler, random_ltd_layer_id):
        self.random_ltd_scheduler = scheduler
        self.random_ltd_layer_id = random_ltd_layer_id
        self.max_length = self.random_ltd_scheduler.state[RANDOM_LTD_MAX_VALUE]

        self.mask_name = config[RANDOM_LTD_MODEL_MASK_NAME]
        self.micro_bs = config[RANDOM_LTD_MICRO_BATCH_SIZE]
        self.random_ltd_num_layer = self.random_ltd_scheduler.random_ltd_layer_num
        hs_order = config[RANDOM_LTD_HIDDEN_STATE_ORDER]
        self.model_type = config[RANDOM_LTD_MODEL_TYPE]

        if hs_order == 'batch_seq_dim':
            self.get_hidden_tensor_shape = self.get_bsh
            self.batch_first = True
        elif hs_order == 'seq_batch_dim':
            self.get_hidden_tensor_shape = self.get_sbh
            self.batch_first = False
        else:
            logger.warning("************For now, we only support batch_seq_dim or seq_batch_dim inputs. You can \
                           easily add support for your own input dimension orders************")
            raise NotImplementedError

        if self.model_type == 'encoder':
            self.index_generator = bert_sample_tokens
        elif self.model_type == 'decoder':
            self.index_generator = gpt_sample_tokens
        else:
            logger.warning("************For now, we only support encoder-only or decoder-only models************")
            raise NotImplementedError
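
    # Illustrative sketch only (not from this file): the `config` passed to init_config
    # is expected to look roughly like
    #     {RANDOM_LTD_MODEL_MASK_NAME: 'attention_mask',
    #      RANDOM_LTD_MICRO_BATCH_SIZE: 8,
    #      RANDOM_LTD_HIDDEN_STATE_ORDER: 'batch_seq_dim',
    #      RANDOM_LTD_MODEL_TYPE: 'decoder'}
    # where the key constants come from ..constants; the concrete values shown
    # ('attention_mask', 8, ...) are assumptions for illustration. `scheduler` is
    # the random-LTD scheduler, which exposes .state, .random_ltd_layer_num and
    # .get_current_seq() as used in this class.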

    def get_bsh(self, hidden_states):
        # hidden_states is (batch, seq, hidden)
        self.curr_seq, self.curr_micro_batch = hidden_states.size()[1], hidden_states.size()[0]

    def get_sbh(self, hidden_states):
        # hidden_states is (seq, batch, hidden)
        self.curr_seq, self.curr_micro_batch = hidden_states.size()[0], hidden_states.size()[1]
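
    # Forward pass: during training, when the scheduler's current reserved length is
    # shorter than the incoming sequence, gather a random subset of tokens (indices are
    # sampled once at layer 0 and cached in the scheduler state for reuse by the later
    # random-LTD layers), run the wrapped layer on that subset, then scatter its output
    # back into the full-length hidden states. Outside training, or when nothing would
    # be dropped, the wrapped layer is called on the full input unchanged.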

    def forward(self, hidden_states, **kwargs) -> Tensor:
        if self.random_ltd_scheduler is not None:
            self.reserved_length = self.random_ltd_scheduler.get_current_seq()
            self.get_hidden_tensor_shape(hidden_states)
        if self.training and self.random_ltd_scheduler is not None and self.reserved_length < self.curr_seq:
            if self.mask_name is not None:
                mask = kwargs[self.mask_name]
            else:
                mask = None
            if self.random_ltd_layer_id == 0:
                # The first random-LTD layer samples the token indices (and the matching
                # partial attention mask) for all layers and caches them in the scheduler state.
                sampled_indices, part_attention_mask = self.index_generator(self.reserved_length, \
                                                                            self.curr_seq, \
                                                                            self.curr_micro_batch, \
                                                                            self.random_ltd_num_layer, \
                                                                            hidden_states.device, mask)
                self.random_ltd_scheduler.state[RANDOM_LTD_SAMPLE_INDEX] = sampled_indices
                self.random_ltd_scheduler.state[RANDOM_LTD_ATTENTION_MASK] = part_attention_mask
            else:
                sampled_indices = self.random_ltd_scheduler.state[RANDOM_LTD_SAMPLE_INDEX]
                part_attention_mask = self.random_ltd_scheduler.state[RANDOM_LTD_ATTENTION_MASK]

            hidden_states, part_hidden_states = GatherTokens.apply(hidden_states,
                                                                   sampled_indices[self.random_ltd_layer_id, :, :],
                                                                   self.batch_first)
            if self.mask_name is not None:
                if self.model_type == 'encoder':
                    kwargs[self.mask_name] = part_attention_mask[self.random_ltd_layer_id]
                else:
                    kwargs[self.mask_name] = part_attention_mask

            outputs = self.random_ltd_layer(part_hidden_states, **kwargs)

            if isinstance(outputs, tuple):
                hidden_states = ScatterTokens.apply(hidden_states, outputs[0],
                                                    sampled_indices[self.random_ltd_layer_id, :, :], self.batch_first)
                my_list = list(outputs)
                my_list[0] = hidden_states
                return tuple(my_list)
            elif isinstance(outputs, Tensor):
                hidden_states = ScatterTokens.apply(hidden_states, outputs,
                                                    sampled_indices[self.random_ltd_layer_id, :, :], self.batch_first)
                return hidden_states
            else:
                logger.warning("************For now, we only support tuple and tensor outputs. \
                               You need to adjust the output handling according to the layers in your model************")
                raise NotImplementedError
        else:
            return self.random_ltd_layer(hidden_states, **kwargs)
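

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module). In DeepSpeed
# this wrapping is normally done by the data-efficiency / random-LTD
# initialization rather than by hand; `model.transformer.layers`, `ltd_config`
# and `scheduler` are hypothetical names standing in for your model's layer
# list, the random-LTD section of your config, and the random-LTD scheduler.
#
#     layers = model.transformer.layers
#     for layer_id, layer in enumerate(layers):
#         wrapped = RandomLayerTokenDrop(layer)
#         wrapped.init_config(ltd_config, scheduler, layer_id)
#         layers[layer_id] = wrapped
#
# After wrapping, each layer processes only the scheduler's reserved number of
# tokens during training and behaves exactly like the original layer during
# evaluation.
# ---------------------------------------------------------------------------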