# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''

import torch
from deepspeed.accelerator import get_accelerator
from .reduce_op import ReduceOp
from .torch import TorchBackend

def build_ccl_op():
    builder = get_accelerator().create_op_builder("CCLCommBuilder")
    if builder is None:
        return None
    ccl_cpp_module = builder.load()
    print(f'DeepSpeed {builder.absolute_name()} built successfully')
    return ccl_cpp_module

class CCLHandler():

    def __init__(self, ccl_comm_op=None):
        self.ccl_comm_op = ccl_comm_op

    def wait(self):
        # backend covered it
        pass

class CCLBackend(TorchBackend):

    def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
        self.ccl_comm_op = build_ccl_op()
        if self.ccl_comm_op is None:
            # set CCLBackend to uninitialized state if CCLCommBuilder cannot be loaded
            self.initialized = False
            return
        super(CCLBackend, self).__init__(backend='ccl',
                                         name='torch',
                                         rank=rank,
                                         world_size=world_size,
                                         timeout=timeout,
                                         init_method=init_method)
        self.name = 'ccl'
        size = self.get_world_size()
        rank = self.get_rank()
        # Fetch the oneCCL key-value store address and broadcast it from rank 0
        # through the torch backend so every rank joins the same CCL communicator.
        main_kvs = self.ccl_comm_op.get_kvs_addr(rank)
        main_kvs = torch.tensor(main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
        super(CCLBackend, self).broadcast(main_kvs, 0)
        self.ccl_comm_op.initialize(size, rank, main_kvs)
        self.initialized = True
        self.groups = [tuple(range(self.get_world_size()))]
        self.available_coll = self.ccl_comm_op.get_available_coll()

    def is_initialized(self):
        return self.initialized

    def run_collective(self, name, **kwargs):
        # Dispatch to the compiled CCL op when the collective is available there,
        # otherwise fall back to the inherited TorchBackend implementation.
        if name in self.available_coll:
            if 'group' in kwargs:
                kwargs['group'] = self.get_all_ranks_from_group(kwargs['group'])
            if 'dst' in kwargs:
                kwargs['dst'] = kwargs['group'].index(kwargs['dst'])
            if 'src' in kwargs:
                kwargs['src'] = kwargs['group'].index(kwargs['src'])
            func = "self.ccl_comm_op." + name
            eval(func)(*(kwargs.values()))
            return CCLHandler(self.ccl_comm_op)
        else:
            func = "super(CCLBackend, self)." + name
            eval(func)(*(kwargs.values()))
            return CCLHandler(self.ccl_comm_op)

    def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
        use_caching = False
        if use_caching:
            match_id = f"{tensor.size()}-{op}"
            name = "all_reduce_caching"
            if name in self.available_coll:
                group = self.get_all_ranks_from_group(group)
                return self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
            else:
                return self.run_collective(name=name,
                                           tensor=tensor,
                                           op=op,
                                           match_id=match_id,
                                           group=group,
                                           async_op=async_op)
        else:
            name = "all_reduce"
            if name in self.available_coll:
                group = self.get_all_ranks_from_group(group)
                return self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
            else:
                return self.run_collective(name=name, tensor=tensor, op=op, group=group, async_op=async_op)

    def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
        name = "inference_all_reduce"
        if name in self.available_coll:
            return self.ccl_comm_op.inference_all_reduce(tensor, op, async_op)
        else:
            return self.run_collective(name=name, tensor=tensor, op=op, group=None, async_op=async_op)

    def broadcast(self, tensor, src, group=None, async_op=False):
        return self.run_collective(name="broadcast", tensor=tensor, src=src, group=group, async_op=async_op)

    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
        return self.run_collective(name="all_gather",
                                   tensor_list=tensor_list,
                                   tensor=tensor,
                                   group=group,
                                   async_op=async_op)

    def reduce_scatter_tensor(self, output_tensor, input_tensor, op, group=None, async_op=False):
        return self.run_collective(name="reduce_scatter_tensor",
                                   output_tensor=output_tensor,
                                   input_tensor=input_tensor,
                                   op=op,
                                   group=group)

    def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
        return self.run_collective(name="all_gather_into_tensor",
                                   output_tensor=output_tensor,
                                   input_tensor=input_tensor,
                                   group=group)

    def all_to_all_single(self, output, input, output_split_sizes, input_split_sizes, group=None, async_op=False):
        return self.run_collective(name="all_to_all_single",
                                   output=output,
                                   input=input,
                                   output_split_sizes=output_split_sizes,
                                   input_split_sizes=input_split_sizes,
                                   group=group)

    def send(self, tensor, dst, group=None, tag=0):
        return self.run_collective(name="send", tensor=tensor, dst=dst, group=group, tag=tag)

    def recv(self, tensor, src, group=None, tag=0):
        return self.run_collective(name="recv", tensor=tensor, src=src, group=group, tag=tag)

    def gather(self, tensor, gather_list, dst, group=None, async_op=False):
        return self.run_collective(name="gather", tensor=tensor, gather_list=gather_list, dst=dst, group=group)

    def scatter(self, tensor, gather_list, dst, group=None, async_op=False):
        return self.run_collective(name="scatter", tensor=tensor, gather_list=gather_list, dst=dst, group=group)

    def barrier(self, group=None, async_op=False):
        return self.run_collective(name="barrier", group=group, async_op=async_op)

    def monitored_barrier(self, group=None, timeout=None, wait_all_ranks=False):
        return self.run_collective(name="monitored_barrier", group=group)

    def reduce_scatter(self, output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
        return self.run_collective(name="reduce_scatter",
                                   output=output,
                                   input_list=input_list,
                                   op=op,
                                   group=group,
                                   async_op=async_op)

    def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
        return self.run_collective(name="reduce", tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)

    def new_group(self, ranks):
        return super(CCLBackend, self).new_group(ranks)

    def _new_group(self, ranks, group):
        # Build a CCL sub-communicator for `ranks`: a sub key-value store is created on
        # the first rank of the group and its address is broadcast within the torch group.
        size = len(ranks)
        rank = self.get_rank()
        sub_main_kvs = self.ccl_comm_op.get_sub_kvs_addr(rank == ranks[0])
        sub_main_kvs = torch.tensor(sub_main_kvs).to(torch.uint8).to(get_accelerator().current_device_name())
        super(CCLBackend, self).broadcast(sub_main_kvs, ranks[0], group)
        self.ccl_comm_op.initialize_sub_comm(size, ranks.index(rank), sub_main_kvs, ranks)
        self.groups.append(tuple(ranks))

    def get_all_ranks_from_group(self, group):
        if group is None:
            return list(range(self.get_world_size()))
        rank = 0
        results = []
        # Probe global ranks one by one until torch raises, which marks the end of the group.
        try:
            while True:
                results.append(super(CCLBackend, self).get_global_rank(group, rank))
                rank += 1
        except (ValueError, RuntimeError):
            pass
        # Lazily create the matching CCL sub-communicator the first time a group is seen.
        if tuple(results) not in self.groups:
            self._new_group(results, group)
        return results
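
# Usage sketch (illustrative only, not part of this module): a minimal example of how
# this backend is typically reached, assuming DeepSpeed's public init_distributed()
# entry point is used to select the 'ccl' communication backend. Whether the CCL
# kernels are actually used then depends on CCLCommBuilder loading successfully.
#
#   import deepspeed
#   import torch
#
#   deepspeed.init_distributed(dist_backend='ccl')  # routes collectives to CCLBackend
#   t = torch.ones(4)
#   deepspeed.comm.all_reduce(t)                    # served by the CCL op if available,
#                                                   # otherwise by the TorchBackend fallback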