layer.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module

import deepspeed.comm as dist
class _SeqAllToAll(torch.autograd.Function):
    """All-to-all over the sequence-parallel group.

    Splits the input along ``scatter_idx``, exchanges the chunks across the
    group, and re-assembles the received chunks along ``gather_idx``.
    """

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        seq_world_size = dist.get_world_size(group)

        input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)]
        output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
        # TODO Use all_to_all_single instead
        dist.all_to_all(output_list, input_list, group=group)
        return torch.cat(output_list, dim=gather_idx).contiguous()

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # The backward of an all-to-all is an all-to-all with the scatter and
        # gather dimensions swapped; gradients for the non-tensor args are None.
        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
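

# ---------------------------------------------------------------------------
# Illustrative shape walk-through, assuming 4-D activations laid out as
# [sequence, batch, heads, head_dim] and a sequence-parallel world size p = 2:
#
#   per-rank input                    : [s/2, b, h,   d]
#   tensor_split along scatter_idx=2  : 2 chunks of [s/2, b, h/2, d]
#   all_to_all                        : each rank swaps chunks with its peer
#   cat along gather_idx=0            : [s,   b, h/2, d]
#
# i.e. the collective trades a sharded sequence dimension for a sharded head
# dimension; backward() performs the same exchange with the indices swapped.
# ---------------------------------------------------------------------------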
class DistributedAttention(torch.nn.Module):
    """Attention layer parallelized over the sequence dimension.

    Arguments:
        local_attention (Module): local attention module that consumes q, k, v
        sequence_process_group (ProcessGroup): sequence-parallel process group
        scatter_idx (int): dimension to scatter in the all-to-all communication
        gather_idx (int): dimension to gather in the all-to-all communication
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 0,
    ) -> None:
        super(DistributedAttention, self).__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tensor:
        """Forward pass.

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other args passed through to the local attention module

        Returns:
            output (Tensor): context output
        """
        # TODO Merge three alltoall calls into one
        # in shape: e.g., [s/p, :, h, :] -- sequence sharded, all heads local
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        # out shape: e.g., [s, :, h/p, :] -- full sequence, heads sharded
        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args)

        # out shape: e.g., [s/p, :, h, :] -- back to a sharded sequence
        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

        return output
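

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; `seq_ranks`, `local_attention_mod`, and the
# q/k/v tensors below are hypothetical placeholders). It assumes the
# distributed backend is already initialized and that the local attention
# module both consumes and returns tensors shaped
# [sequence, batch, heads, head_dim]:
#
#     import torch.distributed as torch_dist
#
#     seq_group = torch_dist.new_group(ranks=seq_ranks)
#     dist_attn = DistributedAttention(local_attention_mod, seq_group)
#
#     # each rank holds a slice of the sequence:
#     # q, k, v are [s/p, batch, heads, head_dim]
#     output = dist_attn(q, k, v)     # output: [s/p, batch, heads, head_dim]
# ---------------------------------------------------------------------------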