# elastic_agent.py

'''Copyright The Microsoft DeepSpeed Team'''

from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from typing import Any, Dict, Optional, Tuple
from datetime import datetime
from torch.distributed.elastic.agent.server.api import log, _get_socket_with_port
from torch.distributed.elastic.metrics import put_metric
from torch.distributed.elastic.agent.server.api import (
    RunResult,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from torch.distributed import Store
import time
import os
from torch.distributed.elastic.multiprocessing import start_processes
from torch.distributed.elastic.utils import macros
import shutil
import copy
from contextlib import closing
import subprocess
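

# DSElasticAgent extends torchelastic's LocalElasticAgent so that a DeepSpeed-specific
# environment dict (passed in as `env`) is merged into the environment of every worker
# process it launches, alongside the standard torchelastic variables (RANK, WORLD_SIZE,
# MASTER_ADDR, ...). It also customizes master-address selection and the monitoring loop.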
class DSElasticAgent(LocalElasticAgent):
    def __init__(
        self,
        spec: WorkerSpec,
        env: Dict,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_dir: Optional[str] = None,
    ):
        super().__init__(spec, start_method, exit_barrier_timeout, log_dir)
        self.ds_env = env
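
    # Override of the base agent's master address/port selection. When no port is given,
    # bind a temporary socket to pick a free one; when no address is given, use the first
    # address reported by `hostname -I` instead of the fully-qualified hostname. Both
    # values are published through the rendezvous store for the workers to read.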
    @staticmethod
    def _set_master_addr_port(store: Store,
                              master_addr: Optional[str],
                              master_port: Optional[int]):
        if master_port is None:
            sock = _get_socket_with_port()
            with closing(sock):
                master_port = sock.getsockname()[1]

        if master_addr is None:
            # master_addr = _get_fq_hostname()
            result = subprocess.check_output("hostname -I", shell=True)
            master_addr = result.decode('utf-8').split()[0]

        store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
        store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))
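
    # Largely mirrors LocalElasticAgent._start_workers; the difference is that the
    # DeepSpeed environment captured at construction time (self.ds_env) is deep-copied
    # and then overlaid with the torchelastic variables for each worker.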
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        master_addr, master_port = super()._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        use_agent_store = spec.rdzv_handler.get_backend() == "static"
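
        # Build the per-local-rank argument tuples and environment dicts that
        # start_processes expects: one entry per worker in this group.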
        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        for worker in worker_group.workers:
            local_rank = worker.local_rank

            worker_env_ds = copy.deepcopy(self.ds_env)
            worker_env_elastic = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": master_addr,
                "MASTER_PORT": str(master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "NCCL_ASYNC_ERROR_HANDLING": os.getenv("NCCL_ASYNC_ERROR_HANDLING",
                                                       str(1)),
            }
            worker_env_ds.update(worker_env_elastic)
            if "OMP_NUM_THREADS" in os.environ:
                worker_env_ds["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]

            envs[local_rank] = worker_env_ds
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)

        # scaling events do not count towards restarts (gets same attempt #)
        # remove existing log dir if this restart is due to a scaling event
        attempt_log_dir = os.path.join(self._log_dir, f"attempt_{restart_count}")
        shutil.rmtree(attempt_log_dir, ignore_errors=True)
        os.makedirs(attempt_log_dir)

        assert spec.entrypoint is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            log_dir=attempt_log_dir,
            start_method=self._start_method,
            redirects=spec.redirects,
            tee=spec.tee,
        )

        return self._pcontext.pids()
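
    # Monitoring loop adapted from LocalElasticAgent._invoke_run. In addition to the worker
    # state it tracks the rendezvous participants recorded by the rendezvous state holder
    # (this assumes the c10d dynamic rendezvous backend, which exposes `_state_holder` and
    # `_settings`); a shrinking participant list is treated like a worker failure and
    # consumes a restart attempt.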
    def _invoke_run(self, role: str = "default") -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        log.info(
            f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}")

        self._initialize_workers(self._worker_group)
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        participants = rdzv_handler._state_holder.state.participants
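
        # Poll the workers every monitor_interval seconds. Nodes whose rendezvous
        # keep-alive heartbeat is older than keep_alive_interval * keep_alive_max_attempt
        # are considered dead.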
        while True:
            assert self._worker_group.state != WorkerState.INIT
            time.sleep(monitor_interval)
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state
            self._worker_group.state = state

            expire_time = datetime.utcnow() - (
                rdzv_handler._settings.keep_alive_interval *
                rdzv_handler._settings.keep_alive_max_attempt)
            _dead_nodes = [
                node for node, last_heartbeat in
                rdzv_handler._state_holder.state.last_heartbeats.items()
                if last_heartbeat < expire_time
            ]

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)
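
            # SUCCEEDED: wait at the exit barrier for the other agents, then return.
            # UNHEALTHY/FAILED, or the participant list shrank: restart if attempts
            # remain, otherwise stop the workers and report failure.
            # HEALTHY: restart only when new nodes are waiting to join the job.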
            if state == WorkerState.SUCCEEDED:
                log.info(
                    f"[{role}] worker group successfully finished."
                    f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
                )
                self._exit_barrier()
                return run_result
            elif state in {
                    WorkerState.UNHEALTHY,
                    WorkerState.FAILED
            } or len(participants) > len(rdzv_handler._state_holder.state.participants):
                if self._remaining_restarts > 0:
                    log.info(
                        f"[{role}] Worker group {state.name}. "
                        f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                        f" will restart worker group")
                    self._remaining_restarts -= 1
                    # rdzv_handler._state_holder.state.restart = False
                    self._restart_workers(self._worker_group)
                    participants = rdzv_handler._state_holder.state.participants
                else:
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
                    log.info(f"[{role}] Detected {num_nodes_waiting} "
                             f"new nodes from group_rank={group_rank}; "
                             f"will restart worker group")
                    self._restart_workers(self._worker_group)
                    participants = rdzv_handler._state_holder.state.participants
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")
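

# Minimal usage sketch (commented out; not part of the original module). It assumes a
# torchelastic WorkerSpec and a c10d rendezvous handler created elsewhere; names such as
# `trainer_fn`, the endpoint, and the env contents are placeholders for illustration.
#
#     from torch.distributed.elastic.agent.server.api import WorkerSpec
#     from torch.distributed.elastic.rendezvous import RendezvousParameters
#     from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
#
#     def trainer_fn():
#         ...  # placeholder training entrypoint run in each worker process
#
#     rdzv_params = RendezvousParameters(backend="c10d",
#                                        endpoint="localhost:29400",
#                                        run_id="ds_job",
#                                        min_nodes=1,
#                                        max_nodes=2)
#     rdzv_handler = get_rendezvous_handler(rdzv_params)
#     spec = WorkerSpec(role="trainer",
#                       local_world_size=2,
#                       rdzv_handler=rdzv_handler,
#                       entrypoint=trainer_fn,
#                       args=(),
#                       max_restarts=3,
#                       monitor_interval=5)
#     agent = DSElasticAgent(spec, env={"PYTHONPATH": "."})
#     result = agent.run()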