12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- '''Copyright The Microsoft DeepSpeed Team'''
- import torch
- from .utils import *
- from deepspeed import utils
- supported_torch_version = False
- # See more details at: https://github.com/pytorch/pytorch/pull/48767
- # The PG API in torch versions older than 1.8 is different, so it is
- # non-trivial to support both in the same API. We will just use the
- # DS comm. backend in deepspeed/comm/comm.py if the torch version is 1.8+.
- if older_torch():
- # Add custom deepspeed torch comm functions here since we can't import deepspeed.comm
- # NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them.
- supported_torch_version = False
- from torch.distributed import *
def get_world_group():
    """Return the ``WORLD`` attribute of torch.distributed's ``group`` object.

    ``group`` is the name brought into this module by the module-level
    ``from torch.distributed import *``.
    """
    world = group.WORLD
    return world
def get_global_rank(group, group_rank):
    """Map ``group_rank`` within ``group`` to a rank via torch's c10d helper.

    The helper is looked up under its public name ``get_global_rank`` when
    ``torch.distributed.distributed_c10d`` exposes it, and under the private
    name ``_get_global_rank`` otherwise.

    Args:
        group: Process group passed through unchanged to the torch helper.
        group_rank: Rank relative to ``group``.

    Returns:
        Whatever the selected torch helper returns for ``(group, group_rank)``.
    """
    c10d = torch.distributed.distributed_c10d
    # Import at call time and pick whichever spelling this torch build provides.
    if not hasattr(c10d, "get_global_rank"):
        from torch.distributed.distributed_c10d import _get_global_rank as resolve_rank
    else:
        from torch.distributed.distributed_c10d import get_global_rank as resolve_rank
    return resolve_rank(group, group_rank)
def allgather_fn(output_tensor, input_tensor, group=None, async_op=False):
    """All-gather ``input_tensor`` into chunks of ``output_tensor``.

    ``output_tensor`` is split with ``torch.chunk`` into
    ``get_world_size(group)`` pieces, and that list is handed to
    ``torch.distributed.all_gather`` as the receive buffers.

    Args:
        output_tensor: Tensor that is chunked to form the gather outputs.
        input_tensor: Tensor contributed by the calling rank.
        group: Process group forwarded to torch (``None`` is passed as-is).
        async_op: Forwarded to ``all_gather``.

    Returns:
        The return value of ``torch.distributed.all_gather``.
    """
    # Functions are imported right before use rather than called through
    # ``torch.distributed`` directly (see the NOTE at the top of this branch).
    from torch import chunk
    from torch.distributed import all_gather, get_world_size

    world_size = get_world_size(group)
    receive_buffers = list(chunk(output_tensor, world_size))
    return all_gather(receive_buffers,
                      input_tensor,
                      group=group,
                      async_op=async_op)
def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False):
    """Reduce-scatter chunks of ``input_tensor`` into ``output_tensor``.

    ``input_tensor`` is split with ``torch.chunk`` into
    ``get_world_size(group)`` pieces, and that list is passed to
    ``torch.distributed.reduce_scatter`` as the per-rank inputs.

    Args:
        output_tensor: Tensor that receives this rank's reduced result.
        input_tensor: Tensor that is chunked to form the scatter list.
        group: Process group forwarded to torch (``None`` is passed as-is).
        async_op: Forwarded to ``reduce_scatter``. Previously this argument
            was accepted but dropped, so the collective always ran
            synchronously regardless of what the caller requested.

    Returns:
        The return value of ``torch.distributed.reduce_scatter``.
    """
    # Functions are imported right before use rather than called through
    # ``torch.distributed`` directly (see the NOTE at the top of this branch).
    from torch.distributed import reduce_scatter, get_world_size
    from torch import chunk

    input_tensor_lst = list(chunk(input_tensor, get_world_size(group)))
    return reduce_scatter(output_tensor,
                          input_tensor_lst,
                          group=group,
                          async_op=async_op)
def configure(deepspeed_config=None,
              enabled=None,
              prof_all=None,
              prof_ops=None,
              verbose=None):
    """Placeholder ``configure`` for the older-torch code path.

    All arguments are accepted and ignored; the function only emits a
    warning that communication logging is unavailable.

    Args:
        deepspeed_config: Unused.
        enabled: Unused.
        prof_all: Unused.
        prof_ops: Unused.
        verbose: Unused.
    """
    # ``Logger.warn`` is a deprecated alias that newer Python versions have
    # removed; ``warning`` is the supported spelling.
    # NOTE(review): assumes utils.logger is a standard logging.Logger - confirm.
    utils.logger.warning(
        "Communication logging is not supported in torch versions older than 1.8")
- else:
- supported_torch_version = True
- from .comm import *
|