distributed.py

'''
Copyright 2020 The Microsoft DeepSpeed Team
'''
import os
import torch
from datetime import timedelta

from .logging import logger
from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout


def init_distributed(dist_backend="nccl",
                     auto_mpi_discovery=True,
                     distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
                     verbose=True,
                     timeout=default_pg_timeout,
                     init_method=None):
    """Initialize the torch.distributed backend, performing MPI discovery if needed.

    Arguments:
        dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
        auto_mpi_discovery: Optional (bool). If the distributed environment variables are not set, attempt to discover them from MPI.
        distributed_port: Optional (int). torch distributed backend port
        verbose: Optional (bool). verbose logging
        timeout: Optional (timedelta). Timeout for operations executed against the process group. Defaults to 30 minutes.
        init_method: Optional (str). Torch distributed URL specifying how to initialize the process group. Defaults to "env://" if no init_method or store is specified.
    """
    required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
        if verbose:
            logger.info(
                "Not using the DeepSpeed or torch.distributed launchers, attempting to detect MPI environment..."
            )
        if in_aml() and not in_dlts():
            patch_aml_env_for_torch_nccl_backend(verbose=verbose)
        else:
            mpi_discovery(distributed_port=distributed_port, verbose=verbose)

    if not torch.distributed.is_initialized():
        if verbose:
            logger.info(
                "Initializing torch distributed with backend: {}".format(dist_backend))
        assert isinstance(timeout, timedelta)
        torch.distributed.init_process_group(backend=dist_backend,
                                             timeout=timeout,
                                             init_method=init_method)
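
# Usage sketch (illustrative, not part of the original module; the import path
# deepspeed.utils.distributed is an assumption based on the relative imports above):
#
#     import os
#     import torch
#     from deepspeed.utils.distributed import init_distributed
#
#     init_distributed(dist_backend="nccl")                  # env:// rendezvous
#     torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
#
# Under the DeepSpeed or torch.distributed launchers the required RANK, WORLD_SIZE,
# MASTER_ADDR, MASTER_PORT and LOCAL_RANK variables are already set; under plain
# mpirun they are filled in by mpi_discovery() or the AzureML helper below.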


def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
    """
    Discover the MPI environment via mpi4py and map it to the relevant torch.distributed state.
    """
    from mpi4py import MPI
    import subprocess
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()

    # Rank 0 looks up its own IP address and broadcasts it as the master address.
    master_addr = None
    if rank == 0:
        hostname_cmd = ["hostname -I"]
        result = subprocess.check_output(hostname_cmd, shell=True)
        master_addr = result.decode('utf-8').split()[0]
    master_addr = comm.bcast(master_addr, root=0)

    # Determine local rank by assuming hostnames are unique
    proc_name = MPI.Get_processor_name()
    all_procs = comm.allgather(proc_name)
    local_rank = sum([i == proc_name for i in all_procs[:rank]])

    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(distributed_port)

    if verbose:
        logger.info(
            "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
            .format(os.environ['RANK'],
                    os.environ['LOCAL_RANK'],
                    os.environ['WORLD_SIZE'],
                    os.environ['MASTER_ADDR'],
                    os.environ['MASTER_PORT']))

    if torch.distributed.is_initialized():
        assert torch.distributed.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
            rank, torch.distributed.get_rank())
        assert torch.distributed.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
            world_size, torch.distributed.get_world_size())
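
# Worked example (illustrative): local_rank is derived by counting how many earlier
# ranks share this rank's hostname. With four hypothetical ranks placed as
#
#     all_procs = ["node0", "node0", "node1", "node1"]
#     local_ranks = [sum(p == all_procs[r] for p in all_procs[:r])
#                    for r in range(len(all_procs))]
#     # local_ranks == [0, 1, 0, 1]
#
# rank 1 sees one earlier "node0" (local_rank=1) and rank 2 sees no earlier "node1"
# (local_rank=0). The hostnames are hypothetical; the scheme only requires that
# hostnames be unique per node, as noted in the comment above.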


def in_aml():
    # Are we running inside an Azure Machine Learning (AML) environment?
    return 'AZUREML_EXPERIMENT_ID' in os.environ


def in_dlts():
    # Are we running on a DLTS cluster?
    return 'DLTS_JOB_ID' in os.environ


def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
    """Helper routine to get and set environment variables.
    This is adapted from Azure ML's documentation available at:
    https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
    """
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
        os.environ["WORLD_SIZE"])

    if not single_node:
        master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
        os.environ["MASTER_ADDR"] = master_node_params[0]
        # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = str(master_port)
    else:
        os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
        os.environ["MASTER_PORT"] = "54965"

    if verbose:
        logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
            os.environ["NCCL_SOCKET_IFNAME"]))

    # Exclude the docker bridge and loopback interfaces from NCCL's socket discovery.
    os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
    os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]

    if verbose:
        logger.info(
            "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
            .format(os.environ['RANK'],
                    os.environ['LOCAL_RANK'],
                    os.environ['WORLD_SIZE'],
                    os.environ['MASTER_ADDR'],
                    os.environ['MASTER_PORT']))
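
# Mapping sketch (illustrative): on a hypothetical two-node AzureML job with four
# processes per node, OpenMPI exports roughly the following, which the helper above
# translates into the variables torch.distributed expects:
#
#     OMPI_COMM_WORLD_RANK=5             -> RANK=5
#     OMPI_COMM_WORLD_SIZE=8             -> WORLD_SIZE=8
#     OMPI_COMM_WORLD_LOCAL_RANK=1       -> LOCAL_RANK=1
#     OMPI_COMM_WORLD_LOCAL_SIZE=4       -> not equal to WORLD_SIZE, so multi-node:
#     AZ_BATCH_MASTER_NODE=10.0.0.4:6105 -> MASTER_ADDR=10.0.0.4; MASTER_PORT is set
#                                           to master_port only if not already defined
#
# The concrete values are hypothetical; only the variable names match the code above.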