train.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. #!/usr/bin/env python
  2. import argparse
  3. import os
  4. from pathlib import Path
  5. import yaml
  6. import ray
  7. from ray.tune.config_parser import make_parser
  8. from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
  9. from ray.tune.result import DEFAULT_RESULTS_DIR
  10. from ray.tune.resources import resources_to_json
  11. from ray.tune.tune import run_experiments
  12. from ray.tune.schedulers import create_scheduler
  13. from ray.rllib.utils.deprecation import deprecation_warning
  14. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  15. try:
  16. class_name = get_ipython().__class__.__name__
  17. IS_NOTEBOOK = True if "Terminal" not in class_name else False
  18. except NameError:
  19. IS_NOTEBOOK = False
# Try to import both backends for flag checking/warnings.
# NOTE(review): tf1/tf/tfv/torch are not referenced anywhere else in this
# file; the calls appear to be kept for their import-time side effects
# (presumably warnings about missing frameworks) -- confirm before removing.
# Keep the call order as-is: it determines which framework is imported first.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
  23. EXAMPLE_USAGE = """
  24. Training example via RLlib CLI:
  25. rllib train --run DQN --env CartPole-v0
  26. Grid search example via RLlib CLI:
  27. rllib train -f tuned_examples/cartpole-grid-search-example.yaml
  28. Grid search example via executable:
  29. ./train.py -f tuned_examples/cartpole-grid-search-example.yaml
  30. Note that -f overrides all other trial-specific command-line options.
  31. """
  32. def create_parser(parser_creator=None):
  33. parser = make_parser(
  34. parser_creator=parser_creator,
  35. formatter_class=argparse.RawDescriptionHelpFormatter,
  36. description="Train a reinforcement learning agent.",
  37. epilog=EXAMPLE_USAGE)
  38. # See also the base parser definition in ray/tune/config_parser.py
  39. parser.add_argument(
  40. "--ray-address",
  41. default=None,
  42. type=str,
  43. help="Connect to an existing Ray cluster at this address instead "
  44. "of starting a new one.")
  45. parser.add_argument(
  46. "--ray-ui",
  47. action="store_true",
  48. help="Whether to enable the Ray web UI.")
  49. # Deprecated: Use --ray-ui, instead.
  50. parser.add_argument(
  51. "--no-ray-ui",
  52. action="store_true",
  53. help="Deprecated! Ray UI is disabled by default now. "
  54. "Use `--ray-ui` to enable.")
  55. parser.add_argument(
  56. "--local-mode",
  57. action="store_true",
  58. help="Run ray in local mode for easier debugging.")
  59. parser.add_argument(
  60. "--ray-num-cpus",
  61. default=None,
  62. type=int,
  63. help="--num-cpus to use if starting a new cluster.")
  64. parser.add_argument(
  65. "--ray-num-gpus",
  66. default=None,
  67. type=int,
  68. help="--num-gpus to use if starting a new cluster.")
  69. parser.add_argument(
  70. "--ray-num-nodes",
  71. default=None,
  72. type=int,
  73. help="Emulate multiple cluster nodes for debugging.")
  74. parser.add_argument(
  75. "--ray-object-store-memory",
  76. default=None,
  77. type=int,
  78. help="--object-store-memory to use if starting a new cluster.")
  79. parser.add_argument(
  80. "--experiment-name",
  81. default="default",
  82. type=str,
  83. help="Name of the subdirectory under `local_dir` to put results in.")
  84. parser.add_argument(
  85. "--local-dir",
  86. default=DEFAULT_RESULTS_DIR,
  87. type=str,
  88. help="Local dir to save training results to. Defaults to '{}'.".format(
  89. DEFAULT_RESULTS_DIR))
  90. parser.add_argument(
  91. "--upload-dir",
  92. default="",
  93. type=str,
  94. help="Optional URI to sync training results to (e.g. s3://bucket).")
  95. # This will override any framework setting found in a yaml file.
  96. parser.add_argument(
  97. "--framework",
  98. choices=["tf", "tf2", "tfe", "torch"],
  99. default=None,
  100. help="The DL framework specifier.")
  101. parser.add_argument(
  102. "-v", action="store_true", help="Whether to use INFO level logging.")
  103. parser.add_argument(
  104. "-vv", action="store_true", help="Whether to use DEBUG level logging.")
  105. parser.add_argument(
  106. "--resume",
  107. action="store_true",
  108. help="Whether to attempt to resume previous Tune experiments.")
  109. parser.add_argument(
  110. "--trace",
  111. action="store_true",
  112. help="Whether to attempt to enable tracing for eager mode.")
  113. parser.add_argument(
  114. "--env", default=None, type=str, help="The gym environment to use.")
  115. parser.add_argument(
  116. "-f",
  117. "--config-file",
  118. default=None,
  119. type=str,
  120. help="If specified, use config options from this file. Note that this "
  121. "overrides any trial-specific options set via flags above.")
  122. # Obsolete: Use --framework=torch|tf2|tfe instead!
  123. parser.add_argument(
  124. "--torch",
  125. action="store_true",
  126. help="Whether to use PyTorch (instead of tf) as the DL framework.")
  127. parser.add_argument(
  128. "--eager",
  129. action="store_true",
  130. help="Whether to attempt to enable TF eager execution.")
  131. return parser
  132. def run(args, parser):
  133. if args.config_file:
  134. with open(args.config_file) as f:
  135. experiments = yaml.safe_load(f)
  136. else:
  137. # Note: keep this in sync with tune/config_parser.py
  138. experiments = {
  139. args.experiment_name: { # i.e. log to ~/ray_results/default
  140. "run": args.run,
  141. "checkpoint_freq": args.checkpoint_freq,
  142. "checkpoint_at_end": args.checkpoint_at_end,
  143. "keep_checkpoints_num": args.keep_checkpoints_num,
  144. "checkpoint_score_attr": args.checkpoint_score_attr,
  145. "local_dir": args.local_dir,
  146. "resources_per_trial": (args.resources_per_trial
  147. and resources_to_json(
  148. args.resources_per_trial)),
  149. "stop": args.stop,
  150. "config": dict(args.config, env=args.env),
  151. "restore": args.restore,
  152. "num_samples": args.num_samples,
  153. "sync_config": {
  154. "upload_dir": args.upload_dir,
  155. }
  156. }
  157. }
  158. # Ray UI.
  159. if args.no_ray_ui:
  160. deprecation_warning(old="--no-ray-ui", new="--ray-ui", error=False)
  161. args.ray_ui = False
  162. verbose = 1
  163. for exp in experiments.values():
  164. # Bazel makes it hard to find files specified in `args` (and `data`).
  165. # Look for them here.
  166. # NOTE: Some of our yaml files don't have a `config` section.
  167. input_ = exp.get("config", {}).get("input")
  168. if input_ and input_ != "sampler":
  169. # This script runs in the ray/rllib dir.
  170. rllib_dir = Path(__file__).parent
  171. def patch_path(path):
  172. if isinstance(path, list):
  173. return [patch_path(i) for i in path]
  174. elif isinstance(path, dict):
  175. return {
  176. patch_path(k): patch_path(v)
  177. for k, v in path.items()
  178. }
  179. elif isinstance(path, str):
  180. if os.path.exists(path):
  181. return path
  182. else:
  183. abs_path = str(rllib_dir.absolute().joinpath(path))
  184. return abs_path if os.path.exists(abs_path) else path
  185. else:
  186. return path
  187. exp["config"]["input"] = patch_path(input_)
  188. if not exp.get("run"):
  189. parser.error("the following arguments are required: --run")
  190. if not exp.get("env") and not exp.get("config", {}).get("env"):
  191. parser.error("the following arguments are required: --env")
  192. if args.torch:
  193. deprecation_warning("--torch", "--framework=torch")
  194. exp["config"]["framework"] = "torch"
  195. elif args.eager:
  196. deprecation_warning("--eager", "--framework=[tf2|tfe]")
  197. exp["config"]["framework"] = "tfe"
  198. elif args.framework is not None:
  199. exp["config"]["framework"] = args.framework
  200. if args.trace:
  201. if exp["config"]["framework"] not in ["tf2", "tfe"]:
  202. raise ValueError("Must enable --eager to enable tracing.")
  203. exp["config"]["eager_tracing"] = True
  204. if args.v:
  205. exp["config"]["log_level"] = "INFO"
  206. verbose = 3 # Print details on trial result
  207. if args.vv:
  208. exp["config"]["log_level"] = "DEBUG"
  209. verbose = 3 # Print details on trial result
  210. if args.ray_num_nodes:
  211. # Import this only here so that train.py also works with
  212. # older versions (and user doesn't use `--ray-num-nodes`).
  213. from ray.cluster_utils import Cluster
  214. cluster = Cluster()
  215. for _ in range(args.ray_num_nodes):
  216. cluster.add_node(
  217. num_cpus=args.ray_num_cpus or 1,
  218. num_gpus=args.ray_num_gpus or 0,
  219. object_store_memory=args.ray_object_store_memory)
  220. ray.init(address=cluster.address)
  221. else:
  222. ray.init(
  223. include_dashboard=args.ray_ui,
  224. address=args.ray_address,
  225. object_store_memory=args.ray_object_store_memory,
  226. num_cpus=args.ray_num_cpus,
  227. num_gpus=args.ray_num_gpus,
  228. local_mode=args.local_mode)
  229. if IS_NOTEBOOK:
  230. progress_reporter = JupyterNotebookReporter(
  231. overwrite=verbose >= 3, print_intermediate_tables=verbose >= 1)
  232. else:
  233. progress_reporter = CLIReporter(print_intermediate_tables=verbose >= 1)
  234. run_experiments(
  235. experiments,
  236. scheduler=create_scheduler(args.scheduler, **args.scheduler_config),
  237. resume=args.resume,
  238. verbose=verbose,
  239. progress_reporter=progress_reporter,
  240. concurrent=True)
  241. ray.shutdown()
  242. def main():
  243. parser = create_parser()
  244. args = parser.parse_args()
  245. run(args, parser)
  246. if __name__ == "__main__":
  247. main()