__init__.py

'''Copyright The Microsoft DeepSpeed Team'''

import torch

from .utils import *
from deepspeed import utils

supported_torch_version = False

# See more details at: https://github.com/pytorch/pytorch/pull/48767
# The process-group (PG) API in torch versions older than 1.8 is different, so it is
# non-trivial to support both in the same API. We only use the DeepSpeed comm
# backend in deepspeed/comm/comm.py if the torch version is 1.8 or newer.
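# older_torch() comes from the star import of .utils above. Based on its use here, it
# presumably returns True when the installed torch predates 1.8, along the lines of the
# hypothetical sketch below (the real implementation lives in deepspeed/comm/utils.py):
#
#     from packaging import version
#     def older_torch():
#         return version.parse(torch.__version__) < version.parse("1.8")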
if older_torch():
    # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm.
    # NOTE: We can't call torch.distributed directly here. The current workaround is to
    # import the needed functions inside each wrapper just before calling them.
    supported_torch_version = False
    from torch.distributed import *

    def get_world_group():
        # `group.WORLD` is provided by the `from torch.distributed import *` above.
        return group.WORLD
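
    # Newer torch releases expose get_global_rank as a public function on distributed_c10d,
    # while older ones only have the private _get_global_rank, so this wrapper probes for
    # the public name first. Hedged usage sketch (assumes an initialized process group;
    # `pg` and `local_rank` are illustrative names, not part of this module):
    #
    #     pg = get_world_group()
    #     global_rank = get_global_rank(pg, local_rank)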
    def get_global_rank(group, group_rank):
        if hasattr(torch.distributed.distributed_c10d, "get_global_rank"):
            from torch.distributed.distributed_c10d import get_global_rank as _get_global_rank
        else:
            from torch.distributed.distributed_c10d import _get_global_rank
        return _get_global_rank(group, group_rank)
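
    # The legacy torch.distributed.all_gather takes a *list* of per-rank output tensors
    # rather than one flat tensor, so the wrapper below splits the flat output into
    # world_size chunks with torch.chunk. Hedged usage sketch (shapes are illustrative;
    # assumes torch.distributed is initialized):
    #
    #     ws = get_world_size()
    #     inp = torch.ones(4)
    #     out = torch.empty(4 * ws)
    #     allgather_fn(out, inp)   # out now holds every rank's 4-element tensor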
    def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
        from torch.distributed import all_gather, get_world_size
        from torch import chunk
        # torch.chunk returns views into output_tensor, so all_gather fills it in place.
        output_tensors = list(chunk(output_tensor, get_world_size(group)))
        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op)
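
    # reduce_scatter is the mirror image: the legacy API takes a list of per-rank *input*
    # tensors and reduces chunk i onto rank i's output_tensor, so here the input is the
    # one that gets chunked. Hedged usage sketch (illustrative shapes):
    #
    #     inp = torch.ones(4 * ws)
    #     out = torch.empty(4)
    #     reduce_scatter_fn(out, inp)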
    def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False):
        from torch.distributed import reduce_scatter, get_world_size
        from torch import chunk
        input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
        # Forward async_op so callers that request a handle actually get one
        # (the original call silently dropped it).
        return reduce_scatter(output_tensor, input_tensor_lst, group=group, async_op=async_op)
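
    # configure() mirrors the comm-logger configuration entry point that the torch 1.8+
    # path exposes via `from .comm import *` below; on older torch it is a stub that only
    # warns, so callers can invoke it unconditionally on either path.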
    def configure(deepspeed_config=None,
                  enabled=None,
                  prof_all=None,
                  prof_ops=None,
                  verbose=None):
        utils.logger.warn(
            "Communication logging is not supported in torch versions older than 1.8")

else:
    supported_torch_version = True
    from .comm import *
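
# Minimal usage sketch for the legacy (torch < 1.8) path, kept as a comment so the module
# stays free of import-time side effects; `dist` and the shapes are illustrative:
#
#     import torch
#     import torch.distributed as dist
#     import deepspeed.comm as comm
#
#     dist.init_process_group("gloo")          # or "nccl" on GPU
#     ws = dist.get_world_size()
#     out = torch.empty(4 * ws)
#     comm.allgather_fn(out, torch.ones(4))    # gathers every rank's tensor into `out`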