123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279 |
- #!/usr/bin/env python
- import argparse
- import os
- from pathlib import Path
- import yaml
- import ray
- from ray.tune.config_parser import make_parser
- from ray.tune.progress_reporter import CLIReporter, JupyterNotebookReporter
- from ray.tune.result import DEFAULT_RESULTS_DIR
- from ray.tune.resources import resources_to_json
- from ray.tune.tune import run_experiments
- from ray.tune.schedulers import create_scheduler
- from ray.rllib.utils.deprecation import deprecation_warning
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
# Detect whether this module is running inside an IPython front-end that is
# not the terminal shell (e.g. a Jupyter notebook). `get_ipython` is not
# defined in a plain Python interpreter, hence the NameError fallback.
try:
    class_name = get_ipython().__class__.__name__  # noqa: F821
    # Any IPython shell whose class name lacks "Terminal" counts as a notebook.
    IS_NOTEBOOK = "Terminal" not in class_name
except NameError:
    IS_NOTEBOOK = False
# Try to import both backends for flag checking/warnings.
# NOTE(review): none of `tf1`, `tf`, `tfv` or `torch` is referenced elsewhere
# in the visible part of this file; these calls appear to be kept for their
# import side effects (and import order) -- confirm before removing/reordering.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
# Usage examples passed as the `epilog` of the parser built in
# create_parser(). That parser is requested with
# argparse.RawDescriptionHelpFormatter, so the line breaks in this text are
# presumably kept verbatim in the `--help` output.
EXAMPLE_USAGE = """
Training example via RLlib CLI:
rllib train --run DQN --env CartPole-v0
Grid search example via RLlib CLI:
rllib train -f tuned_examples/cartpole-grid-search-example.yaml
Grid search example via executable:
./train.py -f tuned_examples/cartpole-grid-search-example.yaml
Note that -f overrides all other trial-specific command-line options.
"""
def create_parser(parser_creator=None):
    """Build the command-line parser for training an RLlib agent.

    Flags that run() reads but that are not added here (e.g. ``--run``,
    ``--stop``, ``--config``) presumably come from the base parser built by
    ``make_parser`` (see ray/tune/config_parser.py).

    Args:
        parser_creator: Forwarded unchanged to ``make_parser``.

    Returns:
        The parser returned by ``make_parser``, extended with the Ray-cluster
        and RLlib-specific flags below.
    """
    parser = make_parser(
        parser_creator=parser_creator,
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Train a reinforcement learning agent.",
        epilog=EXAMPLE_USAGE)

    # See also the base parser definition in ray/tune/config_parser.py
    parser.add_argument(
        "--ray-address",
        default=None,
        type=str,
        help="Connect to an existing Ray cluster at this address instead "
        "of starting a new one.")
    parser.add_argument(
        "--ray-ui",
        action="store_true",
        help="Whether to enable the Ray web UI.")
    # Deprecated: Use --ray-ui, instead.
    parser.add_argument(
        "--no-ray-ui",
        action="store_true",
        help="Deprecated! Ray UI is disabled by default now. "
        "Use `--ray-ui` to enable.")
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Run ray in local mode for easier debugging.")
    parser.add_argument(
        "--ray-num-cpus",
        default=None,
        type=int,
        help="--num-cpus to use if starting a new cluster.")
    parser.add_argument(
        "--ray-num-gpus",
        default=None,
        type=int,
        help="--num-gpus to use if starting a new cluster.")
    parser.add_argument(
        "--ray-num-nodes",
        default=None,
        type=int,
        help="Emulate multiple cluster nodes for debugging.")
    parser.add_argument(
        "--ray-object-store-memory",
        default=None,
        type=int,
        help="--object-store-memory to use if starting a new cluster.")
    parser.add_argument(
        "--experiment-name",
        default="default",
        type=str,
        help="Name of the subdirectory under `local_dir` to put results in.")
    parser.add_argument(
        "--local-dir",
        default=DEFAULT_RESULTS_DIR,
        type=str,
        help="Local dir to save training results to. Defaults to '{}'.".format(
            DEFAULT_RESULTS_DIR))
    parser.add_argument(
        "--upload-dir",
        default="",
        type=str,
        help="Optional URI to sync training results to (e.g. s3://bucket).")
    # This will override any framework setting found in a yaml file.
    parser.add_argument(
        "--framework",
        choices=["tf", "tf2", "tfe", "torch"],
        default=None,
        help="The DL framework specifier.")
    parser.add_argument(
        "-v", action="store_true", help="Whether to use INFO level logging.")
    parser.add_argument(
        "-vv", action="store_true", help="Whether to use DEBUG level logging.")
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume previous Tune experiments.")
    parser.add_argument(
        "--trace",
        action="store_true",
        help="Whether to attempt to enable tracing for eager mode.")
    parser.add_argument(
        "--env", default=None, type=str, help="The gym environment to use.")
    parser.add_argument(
        "-f",
        "--config-file",
        default=None,
        type=str,
        help="If specified, use config options from this file. Note that this "
        "overrides any trial-specific options set via flags above.")

    # Obsolete: Use --framework=torch|tf2|tfe instead! run() still honors
    # these two flags but emits a deprecation warning, so say so in --help
    # (same wording style as --no-ray-ui above).
    parser.add_argument(
        "--torch",
        action="store_true",
        help="Deprecated! Use `--framework=torch` instead. "
        "Whether to use PyTorch (instead of tf) as the DL framework.")
    parser.add_argument(
        "--eager",
        action="store_true",
        help="Deprecated! Use `--framework=[tf2|tfe]` instead. "
        "Whether to attempt to enable TF eager execution.")
    return parser
