launcher_helper.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) Microsoft Corporation.
  2. # SPDX-License-Identifier: Apache-2.0
  3. # DeepSpeed Team
  4. import os
  5. import sys
  6. import argparse
  7. import subprocess
  8. from deepspeed.utils import logger
  9. from deepspeed.launcher.constants import MPICH_LAUNCHER
  10. def parse_args(args=None):
  11. parser = argparse.ArgumentParser(description="DeepSpeed launcher helper to map environment variables for"
  12. "multi-node/multi-gpu training jobs.",
  13. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  14. parser.add_argument("--launcher",
  15. default=MPICH_LAUNCHER,
  16. type=str,
  17. help="(optional) choose launcher backend for multi-node "
  18. "training. Options currently include MPICH.")
  19. parser.add_argument("--module",
  20. action="store_true",
  21. help="Change each process to interpret the launch "
  22. "script as a Python module, executing with the same "
  23. "behavior as 'python -m'.")
  24. parser.add_argument("--no_python",
  25. action="store_true",
  26. help="Skip prepending the training script with "
  27. "'python' - just execute it directly.")
  28. parser.add_argument("user_script", type=str, help="User script to launch, followed by any required "
  29. "arguments.")
  30. parser.add_argument('user_args', nargs=argparse.REMAINDER)
  31. parser.add_argument("--bind_cores_to_rank",
  32. action="store_true",
  33. help="Bind each rank to different cores of the host")
  34. parser.add_argument("--bind_core_list",
  35. type=str,
  36. default=None,
  37. help="List of cores to bind to with comma separated list of "
  38. "numbers and range. i.e. 1,3-5,7 => [1,3,4,5,7]. When not "
  39. "specified, all cores on system would be used rank binding")
  40. return parser.parse_args(args=args)
  41. def env_mapping(env, rank_name_list=None, local_rank_name_list=None):
  42. rank = None
  43. for rank_name in rank_name_list:
  44. if rank_name in env:
  45. if rank == None:
  46. rank = env.get(rank_name)
  47. elif rank != env.get(rank_name):
  48. raise EnvironmentError(f"rank number doesn't match!")
  49. if rank == None:
  50. raise EnvironmentError(f"rank number is not in current env!")
  51. env['RANK'] = rank
  52. local_rank = None
  53. for local_rank_name in local_rank_name_list:
  54. if local_rank_name in env:
  55. if local_rank == None:
  56. local_rank = env.get(local_rank_name)
  57. elif local_rank != env.get(local_rank_name):
  58. raise EnvironmentError(f"local_rank number doesn't match!")
  59. if local_rank == None:
  60. raise EnvironmentError(f"rank number is not in current env!")
  61. env['LOCAL_RANK'] = local_rank
  62. return env
  63. def main(args=None):
  64. args = parse_args(args)
  65. env = os.environ.copy()
  66. args.launcher = args.launcher.lower()
  67. if args.launcher == MPICH_LAUNCHER:
  68. rank_name_list = ["PMIX_RANK"] + ["PMI_RANK"]
  69. local_rank_name_list = ["PALS_LOCAL_RANKID"] + ["MPI_LOCALRANKID"]
  70. env = env_mapping(env, rank_name_list=rank_name_list, local_rank_name_list=local_rank_name_list)
  71. else:
  72. raise NotImplementedError(f"Unknown launcher {args.launcher}")
  73. python_exec = []
  74. if not args.no_python:
  75. python_exec += [sys.executable, "-u"]
  76. if args.module:
  77. python_exec.append("-m")
  78. cmd = python_exec + [args.user_script] + args.user_args
  79. logger.info(f"launcher_helper cmd = {' '.join(cmd)}")
  80. result = subprocess.Popen(cmd, env=env, close_fds=False)
  81. result.wait()
  82. if __name__ == "__main__":
  83. main()