# ccl.py
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
import torch
from deepspeed.accelerator import get_accelerator
from .reduce_op import ReduceOp
from .torch import TorchBackend
  11. def build_ccl_op():
  12. builder = get_accelerator().create_op_builder("CCLCommBuilder")
  13. if builder is None:
  14. return None
  15. ccl_cpp_module = builder.load()
  16. print(f'DeepSpeed {builder.absolute_name()} built successfully')
  17. return ccl_cpp_module
  18. class CCLHandler():
  19. def __init__(self, ccl_comm_op=None):
  20. self.ccl_comm_op = ccl_comm_op
  21. def wait(self):
  22. # backend covered it
  23. pass
  24. class CCLBackend(TorchBackend):
  25. def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
  26. self.ccl_comm_op = build_ccl_op()
  27. if self.ccl_comm_op is None:
  28. # set CCLBackend to uninitialized state if CCLCommBuilder cannot be loaded
  29. self.initialized = False
  30. return
  31. super(CCLBackend, self).__init__(backend='ccl',
  32. name='torch',
  33. rank=rank,
  34. world_size=world_size,
  35. timeout=timeout,
  36. init_method=init_method)
  37. self.name = 'ccl'
  38. size = self.get_world_size()
  39. rank = self.get_rank()
  40. main_kvs = self.ccl_comm_op.get_kvs_addr(rank)
  41. main_kvs = torch.tensor(main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
  42. super(CCLBackend, self).broadcast(main_kvs, 0)
  43. self.ccl_comm_op.initialize(size, rank, main_kvs)
  44. self.initialized = True
  45. self.groups = [tuple(range(self.get_world_size()))]
  46. self.available_coll = self.ccl_comm_op.get_available_coll()
  47. def is_initialized(self):
  48. return self.initialized
  49. def run_collective(self, name, **kwargs):
  50. if name in self.available_coll:
  51. kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
  52. if 'dst' in kwargs:
  53. kwargs['dst'] = kwargs['group'].index(kwargs['dst'])
  54. if 'src' in kwargs:
  55. kwargs['src'] = kwargs['group'].index(kwargs['src'])
  56. func = "self.ccl_comm_op." + name
  57. eval(func)(*(kwargs.values()))
  58. return CCLHandler(self.ccl_comm_op)
  59. else:
  60. func = "super(CCLBackend, self)." + name
  61. return eval(func)(*(kwargs.values()))
  62. def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
  63. use_caching = False
  64. if use_caching:
  65. match_id = f"{tensor.size()}-{op}"
  66. return self.run_collective(name="all_reduce_caching",
  67. tensor=tensor,
  68. op=op,
  69. match_id=match_id,
  70. group=group,
  71. async_op=async_op)
  72. else:
  73. return self.run_collective(name="all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
  74. def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
  75. return self.run_collective(name="inference_all_reduce", tensor=tensor, op=op, group=group, async_op=async_op)
  76. def broadcast(self, tensor, src, group=None, async_op=False):
  77. return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)
  78. def all_gather(self, tensor_list, tensor, group=None, async_op=False):
  79. return self.run_collective(name="all_gather",
  80. tensor_list=tensor_list,
  81. tensor=tensor,
  82. group=group,
  83. async_op=async_op)
  84. def reduce_scatter_tensor(self, output_tensor, input_tensor, op, group=None, async_op=False):
  85. return self.run_collective(name="reduce_scatter_tensor",
  86. output_tensor=output_tensor,
  87. input_tensor=input_tensor,
  88. op=op,
  89. group=group)
  90. def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
  91. return self.run_collective(name="all_gather_into_tensor",
  92. output_tensor=output_tensor,
  93. input_tensor=input_tensor,
  94. group=group)
  95. def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes, group=None, async_op=False):
  96. return self.run_collective(name="all_to_all_single",
  97. output=output,
  98. input=input,
  99. output_split_sizes=output_split_sizes,
  100. input_split_sizes=input_split_sizes,
  101. group=group)
  102. def send(self, tensor, dst, group=None, async_op=False):
  103. return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, async_op=async_op)
  104. def recv(self, tensor, src, group=None, async_op=False):
  105. return self.run_collective(name="recv", tensor=tensor, src=src, group=group, async_op=async_op)
  106. def gather(self, tensor, gather_list, dst, group=None, async_op=False):
  107. return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group)
  108. def scatter(self, tensor, gather_list, dst, group=None, async_op=False):
  109. return self.run_collective(name="scatter", tensor=tensor, gather_list=gather_list, dst=dst, group=group)
  110. def barrier(self, group=None, async_op=False):
  111. return self.run_collective(name="barrier", group=group, async_op=async_op)
  112. def monitored_barrier(self, group=None, timeout=None, wait_all_ranks=False):
  113. return self.run_collective(name="monitored_barrier", group=group)
  114. def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
  115. return self.run_collective(name="reduce_scatter",
  116. output=output,
  117. input_list=input_list,
  118. op=op,
  119. group=group,
  120. async_op=async_op)
  121. def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
  122. return self.run_collective(name="reduce", tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
  123. def new_group(self, ranks):
  124. return super(CCLBackend, self).new_group(ranks)
  125. def _new_group(self, ranks, group):
  126. size = len(ranks)
  127. rank = self.get_rank()
  128. sub_main_kvs = self.ccl_comm_op.get_sub_kvs_addr(rank == ranks[0])
  129. sub_main_kvs = torch.tensor(sub_main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
  130. super(CCLBackend, self).broadcast(sub_main_kvs, ranks[0], group)
  131. self.ccl_comm_op.initialize_sub_comm(size, ranks.index(rank), sub_main_kvs, ranks)
  132. self.groups.append(tuple(ranks))
  133. def get_all_ranks_from_group(self, group):
  134. if group is None:
  135. return list(range(self.get_world_size()))
  136. rank = 0
  137. results = []
  138. try:
  139. while True:
  140. results.append(super(CCLBackend, self).get_global_rank(group, rank))
  141. rank += 1
  142. except ValueError:
  143. pass
  144. if tuple(results) not in self.groups:
  145. self._new_group(results, group)
  146. return results