layer.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from typing import Any, Tuple
from torch import Tensor
from torch.nn import Module

import deepspeed.comm as dist

def single_all_to_all(input, scatter_idx, gather_idx, group):
    """Scatter `input` along `scatter_idx` and gather it along `gather_idx`
    across the sequence-parallel process group with a single all-to-all."""
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size
    if scatter_idx < 2:
        input_t = input.reshape(
            [seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] + \
            inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # if scattering the seq-dim, transpose the heads back to the original dimension
    if scatter_idx < 2:
        output = output.transpose(0, 2).contiguous()

    return output.reshape(
        inp_shape[:gather_idx] + \
        [inp_shape[gather_idx] * seq_world_size,] + \
        inp_shape[gather_idx + 1:]).contiguous()
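
# Worked shape example, assuming the sequence-first Megatron-style layout
# [s/p, b, n, d] (local sequence shard, batch, total heads, head dim) and the
# DistributedAttention defaults (scatter_idx=2, gather_idx=0), so the else
# branch above is taken:
#
#   single_all_to_all(q, scatter_idx=2, gather_idx=0, group)
#     reshape   : [s/p, b, n, d]       -> [(s/p)*b, p, n/p, d]
#     transpose : [(s/p)*b, p, n/p, d] -> [p, (s/p)*b, n/p, d]
#     all-to-all: dim 0 now indexes the source rank's sequence chunk
#     reshape   : [p, (s/p)*b, n/p, d] -> [s, b, n/p, d]
#
# i.e. each rank trades its sequence shard of all heads for the full sequence
# of its head shard, which lets the local attention run unmodified.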

class _SeqAllToAll(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, gather_idx: int) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        return single_all_to_all(input, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None)
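
# The gradient of an all-to-all that scatters dim `scatter_idx` and gathers dim
# `gather_idx` is the same all-to-all with the two roles swapped, which is why
# `backward` simply re-applies `_SeqAllToAll` with `gather_idx` and `scatter_idx`
# exchanged instead of defining a separate gradient path.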

class DistributedAttention(torch.nn.Module):
    """Initialization.

    Arguments:
        local_attention (Module): local attention with q,k,v
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        local_attention: Module,
        sequence_process_group: dist.ProcessGroup,
        scatter_idx: int = 2,
        gather_idx: int = 0,
    ) -> None:
        super(DistributedAttention, self).__init__()
        self.local_attn = local_attention
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """forward

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            args: other positional args passed to the local attention
            kwargs: other keyword args passed to the local attention

        Returns:
            * output (Tensor): context output
        """
        # TODO Merge three alltoall calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q, k, and v) together!
        # in shape : e.g., [s/p:h:]
        query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        # out shape : e.g., [s:h/p:]
        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)

        output = _SeqAllToAll.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

        # out e.g., [s/p::h]
        return output
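

# A minimal usage sketch, assuming the script is launched with a distributed
# launcher (e.g. `deepspeed` or `torchrun`) on GPUs with NCCL available.
# `_IdentityAttention` is a hypothetical placeholder for a real local attention
# module (e.g. FlashAttention); it only illustrates the expected tensor layouts.
if __name__ == "__main__":
    import deepspeed

    class _IdentityAttention(Module):

        def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
            # A real local attention would attend over the full sequence here.
            return q

    deepspeed.init_distributed()
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # One sequence-parallel group spanning all ranks.
    sp_group = dist.new_group(list(range(dist.get_world_size())))
    p = dist.get_world_size(sp_group)

    dist_attn = DistributedAttention(_IdentityAttention(), sp_group)

    # Each rank holds its local shard of the sequence: [s/p, b, n, d].
    s, b, n, d = 8 * p, 2, 4 * p, 16
    q = torch.randn(s // p, b, n, d, device='cuda')
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Internally: all-to-all to [s, b, n/p, d], local attention, all-to-all back.
    out = dist_attn(q, k, v)
    assert out.shape == (s // p, b, n, d)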