p2p.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pickle
import typing

import torch
from deepspeed import comm as dist

# To query whether we have send/recv support
from packaging.version import Version
from deepspeed.git_version_info import torch_info
from deepspeed.accelerator import get_accelerator

_groups = None
_grid = None

_async = []


def can_send_recv() -> bool:
    torch_version = Version(torch_info['version'])
    sendrecv_min = Version('1.8')
    return torch_version >= sendrecv_min


# Initializes adjacent process groups.
# Run this only after deepspeed.init_distributed() has been called.
def init_process_groups(grid):
    global _groups, _grid
    _grid = grid

    assert _grid.pipe_parallel_size > 1, "There is no pipeline parallelism"

    if not can_send_recv():
        _groups = [dist.new_group(ranks=group) for group in _grid.p2p_groups]
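

# Illustrative setup sketch (not part of the original file): the helpers in this
# module assume deepspeed.init_distributed() has already run and that `grid` exposes
# pipe_parallel_size, p2p_groups, get_stage_id() and stage_to_global() -- in DeepSpeed
# this is normally the PipelineParallelGrid owned by the pipeline engine.
def _example_setup(grid):
    import deepspeed
    deepspeed.init_distributed()  # must precede init_process_groups()
    init_process_groups(grid)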


def _is_valid_send_recv(src_stage, dest_stage):
    first_stage = 0
    last_stage = _grid.pipe_parallel_size - 1
    assert abs(src_stage - dest_stage) == 1 or \
        (src_stage == first_stage and dest_stage == last_stage) or \
        (src_stage == last_stage and dest_stage == first_stage), \
        "Functionality currently limited to send and receive between adjacent ranks only"


def send(tensor, dest_stage, async_op=False):
    global _groups
    assert async_op == False, "Doesn't support async_op true"
    src_stage = _grid.get_stage_id()
    _is_valid_send_recv(src_stage, dest_stage)

    dest_rank = _grid.stage_to_global(stage_id=dest_stage)
    if async_op:
        global _async
        op = dist.isend(tensor, dest_rank)
        _async.append(op)
    else:
        if can_send_recv():
            return dist.send(tensor, dest_rank)
        else:
            group = _get_send_recv_group(src_stage, dest_stage)
            src_rank = _grid.stage_to_global(stage_id=src_stage)
            return dist.broadcast(tensor, src_rank, group=group, async_op=async_op)


def recv(tensor, src_stage, async_op=False):
    global _groups
    assert async_op == False, "Doesn't support async_op true"

    dest_stage = _grid.get_stage_id()
    _is_valid_send_recv(src_stage, dest_stage)

    src_rank = _grid.stage_to_global(stage_id=src_stage)

    if async_op:
        global _async
        op = dist.irecv(tensor, src_rank)
        _async.append(op)
    else:
        if can_send_recv():
            return dist.recv(tensor, src_rank)
        else:
            group = _get_send_recv_group(src_stage, dest_stage)
            return dist.broadcast(tensor, src_rank, group=group, async_op=async_op)


def wait():
    global _async
    for op in _async:
        op.wait()
    _async = []

    get_accelerator().synchronize()
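

# Illustrative usage sketch (not part of the original module): a minimal exchange of
# activations between adjacent pipeline stages using send()/recv() above.  It assumes
# init_process_groups() has already run on every rank and that `send_buf`/`recv_buf`
# are pre-allocated tensors whose shapes and dtypes match on both sides of each link.
def _example_stage_boundary_exchange(send_buf, recv_buf):
    stage = _grid.get_stage_id()
    last_stage = _grid.pipe_parallel_size - 1
    if stage != last_stage:
        # push this stage's output to the next stage
        send(send_buf, dest_stage=stage + 1)
    if stage != 0:
        # pull the previous stage's output into the local buffer
        recv(recv_buf, src_stage=stage - 1)
    return recv_buf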


def send_obj(msg: typing.Any, dest: int):
    """Send an arbitrary python object to ``dest``.

    Note: ``msg`` must be pickleable.

    WARN: This incurs a CPU -> GPU transfer and should be used sparingly
    for performance reasons.

    Args:
        msg (typing.Any): The object to send.
        dest (int): Destination rank.
    """
    # Serialize the message
    msg = pickle.dumps(msg)
    # Construct a byte tensor to send
    msg = torch.ByteTensor(torch.ByteStorage.from_buffer(msg)).to(get_accelerator().device_name())

    # Send meta and message
    length_tensor = torch.tensor([len(msg)], dtype=torch.long).to(get_accelerator().device_name())
    dist.send(length_tensor, dst=dest)
    dist.send(msg, dst=dest)


def recv_obj(sender: int) -> typing.Any:
    """Receive an arbitrary python object from ``sender``.

    WARN: This incurs CPU <-> GPU transfers and should be used sparingly
    for performance reasons.

    Args:
        sender (int): The rank sending the message.
    """
    # Get message meta
    length = torch.tensor([0], dtype=torch.long).to(get_accelerator().device_name())
    dist.recv(length, src=sender)

    # Receive and deserialize
    msg = torch.empty(length.item(), dtype=torch.uint8).to(get_accelerator().device_name())
    dist.recv(msg, src=sender)
    msg = pickle.loads(msg.cpu().numpy().tobytes())

    def _to(x):
        """Recursively move to the current device."""
        if torch.is_tensor(x):
            return x.to(get_accelerator().device_name())
        if isinstance(x, (tuple, list)):
            ret = [_to(x_) for x_ in x]
            if isinstance(x, tuple):
                ret = tuple(ret)
            return ret
        # Handle kwargs
        if isinstance(x, dict):
            ret = dict()
            for key, val in x.items():
                ret[_to(key)] = _to(val)
            return ret
        # Anything else is a no-op
        return x

    msg = _to(msg)
    return msg
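

# Illustrative round-trip sketch (not part of the original module): rank 0 ships a small
# metadata dict to rank 1 with send_obj(), and rank 1 rebuilds it with recv_obj().  Any
# tensors nested inside the object are moved onto the current accelerator device on the
# receiving side by the ``_to`` helper above.
def _example_obj_roundtrip():
    if dist.get_rank() == 0:
        meta = {"num_layers": 24, "hidden": torch.zeros(4)}
        send_obj(meta, dest=1)
    elif dist.get_rank() == 1:
        return recv_obj(sender=0)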


def _get_send_recv_group(src_stage, dest_stage):
    '''The group id is always the smaller rank unless it's a wrap around.'''

    stage_id = None

    first_stage = 0
    last_stage = _grid.pipe_parallel_size - 1

    if (src_stage == first_stage and dest_stage == last_stage
            or dest_stage == first_stage and src_stage == last_stage):
        stage_id = last_stage
    elif src_stage > dest_stage:
        stage_id = dest_stage
    else:
        stage_id = src_stage
    '''group_id corresponds to the group [group_id, group_id+1],
    unless group_id is the rank of the last stage,
    in which case group_id corresponds to the group [group_id-num_stages+1, group_id]
    '''
    group_id = _grid.stage_to_global(stage_id=stage_id)

    return _groups[group_id]