# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from deepspeed.ops.op_builder import RandomLTDBuilder
"""
Returns:
    sampled_indices: [layers, batch_size, reserved_length]
    new_mask: [batch_size, 1, reserved_length, reserved_length]
"""

random_ltd_module = None


def gpt_sample_tokens(reserved_length: int,
                      seq_length: int,
                      batch_size: int,
                      layers: int = 1,
                      device: str = 'cpu',
                      attn_mask: torch.Tensor = None):

    # Uniformly sample `reserved_length` positions per (layer, sequence), then sort
    # them in place so that gathered tokens keep their original order.
    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)
    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()
    sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)
    # The fused mask_gather_gpt kernel is not clearly faster here, and it can hit
    # alignment issues when the sequence length is not divisible by 16, so slice the
    # mask directly instead:
    # new_mask = random_ltd_module.mask_gather_gpt(attn_mask, reserved_length)
    if attn_mask is not None:
        new_mask = attn_mask[:, :, :reserved_length, :reserved_length]
    else:
        new_mask = None

    return sampled_indices, new_mask
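
# Example usage (a minimal sketch; it assumes the RandomLTD op builds on this machine,
# a CUDA device, and a causal mask of shape [batch_size, 1, seq_length, seq_length]):
#
#   mask = torch.tril(torch.ones(4, 1, 128, 128, device='cuda'))
#   indices, new_mask = gpt_sample_tokens(64, 128, 4, layers=2, device='cuda', attn_mask=mask)
#   # indices: [2, 4, 64] (int32), new_mask: [4, 1, 64, 64]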
- """
- Returns:
- sampled_indices: [layers, batch_size, reserved_length]
- new_mask: [layers, batch_size, 1, reserved_length, reserved_length]
- """
- def bert_sample_tokens(reserved_length: int,
- seq_length: int,
- batch_size: int,
- layers: int = 1,
- device: str = 'cpu',
- attn_mask: torch.Tensor = None):
    assert attn_mask is not None
    # Uniformly sample `reserved_length` positions per (layer, sequence), then sort
    # them in place so that gathered tokens keep their original order.
    prob_dist = torch.ones((layers * batch_size, seq_length), device=device)
    sampled_indices = torch.multinomial(prob_dist, reserved_length)
    sampled_indices = sampled_indices.reshape(layers, batch_size, reserved_length).to(torch.int32)
    global random_ltd_module
    if random_ltd_module is None:
        random_ltd_module = RandomLTDBuilder().load()
    sampled_indices = random_ltd_module.token_sort_(sampled_indices, seq_length)
    dtype = sampled_indices.dtype
    sampled_indices = sampled_indices.to(torch.long)
    # BERT masks are per-sample and non-causal, so build each layer's reduced mask by
    # selecting the sampled rows and columns from the full attention mask.
    new_mask = []
    for layer in range(layers):
        tmp_mask_list = []
        for i in range(batch_size):
            mask_tmp = attn_mask[i:i + 1, :, sampled_indices[layer][i], :]
            tmp_mask_list.append(mask_tmp[:, :, :, sampled_indices[layer][i]])
        new_mask.append(torch.cat(tmp_mask_list, dim=0))

    return sampled_indices.to(dtype), new_mask
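
# Example usage (a minimal sketch; it assumes the RandomLTD op builds on this machine,
# a CUDA device, and a BERT attention mask of shape [batch_size, 1, seq_length, seq_length]):
#
#   mask = torch.ones(4, 1, 128, 128, device='cuda')
#   indices, masks = bert_sample_tokens(64, 128, 4, layers=2, device='cuda', attn_mask=mask)
#   # indices: [2, 4, 64] (int32); masks: list of 2 tensors, each [4, 1, 64, 64]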


class GatherTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, activations: torch.Tensor, sorted_indices: torch.Tensor, batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        ctx.save_for_backward(activations, sorted_indices)
        ctx.batch_first = batch_first
        # Return both the untouched full activations and the gathered subset of tokens.
        return activations, random_ltd_module.token_gather(activations, sorted_indices, batch_first)

    @staticmethod
    def backward(ctx, a_gradients: torch.Tensor, g_gradients: torch.Tensor):
        g_gradients = g_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        activations, sorted_indices = ctx.saved_tensors
        batch_first = ctx.batch_first
        # Scatter the gathered-token gradients back into the full-sequence gradient.
        return random_ltd_module.token_scatter_(a_gradients, g_gradients, sorted_indices, batch_first), None, None
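
# Example usage (a minimal sketch reusing `indices` from the gpt_sample_tokens sketch above;
# the hidden size 768 and the CUDA device are assumptions). With batch_first=True,
# activations are [batch_size, seq_length, hidden] and one layer's indices are
# [batch_size, reserved_length]:
#
#   x = torch.randn(4, 128, 768, device='cuda')
#   full_x, sampled_x = GatherTokens.apply(x, indices[0], True)
#   # full_x: [4, 128, 768] (unchanged), sampled_x: [4, 64, 768]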


class ScatterTokens(torch.autograd.Function):

    @staticmethod
    def forward(ctx, all_activations: torch.Tensor, layer_activations: torch.Tensor, sorted_indices: torch.Tensor,
                batch_first: bool):
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        # Write the layer's outputs for the sampled tokens back into a copy of the full activations.
        scatter_results = random_ltd_module.token_scatter_(all_activations.clone(), layer_activations, sorted_indices,
                                                           batch_first)
        ctx.save_for_backward(sorted_indices)
        ctx.batch_first = batch_first
        return scatter_results

    @staticmethod
    def backward(ctx, out_gradients: torch.Tensor):
        out_gradients = out_gradients.contiguous()
        global random_ltd_module
        if random_ltd_module is None:
            random_ltd_module = RandomLTDBuilder().load()
        sorted_indices, = ctx.saved_tensors
        batch_first = ctx.batch_first
        # The gradient w.r.t. the layer activations is the slice of the output gradient
        # taken at the sampled token positions.
        ret_val = random_ltd_module.token_gather(out_gradients, sorted_indices, batch_first)
        return out_gradients, ret_val, None, None
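
# Continuing the sketch above: run a transformer layer on the gathered tokens, then scatter
# its output back into the full sequence (`some_transformer_layer` is a hypothetical
# stand-in for that layer):
#
#   layer_out = some_transformer_layer(sampled_x)  # [4, 64, 768]
#   y = ScatterTokens.apply(full_x, layer_out, indices[0], True)
#   # y: [4, 128, 768], equal to full_x except at the sampled positions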