def _patch_input_path(path, rllib_dir):
    """Resolve offline `input` paths that do not exist relative to the CWD.

    Recurses into lists and dicts (both keys and values). A string that is not
    an existing path is replaced by ``rllib_dir/<path>`` if that exists;
    otherwise, and for any non-string leaf, the value is returned unchanged.
    """
    if isinstance(path, list):
        return [_patch_input_path(p, rllib_dir) for p in path]
    if isinstance(path, dict):
        return {
            _patch_input_path(k, rllib_dir): _patch_input_path(v, rllib_dir)
            for k, v in path.items()
        }
    if isinstance(path, str) and not os.path.exists(path):
        abs_path = str(rllib_dir.absolute().joinpath(path))
        if os.path.exists(abs_path):
            return abs_path
    return path


def _load_experiments(args, parser):
    """Return the experiments dict from `-f/--config-file` or from the flags."""
    if args.config_file:
        with open(args.config_file) as f:
            experiments = yaml.safe_load(f)
        # An empty file yields None (and a list yields a list); fail with a
        # usage error instead of an AttributeError on `.values()` later.
        if not isinstance(experiments, dict):
            parser.error(
                "config file {} must contain a mapping of experiment names "
                "to experiment specs".format(args.config_file))
        return experiments

    # Note: keep this in sync with tune/config_parser.py
    return {
        args.experiment_name: {  # i.e. log to ~/ray_results/default
            "run": args.run,
            "checkpoint_freq": args.checkpoint_freq,
            "checkpoint_at_end": args.checkpoint_at_end,
            "keep_checkpoints_num": args.keep_checkpoints_num,
            "checkpoint_score_attr": args.checkpoint_score_attr,
            "local_dir": args.local_dir,
            "resources_per_trial": (args.resources_per_trial
                                    and resources_to_json(
                                        args.resources_per_trial)),
            "stop": args.stop,
            "config": dict(args.config, env=args.env),
            "restore": args.restore,
            "num_samples": args.num_samples,
            "sync_config": {
                "upload_dir": args.upload_dir,
            }
        }
    }


