- """
- Copyright 2021 The Microsoft DeepSpeed Team
- DeepSpeed Communication Package: deepspeed.comm
- deepspeed.comm
- -- import and use deepspeed.ops.comm
- -- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
- -- use custom collectives
- -- can either use torch.dist or ds.ops.comm?
- Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
- deepspeed.comm API
- -- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
- -- e.g. if client code used
- from deepspeed import comm as dist
- instead of
- import torch.distributed as dist
- The code should work without breaking any of the public torch.distributed functionality
- Future:
- -- deepspeed groups API should be brought into ds.comm
- """
- from enum import Enum
- import torch
- import os
- from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
- from .constants import *
- from deepspeed.accelerator import get_accelerator
- class ReduceOp(Enum):
- SUM = 0
- PRODUCT = 1
- MIN = 2
- MAX = 3
- BAND = 4
- BOR = 5
- BXOR = 6
- AVG = 7
- UNUSED = 8
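- # ReduceOp mirrors torch.distributed.ReduceOp so callers need not import torch
- # directly. Illustrative sketch (assumes an initialized backend and a tensor t):
- #
- #   dist.all_reduce(t, op=ReduceOp.MAX)  # element-wise max across all ranks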
- from deepspeed.utils.comms_logging import CommsLogger
- from deepspeed.utils import timer, get_caller_func
- from deepspeed.comm.torch import TorchBackend
- from deepspeed import utils
- from datetime import timedelta
- # Current deepspeed.comm backend (cdb) global object for simple access by client code
- use_ds_backend = False
- cdb = None
- # Create global timer for ops
- timers = timer.SynchronizedWallClockTimer()
- timer_summary = {}
- comms_logger = CommsLogger()
- # Maintain objects of all initialized ds backends and assign them using the API functions in this file
- nccl_backend = None
- mpi_backend = None
- # This should be set here so all rank/size information from the launcher can be propagated
- from deepspeed.comm.utils import *
- def _configure_using_config_file(config):
- if config.comms_logger_enabled:
- comms_logger.configure(config)
- def configure(
- deepspeed_config=None,
- enabled=None,
- prof_all=None,
- prof_ops=None,
- verbose=None,
- debug=None,
- ):
- if deepspeed_config is not None:
- _configure_using_config_file(deepspeed_config.comms_config)
- if enabled is not None:
- comms_logger.enabled = enabled
- if prof_all is not None:
- comms_logger.prof_all = prof_all
- if prof_ops is not None:
- comms_logger.prof_ops = prof_ops
- if verbose is not None:
- comms_logger.verbose = verbose
- if debug is not None:
- comms_logger.debug = debug
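- # Illustrative sketch: enabling the comms logger programmatically instead of
- # via the DeepSpeed config file (assumes init_distributed() was called):
- #
- #   import deepspeed.comm as dist
- #   dist.configure(enabled=True, prof_all=True, verbose=True)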
- # Logging wrapper for timing ops
- def timed_op(func):
- def log_wrapper(*args, **kwargs):
- # Check the enabled flag first so the overhead added to each comm op is at most two if-conditions
- if comms_logger.enabled:
- if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
- 'log_name' in kwargs
- and kwargs['log_name'] in comms_logger.prof_ops):
- # Need func args for their defaults
- func_args = get_default_args(func)
- func_args.update(kwargs)
- msg_size = get_msg_size_from_args(func, *args, **kwargs)
- log_name = get_debug_log_name(func_args, comms_logger.debug)
- timers(log_name).start()
- # Return the op, then stop the op's timer
- try:
- return func(*args, **kwargs)
- finally:
- if comms_logger.enabled:
- # Need to make op blocking for accurate logging
- get_accelerator().synchronize()
- # If we're using MPI, we can't simply sync the stream
- if cdb.using_mpi:
- cdb.barrier()
- if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
- 'log_name' in kwargs
- and kwargs['log_name'] in comms_logger.prof_ops):
- log_name = get_debug_log_name(func_args, comms_logger.debug)
- raw_name = func.__name__
- timers(log_name).stop()
- # need temp var since 'elapsed' resets events
- time_elapsed = timers(log_name).elapsed(reset=False)
- comms_logger.append(raw_name, log_name, time_elapsed, msg_size)
- return log_wrapper
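- # Illustrative sketch of the per-op profiling hooks the wrapper exposes: every
- # timed op accepts prof/log_name kwargs (assumes logging was enabled above;
- # names are hypothetical):
- #
- #   dist.all_reduce(t, prof=True, log_name='grad_allreduce')
- #   dist.log_summary()  # rank 0 prints the aggregated timings, see below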
- # For compatibility with torch distributed's init_process_group, we shall retain the signature from PyTorch code.
- # DeepSpeed NCCL/MPI backend may not need all these params as we will have our own implementation.
- # Please read full torch.distributed API docs from https://pytorch.org/docs/stable/distributed.html
- # UNUSED: Future helper function to initialize DS backends
- def init_deepspeed_backend(ds_backend):
- global cdb
- global nccl_backend
- global mpi_backend
- global use_ds_backend
- if ds_backend == NCCL_BACKEND:
- utils.logger.warning("NCCL backend in DeepSpeed not yet implemented")
- elif ds_backend == MPI_BACKEND:
- utils.logger.warning("MPI backend in DeepSpeed not yet implemented")
- elif ds_backend == GLOO_BACKEND:
- utils.logger.warning("Gloo backend in DeepSpeed not yet implemented")
- else:
- utils.logger.warning(f"DeepSpeed does not support {ds_backend} backend")
- def is_initialized():
- #assert cdb is not None, 'DeepSpeed backend not set, please initialize it using init_process_group()'
- if cdb is None:
- return False
- else:
- return cdb.is_initialized()
- def destroy_process_group(group=None):
- global cdb
- return cdb.destroy_process_group(group=group)
- def new_group(ranks):
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- return cdb.new_group(ranks)
- def is_available() -> bool:
- # Returns ``True`` if the deepspeed comm package is available.
- # TODO: load other ops. Clients including deepspeed itself should use deepspeed.comm to import
- # any communication related primitives from this package.
- # use hasattr(deepspeed.csrc.ops, "_comm") or something
- return True
- def set_backend(backend):
- if not use_ds_backend:
- utils.logger.error(
- "DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
- )
- raise RuntimeError(
- 'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
- )
- global cdb
- global nccl_backend
- global mpi_backend
- try:
- if backend == NCCL_BACKEND:
- if nccl_backend is not None and nccl_backend.is_initialized():
- cdb = nccl_backend
- elif backend == MPI_BACKEND:
- if mpi_backend is not None and mpi_backend.is_initialized():
- cdb = mpi_backend
- except Exception as inst:
- utils.logger.error(inst)
- @timed_op
- def broadcast(tensor,
- src,
- group=None,
- async_op=False,
- prof=False,
- log_name='broadcast',
- debug=get_caller_func()):
- global cdb
- return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
- @timed_op
- def all_gather(tensor_list,
- tensor,
- group=None,
- async_op=False,
- prof=False,
- log_name='all_gather',
- debug=get_caller_func()):
- global cdb
- return cdb.all_gather(tensor_list=tensor_list,
- tensor=tensor,
- group=group,
- async_op=async_op)
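- # Illustrative all_gather sketch: one output slot per rank, each the same
- # shape as the local tensor t (names are hypothetical):
- #
- #   out = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
- #   dist.all_gather(out, t)  # out[i] now holds rank i's tensor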
- def has_reduce_scatter_base():
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- assert cdb.has_reduce_scatter_base is not None, 'has_reduce_scatter_base is not yet defined'
- return cdb.has_reduce_scatter_base
- def reduce_scatter_fn(output_tensor,
- tensor,
- op=ReduceOp.SUM,
- group=None,
- async_op=False,
- prof=False,
- debug=get_caller_func()):
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- if cdb.has_reduce_scatter_base:
- return reduce_scatter_base(output_tensor,
- tensor,
- op=op,
- group=group,
- async_op=async_op,
- prof=prof,
- debug=debug)
- else:
- utils.logger.warning_once(
- "unable to find torch.distributed._reduce_scatter_base. will fall back to "
- "torch.distributed.reduce_scatter which will result in suboptimal performance. "
- "please consider upgrading your pytorch installation.")
- input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
- return reduce_scatter(output_tensor,
- input_tensor_lst,
- op=op,
- group=group,
- async_op=async_op,
- prof=prof,
- debug=debug)
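- # Illustrative sketch of the flat-tensor contract reduce_scatter_fn expects:
- # the input holds world_size chunks and the output receives this rank's
- # reduced chunk (sizes are hypothetical; device is the local accelerator):
- #
- #   ws = dist.get_world_size()
- #   full = torch.randn(ws * 4, device=device)  # world_size * chunk elements
- #   out = torch.empty(4, device=device)        # one chunk
- #   dist.reduce_scatter_fn(out, full)          # out = reduced chunk for this rank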
- @timed_op
- def reduce_scatter_base(output_tensor,
- tensor,
- op=ReduceOp.SUM,
- group=None,
- async_op=False,
- prof=False,
- log_name='reduce_scatter_base',
- debug=get_caller_func()):
- global cdb
- return cdb.reduce_scatter_base(output_tensor=output_tensor,
- input_tensor=tensor,
- op=op,
- group=group,
- async_op=async_op)
- @timed_op
- def all_gather_base(output_tensor,
- tensor,
- group=None,
- async_op=False,
- prof=False,
- log_name='all_gather_base',
- debug=get_caller_func()):
- global cdb
- return cdb.all_gather_base(output_tensor=output_tensor,
- input_tensor=tensor,
- group=group,
- async_op=async_op)
- def has_allgather_base():
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- assert cdb.has_allgather_base is not None, 'has_allgather_base is not yet defined'
- return cdb.has_allgather_base
- def allgather_fn(output_tensor,
- input_tensor,
- group=None,
- async_op=False,
- debug=get_caller_func()):
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- if cdb.has_allgather_base:
- return all_gather_base(output_tensor,
- input_tensor,
- group=group,
- async_op=async_op,
- debug=debug)
- else:
- if get_rank() == 0:
- utils.logger.warning_once(
- "unable to find torch.distributed._all_gather_base. will fall back to "
- "torch.distributed.all_gather which will result in suboptimal performance. "
- "please consider upgrading your pytorch installation.")
- output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
- return all_gather(output_tensors,
- input_tensor,
- group=group,
- async_op=async_op,
- debug=debug)
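- # Illustrative counterpart sketch for allgather_fn: the flat output is the
- # concatenation of every rank's input (hypothetical sizes):
- #
- #   ws = dist.get_world_size()
- #   inp = torch.randn(4, device=device)
- #   out = torch.empty(ws * 4, device=device)
- #   dist.allgather_fn(out, inp)  # out[i*4:(i+1)*4] holds rank i's inp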
- @timed_op
- def all_to_all_single(output,
- tensor,
- output_split_sizes=None,
- input_split_sizes=None,
- group=None,
- async_op=False,
- prof=False,
- log_name='all_to_all_single',
- debug=get_caller_func()):
- global cdb
- return cdb.all_to_all_single(output=output,
- input=tensor,
- output_split_sizes=output_split_sizes,
- input_split_sizes=input_split_sizes,
- group=group,
- async_op=async_op)
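- # Illustrative all_to_all_single sketch with equal splits: rank r sends its
- # i-th element to rank i and receives element r from every peer:
- #
- #   ws = dist.get_world_size()
- #   inp = torch.arange(ws, dtype=torch.float, device=device)
- #   out = torch.empty(ws, device=device)
- #   dist.all_to_all_single(out, inp)  # out[i] came from rank i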
- @timed_op
- def send(tensor,
- dst,
- group=None,
- tag=0,
- prof=False,
- log_name='send',
- debug=get_caller_func()):
- global cdb
- return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
- @timed_op
- def recv(tensor,
- src=None,
- group=None,
- tag=0,
- prof=False,
- log_name='recv',
- debug=get_caller_func()):
- global cdb
- return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
- @timed_op
- def isend(tensor,
- dst,
- group=None,
- tag=0,
- prof=False,
- log_name='isend',
- debug=get_caller_func()):
- global cdb
- return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
- @timed_op
- def irecv(tensor,
- src=None,
- group=None,
- tag=0,
- prof=False,
- log_name='irecv',
- debug=get_caller_func()):
- global cdb
- return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
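- # Illustrative point-to-point sketch pairing the ops above (assumes exactly
- # two ranks; blocking send/recv shown, isend/irecv would return handles):
- #
- #   if dist.get_rank() == 0:
- #       dist.send(t, dst=1)
- #   else:
- #       dist.recv(t, src=0)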
- @timed_op
- def gather(tensor,
- gather_list=None,
- dst=0,
- group=None,
- async_op=False,
- prof=False,
- log_name='gather',
- debug=get_caller_func()):
- global cdb
- return cdb.gather(tensor=tensor,
- gather_list=gather_list,
- dst=dst,
- group=group,
- async_op=async_op)
- @timed_op
- def scatter(tensor,
- scatter_list=None,
- src=0,
- group=None,
- async_op=False,
- prof=False,
- log_name='scatter',
- debug=get_caller_func()):
- global cdb
- return cdb.scatter(tensor=tensor,
- scatter_list=scatter_list,
- src=src,
- group=group,
- async_op=async_op)
- @timed_op
- def barrier(group=None,
- async_op=False,
- device_ids=None,
- prof=False,
- log_name='barrier',
- debug=get_caller_func()):
- global cdb
- return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
- @timed_op
- def monitored_barrier(group=None,
- timeout=None,
- wait_all_ranks=False,
- prof=False,
- log_name='monitored_barrier',
- debug=get_caller_func()):
- global cdb
- return cdb.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
- def log_summary():
- global cdb
- barrier(log_name='log_summary_barrier')
- if cdb.get_rank() == 0:
- comms_logger.log_all()
- barrier(log_name='log_summary_barrier')
- @timed_op
- def reduce(tensor,
- dst,
- op=ReduceOp.SUM,
- group=None,
- async_op=False,
- prof=False,
- log_name='reduce',
- debug=get_caller_func()):
- global cdb
- return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
- @timed_op
- def reduce_scatter(output,
- input_list,
- op=ReduceOp.SUM,
- group=None,
- async_op=False,
- prof=False,
- log_name='reduce_scatter',
- debug=get_caller_func()):
- global cdb
- return cdb.reduce_scatter(output=output,
- input_list=input_list,
- op=op,
- group=group,
- async_op=async_op)
- @timed_op
- def all_reduce(tensor,
- op=ReduceOp.SUM,
- group=None,
- async_op=False,
- prof=False,
- log_name='all_reduce',
- debug=get_caller_func()):
- global cdb
- return cdb.all_reduce(tensor, op, group, async_op)
- def get_world_group():
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- return cdb.get_world_group()
- def get_world_size(group=None) -> int:
- """
- Returns the number of processes in the current process group
- Args:
- group (ProcessGroup, optional): The process group to work on. If None,
- the default process group will be used.
- Returns:
- The world size of the process group
- -1, if not part of the group
- """
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- return cdb.get_world_size(group)
- def get_rank(group=None):
- """
- Returns the rank of the current process in the provided ``group`` or the
- default group if none was provided.
- Rank is a unique identifier assigned to each process within a distributed
- process group. They are always consecutive integers ranging from 0 to
- ``world_size - 1``.
- Args:
- group (ProcessGroup, optional): The process group to work on. If None,
- the default process group will be used.
- Returns:
- The rank of the calling process within the group
- -1, if not part of the group
- """
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- return cdb.get_rank(group)
- def get_local_rank():
- """
- Helper function to get local rank after a backend has been set and initialized
- Args:
- None
- Returns:
- local rank (= GPU device ID)
- """
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- return get_local_rank_from_launcher()
- def get_global_rank(group=None, group_rank=0):
- global cdb
- assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
- return cdb.get_global_rank(group, group_rank)
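- # Illustrative sketch tying the rank helpers together: rank-0-only logging
- # plus per-node device selection (assumes a CUDA accelerator):
- #
- #   if dist.get_rank() == 0:
- #       print(f'world size: {dist.get_world_size()}')
- #   torch.cuda.set_device(dist.get_local_rank())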
- # Main DeepSpeed Comms. public API.
- def init_distributed(dist_backend=None,
- auto_mpi_discovery=True,
- distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
- verbose=True,
- timeout=default_pg_timeout,
- init_method=None,
- dist_init_required=None,
- config=None,
- rank=-1,
- world_size=-1):
- ''' Initialize dist backend, potentially performing MPI discovery if needed
- Arguments:
- dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
- auto_mpi_discovery: Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
- distributed_port: Optional (int). torch distributed backend port
- verbose: Optional (bool). verbose logging
- timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
- init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is "env://" if no init_method or store is specified.
- dist_init_required: Optional (bool). If False, assume torch distributed is already initialized; if None (default), initialize only when needed.
- config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
- rank: Optional (int). The current manually specified rank. Some init_method like "tcp://" need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
- world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
- '''
- global cdb
- configure(deepspeed_config=config)
- if dist_init_required is None:
- dist_init_required = cdb is None or not cdb.is_initialized()
- if cdb is None and torch.distributed.is_initialized():
- # The user initialized torch.dist themselves, create cdb and short-circuit
- cdb = TorchBackend(dist_backend, timeout, init_method)
- return
- if dist_init_required is False:
- assert (
- cdb is not None and cdb.is_initialized() is True
- ), "Distributed backend is not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
- else:
- # Initialize torch distributed if needed
- required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
- if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
- if verbose:
- utils.logger.info(
- "Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
- )
- if in_aml() and not in_dlts():
- patch_aml_env_for_torch_nccl_backend(verbose=verbose)
- elif in_aws_sm():
- patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
- else:
- mpi_discovery(distributed_port=distributed_port, verbose=verbose)
- if cdb is not None and cdb.is_initialized():
- if int(os.getenv('RANK', '0')) == 0:
- utils.logger.info('Distributed backend already initialized')
- else:
- assert isinstance(timeout, timedelta)
- if dist_backend is None:
- dist_backend = get_accelerator().communication_backend_name()
- if int(os.getenv('RANK', '0')) == 0:
- utils.logger.info(
- 'Initializing TorchBackend in DeepSpeed with backend {}'.format(
- dist_backend))
- # Create a torch backend object, initialize torch distributed, and assign to cdb
- cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
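- # Illustrative initialization sketch (env:// rendezvous assumes the launcher
- # has set RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT):
- #
- #   from datetime import timedelta
- #   import deepspeed.comm as dist
- #   dist.init_distributed(dist_backend='nccl', timeout=timedelta(minutes=10))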
- def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
- '''
- Discover the MPI environment via mpi4py and map it to the relevant dist state
- '''
- from mpi4py import MPI
- import subprocess
- comm = MPI.COMM_WORLD
- rank = comm.Get_rank()
- world_size = comm.Get_size()
- master_addr = None
- if rank == 0:
- hostname_cmd = ["hostname -I"]
- result = subprocess.check_output(hostname_cmd, shell=True)
- master_addr = result.decode('utf-8').split()[0]
- master_addr = comm.bcast(master_addr, root=0)
- # Determine local rank by assuming hostnames are unique
- proc_name = MPI.Get_processor_name()
- all_procs = comm.allgather(proc_name)
- local_rank = sum([i == proc_name for i in all_procs[:rank]])
- os.environ['RANK'] = str(rank)
- os.environ['WORLD_SIZE'] = str(world_size)
- os.environ['LOCAL_RANK'] = str(local_rank)
- os.environ['MASTER_ADDR'] = master_addr
- os.environ['MASTER_PORT'] = str(distributed_port)
- if verbose:
- utils.logger.info(
- "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
- .format(os.environ['RANK'],
- os.environ['LOCAL_RANK'],
- os.environ['WORLD_SIZE'],
- os.environ['MASTER_ADDR'],
- os.environ['MASTER_PORT']))
- if cdb is not None and cdb.is_initialized():
- assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
- rank, cdb.get_rank())
- assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
- world_size, cdb.get_world_size())
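- # Illustrative sketch: launching without the deepspeed/torch launchers and
- # relying on the MPI discovery above to populate the env (hypothetical script):
- #
- #   mpirun -np 4 python train.py   # train.py simply calls dist.init_distributed()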
- def in_aml():
- # Are we running inside an Azure Machine Learning (AML) environment?
- return 'AZUREML_EXPERIMENT_ID' in os.environ
- def in_aws_sm():
- # Are we running inside an AWS SageMaker environment?
- return 'SM_TRAINING_ENV' in os.environ
- def in_dlts():
- # Are we running on a DLTS cluster?
- return 'DLTS_JOB_ID' in os.environ
- def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
- """Helper routine to get and set environment variables.
- This is adapted from Azure ML's documentation available from:
- https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
- """
- os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
- os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
- single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
- os.environ["WORLD_SIZE"])
- if not single_node:
- master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
- os.environ["MASTER_ADDR"] = master_node_params[0]
- # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
- if "MASTER_PORT" not in os.environ:
- os.environ["MASTER_PORT"] = str(master_port)
- else:
- os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
- os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
- if verbose:
- utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
- os.environ["NCCL_SOCKET_IFNAME"]))
- os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
- os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
- if verbose:
- utils.logger.info(
- "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
- .format(os.environ['RANK'],
- os.environ['LOCAL_RANK'],
- os.environ['WORLD_SIZE'],
- os.environ['MASTER_ADDR'],
- os.environ['MASTER_PORT']))
- def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
- """Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
- """
- os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
- os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
- os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
- if verbose:
- utils.logger.info(
- "Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
- .format(os.environ['RANK'],
- os.environ['LOCAL_RANK'],
- os.environ['WORLD_SIZE'],
- os.environ['MASTER_ADDR'],
- os.environ['MASTER_PORT']))