12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- # Copyright (c) Microsoft Corporation.
- # SPDX-License-Identifier: Apache-2.0
- # DeepSpeed Team
- import torch
- from typing import Any, Tuple
- from torch import Tensor
- from torch.nn import Module
- import deepspeed.comm as dist
class _SeqAllToAll(torch.autograd.Function):
    """Autograd-aware all-to-all that re-shards a tensor across a process group.

    The forward pass splits ``input`` into ``world_size`` equal chunks along
    ``scatter_idx``, exchanges the chunks with every rank of ``group`` via
    ``dist.all_to_all``, and concatenates the received chunks along
    ``gather_idx``. The backward pass runs the same exchange with the two
    dimensions swapped, which routes gradients back to their source ranks.
    """

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        """Scatter ``input`` along ``scatter_idx`` and gather along ``gather_idx``.

        Arguments:
            ctx: autograd context; stores ``group`` and both dims for backward.
            group (ProcessGroup): process group that takes part in the exchange.
            input (Tensor): local tensor to re-shard. Assumed to have the same
                shape on every rank of ``group`` (TODO confirm with callers).
            scatter_idx (int): dimension that is split across ranks.
            gather_idx (int): dimension along which received chunks are concatenated.

        Returns:
            Tensor: the re-sharded tensor. Its ``scatter_idx`` size is divided by
            the group size and its ``gather_idx`` size is multiplied by it.

        Raises:
            ValueError: if ``input.size(scatter_idx)`` is not divisible by the
                size of ``group``.
        """
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        seq_world_size = dist.get_world_size(group)
        scatter_len = input.size(scatter_idx)
        # Every receive buffer below is shaped like the first local chunk. That is
        # only valid when all chunks are equal, i.e. the split is exact. Otherwise
        # peers would send tensors that do not match the buffers.
        if scatter_len % seq_world_size != 0:
            raise ValueError(f"_SeqAllToAll: size {scatter_len} of dim {scatter_idx} is not divisible by "
                             f"the sequence-parallel world size {seq_world_size}")

        input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
        output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
        # TODO Use all_to_all_single instead
        dist.all_to_all(output_list, input_list, group=group)
        return torch.cat(output_list, dim=gather_idx).contiguous()

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        """Propagate gradients by running the inverse all-to-all.

        The scatter and gather dimensions are swapped relative to forward.
        ``None`` is returned for the non-tensor inputs
        (``group``, ``scatter_idx``, ``gather_idx``).
        """
        grad_input = _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx)
        return (None, grad_input, None, None)
class DistributedAttention(torch.nn.Module):
    """Sequence-parallel wrapper that runs a local attention module on re-sharded q, k, v.

    Before attention, query, key and value are exchanged across
    ``sequence_process_group`` with an all-to-all. Each one is split along
    ``scatter_idx`` and concatenated along ``gather_idx``. The attention output
    is then sent back with the inverse all-to-all.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 0,
    ) -> None:
        super().__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Run ``local_attn`` on all-to-all re-sharded inputs and undo the re-sharding.

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: extra positional arguments forwarded unchanged to ``local_attn``
            kwargs: extra keyword arguments forwarded unchanged to ``local_attn``

        Returns:
            * output (Tensor): context output
        """
        # TODO Merge three alltoall calls into one
        # in shape: e.g., [s/p:h:]
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        # after all-to-all: e.g., [s:h/p:]
        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

        # Inverse exchange (scatter/gather dims swapped). Assuming local_attn keeps
        # the layout of its query input, this restores the input layout,
        # e.g., [s/p:h:].
        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
        return output
|