def run(args, parser):
    """Build the experiment specs and run them with Ray Tune.

    Specs come from the YAML file given via ``-f/--config-file`` or, if none
    is given, from the individual command-line flags. Command-line overrides
    (framework, tracing, log level) are applied to every experiment, Ray is
    started (or connected to / emulated as a multi-node cluster), the
    experiments are run, and Ray is shut down again -- also if training
    raises.

    Args:
        args: Namespace parsed by the parser from create_parser().
        parser: That same parser; missing/invalid arguments are reported via
            ``parser.error()``, which exits the process.

    Raises:
        ValueError: If ``--trace`` is given without a tf2/tfe framework.
    """
    experiments = _load_experiments(args, parser)

    # Ray UI.
    if args.no_ray_ui:
        deprecation_warning(old="--no-ray-ui", new="--ray-ui", error=False)
        args.ray_ui = False

    # This script runs in the ray/rllib dir.
    rllib_dir = Path(__file__).parent

    verbose = 1
    for exp in experiments.values():
        # NOTE: Some of our yaml files don't have a `config` section. Create an
        # empty one so the flag overrides below can't raise a KeyError.
        config = exp.setdefault("config", {})

        # Bazel makes it hard to find files specified in `args` (and `data`).
        # Look for them here.
        input_ = config.get("input")
        if input_ and input_ != "sampler":
            config["input"] = _patch_input_path(input_, rllib_dir)

        if not exp.get("run"):
            parser.error("the following arguments are required: --run")
        if not exp.get("env") and not config.get("env"):
            parser.error("the following arguments are required: --env")

        if args.torch:
            deprecation_warning("--torch", "--framework=torch")
            config["framework"] = "torch"
        elif args.eager:
            deprecation_warning("--eager", "--framework=[tf2|tfe]")
            config["framework"] = "tfe"
        elif args.framework is not None:
            config["framework"] = args.framework

        if args.trace:
            # `.get`: the framework may be set neither in the yaml nor via
            # flags; that must produce this error, not a KeyError.
            if config.get("framework") not in ["tf2", "tfe"]:
                raise ValueError("Must enable --eager to enable tracing.")
            config["eager_tracing"] = True

        if args.v:
            config["log_level"] = "INFO"
            verbose = 3  # Print details on trial result
        if args.vv:
            config["log_level"] = "DEBUG"
            verbose = 3  # Print details on trial result

    if args.ray_num_nodes:
        # Import this only here so that train.py also works with
        # older versions (and user doesn't use `--ray-num-nodes`).
        from ray.cluster_utils import Cluster
        cluster = Cluster()
        for _ in range(args.ray_num_nodes):
            cluster.add_node(
                num_cpus=args.ray_num_cpus or 1,
                num_gpus=args.ray_num_gpus or 0,
                object_store_memory=args.ray_object_store_memory)
        ray.init(address=cluster.address)
    else:
        ray.init(
            include_dashboard=args.ray_ui,
            address=args.ray_address,
            object_store_memory=args.ray_object_store_memory,
            num_cpus=args.ray_num_cpus,
            num_gpus=args.ray_num_gpus,
            local_mode=args.local_mode)

    # Always disconnect from / stop Ray, even if training fails or is
    # interrupted.
    try:
        if IS_NOTEBOOK:
            progress_reporter = JupyterNotebookReporter(
                overwrite=verbose >= 3, print_intermediate_tables=verbose >= 1)
        else:
            progress_reporter = CLIReporter(
                print_intermediate_tables=verbose >= 1)

        run_experiments(
            experiments,
            scheduler=create_scheduler(args.scheduler, **args.scheduler_config),
            resume=args.resume,
            verbose=verbose,
            progress_reporter=progress_reporter,
            concurrent=True)
    finally:
        ray.shutdown()
def main():
    """Script entry point: parse the command-line flags and hand them to run()."""
    cli_parser = create_parser()
    run(cli_parser.parse_args(), cli_parser)


if __name__ == "__main__":
    main()
|