mappings.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. '''
  2. Copyright 2022 The Microsoft DeepSpeed Team
  3. '''
  4. # The file has been adapted from the following Megatron-LM file:
  5. # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/mappings.py
  6. # Git commit hash: 9dc3c42a84aa656f583703cf8b6b4f79f712b796
  7. # We retain the following copyright from the original files:
  8. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import torch
  21. import deepspeed
  22. def _gather_tokens(input_, dim=0):
  23. """Gather tensors and concatenate them along a dimension"""
  24. mpu = deepspeed.utils.groups.mpu
  25. input_ = input_.contiguous()
  26. # Size and dimension.
  27. rank = mpu.get_tensor_model_parallel_rank()
  28. tensor_list = [
  29. torch.empty_like(input_)
  30. for _ in range(mpu.get_tensor_model_parallel_world_size())
  31. ]
  32. tensor_list[rank] = input_
  33. deepspeed.comm.all_gather(tensor_list,
  34. input_,
  35. group=mpu.get_tensor_model_parallel_group())
  36. # Note: torch.cat already creates a contiguous tensor.
  37. output = torch.cat(tensor_list, dim=dim).contiguous()
  38. return output
  39. def _drop_tokens(input_, dim=0):
  40. """Divide a tensor among the tensor parallel ranks"""
  41. mpu = deepspeed.utils.groups.mpu
  42. total_chunks = mpu.get_tensor_model_parallel_world_size()
  43. this_chunk = mpu.get_tensor_model_parallel_rank()
  44. assert input_.shape[dim] % total_chunks == 0, f"input dimension {dim} ({input_.shape[dim]}) is not divisible by tensor parallel world size ({total_chunks})"
  45. chunk_size = input_.shape[dim] // total_chunks
  46. return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)
  47. class _GatherTokens(torch.autograd.Function):
  48. """All gather tokens among the tensor parallel ranks"""
  49. @staticmethod
  50. def symbolic(graph, input_, dim):
  51. return _gather_tokens(input_, dim)
  52. @staticmethod
  53. def forward(ctx, input_, dim):
  54. ctx.dim = dim
  55. return _gather_tokens(input_, dim)
  56. @staticmethod
  57. def backward(ctx, grad_output):
  58. return _drop_tokens(grad_output, ctx.dim), None
  59. class _DropTokens(torch.autograd.Function):
  60. "Divide tokens equally among the tensor parallel ranks"
  61. @staticmethod
  62. def symbolic(graph, input_, dim):
  63. return _drop_tokens(input_, dim)
  64. @staticmethod
  65. def forward(ctx, input_, dim):
  66. ctx.dim = dim
  67. return _drop_tokens(input_, dim)
  68. @staticmethod
  69. def backward(ctx, input_):
  70. return _gather_tokens(input_, ctx.dim), None
  71. def gather_tokens(input_, dim=0):
  72. mpu = deepspeed.utils.groups.mpu
  73. if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
  74. # no tensor parallelism for non-experts
  75. return input_
  76. return _GatherTokens.apply(input_, dim)
  77. def drop_tokens(input_, dim=0):
  78. mpu = deepspeed.utils.groups.mpu
  79. if mpu is None or mpu.get_tensor_model_parallel_world_size() == 1:
  80. # no tensor parallelism for non-experts
  81. return input_
  82. return _DropTokens.apply(input_, dim)