# File: slurm-launch.py (3.1 KB)
  1. # slurm-launch.py
  2. # Usage:
  3. # python slurm-launch.py --exp-name test \
  4. # --command "rllib train --run PPO --env CartPole-v0"
  5. import argparse
  6. import subprocess
  7. import sys
  8. import time
  9. from pathlib import Path
  10. template_file = Path(__file__) / "slurm-template.sh"
  11. JOB_NAME = "${JOB_NAME}"
  12. NUM_NODES = "${NUM_NODES}"
  13. NUM_GPUS_PER_NODE = "${NUM_GPUS_PER_NODE}"
  14. PARTITION_OPTION = "${PARTITION_OPTION}"
  15. COMMAND_PLACEHOLDER = "${COMMAND_PLACEHOLDER}"
  16. GIVEN_NODE = "${GIVEN_NODE}"
  17. LOAD_ENV = "${LOAD_ENV}"
  18. if __name__ == "__main__":
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument(
  21. "--exp-name",
  22. type=str,
  23. required=True,
  24. help="The job name and path to logging file (exp_name.log).",
  25. )
  26. parser.add_argument(
  27. "--num-nodes", "-n", type=int, default=1, help="Number of nodes to use."
  28. )
  29. parser.add_argument(
  30. "--node",
  31. "-w",
  32. type=str,
  33. help="The specified nodes to use. Same format as the "
  34. "return of 'sinfo'. Default: ''.",
  35. )
  36. parser.add_argument(
  37. "--num-gpus",
  38. type=int,
  39. default=0,
  40. help="Number of GPUs to use in each node. (Default: 0)",
  41. )
  42. parser.add_argument(
  43. "--partition",
  44. "-p",
  45. type=str,
  46. )
  47. parser.add_argument(
  48. "--load-env",
  49. type=str,
  50. help="The script to load your environment ('module load cuda/10.1')",
  51. default="",
  52. )
  53. parser.add_argument(
  54. "--command",
  55. type=str,
  56. required=True,
  57. help="The command you wish to execute. For example: "
  58. " --command 'python test.py'. "
  59. "Note that the command must be a string.",
  60. )
  61. args = parser.parse_args()
  62. if args.node:
  63. # assert args.num_nodes == 1
  64. node_info = "#SBATCH -w {}".format(args.node)
  65. else:
  66. node_info = ""
  67. job_name = "{}_{}".format(
  68. args.exp_name, time.strftime("%m%d-%H%M", time.localtime())
  69. )
  70. partition_option = (
  71. "#SBATCH --partition={}".format(args.partition) if args.partition else ""
  72. )
  73. # ===== Modified the template script =====
  74. with open(template_file, "r") as f:
  75. text = f.read()
  76. text = text.replace(JOB_NAME, job_name)
  77. text = text.replace(NUM_NODES, str(args.num_nodes))
  78. text = text.replace(NUM_GPUS_PER_NODE, str(args.num_gpus))
  79. text = text.replace(PARTITION_OPTION, partition_option)
  80. text = text.replace(COMMAND_PLACEHOLDER, str(args.command))
  81. text = text.replace(LOAD_ENV, str(args.load_env))
  82. text = text.replace(GIVEN_NODE, node_info)
  83. text = text.replace(
  84. "# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO " "PRODUCTION!",
  85. "# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
  86. "RUNNABLE!",
  87. )
  88. # ===== Save the script =====
  89. script_file = "{}.sh".format(job_name)
  90. with open(script_file, "w") as f:
  91. f.write(text)
  92. # ===== Submit the job =====
  93. print("Starting to submit job!")
  94. subprocess.Popen(["sbatch", script_file])
  95. print(
  96. "Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
  97. script_file, "{}.log".format(job_name)
  98. )
  99. )
  100. sys.exit(0)