comm.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. """
  5. DeepSpeed Communication Package: deepspeed.comm
  6. deepspeed.comm
  7. -- import and use deepspeed.ops.comm
  8. -- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
  9. -- use custom collectives
  10. -- can either use torch.dist or ds.ops.comm?
  11. Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
  12. deepspeed.comm API
  13. -- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
  14. -- e.g. if a client code used
  15. from deepspeed import comm as dist
  16. instead of
  17. import torch.distributed as dist
  18. The code should work without breaking any of the public torch.distributed functionality
  19. Future:
  20. -- deepspeed groups API should be brought into ds.comm
  21. """
  22. import torch
  23. import os
  24. from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout
  25. from .constants import *
  26. from .reduce_op import *
  27. from deepspeed.accelerator import get_accelerator
  28. from deepspeed.comm.ccl import CCLBackend
  29. from deepspeed.utils.comms_logging import CommsLogger
  30. from deepspeed.utils import timer, get_caller_func
  31. from deepspeed.comm.torch import TorchBackend
  32. from deepspeed import utils
  33. from datetime import timedelta
# Current deepspeed.comm backend (cdb) global object for simple access by client code
cdb = None
# Create global timer for ops
timers = timer.SynchronizedWallClockTimer()
# NOTE(review): timer_summary is never written or read in this module — confirm
# external users before removing.
timer_summary = {}
# Global comms logger; configured via configure()/_configure_using_config_file().
comms_logger = CommsLogger()
# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend = None
mpi_backend = None
ccl_backend = None
hccl_backend = None
# This should be set here so all rank/size information from the launcher can be propagated
from deepspeed.comm.utils import *
  47. class ProcessGroup():
  48. def __init__(self, comm_id, ranks=[]):
  49. self.ranks = ranks
  50. self.comm_id = comm_id
  51. self.size = len(ranks)
  52. def _configure_using_config_file(config):
  53. if config.comms_logger_enabled:
  54. comms_logger.configure(config)
  55. def configure(
  56. deepspeed_config=None,
  57. enabled=None,
  58. prof_all=None,
  59. prof_ops=None,
  60. verbose=None,
  61. debug=None,
  62. ):
  63. if deepspeed_config is not None:
  64. _configure_using_config_file(deepspeed_config.comms_config)
  65. if enabled is not None:
  66. comms_logger.enabled = enabled
  67. if prof_all is not None:
  68. comms_logger.prof_all = prof_all
  69. if prof_ops is not None:
  70. comms_logger.prof_ops = prof_ops
  71. if verbose is not None:
  72. comms_logger.verbose = verbose
  73. if debug is not None:
  74. comms_logger.debug = debug
