# dropping_utils.py

  1. """
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. """
  4. import torch
  5. from deepspeed.ops.op_builder import RandomLTDBuilder
  6. """
  7. Returns:
  8. sampled_indices: [layers, batch_size, reserved_length]
  9. new_mask: [batch_size, 1, reserved_length, reserved_length]
  10. """
  11. random_ltd_module = None
  12. def gpt_sample_tokens(reserved_length: int,
  13. seq_length: int,
  14. batch_size: int,
  15. layers: int = 1,
  16. device: str = 'cpu',
  17. attn_mask: torch.Tensor = None):
  18. prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
  19. sampled_indices = torch.multinomial(prob_dist, reserved_length)
  20. sampled_indices = sampled_indices.reshape(layers,
  21. batch_size,
  22. reserved_length).to(torch.int32)
  23. global random_ltd_module
  24. if random_ltd_module is None:
  25. random_ltd_module = RandomLTDBuilder().load()
  26. sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)
  27. # Not certain the optimized kernel is actually better here, cause it kind of screws
  28. # with alignment right if the sequence length is not divisble by like 16
  29. # new_mask = random_ltd_module.mask_gather_gpt(attn_mask, reserved_length)
  30. if attn_mask is not None:
  31. new_mask = attn_mask[:, :, :reserved_length, :reserved_length]
  32. else:
  33. new_mask = None
  34. return sampled_indices, new_mask
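

# ---------------------------------------------------------------------------
# Illustrative usage sketch: shapes follow the docstring above. The variable
# names below are placeholders, and actually running this assumes the
# RandomLTD op can be built on a CUDA-capable machine.
#
#   causal_mask = torch.tril(torch.ones(4, 1, 128, 128, device='cuda'))
#   indices, mask = gpt_sample_tokens(reserved_length=64,
#                                     seq_length=128,
#                                     batch_size=4,
#                                     layers=12,
#                                     device='cuda',
#                                     attn_mask=causal_mask)
#   # indices: [12, 4, 64] (int32); mask: [4, 1, 64, 64]
# ---------------------------------------------------------------------------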
  35. """
  36. Returns:
  37. sampled_indices: [layers, batch_size, reserved_length]
  38. new_mask: [layers, batch_size, 1, reserved_length, reserved_length]
  39. """
  40. def bert_sample_tokens(reserved_length: int,
  41. seq_length: int,
  42. batch_size: int,
  43. layers: int = 1,
  44. device: str = 'cpu',
  45. attn_mask: torch.Tensor = None):
  46. assert attn_mask is not None
  47. prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
  48. sampled_indices = torch.multinomial(prob_dist, reserved_length)
  49. sampled_indices = sampled_indices.reshape(layers,
  50. batch_size,
  51. reserved_length).to(torch.int32)
  52. global random_ltd_module
  53. if random_ltd_module is None:
  54. random_ltd_module = RandomLTDBuilder().load()
  55. sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)
  56. dtype = sampled_indices.dtype
  57. sampled_indices = sampled_indices.to(torch.long)
  58. new_mask = []
  59. for l in range(layers):
  60. tmp_mask_list = []
  61. for i in range(batch_size):
  62. mask_tmp = attn_mask[i:i + 1, :, sampled_indices[l][i], :]
  63. tmp_mask_list.append(mask_tmp[:, :, :, sampled_indices[l][i]])
  64. new_mask.append(torch.cat(tmp_mask_list, dim=0))
  65. return sampled_indices.to(dtype), new_mask
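

# ---------------------------------------------------------------------------
# Illustrative usage sketch: bert_sample_tokens requires attn_mask (asserted
# above) and returns the mask as a Python list with one entry per layer.
# Placeholder names; assumes the RandomLTD op builds on a CUDA-capable machine.
#
#   bert_mask = torch.ones(4, 1, 128, 128, device='cuda')
#   indices, masks = bert_sample_tokens(reserved_length=64,
#                                       seq_length=128,
#                                       batch_size=4,
#                                       layers=12,
#                                       device='cuda',
#                                       attn_mask=bert_mask)
#   # indices: [12, 4, 64] (int32); len(masks) == 12; masks[0]: [4, 1, 64, 64]
# ---------------------------------------------------------------------------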


class GatherTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx,
                activations: torch.Tensor,
                sorted_indices: torch.Tensor,
                batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        ctx.save_for_backward(activations, sorted_indices)
        ctx.batch_first = batch_first
        # Return both the untouched full activations and the gathered subset.
        return activations, random_ltd_module.token_gather(activations, sorted_indices, batch_first)

    @staticmethod
    def backward(ctx, a_gradients: torch.Tensor, g_gradients: torch.Tensor):
        g_gradients = g_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        activations, sorted_indices = ctx.saved_tensors
        batch_first = ctx.batch_first
        # Scatter the gathered-branch gradients back into the full-activation gradient.
        return random_ltd_module.token_scatter_(a_gradients,
                                                g_gradients,
                                                sorted_indices,
                                                batch_first), None, None


class ScatterTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx,
                all_activations: torch.Tensor,
                layer_activations: torch.Tensor,
                sorted_indices: torch.Tensor,
                batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        # Clone so the in-place scatter kernel does not modify the caller's tensor.
        scatter_results = random_ltd_module.token_scatter_(all_activations.clone(),
                                                           layer_activations,
                                                           sorted_indices,
                                                           batch_first)
        ctx.save_for_backward(sorted_indices)
        ctx.batch_first = batch_first
        return scatter_results

    @staticmethod
    def backward(ctx, out_gradients: torch.Tensor):
        out_gradients = out_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        sorted_indices, = ctx.saved_tensors
        batch_first = ctx.batch_first
        # Gradient w.r.t. the scattered layer activations is the gathered slice
        # of the output gradient; the full-activation gradient passes through.
        ret_val = random_ltd_module.token_gather(out_gradients,
                                                 sorted_indices,
                                                 batch_first)
        return out_gradients, ret_val, None, None
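

# ---------------------------------------------------------------------------
# Illustrative sketch of how the two autograd functions can be paired around a
# single transformer layer when tokens are dropped for that layer only.
# `layer`, `hidden_states`, `indices`, `mask`, and `l` are placeholders, not
# names defined in this module; shapes assume batch-first activations of
# [batch_size, seq_length, hidden_size] and `indices` from gpt_sample_tokens.
#
#   hidden_states, kept = GatherTokens.apply(hidden_states, indices[l], True)
#   kept = layer(kept, attention_mask=mask)   # runs on the reserved tokens only
#   hidden_states = ScatterTokens.apply(hidden_states, kept, indices[l], True)
# ---------------------------------------------------------------------------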