elastic_agent.py

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
from torch.distributed.elastic.metrics import put_metric
from torch.distributed.elastic.agent.server.api import (
    RunResult,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed import Store
import time
import os
from torch.distributed.elastic.multiprocessing import start_processes
from torch.distributed.elastic.utils import macros
import shutil
import copy
from contextlib import closing
import subprocess
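

# DSElasticAgent specializes torchelastic's LocalElasticAgent for the DeepSpeed
# launcher: it seeds every worker's environment from a caller-supplied dict
# (self.ds_env) and reacts to elastic scale-up/scale-down events observed through
# the rendezvous state, not only to worker failures.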
class DSElasticAgent(LocalElasticAgent):

    def __init__(
        self,
        spec: WorkerSpec,
        env: Dict,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
        self.ds_env = env
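
    # Overrides the base agent's _set_master_addr_port: when no master_addr is
    # supplied, resolve this node's IP with `hostname -I` instead of the default
    # fully qualified hostname lookup, and grab a free port when no master_port is
    # given, then publish both to the store for the workers to consume.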
    @staticmethod
    def _set_master_addr_port(store: Store, master_addr: Optional[str], master_port: Optional[int]):
        if master_port is None:
            sock = _get_socket_with_port()
            with closing(sock):
                master_port = sock.getsockname()[1]

        if master_addr is None:
            # master_addr = _get_fq_hostname()
            result = subprocess.check_output("hostname -I", shell=True)
            master_addr = result.decode('utf-8').split()[0]

        store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
        store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))
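
    # Same overall flow as LocalElasticAgent._start_workers, except that each
    # worker's environment starts from the DeepSpeed-provided self.ds_env and the
    # standard torchelastic variables are merged on top of it.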
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"

        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank

            worker_env_ds = copy.deepcopy(self.ds_env)
            worker_env_elastic = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING", str(1)),
            }
            worker_env_ds.update(worker_env_elastic)
            if "OMP_NUM_THREADS" in os.environ:
                worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

            envs[local_rank] = worker_env_ds
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        assert spec.entrypoint is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()
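
    # Monitoring loop: poll worker state every monitor_interval seconds; restart the
    # group (consuming one of max_restarts) when workers fail or the set of rendezvous
    # participants shrinks; restart without consuming a retry when new nodes are
    # waiting to join; exit through the barrier once the workers succeed.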
    def _invoke_run(self, role: str = "default") -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        log.info(f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")

        self._initialize_workers(self._worker_group)
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        participants = rdzv_handler._state_holder.state.participants

        while True:
            assert self._worker_group.state != WorkerState.INIT
            time.sleep(monitor_interval)
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state
            self._worker_group.state = state

            expire_time = datetime.utcnow() - (rdzv_handler._settings.keep_alive_interval *
                                               rdzv_handler._settings.keep_alive_max_attempt)
            _dead_nodes = [
                node for node, last_heartbeat in rdzv_handler._state_holder.state.last_heartbeats.items()
                if last_heartbeat < expire_time
            ]

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                log.info(f"[{role}] worker group successfully finished."
                         f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish.")
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED
                           } or len(participants) > len(rdzv_handler._state_holder.state.participants):
                if self._remaining_restarts > 0:
                    log.info(f"[{role}] Worker group {state.name}. "
                             f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                             f" will restart worker group")
                    self._remaining_restarts -= 1
                    # rdzv_handler._state_holder.state.restart = False
                    self._restart_workers(self._worker_group)
                    participants = rdzv_handler._state_holder.state.participants
                else:
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
                    log.info(f"[{role}] Detected {num_nodes_waiting} "
                             f"new nodes from group_rank={group_rank}; "
                             f"will restart worker group")
                    self._restart_workers(self._worker_group)
                    participants = rdzv_handler._state_holder.state.participants
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")
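

# A minimal sketch of how a launcher might drive DSElasticAgent. This block is not
# part of the original module: the rendezvous endpoint, run id, node counts, and the
# training entrypoint below are placeholder assumptions, and only the commonly used
# WorkerSpec / RendezvousParameters fields are shown.
if __name__ == "__main__":
    from torch.distributed.elastic.rendezvous import RendezvousParameters
    import torch.distributed.elastic.rendezvous.registry as rdzv_registry

    # Describe the elastic job: a c10d rendezvous that tolerates 1-4 nodes.
    rdzv_params = RendezvousParameters(
        backend="c10d",
        endpoint="localhost:29400",  # assumed rendezvous host:port
        run_id="ds_elastic_demo",  # assumed job id
        min_nodes=1,
        max_nodes=4,
    )
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_params)

    # One worker per node here, each launched as `python -u train.py` (placeholder).
    spec = WorkerSpec(
        role="trainer",
        local_world_size=1,
        entrypoint="python",
        args=("-u", "train.py"),
        rdzv_handler=rdzv_handler,
        max_restarts=3,
        monitor_interval=5,
    )

    # Hand the current environment to the agent; it is copied into every worker.
    agent = DSElasticAgent(spec, env=dict(os.environ))
    result = agent.run()
    print(f"worker group finished with state: {result.state.name}")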