elastic_agent.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
  5. from typing import Any, Dict, Optional, Tuple
  6. from datetime import datetime
  7. from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
  8. from torch.distributed.elastic.metrics import put_metric
  9. from torch.distributed.elastic.agent.server.api import (
  10. RunResult,
  11. WorkerGroup,
  12. WorkerSpec,
  13. WorkerState,
  14. )
  15. from torch.distributed import Store
  16. import time
  17. import os
  18. from torch.distributed.elastic.multiprocessing import start_processes
  19. from torch.distributed.elastic.utils import macros
  20. import shutil
  21. import copy
  22. from contextlib import closing
  23. import subprocess
  24. class DSElasticAgent(LocalElasticAgent):
  25. def __init__(
  26. self,
  27. spec: WorkerSpec,
  28. env: Dict,
  29. start_method="spawn",
  30. exit_barrier_timeout: float = 300,
  31. log_dir: Optional[str] = None,
  32. ):
  33. super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
  34. self.ds_env = env
  35. @staticmethod
  36. def _set_master_addr_port(store: Store,
  37. master_addr: Optional[str],
  38. master_port: Optional[int],
  39. local_addr: Optional[str] = None):
  40. if master_port is None:
  41. sock = _get_socket_with_port()
  42. with closing(sock):
  43. master_port = sock.getsockname()[1]
  44. if master_addr is None:
  45. # master_addr = _get_fq_hostname()
  46. result = subprocess.check_output("hostname -I", shell=True)
  47. master_addr = result.decode('utf-8').split()[0]
  48. store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
  49. store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))
  50. def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
  51. spec = worker_group.spec
  52. store = worker_group.store
  53. assert store is not None
  54. master_addr, master_port = super()._get_master_addr_port(store)
  55. restart_count = spec.max_restarts - self._remaining_restarts
  56. use_agent_store = spec.rdzv_handler.get_backend() == "static"
  57. args: Dict[int, Tuple] = {}
  58. envs: Dict[int, Dict[str, str]] = {}
  59. for worker in worker_group.workers:
  60. local_rank = worker.local_rank
  61. worker_env_ds = copy.deepcopy(self.ds_env)
  62. worker_env_elastic = {
  63. "LOCAL_RANK": str(local_rank),
  64. "RANK": str(worker.global_rank),
  65. "GROUP_RANK": str(worker_group.group_rank),
  66. "ROLE_RANK": str(worker.role_rank),
  67. "ROLE_NAME": spec.role,
  68. "LOCAL_WORLD_SIZE": str(spec.local_world_size),
  69. "WORLD_SIZE": str(worker.world_size),
  70. "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
  71. "ROLE_WORLD_SIZE": str(worker.role_world_size),
  72. "MASTER_ADDR": master_addr,
  73. "MASTER_PORT": str(master_port),
  74. "TORCHELASTIC_RESTART_COUNT": str(restart_count),
  75. "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
  76. "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
  77. "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
  78. "NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING", str(1)),
  79. }
  80. worker_env_ds.update(worker_env_elastic)
  81. if "OMP_NUM_THREADS" in os.environ:
  82. worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
  83. envs[local_rank] = worker_env_ds
  84. worker_args = list(spec.args)
  85. worker_args = macros.substitute(worker_args, str(local_rank))
  86. args[local_rank] = tuple(worker_args)
  87. # scaling events do not count towards restarts (gets same attempt #)
  88. # remove existing log dir if this restart is due to a scaling event
  89. attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
  90. shutil.rmtree(attempt_log_dir, ignore_errors=True)
  91. os.makedirs(attempt_log_dir)
  92. assert spec.entrypoint is not None
  93. self._pcontext = start_processes(
  94. name=spec.role,
  95. entrypoint=spec.entrypoint,
  96. args=args,
  97. envs=envs,
  98. log_dir=attempt_log_dir,
  99. start_method=self._start_method,
  100. redirects=spec.redirects,
  101. tee=spec.tee,
  102. )
  103. return self._pcontext.pids()
  104. def _invoke_run(self, role: str = "default") -> RunResult:
  105. # NOTE: currently only works for a single role
  106. spec = self._worker_group.spec
  107. role = spec.role
  108. log.info(f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")
  109. self._initialize_workers(self._worker_group)
  110. monitor_interval = spec.monitor_interval
  111. rdzv_handler = spec.rdzv_handler
  112. participants = rdzv_handler._state_holder.state.participants
  113. while True:
  114. assert self._worker_group.state != WorkerState.INIT
  115. time.sleep(monitor_interval)
  116. run_result = self._monitor_workers(self._worker_group)
  117. state = run_result.state
  118. self._worker_group.state = state
  119. expire_time = datetime.utcnow() - (rdzv_handler._settings.keep_alive_interval *
  120. rdzv_handler._settings.keep_alive_max_attempt)
  121. _dead_nodes = [
  122. node for node, last_heartbeat in rdzv_handler._state_holder.state.last_heartbeats.items()
  123. if last_heartbeat < expire_time
  124. ]
  125. put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
  126. put_metric(f"workers.{role}.{state.name.lower()}", 1)
  127. if state == WorkerState.SUCCEEDED:
  128. log.info(f"[{role}] worker group successfully finished."
  129. f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.")
  130. self._exit_barrier()
  131. return run_result
  132. elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED
  133. } or len(participants) > len(rdzv_handler._state_holder.state.participants):
  134. if self._remaining_restarts > 0:
  135. log.info(f"[{role}] Worker group {state.name}. "
  136. f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
  137. f" will restart worker group")
  138. self._remaining_restarts -= 1
  139. # rdzv_handler._state_holder.state.restart = False
  140. self._restart_workers(self._worker_group)
  141. participants = rdzv_handler._state_holder.state.participants
  142. else:
  143. self._stop_workers(self._worker_group)
  144. self._worker_group.state = WorkerState.FAILED
  145. self._exit_barrier()
  146. return run_result
  147. elif state == WorkerState.HEALTHY:
  148. # membership changes do not count as retries
  149. num_nodes_waiting = rdzv_handler.num_nodes_waiting()
  150. group_rank = self._worker_group.group_rank
  151. if num_nodes_waiting > 0:
  152. log.info(f"[{role}] Detected {num_nodes_waiting} "
  153. f"new nodes from group_rank={group_rank}; "
  154. f"will restart worker group")
  155. self._restart_workers(self._worker_group)
  156. participants = rdzv_handler._state_holder.state.participants
  157. else:
  158. raise Exception(f"[{role}] Worker group in {state.name} state")