comm.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778
  1. """
  2. Copyright 2021 The Microsoft DeepSpeed Team
  3. DeepSpeed Communication Package: deepspeed.comm
  4. deepspeed.comm
  5. -- import and use deepspeed.ops.comm
  6. -- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
  7. -- use custom collectives
  8. -- can either use torch.dist or ds.ops.comm?
  9. Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
  10. deepspeed.comm API
  11. -- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
  12. -- e.g. if a client code used
  13. from deepspeed import comm as dist
  14. instead of
  15. import torch.distributed as dist
  16. The code should work without breaking any of the public torch.distributed functionality
  17. Future:
  18. -- deepspeed groups API should be brought into ds.comm
  19. """
  20. from enum import Enum
  21. import torch
  22. import os
  23. from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
  24. from .constants import *
  25. from deepspeed.accelerator import get_accelerator
class ReduceOp(Enum):
    """Reduction operation selector used as the ``op`` argument of the reducing
    collectives in this module (``reduce``, ``all_reduce``, ``reduce_scatter``, ...).

    The integer values are local to this enum.
    NOTE(review): the mapping of these members onto the backend's own reduce-op
    type is not visible in this file -- confirm in the backend implementation.
    """
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
    # presumably bitwise AND / OR / XOR -- verify against the backend
    BAND = 4
    BOR = 5
    BXOR = 6
    AVG = 7
    UNUSED = 8
  36. from deepspeed.utils.comms_logging import CommsLogger
  37. from deepspeed.utils import timer, get_caller_func
  38. from deepspeed.comm.torch import TorchBackend
  39. from deepspeed import utils
  40. from datetime import timedelta
# Current deepspeed.comm backend (cdb) global object for simple access by client code.
# `cdb` stays None until init_distributed() assigns a backend; `use_ds_backend`
# is the flag checked by set_backend().
use_ds_backend = False
cdb = None

# Create global timer for ops (started/stopped by the timed_op wrapper)
timers = timer.SynchronizedWallClockTimer()
timer_summary = {}

# Logger consulted by timed_op and adjusted through configure()
comms_logger = CommsLogger()

# Ensure we don't warn about base collectives more than once
has_warned_all_gather = False
has_warned_reduce_scatter = False

# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None
  54. # This should be set here so all rank/size information from the launcher can be propagated
  55. from deepspeed.comm.utils import *
  56. def _configure_using_config_file(config):
  57. if config.comms_logger_enabled:
  58. comms_logger.configure(config)
  59. def configure(
  60. deepspeed_config=None,
  61. enabled=None,
  62. prof_all=None,
  63. prof_ops=None,
  64. verbose=None,
  65. debug=None,
  66. ):
  67. if deepspeed_config is not None:
  68. _configure_using_config_file(deepspeed_config.comms_config)
  69. if enabled is not None:
  70. comms_logger.enabled = enabled
  71. if prof_all is not None:
  72. comms_logger.prof_all = prof_all
  73. if prof_ops is not None:
  74. comms_logger.prof_ops = prof_ops
  75. if verbose is not None:
  76. comms_logger.verbose = verbose
  77. if debug is not None:
  78. comms_logger.debug = debug
  79. # Logging wrapper for timing ops
  80. def timed_op(func):
  81. def log_wrapper(*args, **kwargs):
  82. # Add enabled flag so that overhead to each comm op is two if conditions at most
  83. if comms_logger.enabled:
  84. if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
  85. 'log_name' in kwargs
  86. and kwargs['log_name'] in comms_logger.prof_ops):
  87. # Need func args for their defaults
  88. func_args = get_default_args(func)
  89. func_args.update(kwargs)
  90. msg_size = get_msg_size_from_args(func, *args, **kwargs)
  91. log_name = get_debug_log_name(func_args, comms_logger.debug)
  92. timers(log_name).start()
  93. # Return the op, then stop the op's timer
  94. try:
  95. return func(*args, **kwargs)
  96. finally:
  97. if comms_logger.enabled:
  98. # Need to make op blocking for accurate logging
  99. get_accelerator().synchronize()
  100. # If we're using MPI, we can't simply sync the stream
  101. if cdb.using_mpi:
  102. cdb.barrier()
  103. if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
  104. 'log_name' in kwargs
  105. and kwargs['log_name'] in comms_logger.prof_ops):
  106. log_name = get_debug_log_name(func_args, comms_logger.debug)
  107. raw_name = func.__name__
  108. timers(log_name).stop()
  109. # need temp var since 'elapsed' resets events
  110. time_elapsed = timers(log_name).elapsed(reset=False)
  111. comms_logger.append(raw_name, log_name, time_elapsed, msg_size)
  112. return log_wrapper
  113. # For compatibility with torch distributed's init_process_group, we shall retain the signature from PyTorch code.
  114. # DeepSpeed NCCL/MPI backend may not need all these params as we will have our own implementation.
  115. # Please read full torch.distributed API docs from https://pytorch.org/docs/stable/distributed.html
  116. # UNUSED: Future helper function to initialize DS backends
  117. def init_deepspeed_backend(ds_backend):
  118. global cdb
  119. global nccl_backend
  120. global mpi_backend
  121. global use_ds_backend
  122. if ds_backend == NCCL_BACKEND:
  123. utils.logger.warn("NCCL backend in DeepSpeed not yet implemented")
  124. elif ds_backend == MPI_BACKEND:
  125. utils.logger.warn("MPI backend in DeepSpeed not yet implemented")
  126. elif ds_backend == GLOO_BACKEND:
  127. utils.logger.warn("Gloo backend in DeepSpeed not yet implemented")
  128. else:
  129. utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend")
  130. def is_initialized():
  131. #assert cdb is not None, 'DeepSpeed backend not set, please initialize it using init_process_group()'
  132. if cdb is None:
  133. return False
  134. else:
  135. return cdb.is_initialized()
  136. def destroy_process_group(group=None):
  137. global cdb
  138. return cdb.destroy_process_group(group=group)
  139. def new_group(ranks):
  140. global cdb
  141. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  142. return cdb.new_group(ranks)
  143. def is_available() -> bool:
  144. # Returns ``True`` if the deepspeed comm package is available.
  145. # TODO: load other ops. Clients including deepspeed itself should use deepspeed.comm to import
  146. # any communication related primitives from this package.
  147. # use hasattr(deepspeed.csrc.ops, "_comm") or something
  148. return True
  149. def set_backend(backend):
  150. if not use_ds_backend:
  151. utils.logger.error(
  152. "DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
  153. )
  154. raise RuntimeError(
  155. 'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
  156. )
  157. global cdb
  158. global nccl_backend
  159. global mpi_backend
  160. try:
  161. if backend_name == NCCL_BACKEND:
  162. if nccl_backend is not None and nccl_backend.is_initialized():
  163. cdb = nccl_backend
  164. elif backend_name == MPI_BACKEND:
  165. if mpi_backend is not None and mpi_backend.is_initialized():
  166. cdb = mpi_backend
  167. except Exception as inst:
  168. print(inst)
  169. @timed_op
  170. def broadcast(tensor,
  171. src,
  172. group=None,
  173. async_op=False,
  174. prof=False,
  175. log_name='broadcast',
  176. debug=get_caller_func()):
  177. global cdb
  178. return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
  179. @timed_op
  180. def all_gather(tensor_list,
  181. tensor,
  182. group=None,
  183. async_op=False,
  184. prof=False,
  185. log_name='all_gather',
  186. debug=get_caller_func()):
  187. global cdb
  188. return cdb.all_gather(tensor_list=tensor_list,
  189. tensor=tensor,
  190. group=group,
  191. async_op=async_op)
  192. def has_reduce_scatter_base():
  193. global cdb
  194. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  195. assert cdb.has_reduce_scatter_base is not None, 'has_reduce_scatter_base is not yet defined'
  196. return cdb.has_reduce_scatter_base
  197. def reduce_scatter_fn(output_tensor,
  198. tensor,
  199. op=ReduceOp.SUM,
  200. group=None,
  201. async_op=False,
  202. prof=False,
  203. debug=get_caller_func()):
  204. global cdb
  205. global has_warned_reduce_scatter
  206. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  207. if cdb.has_reduce_scatter_base:
  208. return reduce_scatter_base(output_tensor,
  209. tensor,
  210. op=op,
  211. group=group,
  212. async_op=async_op,
  213. prof=prof,
  214. debug=debug)
  215. else:
  216. if not has_warned_reduce_scatter:
  217. utils.logger.warning(
  218. "unable to find torch.distributed._reduce_scatter_base. will fall back to "
  219. "torch.distributed.all_gather which will result in suboptimal performance. "
  220. "please consider upgrading your pytorch installation.")
  221. has_warned_reduce_scatter = True
  222. input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
  223. return reduce_scatter(output_tensor,
  224. input_tensor_lst,
  225. op=op,
  226. group=group,
  227. async_op=async_op,
  228. prof=prof,
  229. debug=debug)
  230. @timed_op
  231. def reduce_scatter_base(output_tensor,
  232. tensor,
  233. op=ReduceOp.SUM,
  234. group=None,
  235. async_op=False,
  236. prof=False,
  237. log_name='reduce_scatter_base',
  238. debug=get_caller_func()):
  239. global cdb
  240. return cdb.reduce_scatter_base(output_tensor=output_tensor,
  241. input_tensor=tensor,
  242. op=op,
  243. group=group,
  244. async_op=async_op)
  245. @timed_op
  246. def all_gather_base(output_tensor,
  247. tensor,
  248. group=None,
  249. async_op=False,
  250. prof=False,
  251. log_name='all_gather_base',
  252. debug=get_caller_func()):
  253. global cdb
  254. return cdb.all_gather_base(output_tensor=output_tensor,
  255. input_tensor=tensor,
  256. group=group,
  257. async_op=async_op)
  258. def has_allgather_base():
  259. global cdb
  260. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  261. assert cdb.has_allgather_base is not None, 'has_allgather_base is not yet defined'
  262. return cdb.has_allgather_base
  263. def allgather_fn(output_tensor,
  264. input_tensor,
  265. group=None,
  266. async_op=False,
  267. debug=get_caller_func()):
  268. global cdb
  269. global has_warned_all_gather
  270. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  271. if cdb.has_allgather_base:
  272. return all_gather_base(output_tensor,
  273. input_tensor,
  274. group=group,
  275. async_op=async_op,
  276. debug=debug)
  277. else:
  278. if not has_warned_all_gather and get_rank() == 0:
  279. utils.logger.warning(
  280. "unable to find torch.distributed._all_gather_base. will fall back to "
  281. "torch.distributed.all_gather which will result in suboptimal performance. "
  282. "please consider upgrading your pytorch installation.")
  283. has_warned_all_gather = True
  284. output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
  285. return all_gather(output_tensors,
  286. input_tensor,
  287. group=group,
  288. async_op=async_op,
  289. debug=debug)
  290. @timed_op
  291. def all_to_all_single(output,
  292. tensor,
  293. output_split_sizes=None,
  294. input_split_sizes=None,
  295. group=None,
  296. async_op=False,
  297. prof=False,
  298. log_name='all_to_all_single',
  299. debug=get_caller_func()):
  300. global cdb
  301. return cdb.all_to_all_single(output=output,
  302. input=tensor,
  303. output_split_sizes=output_split_sizes,
  304. input_split_sizes=input_split_sizes,
  305. group=group,
  306. async_op=async_op)
  307. @timed_op
  308. def send(tensor,
  309. dst,
  310. group=None,
  311. tag=0,
  312. prof=False,
  313. log_name='send',
  314. debug=get_caller_func()):
  315. global cdb
  316. return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
  317. @timed_op
  318. def recv(tensor,
  319. src=None,
  320. group=None,
  321. tag=0,
  322. prof=False,
  323. log_name='recv',
  324. debug=get_caller_func()):
  325. global cdb
  326. return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
  327. @timed_op
  328. def isend(tensor,
  329. dst,
  330. group=None,
  331. tag=0,
  332. prof=False,
  333. log_name='isend',
  334. debug=get_caller_func()):
  335. global cdb
  336. return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
  337. @timed_op
  338. def irecv(tensor,
  339. src=None,
  340. group=None,
  341. tag=0,
  342. prof=False,
  343. log_name='irecv',
  344. debug=get_caller_func()):
  345. global cdb
  346. return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
  347. @timed_op
  348. def gather(tensor,
  349. gather_list=None,
  350. dst=0,
  351. group=None,
  352. async_op=False,
  353. prof=False,
  354. log_name='gather',
  355. debug=get_caller_func()):
  356. global cdb
  357. return cdb.gather(tensor=tensor,
  358. gather_list=gather_list,
  359. dst=dst,
  360. group=group,
  361. async_op=async_op)
  362. @timed_op
  363. def scatter(tensor,
  364. scatter_list=None,
  365. src=0,
  366. group=None,
  367. async_op=False,
  368. prof=False,
  369. log_name='scatter',
  370. debug=get_caller_func()):
  371. global cdb
  372. return cdb.scatter(tensor=tensor,
  373. scatter_list=scatter_list,
  374. src=src,
  375. group=group,
  376. async_op=async_op)
  377. @timed_op
  378. def barrier(group=None,
  379. async_op=False,
  380. device_ids=None,
  381. prof=False,
  382. log_name='barrier',
  383. debug=get_caller_func()):
  384. global cdb
  385. return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
  386. @timed_op
  387. def monitored_barrier(group=None,
  388. timeout=None,
  389. wait_all_ranks=False,
  390. prof=False,
  391. log_name='monitored_barrier',
  392. debug=get_caller_func()):
  393. global cdb
  394. return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
  395. def log_summary():
  396. global cdb
  397. barrier(log_name='log_summary_barrier')
  398. if cdb.get_rank() == 0:
  399. comms_logger.log_all()
  400. barrier(log_name='log_summary_barrier')
  401. @timed_op
  402. def reduce(tensor,
  403. dst,
  404. op=ReduceOp.SUM,
  405. group=None,
  406. async_op=False,
  407. prof=False,
  408. log_name='reduce',
  409. debug=get_caller_func()):
  410. global cdb
  411. return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
  412. @timed_op
  413. def reduce_scatter(output,
  414. input_list,
  415. op=ReduceOp.SUM,
  416. group=None,
  417. async_op=False,
  418. prof=False,
  419. log_name='reduce_scatter',
  420. debug=get_caller_func()):
  421. global cdb
  422. return cdb.reduce_scatter(output=output,
  423. input_list=input_list,
  424. op=op,
  425. group=group,
  426. async_op=async_op)
  427. @timed_op
  428. def all_reduce(tensor,
  429. op=ReduceOp.SUM,
  430. group=None,
  431. async_op=False,
  432. prof=False,
  433. log_name='all_reduce',
  434. debug=get_caller_func()):
  435. #if profile_comm:
  436. # context of the timers?
  437. # timers.start()
  438. # TensorBoard logging for comm calls.?
  439. global cdb
  440. #print(f'op = {op}, cdb= {cdb.name}')
  441. return cdb.all_reduce(tensor, op, group, async_op)
  442. def get_world_group():
  443. global cdb
  444. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  445. return cdb.get_world_group()
  446. def get_world_size(group=None) -> int:
  447. """
  448. Returns the number of processes in the current process group
  449. Args:
  450. group (ProcessGroup, optional): The process group to work on. If None,
  451. the default process group will be used.
  452. Returns:
  453. The world size of the process group
  454. -1, if not part of the group
  455. """
  456. global cdb
  457. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  458. return cdb.get_world_size(group)
  459. def get_rank(group=None):
  460. """
  461. Returns the rank of the current process in the provided ``group`` or the
  462. default group if none was provided.
  463. Rank is a unique identifier assigned to each process within a distributed
  464. process group. They are always consecutive integers ranging from 0 to
  465. ``world_size``.
  466. Args:
  467. group (ProcessGroup, optional): The process group to work on. If None,
  468. the default process group will be used.
  469. Returns:
  470. The rank of the process group
  471. -1, if not part of the group
  472. """
  473. global cdb
  474. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  475. return cdb.get_rank(group)
  476. def get_local_rank():
  477. """
  478. Helper function to get local rank after a backend has been set and initialized
  479. Args:
  480. None
  481. Returns:
  482. local rank (= GPU device ID)
  483. """
  484. global cdb
  485. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  486. return get_local_rank_from_launcher()
  487. def get_global_rank(group=None, group_rank=0):
  488. global cdb
  489. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  490. return cdb.get_global_rank(group, group_rank)
  491. # Main DeepSpeed Comms. public API.
  492. def init_distributed(dist_backend=None,
  493. auto_mpi_discovery=True,
  494. distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
  495. verbose=True,
  496. timeout=default_pg_timeout,
  497. init_method=None,
  498. dist_init_required=None,
  499. config=None,
  500. rank=-1,
  501. world_size=-1):
  502. ''' Initialize dist backend, potentially performing MPI discovery if needed
  503. Arguments:
  504. dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
  505. auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
  506. distributed_port: Optional (int). torch distributed backend port
  507. verbose: Optional (bool). verbose logging
  508. timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
  509. init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
  510. config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
  511. rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
  512. world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
  513. '''
  514. global cdb
  515. configure(deepspeed_config=config)
  516. if dist_init_required is None:
  517. dist_init_required = cdb is None or not cdb.is_initialized()
  518. if cdb is None and torch.distributed.is_initialized():
  519. # The user initialized torch.dist themselves, create cdb and short-circuit
  520. cdb = TorchBackend(dist_backend, timeout, init_method)
  521. return
  522. if dist_init_required is False:
  523. assert (
  524. cdb is not None and cdb.is_initialized() is True
  525. ), "Distributed backend is not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
  526. else:
  527. # Initialize torch distributed if needed
  528. required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
  529. if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
  530. if verbose:
  531. utils.logger.info(
  532. "Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
  533. )
  534. if in_aml() and not in_dlts():
  535. patch_aml_env_for_torch_nccl_backend(verbose=verbose)
  536. elif in_aws_sm():
  537. patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
  538. else:
  539. mpi_discovery(distributed_port=distributed_port, verbose=verbose)
  540. if cdb is not None and cdb.is_initialized():
  541. if int(os.getenv('RANK', '0')) == 0:
  542. utils.logger.info('Distributed backend already initialized')
  543. else:
  544. assert isinstance(timeout, timedelta)
  545. if dist_backend == None:
  546. dist_backend = get_accelerator().communication_backend_name()
  547. if int(os.getenv('RANK', '0')) == 0:
  548. utils.logger.info(
  549. 'Initializing TorchBackend in DeepSpeed with backend {}'.format(
  550. dist_backend))
  551. # Create a torch backend object, initialize torch distributed, and assign to cdb
  552. cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
  553. def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
  554. '''
  555. Discovery MPI environment via mpi4py and map to relevant dist state
  556. '''
  557. from mpi4py import MPI
  558. import subprocess
  559. comm = MPI.COMM_WORLD
  560. rank = comm.Get_rank()
  561. world_size = comm.Get_size()
  562. master_addr = None
  563. if rank == 0:
  564. hostname_cmd = ["hostname -I"]
  565. result = subprocess.check_output(hostname_cmd, shell=True)
  566. master_addr = result.decode('utf-8').split()[0]
  567. master_addr = comm.bcast(master_addr, root=0)
  568. # Determine local rank by assuming hostnames are unique
  569. proc_name = MPI.Get_processor_name()
  570. all_procs = comm.allgather(proc_name)
  571. local_rank = sum([i == proc_name for i in all_procs[:rank]])
  572. os.environ['RANK'] = str(rank)
  573. os.environ['WORLD_SIZE'] = str(world_size)
  574. os.environ['LOCAL_RANK'] = str(local_rank)
  575. os.environ['MASTER_ADDR'] = master_addr
  576. os.environ['MASTER_PORT'] = str(distributed_port)
  577. if verbose:
  578. utils.logger.info(
  579. "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
  580. .format(os.environ['RANK'],
  581. os.environ['LOCAL_RANK'],
  582. os.environ['WORLD_SIZE'],
  583. os.environ['MASTER_ADDR'],
  584. os.environ['MASTER_PORT']))
  585. if cdb is not None and cdb.is_initialized():
  586. assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
  587. rank, cdb.get_rank())
  588. assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
  589. world_size, cdb.get_world_size())
  590. def in_aml():
  591. # Are we running inside an Azure Machine Learning (AML) environment?
  592. return 'AZUREML_EXPERIMENT_ID' in os.environ
  593. def in_aws_sm():
  594. # Are we running inside an AWS SageMaker environment?
  595. return 'SM_TRAINING_ENV' in os.environ
  596. def in_dlts():
  597. # Are we running on a DLTS cluster?
  598. return 'DLTS_JOB_ID' in os.environ
  599. def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
  600. """Helper routine to get and set environment variables.
  601. This is adapted from Azure ML's documentation available from:
  602. https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
  603. """
  604. os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
  605. os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
  606. single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
  607. os.environ["WORLD_SIZE"])
  608. if not single_node:
  609. master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
  610. os.environ["MASTER_ADDR"] = master_node_params[0]
  611. # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
  612. if "MASTER_PORT" not in os.environ:
  613. os.environ["MASTER_PORT"] = str(master_port)
  614. else:
  615. os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
  616. os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
  617. if verbose:
  618. utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
  619. os.environ["NCCL_SOCKET_IFNAME"]))
  620. os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
  621. os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
  622. if verbose:
  623. utils.logger.info(
  624. "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
  625. .format(os.environ['RANK'],
  626. os.environ['LOCAL_RANK'],
  627. os.environ['WORLD_SIZE'],
  628. os.environ['MASTER_ADDR'],
  629. os.environ['MASTER_PORT']))
  630. def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
  631. """Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
  632. """
  633. os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
  634. os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
  635. os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
  636. if verbose:
  637. utils.logger.info(
  638. "Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
  639. .format(os.environ['RANK'],
  640. os.environ['LOCAL_RANK'],
  641. os.environ['WORLD_SIZE'],
  642. os.environ['MASTER_ADDR'],
  643. os.environ['MASTER_PORT']))