# Logging wrapper for timing ops
def timed_op(func):
    """Decorator that optionally times a comm op and records it in comms_logger.

    Profiling is active for a call when comms_logger is enabled AND either
    prof=True was passed, comms_logger.prof_all is set, or the call's
    log_name appears in comms_logger.prof_ops.
    """

    def log_wrapper(*args, **kwargs):
        # Add enabled flag so that overhead to each comm op is two if conditions at most
        if comms_logger.enabled:
            if ('prof' in kwargs
                    and kwargs['prof']) or comms_logger.prof_all or ('log_name' in kwargs
                                                                     and kwargs['log_name'] in comms_logger.prof_ops):
                # Need func args for their defaults
                func_args = get_default_args(func)
                func_args.update(kwargs)
                msg_size = get_msg_size_from_args(func, *args, **kwargs)
                log_name = get_debug_log_name(func_args, comms_logger.debug)
                timers(log_name).start()
        # Return the op, then stop the op's timer
        try:
            return func(*args, **kwargs)
        finally:
            if comms_logger.enabled:
                # Need to make op blocking for accurate logging
                get_accelerator().synchronize()
                # If we're using MPI, we can't simply sync the stream
                if cdb.using_mpi:
                    cdb.barrier()
                # NOTE(review): func_args/msg_size are only bound when the same
                # profiling condition held on entry; if comms_logger.enabled or
                # prof settings change mid-call this branch would NameError —
                # confirm that configuration is not mutated concurrently.
                if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or (
                        'log_name' in kwargs and kwargs['log_name'] in comms_logger.prof_ops):
                    log_name = get_debug_log_name(func_args, comms_logger.debug)
                    raw_name = func.__name__
                    timers(log_name).stop()
                    # need temp var since 'elapsed' resets events
                    time_elapsed = timers(log_name).elapsed(reset=False)
                    comms_logger.append(raw_name, log_name, time_elapsed, msg_size)

    return log_wrapper
  108. # For compatibility with torch distributed's init_process_group, we shall retain the signature from PyTorch code.
  109. # DeepSpeed NCCL/MPI backend may not need all these params as we will have our own implementation.
  110. # Please read full torch.distributed API docs from https://pytorch.org/docs/stable/distributed.html
  111. # UNUSED: Future helper function to initialize DS backends
  112. def init_deepspeed_backend(ds_backend, timeout, init_method):
  113. global cdb
  114. global nccl_backend
  115. global mpi_backend
  116. global ccl_backend
  117. global hccl_backend
  118. rank = int(os.getenv('RANK', '-1'))
  119. size = int(os.getenv('WORLD_SIZE', '-1'))
  120. if ds_backend == NCCL_BACKEND:
  121. utils.logger.debug("NCCL backend in DeepSpeed not yet implemented")
  122. elif ds_backend == MPI_BACKEND:
  123. utils.logger.debug("MPI backend in DeepSpeed not yet implemented")
  124. elif ds_backend == GLOO_BACKEND:
  125. utils.logger.debug("Gloo backend in DeepSpeed not yet implemented")
  126. elif ds_backend == CCL_BACKEND:
  127. ccl_backend = CCLBackend(rank=rank, world_size=size, timeout=timeout, init_method=init_method)
  128. utils.logger.info(f"Initialize {ds_backend} backend")
  129. elif ds_backend == HCCL_BACKEND:
  130. utils.logger.debug("HCCL backend in DeepSpeed not yet implemented")
  131. else:
  132. utils.logger.debug(f"DeepSpeed does not support {ds_backend} backend")
  133. def is_initialized():
  134. #assert cdb is not None, 'DeepSpeed backend not set, please initialize it using init_process_group()'
  135. if cdb is None:
  136. return False
  137. else:
  138. return cdb.is_initialized()
  139. def destroy_process_group(group=None):
  140. global cdb
  141. return cdb.destroy_process_group(group=group)
  142. def new_group(ranks):
  143. global cdb
  144. assert cdb is not None and cdb.is_initialized(
  145. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  146. return cdb.new_group(ranks)
  147. def is_available() -> bool:
  148. # Returns ``True`` if the deepspeed comm package is available.
  149. # TODO: load other ops. Clients including deepspeed itself should use deepspeed.comm to import
  150. # any communication related primitives from this package.
  151. # use hasattr(deepspeed.csrc.ops, "_comm") or something
  152. return True
  153. def set_backend():
  154. global cdb
  155. global nccl_backend
  156. global mpi_backend
  157. global ccl_backend
  158. global hccl_backend
  159. backend_name = get_accelerator().communication_backend_name()
  160. if backend_name == NCCL_BACKEND:
  161. if nccl_backend is not None and nccl_backend.is_initialized():
  162. cdb = nccl_backend
  163. elif backend_name == MPI_BACKEND:
  164. if mpi_backend is not None and mpi_backend.is_initialized():
  165. cdb = mpi_backend
  166. elif backend_name == CCL_BACKEND:
  167. if ccl_backend is not None and ccl_backend.is_initialized():
  168. cdb = ccl_backend
  169. elif backend_name == HCCL_BACKEND:
  170. if hccl_backend is not None and hccl_backend.is_initialized():
  171. cdb = hccl_backend
  172. @timed_op
  173. def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='broadcast', debug=get_caller_func()):
  174. global cdb
  175. return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
  176. @timed_op
  177. def all_gather(tensor_list,
  178. tensor,
  179. group=None,
  180. async_op=False,
  181. prof=False,
  182. log_name='all_gather',
  183. debug=get_caller_func()):
  184. global cdb
  185. return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
  186. def has_reduce_scatter_tensor():
  187. global cdb
  188. assert cdb is not None and cdb.is_initialized(
  189. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  190. return cdb.has_reduce_scatter_tensor()
  191. def reduce_scatter_fn(output_tensor,
  192. tensor,
  193. op=ReduceOp.SUM,
  194. group=None,
  195. async_op=False,
  196. prof=False,
  197. debug=get_caller_func()):
  198. global cdb
  199. assert cdb is not None and cdb.is_initialized(
  200. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  201. if cdb.has_reduce_scatter_tensor():
  202. return reduce_scatter_tensor(output_tensor,
  203. tensor,
  204. op=op,
  205. group=group,
  206. async_op=async_op,
  207. prof=prof,
  208. debug=debug)
  209. else:
  210. if get_rank() == 0:
  211. utils.logger.warning_once("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
  212. "torch.distributed.reduce_scatter which will result in suboptimal performance. "
  213. "please consider upgrading your pytorch installation.")
  214. input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group)))
  215. return reduce_scatter(output_tensor,
  216. input_tensor_lst,
  217. op=op,
  218. group=group,
  219. async_op=async_op,
  220. prof=prof,
  221. debug=debug)
  222. @timed_op
  223. def reduce_scatter_tensor(output_tensor,
  224. tensor,
  225. op=ReduceOp.SUM,
  226. group=None,
  227. async_op=False,
  228. prof=False,
  229. log_name='reduce_scatter_tensor',
  230. debug=get_caller_func()):
  231. global cdb
  232. return cdb.reduce_scatter_tensor(output_tensor=output_tensor,
  233. input_tensor=tensor,
  234. op=op,
  235. group=group,
  236. async_op=async_op)
  237. @timed_op
  238. def all_gather_into_tensor(output_tensor,
  239. tensor,
  240. group=None,
  241. async_op=False,
  242. prof=False,
  243. log_name='all_gather_into_tensor',
  244. debug=get_caller_func()):
  245. global cdb
  246. return cdb.all_gather_into_tensor(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)
  247. def has_all_gather_into_tensor():
  248. global cdb
  249. assert cdb is not None and cdb.is_initialized(
  250. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  251. return cdb.has_all_gather_into_tensor()
  252. def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()):
  253. global cdb
  254. assert cdb is not None and cdb.is_initialized(
  255. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  256. if cdb.has_all_gather_into_tensor():
  257. return all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
  258. else:
  259. if get_rank() == 0:
  260. utils.logger.warning_once("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
  261. "torch.distributed.all_gather which will result in suboptimal performance. "
  262. "please consider upgrading your pytorch installation.")
  263. output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
  264. return all_gather(output_tensors, input_tensor, group=group, async_op=async_op, debug=debug)
  265. @timed_op
  266. def all_to_all_single(output,
  267. tensor,
  268. output_split_sizes=None,
  269. input_split_sizes=None,
  270. group=None,
  271. async_op=False,
  272. prof=False,
  273. log_name='all_to_all_single',
  274. debug=get_caller_func()):
  275. global cdb
  276. return cdb.all_to_all_single(output=output,
  277. input=tensor,
  278. output_split_sizes=output_split_sizes,
  279. input_split_sizes=input_split_sizes,
  280. group=group,
  281. async_op=async_op)
  282. @timed_op
  283. def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
  284. global cdb
  285. return cdb.all_to_all(output_tensor_list, input_tensor_list, group=group, async_op=async_op)
  286. @timed_op
  287. def send(tensor, dst, group=None, tag=0, prof=False, log_name='send', debug=get_caller_func()):
  288. global cdb
  289. return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
  290. @timed_op
  291. def recv(tensor, src=None, group=None, tag=0, prof=False, log_name='recv', debug=get_caller_func()):
  292. global cdb
  293. return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
  294. @timed_op
  295. def isend(tensor, dst, group=None, tag=0, prof=False, log_name='isend', debug=get_caller_func()):
  296. global cdb
  297. return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
  298. @timed_op
  299. def irecv(tensor, src=None, group=None, tag=0, prof=False, log_name='irecv', debug=get_caller_func()):
  300. global cdb
  301. return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
  302. @timed_op
  303. def gather(tensor,
  304. gather_list=None,
  305. dst=0,
  306. group=None,
  307. async_op=False,
  308. prof=False,
  309. log_name='gather',
  310. debug=get_caller_func()):
  311. global cdb
  312. return cdb.gather(tensor=tensor, gather_list=gather_list, dst=dst, group=group, async_op=async_op)
  313. @timed_op
  314. def scatter(tensor,
  315. scatter_list=None,
  316. src=0,
  317. group=None,
  318. async_op=False,
  319. prof=False,
  320. log_name='scatter',
  321. debug=get_caller_func()):
  322. global cdb
  323. return cdb.scatter(tensor=tensor, scatter_list=scatter_list, src=src, group=group, async_op=async_op)
  324. @timed_op
  325. def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()):
  326. global cdb
  327. return cdb.barrier(group=group, async_op=async_op)
  328. @timed_op
  329. def monitored_barrier(group=None,
  330. timeout=None,
  331. wait_all_ranks=False,
  332. prof=False,
  333. log_name='monitored_barrier',
  334. debug=get_caller_func()):
  335. global cdb
  336. return cdb.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
  337. def log_summary(show_straggler=False):
  338. global cdb
  339. barrier(log_name='log_summary_barrier')
  340. if cdb.get_rank() == 0:
  341. comms_logger.log_all(print_log=True, show_straggler=show_straggler)
  342. else:
  343. comms_logger.log_all(print_log=False, show_straggler=show_straggler)
  344. barrier(log_name='log_summary_barrier')
  345. @timed_op
  346. def reduce(tensor,
  347. dst,
  348. op=ReduceOp.SUM,
  349. group=None,
  350. async_op=False,
  351. prof=False,
  352. log_name='reduce',
  353. debug=get_caller_func()):
  354. global cdb
  355. return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)
  356. @timed_op
  357. def reduce_scatter(output,
  358. input_list,
  359. op=ReduceOp.SUM,
  360. group=None,
  361. async_op=False,
  362. prof=False,
  363. log_name='reduce_scatter',
  364. debug=get_caller_func()):
  365. global cdb
  366. return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op)
  367. def has_all_reduce_coalesced():
  368. """"""
  369. global cdb
  370. assert cdb is not None and cdb.is_initialized(
  371. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  372. assert cdb.has_all_reduce_coalesced is not None, 'has_all_reduce_coalesced is not yet defined'
  373. return cdb.has_all_reduce_coalesced
  374. def has_coalescing_manager():
  375. global cdb
  376. assert cdb is not None and cdb.is_initialized(
  377. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  378. assert cdb.has_coalescing_manager is not None, 'has_coalescing_manager is not yet defined'
  379. return cdb.has_coalescing_manager
  380. def all_gather_coalesced(output_tensors, input_tensors, group=None, async_op=False):
  381. global cdb
  382. assert cdb is not None and cdb.is_initialized(
  383. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  384. return cdb.all_gather_coalesced(output_tensors, input_tensors, group=group, async_op=async_op)
  385. @timed_op
  386. def all_reduce(tensor,
  387. op=ReduceOp.SUM,
  388. group=None,
  389. async_op=False,
  390. prof=False,
  391. log_name='all_reduce',
  392. debug=get_caller_func()):
  393. #if profile_comm:
  394. # context of the timers?
  395. # timers.start()
  396. # TensorBoard logging for comm calls.?
  397. global cdb
  398. #print(f'op = {op}, cdb= {cdb.name}')
  399. return cdb.all_reduce(tensor, op, group, async_op)
  400. @timed_op
  401. def inference_all_reduce(tensor,
  402. op=ReduceOp.SUM,
  403. group=None,
  404. async_op=False,
  405. prof=False,
  406. log_name='all_reduce',
  407. debug=get_caller_func()):
  408. global cdb
  409. return cdb.inference_all_reduce(tensor, op, group)
  410. @timed_op
  411. def all_reduce_coalesced(tensors,
  412. op=ReduceOp.SUM,
  413. group=None,
  414. async_op=False,
  415. prof=False,
  416. log_name='all_reduce',
  417. debug=get_caller_func()):
  418. global cdb
  419. return cdb.all_reduce_coalesced(tensors, op, group, async_op)
  420. def get_world_group():
  421. global cdb
  422. assert cdb is not None and cdb.is_initialized(
  423. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  424. return cdb.get_world_group()
  425. def get_world_size(group=None) -> int:
  426. """
  427. Returns the number of processes in the current process group
  428. Args:
  429. group (ProcessGroup, optional): The process group to work on. If None,
  430. the default process group will be used.
  431. Returns:
  432. The world size of the process group
  433. -1, if not part of the group
  434. """
  435. global cdb
  436. assert cdb is not None and cdb.is_initialized(
  437. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  438. return cdb.get_world_size(group)
  439. def get_rank(group=None):
  440. """
  441. Returns the rank of the current process in the provided ``group`` or the
  442. default group if none was provided.
  443. Rank is a unique identifier assigned to each process within a distributed
  444. process group. They are always consecutive integers ranging from 0 to
  445. ``world_size``.
  446. Args:
  447. group (ProcessGroup, optional): The process group to work on. If None,
  448. the default process group will be used.
  449. Returns:
  450. The rank of the process group
  451. -1, if not part of the group
  452. """
  453. global cdb
  454. assert cdb is not None and cdb.is_initialized(
  455. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  456. return cdb.get_rank(group)
  457. def get_local_rank():
  458. """
  459. Helper function to get local rank after a backend has been set and initialized
  460. Args:
  461. None
  462. Returns:
  463. local rank (= GPU device ID)
  464. """
  465. global cdb
  466. assert cdb is not None and cdb.is_initialized(
  467. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  468. return get_local_rank_from_launcher()
  469. def get_global_rank(group=None, group_rank=0):
  470. global cdb
  471. assert cdb is not None and cdb.is_initialized(
  472. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  473. return cdb.get_global_rank(group, group_rank)
  474. def get_all_ranks_from_group(group=None):
  475. global cdb
  476. assert cdb is not None and cdb.is_initialized(
  477. ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
  478. rank = 0
  479. group_ranks = []
  480. try:
  481. while True:
  482. group_ranks.append(cdb.get_global_rank(group, rank))
  483. rank += 1
  484. except (RuntimeError, ValueError):
  485. pass
  486. return group_ranks
# Main DeepSpeed Comms. public API.
def init_distributed(dist_backend=None,
                     auto_mpi_discovery=True,
                     distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
                     verbose=True,
                     timeout=default_pg_timeout,
                     init_method=None,
                     dist_init_required=None,
                     config=None,
                     rank=-1,
                     world_size=-1):
    ''' Initialize dist backend, potentially performing MPI discovery if needed

    Arguments:
        dist_backend: Optional (str). torch distributed backend, e.g., nccl, mpi, gloo, hccl
        auto_mpi_discovery: Optional (bool). if distributed environment variables are not set, attempt to discover them from MPI
        distributed_port: Optional (int). torch distributed backend port
        verbose: Optional (bool). verbose logging
        timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
        init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
        dist_init_required: Optional (bool). When False, asserts a backend is already initialized; when None, inferred from current state.
        config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
        rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
        world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
    '''
    global cdb
    # Apply comms-logging options from the DeepSpeed config (no-op when config is None).
    configure(deepspeed_config=config)
    if dist_init_required is None:
        dist_init_required = cdb is None or not cdb.is_initialized()
    if cdb is None:
        # Try the native DeepSpeed backends first; only CCL is constructed by
        # init_deepspeed_backend today, so cdb may remain None afterwards.
        init_deepspeed_backend(get_accelerator().communication_backend_name(), timeout, init_method)
        set_backend()
        utils.logger.info(f'cdb={cdb}')
    if cdb is None and torch.distributed.is_initialized():
        # The user initialized torch.dist themselves, create cdb and short-circuit
        cdb = TorchBackend(dist_backend, timeout, init_method)
        return
    if dist_init_required is False:
        assert (
            cdb is not None and cdb.is_initialized() is True
        ), "Distributed backend is not initialized. Please set dist_init_required to True or initialize before calling deepspeed.initialize()"
    else:
        # Initialize torch distributed if needed
        required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
        if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
            # Env vars incomplete: try to derive them from the hosting platform
            # (AzureML, SageMaker) or a raw MPI launch.
            if verbose:
                utils.logger.info("Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...")
            if in_aml() and not in_dlts():
                patch_aml_env_for_torch_nccl_backend(verbose=verbose)
            elif in_aws_sm():
                patch_aws_sm_env_for_torch_nccl_backend(verbose=verbose)
            else:
                mpi_discovery(distributed_port=distributed_port, verbose=verbose)
        if cdb is not None and cdb.is_initialized():
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.info('Distributed backend already initialized')
        else:
            assert isinstance(timeout, timedelta)
            if dist_backend is None:
                dist_backend = get_accelerator().communication_backend_name()
            if int(os.getenv('RANK', '0')) == 0:
                utils.logger.info('Initializing TorchBackend in DeepSpeed with backend {}'.format(dist_backend))
            # Create a torch backend object, initialize torch distributed, and assign to cdb
            cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)
def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
    '''
    Discover the MPI environment via mpi4py and map it to the standard
    torch.distributed environment variables (RANK, WORLD_SIZE, LOCAL_RANK,
    MASTER_ADDR, MASTER_PORT).
    '''
    from mpi4py import MPI
    import subprocess
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()
    master_addr = None
    if rank == 0:
        # Rank 0 picks the first IP reported by `hostname -I` as the rendezvous address.
        hostname_cmd = ["hostname -I"]
        result = subprocess.check_output(hostname_cmd, shell=True)
        master_addr = result.decode('utf-8').split()[0]
    master_addr = comm.bcast(master_addr, root=0)
    # Determine local rank by assuming hostnames are unique
    proc_name = MPI.Get_processor_name()
    all_procs = comm.allgather(proc_name)
    local_rank = sum([i == proc_name for i in all_procs[:rank]])
    # Publish the discovered topology through the torch.distributed env vars.
    os.environ['RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(distributed_port)
    if verbose:
        utils.logger.info(
            "Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}".
            format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
                   os.environ['MASTER_PORT']))
    # Cross-check an already-initialized backend (if any) against MPI's view.
    if cdb is not None and cdb.is_initialized():
        assert cdb.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, cdb.get_rank())
        assert cdb.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
            world_size, cdb.get_world_size())
  582. def in_aml():
  583. # Are we running inside an Azure Machine Learning (AML) environment?
  584. return 'AZUREML_EXPERIMENT_ID' in os.environ
  585. def in_aws_sm():
  586. # Are we running inside an AWS SageMaker environment?
  587. return 'SM_TRAINING_ENV' in os.environ
  588. def in_dlts():
  589. # Are we running on a DLTS cluster?
  590. return 'DLTS_JOB_ID' in os.environ
def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
    """Helper routine to get and set environment variables.
    This is adapted from Azure ML's documentation available from:
    https://azure.github.io/azureml-web/docs/cheatsheet/distributed-training/#environment-variables-from-openmpi

    Maps OpenMPI-provided env vars to the torch.distributed ones and forces the
    AML-recommended NCCL socket interface.
    """
    os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
    # Single node iff every process in the world lives on this node.
    single_node = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"]) == int(os.environ["WORLD_SIZE"])
    if not single_node:
        master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
        os.environ["MASTER_ADDR"] = master_node_params[0]
        # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = str(master_port)
    else:
        os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
        os.environ["MASTER_PORT"] = DEFAULT_AML_MASTER_PORT
    if verbose:
        utils.logger.info("NCCL_SOCKET_IFNAME original value = {}".format(os.environ["NCCL_SOCKET_IFNAME"]))
    # Override the NCCL interface with the AML default from .constants.
    os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_AML_NCCL_SOCKET_IFNAME
    os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
    if verbose:
        utils.logger.info(
            "Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
            .format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
                    os.environ['MASTER_PORT']))
  617. def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
  618. """Helper routine to get and set environment variables when running inside an AWS SageMaker environment.
  619. """
  620. os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
  621. os.environ['LOCAL_RANK'] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
  622. os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
  623. if verbose:
  624. utils.logger.info(
  625. "Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
  626. .format(os.environ['RANK'], os.environ['LOCAL_RANK'], os.environ['WORLD_SIZE'], os.environ['MASTER_ADDR'],
  627. os.environ['MASTER_PORT']))