mappings.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. # The file has been adapted from the following Megatron-LM file:
  5. # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
  6. # Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796
  7. # We retain the following copyright from the original files:
  8. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import torch
  21. import deepspeed
  22. from deepspeed.utils.bwc import (bwc_tensor_model_parallel_world_size, bwc_tensor_model_parallel_rank,
  23. bwc_tensor_model_parallel_group)
  24. def _gather_tokens(input_, dim=0):
  25. """Gather tensors and concatenate them along a dimension"""
  26. mpu = deepspeed.utils.groups.mpu
  27. input_ = input_.contiguous()
  28. # Size and dimension.
  29. rank = bwc_tensor_model_parallel_rank(mpu)
  30. tensor_list = [torch.empty_like(input_) for _ in range(bwc_tensor_model_parallel_world_size(mpu))]
  31. tensor_list[rank] = input_
  32. deepspeed.comm.all_gather(tensor_list, input_, group=bwc_tensor_model_parallel_group(mpu))
  33. # Note: torch.cat already creates a contiguous tensor.
  34. output = torch.cat(tensor_list, dim=dim).contiguous()
  35. return output
  36. def _drop_tokens(input_, dim=0):
  37. """Divide a tensor among the tensor parallel ranks"""
  38. mpu = deepspeed.utils.groups.mpu
  39. total_chunks = bwc_tensor_model_parallel_world_size(mpu)
  40. this_chunk = bwc_tensor_model_parallel_rank(mpu)
  41. assert input_.shape[
  42. dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
  43. chunk_size = input_.shape[dim] // total_chunks
  44. return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
  45. class _GatherTokens(torch.autograd.Function):
  46. """All gather tokens among the tensor parallel ranks"""
  47. @staticmethod
  48. def symbolic(graph, input_, dim):
  49. return _gather_tokens(input_, dim)
  50. @staticmethod
  51. def forward(ctx, input_, dim):
  52. ctx.dim = dim
  53. return _gather_tokens(input_, dim)
  54. @staticmethod
  55. def backward(ctx, grad_output):
  56. return _drop_tokens(grad_output, ctx.dim), None
  57. class _DropTokens(torch.autograd.Function):
  58. "Divide tokens equally among the tensor parallel ranks"
  59. @staticmethod
  60. def symbolic(graph, input_, dim):
  61. return _drop_tokens(input_, dim)
  62. @staticmethod
  63. def forward(ctx, input_, dim):
  64. ctx.dim = dim
  65. return _drop_tokens(input_, dim)
  66. @staticmethod
  67. def backward(ctx, input_):
  68. return _gather_tokens(input_, ctx.dim), None
  69. def gather_tokens(input_, dim=0):
  70. mpu = deepspeed.utils.groups.mpu
  71. if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
  72. # no tensor parallelism for non-experts
  73. return input_
  74. return _GatherTokens.apply(input_, dim)
  75. def drop_tokens(input_, dim=0):
  76. mpu = deepspeed.utils.groups.mpu
  77. if mpu is None or bwc_tensor_model_parallel_world_size(mpu) == 1:
  78. # no tensor parallelism for non-experts
  79. return input_
  80. return _DropTokens.apply(input_, dim)