comm.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769
  1. """
  2. Copyright 2021 The Microsoft DeepSpeed Team
  3. DeepSpeed Communication Package: deepspeed.comm
  4. deepspeed.comm
  5. -- import and use deepspeed.ops.comm
  6. -- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
  7. -- use custom collectives
  8. -- can either use torch.dist or ds.ops.comm?
  9. Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
  10. deepspeed.comm API
  11. -- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
  12. -- e.g. if a client code used
  13. from deepspeed import comm as dist
  14. instead of
  15. import torch.distributed as dist
  16. The code should work without breaking any of the public torch.distributed functionality
  17. Future:
  18. -- deepspeed groups API should be brought into ds.comm
  19. """
  20. from enum import Enum
  21. import torch
  22. import os
  23. from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
  24. from .constants import *
  25. from deepspeed.accelerator import get_accelerator
  26. class ReduceOp(Enum):
  27. SUM = 0
  28. PRODUCT = 1
  29. MIN = 2
  30. MAX = 3
  31. BAND = 4
  32. BOR = 5
  33. BXOR = 6
  34. AVG = 7
  35. UNUSED = 8
  36. from deepspeed.utils.comms_logging import CommsLogger
  37. from deepspeed.utils import timer, get_caller_func
  38. from deepspeed.comm.torch import TorchBackend
  39. from deepspeed import utils
  40. from datetime import timedelta
  41. # Current deepspeed.comm backend (cdb) global object for simple access by client code
  42. use_ds_backend = False
  43. cdb = None
  44. # Create global timer for ops
  45. timers = timer.SynchronizedWallClockTimer()
  46. timer_summary = {}
  47. comms_logger = CommsLogger()
  48. # Maintain objects of all initialized ds backends and assign them using the API functions in this file
  49. nccl_backend = None
  50. mpi_backend = None
  51. # This should be set here so all rank/size information from the launcher can be propagated
  52. from deepspeed.comm.utils import *
  53. def _configure_using_config_file(config):
  54. if config.comms_logger_enabled:
  55. comms_logger.configure(config)
  56. def configure(
  57. deepspeed_config=None,
  58. enabled=None,
  59. prof_all=None,
  60. prof_ops=None,
  61. verbose=None,
  62. debug=None,
  63. ):
  64. if deepspeed_config is not None:
  65. _configure_using_config_file(deepspeed_config.comms_config)
  66. if enabled is not None:
  67. comms_logger.enabled = enabled
  68. if prof_all is not None:
  69. comms_logger.prof_all = prof_all
  70. if prof_ops is not None:
  71. comms_logger.prof_ops = prof_ops
  72. if verbose is not None:
  73. comms_logger.verbose = verbose
  74. if debug is not None:
  75. comms_logger.debug = debug
  76. # Logging wrapper for timing ops
  77. def timed_op(func):
  78. def log_wrapper(*args, **kwargs):
  79. # Add enabled flag so that overhead to each comm op is two if conditions at most
  80. if comms_logger.enabled:
  81. if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
  82. 'log_name' in kwargs
  83. and kwargs['log_name'] in comms_logger.prof_ops):
  84. # Need func args for their defaults
  85. func_args = get_default_args(func)
  86. func_args.update(kwargs)
  87. msg_size = get_msg_size_from_args(func, *args, **kwargs)
  88. log_name = get_debug_log_name(func_args, comms_logger.debug)
  89. timers(log_name).start()
  90. # Return the op, then stop the op's timer
  91. try:
  92. return func(*args, **kwargs)
  93. finally:
  94. if comms_logger.enabled:
  95. # Need to make op blocking for accurate logging
  96. get_accelerator().synchronize()
  97. # If we're using MPI, we can't simply sync the stream
  98. if cdb.using_mpi:
  99. cdb.barrier()
  100. if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
  101. 'log_name' in kwargs
  102. and kwargs['log_name'] in comms_logger.prof_ops):
  103. log_name = get_debug_log_name(func_args, comms_logger.debug)
  104. raw_name = func.__name__
  105. timers(log_name).stop()
  106. # need temp var since 'elapsed' resets events
  107. time_elapsed = timers(log_name).elapsed(reset=False)
  108. comms_logger.append(raw_name, log_name, time_elapsed, msg_size)
  109. return log_wrapper
  110. # For compatibility with torch distributed's init_process_group, we shall retain the signature from PyTorch code.
  111. # DeepSpeed NCCL/MPI backend may not need all these params as we will have our own implementation.
  112. # Please read full torch.distributed API docs from https://pytorch.org/docs/stable/distributed.html
  113. # UNUSED: Future helper function to initialize DS backends
  114. def init_deepspeed_backend(ds_backend):
  115. global cdb
  116. global nccl_backend
  117. global mpi_backend
  118. global use_ds_backend
  119. if ds_backend == NCCL_BACKEND:
  120. utils.logger.warn("NCCL backend in DeepSpeed not yet implemented")
  121. elif ds_backend == MPI_BACKEND:
  122. utils.logger.warn("MPI backend in DeepSpeed not yet implemented")
  123. elif ds_backend == GLOO_BACKEND:
  124. utils.logger.warn("Gloo backend in DeepSpeed not yet implemented")
  125. else:
  126. utils.logger.warn(f"DeepSpeed does not support {ds_backend} backend")
  127. def is_initialized():
  128. #assert cdb is not None, 'DeepSpeed backend not set, please initialize it using init_process_group()'
  129. if cdb is None:
  130. return False
  131. else:
  132. return cdb.is_initialized()
  133. def destroy_process_group(group=None):
  134. global cdb
  135. return cdb.destroy_process_group(group=group)
  136. def new_group(ranks):
  137. global cdb
  138. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  139. return cdb.new_group(ranks)
  140. def is_available() -> bool:
  141. # Returns ``True`` if the deepspeed comm package is available.
  142. # TODO: load other ops. Clients including deepspeed itself should use deepspeed.comm to import
  143. # any communication related primitives from this package.
  144. # use hasattr(deepspeed.csrc.ops, "_comm") or something
  145. return True
  146. def set_backend(backend):
  147. if not use_ds_backend:
  148. utils.logger.error(
  149. "DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
  150. )
  151. raise RuntimeError(
  152. 'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
  153. )
  154. global cdb
  155. global nccl_backend
  156. global mpi_backend
  157. try:
  158. if backend_name == NCCL_BACKEND:
  159. if nccl_backend is not None and nccl_backend.is_initialized():
  160. cdb = nccl_backend
  161. elif backend_name == MPI_BACKEND:
  162. if mpi_backend is not None and mpi_backend.is_initialized():
  163. cdb = mpi_backend
  164. except Exception as inst:
  165. print(inst)
  166. @timed_op
  167. def broadcast(tensor,
  168. src,
  169. group=None,
  170. async_op=False,
  171. prof=False,
  172. log_name='broadcast',
  173. debug=get_caller_func()):
  174. global cdb
  175. return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
  176. @timed_op
  177. def all_gather(tensor_list,
  178. tensor,
  179. group=None,
  180. async_op=False,
  181. prof=False,
  182. log_name='all_gather',
  183. debug=get_caller_func()):
  184. global cdb
  185. return cdb.all_gather(tensor_list=tensor_list,
  186. tensor=tensor,
  187. group=group,
  188. async_op=async_op)
  189. def has_reduce_scatter_base():
  190. global cdb
  191. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  192. assert cdb.has_reduce_scatter_base is not None, 'has_reduce_scatter_base is not yet defined'
  193. return cdb.has_reduce_scatter_base
  194. def reduce_scatter_fn(output_tensor,
  195. tensor,
  196. op=ReduceOp.SUM,
  197. group=None,
  198. async_op=False,
  199. prof=False,
  200. debug=get_caller_func()):
  201. global cdb
  202. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  203. if cdb.has_reduce_scatter_base:
  204. return reduce_scatter_base(output_tensor,
  205. tensor,
  206. op=op,
  207. group=group,
  208. async_op=async_op,
  209. prof=prof,
  210. debug=debug)
  211. else:
  212. utils.logger.warning_once(
  213. "unable to find torch.distributed._reduce_scatter_base. will fall back to "
  214. "torch.distributed.all_gather which will result in suboptimal performance. "
  215. "please consider upgrading your pytorch installation.")
  216. input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
  217. return reduce_scatter(output_tensor,
  218. input_tensor_lst,
  219. op=op,
  220. group=group,
  221. async_op=async_op,
  222. prof=prof,
  223. debug=debug)
  224. @timed_op
  225. def reduce_scatter_base(output_tensor,
  226. tensor,
  227. op=ReduceOp.SUM,
  228. group=None,
  229. async_op=False,
  230. prof=False,
  231. log_name='reduce_scatter_base',
  232. debug=get_caller_func()):
  233. global cdb
  234. return cdb.reduce_scatter_base(output_tensor=output_tensor,
  235. input_tensor=tensor,
  236. op=op,
  237. group=group,
  238. async_op=async_op)
  239. @timed_op
  240. def all_gather_base(output_tensor,
  241. tensor,
  242. group=None,
  243. async_op=False,
  244. prof=False,
  245. log_name='all_gather_base',
  246. debug=get_caller_func()):
  247. global cdb
  248. return cdb.all_gather_base(output_tensor=output_tensor,
  249. input_tensor=tensor,
  250. group=group,
  251. async_op=async_op)
  252. def has_allgather_base():
  253. global cdb
  254. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  255. assert cdb.has_allgather_base is not None, 'has_allgather_base is not yet defined'
  256. return cdb.has_allgather_base
  257. def allgather_fn(output_tensor,
  258. input_tensor,
  259. group=None,
  260. async_op=False,
  261. debug=get_caller_func()):
  262. global cdb
  263. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  264. if cdb.has_allgather_base:
  265. return all_gather_base(output_tensor,
  266. input_tensor,
  267. group=group,
  268. async_op=async_op,
  269. debug=debug)
  270. else:
  271. if get_rank() == 0:
  272. utils.logger.warning_once(
  273. "unable to find torch.distributed._all_gather_base. will fall back to "
  274. "torch.distributed.all_gather which will result in suboptimal performance. "
  275. "please consider upgrading your pytorch installation.")
  276. output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
  277. return all_gather(output_tensors,
  278. input_tensor,
  279. group=group,
  280. async_op=async_op,
  281. debug=debug)
  282. @timed_op
  283. def all_to_all_single(output,
  284. tensor,
  285. output_split_sizes=None,
  286. input_split_sizes=None,
  287. group=None,
  288. async_op=False,
  289. prof=False,
  290. log_name='all_to_all_single',
  291. debug=get_caller_func()):
  292. global cdb
  293. return cdb.all_to_all_single(output=output,
  294. input=tensor,
  295. output_split_sizes=output_split_sizes,
  296. input_split_sizes=input_split_sizes,
  297. group=group,
  298. async_op=async_op)
  299. @timed_op
  300. def send(tensor,
  301. dst,
  302. group=None,
  303. tag=0,
  304. prof=False,
  305. log_name='send',
  306. debug=get_caller_func()):
  307. global cdb
  308. return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
  309. @timed_op
  310. def recv(tensor,
  311. src=None,
  312. group=None,
  313. tag=0,
  314. prof=False,
  315. log_name='recv',
  316. debug=get_caller_func()):
  317. global cdb
  318. return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
  319. @timed_op
  320. def isend(tensor,
  321. dst,
  322. group=None,
  323. tag=0,
  324. prof=False,
  325. log_name='isend',
  326. debug=get_caller_func()):
  327. global cdb
  328. return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
  329. @timed_op
  330. def irecv(tensor,
  331. src=None,
  332. group=None,
  333. tag=0,
  334. prof=False,
  335. log_name='irecv',
  336. debug=get_caller_func()):
  337. global cdb
  338. return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
  339. @timed_op
  340. def gather(tensor,
  341. gather_list=None,
  342. dst=0,
  343. group=None,
  344. async_op=False,
  345. prof=False,
  346. log_name='gather',
  347. debug=get_caller_func()):
  348. global cdb
  349. return cdb.gather(tensor=tensor,
  350. gather_list=gather_list,
  351. dst=dst,
  352. group=group,
  353. async_op=async_op)
  354. @timed_op
  355. def scatter(tensor,
  356. scatter_list=None,
  357. src=0,
  358. group=None,
  359. async_op=False,
  360. prof=False,
  361. log_name='scatter',
  362. debug=get_caller_func()):
  363. global cdb
  364. return cdb.scatter(tensor=tensor,
  365. scatter_list=scatter_list,
  366. src=src,
  367. group=group,
  368. async_op=async_op)
  369. @timed_op
  370. def barrier(group=None,
  371. async_op=False,
  372. device_ids=None,
  373. prof=False,
  374. log_name='barrier',
  375. debug=get_caller_func()):
  376. global cdb
  377. return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
  378. @timed_op
  379. def monitored_barrier(group=None,
  380. timeout=None,
  381. wait_all_ranks=False,
  382. prof=False,
  383. log_name='monitored_barrier',
  384. debug=get_caller_func()):
  385. global cdb
  386. return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
  387. def log_summary():
  388. global cdb
  389. barrier(log_name='log_summary_barrier')
  390. if cdb.get_rank() == 0:
  391. comms_logger.log_all()
  392. barrier(log_name='log_summary_barrier')
  393. @timed_op
  394. def reduce(tensor,
  395. dst,
  396. op=ReduceOp.SUM,
  397. group=None,
  398. async_op=False,
  399. prof=False,
  400. log_name='reduce',
  401. debug=get_caller_func()):
  402. global cdb
  403. return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
  404. @timed_op
  405. def reduce_scatter(output,
  406. input_list,
  407. op=ReduceOp.SUM,
  408. group=None,
  409. async_op=False,
  410. prof=False,
  411. log_name='reduce_scatter',
  412. debug=get_caller_func()):
  413. global cdb
  414. return cdb.reduce_scatter(output=output,
  415. input_list=input_list,
  416. op=op,
  417. group=group,
  418. async_op=async_op)
  419. @timed_op
  420. def all_reduce(tensor,
  421. op=ReduceOp.SUM,
  422. group=None,
  423. async_op=False,
  424. prof=False,
  425. log_name='all_reduce',
  426. debug=get_caller_func()):
  427. #if profile_comm:
  428. # context of the timers?
  429. # timers.start()
  430. # TensorBoard logging for comm calls.?
  431. global cdb
  432. #print(f'op = {op}, cdb= {cdb.name}')
  433. return cdb.all_reduce(tensor, op, group, async_op)
  434. def get_world_group():
  435. global cdb
  436. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  437. return cdb.get_world_group()
  438. def get_world_size(group=None) -> int:
  439. """
  440. Returns the number of processes in the current process group
  441. Args:
  442. group (ProcessGroup, optional): The process group to work on. If None,
  443. the default process group will be used.
  444. Returns:
  445. The world size of the process group
  446. -1, if not part of the group
  447. """
  448. global cdb
  449. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  450. return cdb.get_world_size(group)
  451. def get_rank(group=None):
  452. """
  453. Returns the rank of the current process in the provided ``group`` or the
  454. default group if none was provided.
  455. Rank is a unique identifier assigned to each process within a distributed
  456. process group. They are always consecutive integers ranging from 0 to
  457. ``world_size``.
  458. Args:
  459. group (ProcessGroup, optional): The process group to work on. If None,
  460. the default process group will be used.
  461. Returns:
  462. The rank of the process group
  463. -1, if not part of the group
  464. """
  465. global cdb
  466. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  467. return cdb.get_rank(group)
  468. def get_local_rank():
  469. """
  470. Helper function to get local rank after a backend has been set and initialized
  471. Args:
  472. None
  473. Returns:
  474. local rank (= GPU device ID)
  475. """
  476. global cdb
  477. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  478. return get_local_rank_from_launcher()
  479. def get_global_rank(group=None, group_rank=0):
  480. global cdb
  481. assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  482. return cdb.get_global_rank(group, group_rank)
  483. # Main DeepSpeed Comms. public API.
  484. def init_distributed(dist_backend=None,
  485. auto_mpi_discovery=True,
  486. distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
  487. verbose=True,
  488. timeout=default_pg_timeout,
  489. init_method=None,
  490. dist_init_required=None,
  491. config=None,
  492. rank=-1,
  493. world_size=-1):
  494. ''' Initialize dist backend, potentially performing MPI discovery if needed
  495. Arguments:
  496. dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo
  497. auto_mpi_discovery Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
  498. distributed_port: Optional (int). torch distributed backend port
  499. verbose: Optional (bool). verbose logging
  500. timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
  501. init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
  502. config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
  503. rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
  504. world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
  505. '''
  506. global cdb
  507. configure(deepspeed_config=config)
  508. if dist_init_required is None:
  509. dist_init_required = cdb is None or not cdb.is_initialized()
  510. if cdb is None and torch.distributed.is_initialized():
  511. # The user initialized torch.dist themselves, create cdb and short-circuit
  512. cdb = TorchBackend(dist_backend, timeout, init_method)
  513. return
  514. if dist_init_required is False:
  515. assert (
  516. cdb is not None and cdb.is_initialized() is True
  517. ), "Distributed backend is not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
  518. else:
  519. # Initialize torch distributed if needed
  520. required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
  521. if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
  522. if verbose:
  523. utils.logger.info(
  524. "Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
  525. )
  526. if in_aml() and not in_dlts():
  527. patch_aml_env_for_torch_nccl_backend(verbose=verbose)
  528. elif in_aws_sm():
  529. patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
  530. else:
  531. mpi_discovery(distributed_port=distributed_port, verbose=verbose)
  532. if cdb is not None and cdb.is_initialized():
  533. if int(os.getenv('RANK', '0')) == 0:
  534. utils.logger.info('Distributed backend already initialized')
  535. else:
  536. assert isinstance(timeout, timedelta)
  537. if dist_backend == None:
  538. dist_backend = get_accelerator().communication_backend_name()
  539. if int(os.getenv('RANK', '0')) == 0:
  540. utils.logger.info(
  541. 'Initializing TorchBackend in DeepSpeed with backend {}'.format(
  542. dist_backend))
  543. # Create a torch backend object, initialize torch distributed, and assign to cdb
  544. cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
  545. def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
  546. '''
  547. Discovery MPI environment via mpi4py and map to relevant dist state
  548. '''
  549. from mpi4py import MPI
  550. import subprocess
  551. comm = MPI.COMM_WORLD
  552. rank = comm.Get_rank()
  553. world_size = comm.Get_size()
  554. master_addr = None
  555. if rank == 0:
  556. hostname_cmd = ["hostname -I"]
  557. result = subprocess.check_output(hostname_cmd, shell=True)
  558. master_addr = result.decode('utf-8').split()[0]
  559. master_addr = comm.bcast(master_addr, root=0)
  560. # Determine local rank by assuming hostnames are unique
  561. proc_name = MPI.Get_processor_name()
  562. all_procs = comm.allgather(proc_name)
  563. local_rank = sum([i == proc_name for i in all_procs[:rank]])
  564. os.environ['RANK'] = str(rank)
  565. os.environ['WORLD_SIZE'] = str(world_size)
  566. os.environ['LOCAL_RANK'] = str(local_rank)
  567. os.environ['MASTER_ADDR'] = master_addr
  568. os.environ['MASTER_PORT'] = str(distributed_port)
  569. if verbose:
  570. utils.logger.info(
  571. "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
  572. .format(os.environ['RANK'],
  573. os.environ['LOCAL_RANK'],
  574. os.environ['WORLD_SIZE'],
  575. os.environ['MASTER_ADDR'],
  576. os.environ['MASTER_PORT']))
  577. if cdb is not None and cdb.is_initialized():
  578. assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(
  579. rank, cdb.get_rank())
  580. assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
  581. world_size, cdb.get_world_size())
  582. def in_aml():
  583. # Are we running inside an Azure Machine Learning (AML) environment?
  584. return 'AZUREML_EXPERIMENT_ID' in os.environ
  585. def in_aws_sm():
  586. # Are we running inside an AWS SageMaker environment?
  587. return 'SM_TRAINING_ENV' in os.environ
  588. def in_dlts():
  589. # Are we running on a DLTS cluster?
  590. return 'DLTS_JOB_ID' in os.environ
  591. def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
  592. """Helper routine to get and set environment variables.
  593. This is adapted from Azure ML's documentation available from:
  594. https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi
  595. """
  596. os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
  597. os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
  598. single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(
  599. os.environ["WORLD_SIZE"])
  600. if not single_node:
  601. master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
  602. os.environ["MASTER_ADDR"] = master_node_params[0]
  603. # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
  604. if "MASTER_PORT" not in os.environ:
  605. os.environ["MASTER_PORT"] = str(master_port)
  606. else:
  607. os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
  608. os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
  609. if verbose:
  610. utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(
  611. os.environ["NCCL_SOCKET_IFNAME"]))
  612. os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
  613. os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
  614. if verbose:
  615. utils.logger.info(
  616. "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
  617. .format(os.environ['RANK'],
  618. os.environ['LOCAL_RANK'],
  619. os.environ['WORLD_SIZE'],
  620. os.environ['MASTER_ADDR'],
  621. os.environ['MASTER_PORT']))
  622. def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
  623. """Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
  624. """
  625. os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
  626. os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
  627. os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
  628. if verbose:
  629. utils.logger.info(
  630. "Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
  631. .format(os.environ['RANK'],
  632. os.environ['LOCAL_RANK'],
  633. os.environ['WORLD_SIZE'],
  634. os.environ['MASTER_ADDR'],
  635. os.environ['MASTER_PORT']))