dropping_utils.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.ops.op_builder import RandomLTDBuilder
"""
Returns:
    sampled_indices: [layers, batch_size, reserved_length]
    new_mask: [batch_size, 1, reserved_length, reserved_length]
"""

# Lazily loaded handle to the compiled RandomLTD op module.
random_ltd_module = None


def gpt_sample_tokens(reserved_length: int,
                      seq_length: int,
                      batch_size: int,
                      layers: int = 1,
                      device: str = 'cpu',
                      attn_mask: torch.Tensor = None):
    # Sample `reserved_length` token positions per layer and per sequence,
    # uniformly and without replacement.
    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)
    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()
    # Sort each row of indices in place so downstream gathers preserve token order.
    sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)

    # The optimized kernel is not clearly better here: it can break alignment
    # when the sequence length is not divisible by 16.
    # new_mask = random_ltd_module.mask_gather_gpt(attn_mask, reserved_length)
    if attn_mask is not None:
        # For a purely causal mask, the sub-mask at any sorted index set is again
        # lower triangular, so slicing the top-left block is sufficient.
        new_mask = attn_mask[:, :, :reserved_length, :reserved_length]
    else:
        new_mask = None

    return sampled_indices, new_mask


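# Illustrative sketch (not part of the original DeepSpeed file): a minimal example of
# calling gpt_sample_tokens, assuming the RandomLTD CUDA op can be built on this
# machine and that `attn_mask` is a causal mask of shape
# [batch_size, 1, seq_length, seq_length]. The sizes below are example values only.
def _example_gpt_sample_tokens():
    batch_size, seq_length, reserved_length, layers = 4, 512, 128, 12
    causal_mask = torch.tril(torch.ones(batch_size, 1, seq_length, seq_length, device='cuda'))
    indices, mask = gpt_sample_tokens(reserved_length,
                                      seq_length,
                                      batch_size,
                                      layers=layers,
                                      device='cuda',
                                      attn_mask=causal_mask)
    # indices: int32 tensor of shape [layers, batch_size, reserved_length], sorted per row.
    # mask:    [batch_size, 1, reserved_length, reserved_length] slice of the causal mask.
    return indices, mask

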
  33. """
  34. Returns:
  35. sampled_indices: [layers, batch_size, reserved_length]
  36. new_mask: [layers, batch_size, 1, reserved_length, reserved_length]
  37. """
  38. def bert_sample_tokens(reserved_length: int,
  39. seq_length: int,
  40. batch_size: int,
  41. layers: int = 1,
  42. device: str = 'cpu',
  43. attn_mask: torch.Tensor = None):
  44. assert attn_mask is not None
  45. prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
  46. sampled_indices = torch.multinomial(prob_dist, reserved_length)
  47. sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
  48. global random_ltd_module
  49. if random_ltd_module is None:
  50. random_ltd_module = RandomLTDBuilder().load()
  51. sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)
  52. dtype = sampled_indices.dtype
  53. sampled_indices = sampled_indices.to(torch.long)
  54. new_mask = []
  55. for l in range(layers):
  56. tmp_mask_list = []
  57. for i in range(batch_size):
  58. mask_tmp = attn_mask[i:i + 1, :, sampled_indices[l][i], :]
  59. tmp_mask_list.append(mask_tmp[:, :, :, sampled_indices[l][i]])
  60. new_mask.append(torch.cat(tmp_mask_list, dim=0))
  61. return sampled_indices.to(dtype), new_mask
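# Illustrative sketch (not part of the original DeepSpeed file): sampling tokens for a
# BERT-style model with a bidirectional mask. The mask shape, sizes, and the CUDA
# device are example assumptions.
def _example_bert_sample_tokens():
    batch_size, seq_length, reserved_length, layers = 4, 512, 128, 12
    pad_mask = torch.ones(batch_size, 1, seq_length, seq_length, device='cuda')
    indices, masks = bert_sample_tokens(reserved_length,
                                        seq_length,
                                        batch_size,
                                        layers=layers,
                                        device='cuda',
                                        attn_mask=pad_mask)
    # indices: int32 tensor of shape [layers, batch_size, reserved_length].
    # masks:   list of `layers` tensors, each [batch_size, 1, reserved_length, reserved_length].
    return indices, masks

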
class GatherTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, activations: torch.Tensor, sorted_indices: torch.Tensor, batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        ctx.save_for_backward(activations, sorted_indices)
        ctx.batch_first = batch_first
        # Return both the untouched full-length activations and the gathered
        # (reduced-length) activations at the sorted indices.
        return activations, random_ltd_module.token_gather(activations, sorted_indices, batch_first)

    @staticmethod
    def backward(ctx, a_gradients: torch.Tensor, g_gradients: torch.Tensor):
        g_gradients = g_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        activations, sorted_indices = ctx.saved_tensors
        batch_first = ctx.batch_first
        # Scatter the reduced-length gradients back into the full-length gradient.
        return random_ltd_module.token_scatter_(a_gradients, g_gradients, sorted_indices, batch_first), None, None


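# Illustrative sketch (not part of the original DeepSpeed file): gathering the sampled
# tokens of one layer's hidden states. The hidden size, other shapes, and the CUDA
# device are example assumptions.
def _example_gather_tokens():
    batch_size, seq_length, reserved_length, hidden = 4, 512, 128, 768
    hidden_states = torch.randn(batch_size, seq_length, hidden, device='cuda', requires_grad=True)
    indices, _ = gpt_sample_tokens(reserved_length, seq_length, batch_size, layers=1, device='cuda')
    full_states, kept_states = GatherTokens.apply(hidden_states, indices[0], True)  # batch_first=True
    # kept_states: [batch_size, reserved_length, hidden]; gradients flowing into it
    # are scattered back to the full sequence by GatherTokens.backward.
    return full_states, kept_states

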
class ScatterTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, all_activations: torch.Tensor, layer_activations: torch.Tensor, sorted_indices: torch.Tensor,
                batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        # Write the layer's reduced-length activations back into a copy of the
        # full-length activations at the sampled positions.
        scatter_results = random_ltd_module.token_scatter_(all_activations.clone(), layer_activations, sorted_indices,
                                                           batch_first)
        ctx.save_for_backward(sorted_indices)
        ctx.batch_first = batch_first
        return scatter_results

    @staticmethod
    def backward(ctx, out_gradients: torch.Tensor):
        out_gradients = out_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        sorted_indices, = ctx.saved_tensors
        batch_first = ctx.batch_first
        # The full-length gradient passes through unchanged; the layer input receives
        # the gradient slices at the sampled positions.
        ret_val = random_ltd_module.token_gather(out_gradients, sorted_indices, batch_first)
        return out_gradients, ret_val, None, None
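

# Illustrative sketch (not part of the original DeepSpeed file): merging a layer's
# reduced-length output back into the full-length activations. `layer_output` stands
# in for the output of a transformer layer run on the gathered tokens; the sizes and
# the CUDA device are example assumptions.
def _example_scatter_tokens():
    batch_size, seq_length, reserved_length, hidden = 4, 512, 128, 768
    full_states = torch.randn(batch_size, seq_length, hidden, device='cuda')
    indices, _ = gpt_sample_tokens(reserved_length, seq_length, batch_size, layers=1, device='cuda')
    layer_output = torch.randn(batch_size, reserved_length, hidden, device='cuda')
    merged = ScatterTokens.apply(full_states, layer_output, indices[0], True)  # batch_first=True
    # merged matches full_states with the sampled positions overwritten by layer_output.
    return merged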