ccl.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. '''
  5. Copyright 2021 The Microsoft DeepSpeed Team
  6. '''
  7. import torch
  8. from deepspeed.accelerator import get_accelerator
  9. from .reduce_op import ReduceOp
  10. from .torch import TorchBackend
  11. def build_ccl_op():
  12. builder = get_accelerator().create_op_builder("CCLCommBuilder")
  13. if builder is None:
  14. return None
  15. ccl_cpp_module = builder.load()
  16. print(f'DeepSpeed {builder.absolute_name()} built successfully')
  17. return ccl_cpp_module
  18. class CCLBackend(TorchBackend):
  19. def __init__(self, name='ccl', rank=-1, world_size=-1, mpu=None, timeout=None, init_method=None):
  20. self.ccl_comm_op = build_ccl_op()
  21. if self.ccl_comm_op is None:
  22. # set CCLBackend to uninitialized state if CCLCommBuilder cannot be loaded
  23. self.initialized = False
  24. return
  25. super(CCLBackend, self).__init__(backend='ccl',
  26. name='torch',
  27. rank=rank,
  28. world_size=world_size,
  29. timeout=timeout,
  30. init_method=init_method)
  31. self.name = 'ccl'
  32. size = self.get_world_size()
  33. rank = self.get_rank()
  34. main_kvs = self.ccl_comm_op.get_kvs_addr(rank)
  35. main_kvs = torch.tensor(main_kvs).to(torch.uint8)
  36. super(CCLBackend, self).broadcast(main_kvs, 0)
  37. self.ccl_comm_op.initialize(size, rank, main_kvs)
  38. self.initialized = True
  39. def is_initialized(self):
  40. return self.initialized
  41. def broadcast(self, tensor, src, group=None, async_op=False):
  42. self.ccl_comm_op.broadcast(tensor, src, group, async_op)
  43. def all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
  44. use_caching = False
  45. if use_caching:
  46. match_id = f"{tensor.size()}-{op}"
  47. self.ccl_comm_op.all_reduce_caching(tensor, op, match_id, group, async_op)
  48. else:
  49. self.ccl_comm_op.all_reduce(tensor, op, group, async_op)
  50. def inference_all_reduce(self, tensor, op=ReduceOp.SUM, group=None, async_op=False):
  51. self.ccl_comm_op.inference_all_reduce(tensor, op, group, async_op)
  52. def barrier(self, group=None, async_op=False):
  53. self.ccl_comm_op.barrier(group, async_op)