#!/usr/bin/env python

import importlib.util
import json
import os
from pathlib import Path
import re
import sys
import typer
from typing import Optional
import uuid
import yaml

import ray
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.tune.resources import resources_to_json, json_to_resources
from ray.tune.tune import run_experiments
from ray.tune.schedulers import create_scheduler
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.common import CLIArguments as cli
from ray.rllib.common import FrameworkEnum, SupportedFileType
from ray.rllib.common import download_example_file, get_file_type


def import_backends():
    """Try to import both backends for flag checking/warnings."""
    tf1, tf, tfv = try_import_tf()
    torch, _ = try_import_torch()


# Create the "train" Typer app.
train_app = typer.Typer()
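
# Note: `file()` below registers as a subcommand of this app via
# `@train_app.command()`, while `run()` is the app-level callback that handles
# a bare `rllib train ...` invocation without a subcommand.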


def _patch_path(path: str):
    """Patch a path to be relative to the current working directory.

    Args:
        path: relative input path.

    Returns:
        Patched path.
    """
    # This script runs in the ray/rllib dir.
    rllib_dir = Path(__file__).parent
    if isinstance(path, list):
        return [_patch_path(i) for i in path]
    elif isinstance(path, dict):
        return {_patch_path(k): _patch_path(v) for k, v in path.items()}
    elif isinstance(path, str):
        if os.path.exists(path):
            return path
        else:
            abs_path = str(rllib_dir.absolute().joinpath(path))
            return abs_path if os.path.exists(abs_path) else path
    else:
        return path
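
# Illustrative sketch (hypothetical inputs) of how `_patch_path` resolves
# values, assuming this file lives under ray/rllib/:
#   _patch_path("tuned_examples/ppo/cartpole-ppo.yaml")
#       -> "<ray>/rllib/tuned_examples/ppo/cartpole-ppo.yaml" (if found there)
#   _patch_path(["a.json", "b.json"])     # lists are patched element-wise
#   _patch_path({"key": "offline/data"})  # dicts are patched on keys and values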


def load_experiments_from_file(
    config_file: str,
    file_type: SupportedFileType,
    stop: Optional[str] = None,
    checkpoint_config: Optional[dict] = None,
) -> dict:
    """Load experiments from a file. Supports YAML and Python files.

    If you want to use a Python file, it has to have a `config` variable
    that is an AlgorithmConfig object and - optionally - a `stop` variable
    defining the stop criteria.

    Args:
        config_file: The yaml or python file to be used as experiment definition.
            Must only contain exactly one experiment.
        file_type: One value of the `SupportedFileType` enum (yaml or python).
        stop: An optional stop json string, only used if file_type is python.
            If None (and file_type is python), will try to extract stop information
            from a defined `stop` variable in the python file, otherwise will
            use {}.
        checkpoint_config: An optional checkpoint config to add to the returned
            experiments dict.

    Returns:
        The experiments dict ready to be passed into `tune.run_experiments()`.
    """
    # YAML file.
    if file_type == SupportedFileType.yaml:
        with open(config_file) as f:
            experiments = yaml.safe_load(f)
        if stop is not None and stop != "{}":
            raise ValueError("`stop` criteria only supported for python files.")
    # Python file case (ensured by the file type enum).
    else:
        module_name = os.path.basename(config_file).replace(".py", "")
        spec = importlib.util.spec_from_file_location(module_name, config_file)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)

        if not hasattr(module, "config"):
            raise ValueError(
                "Your Python file must contain a `config` variable "
                "that is an AlgorithmConfig object."
            )
        algo_config = getattr(module, "config")

        # Use the `stop` JSON string from the command line, or fall back to an
        # optional `stop` variable defined in the Python file itself.
        if stop is None:
            stop = getattr(module, "stop", {})
        else:
            stop = json.loads(stop)

        # Note: we do this gymnastics to support the old format that
        # `run_rllib_experiments` expects. Ideally, we'd just build the config
        # and run the algo.
        config = algo_config.to_dict()
        experiments = {
            f"default_{uuid.uuid4().hex}": {
                "run": algo_config.__class__.__name__.replace("Config", ""),
                "env": config.get("env"),
                "config": config,
                "stop": stop,
            }
        }

    for key, val in experiments.items():
        experiments[key]["checkpoint_config"] = checkpoint_config or {}

    return experiments
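
# Minimal sketches of the two supported input formats (illustrative file names
# and values, not taken from this repo):
#
#   YAML -- a single experiment keyed by its name:
#       cartpole-ppo:
#           run: PPO
#           env: CartPole-v1
#           stop:
#               episode_reward_mean: 150
#
#   Python -- must define `config` (an AlgorithmConfig); `stop` is optional:
#       from ray.rllib.algorithms.ppo import PPOConfig
#       config = PPOConfig().environment("CartPole-v1")
#       stop = {"episode_reward_mean": 150}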


@train_app.command()
def file(
    # File-based arguments.
    config_file: str = cli.ConfigFile,
    # Stopping conditions.
    stop: Optional[str] = cli.Stop,
    # Environment override.
    env: Optional[str] = cli.Env,
    # Checkpointing.
    checkpoint_freq: int = cli.CheckpointFreq,
    checkpoint_at_end: bool = cli.CheckpointAtEnd,
    keep_checkpoints_num: int = cli.KeepCheckpointsNum,
    checkpoint_score_attr: str = cli.CheckpointScoreAttr,
    # Additional config arguments used for overriding.
    v: bool = cli.V,
    vv: bool = cli.VV,
    framework: FrameworkEnum = cli.Framework,
    trace: bool = cli.Trace,
    # WandB options.
    wandb_key: Optional[str] = cli.WandBKey,
    wandb_project: Optional[str] = cli.WandBProject,
    wandb_run_name: Optional[str] = cli.WandBRunName,
    # Ray cluster options.
    local_mode: bool = cli.LocalMode,
    ray_address: Optional[str] = cli.RayAddress,
    ray_ui: bool = cli.RayUi,
    ray_num_cpus: Optional[int] = cli.RayNumCpus,
    ray_num_gpus: Optional[int] = cli.RayNumGpus,
    ray_num_nodes: Optional[int] = cli.RayNumNodes,
    ray_object_store_memory: Optional[int] = cli.RayObjectStoreMemory,
    # Ray scheduling options.
    resume: bool = cli.Resume,
    scheduler: Optional[str] = cli.Scheduler,
    scheduler_config: str = cli.SchedulerConfig,
):
    """Train a reinforcement learning agent from file.

    The file argument is required to run this command.\n\n

    Grid search example with the RLlib CLI:\n
    rllib train file tuned_examples/ppo/cartpole-ppo.yaml\n\n

    You can also run an example from a URL with the file content:\n
    rllib train file https://raw.githubusercontent.com/ray-project/ray/\
    master/rllib/tuned_examples/ppo/cartpole-ppo.yaml
    """
    # Attempt to download the file if it's not found locally.
    config_file, temp_file = download_example_file(
        example_file=config_file, base_url=None
    )

    import_backends()
    framework = framework.value if framework else None

    checkpoint_config = {
        "checkpoint_frequency": checkpoint_freq,
        "checkpoint_at_end": checkpoint_at_end,
        "num_to_keep": keep_checkpoints_num,
        "checkpoint_score_attribute": checkpoint_score_attr,
    }

    file_type = get_file_type(config_file)

    experiments = load_experiments_from_file(
        config_file, file_type, stop, checkpoint_config
    )
    exp_name = list(experiments.keys())[0]
    experiment = experiments[exp_name]
    algo = experiment["run"]

    # Override the env in the config with the value given on the command line.
    if env is not None:
        experiment["env"] = env

    # WandB logging support.
    callbacks = None
    if wandb_key is not None:
        project = wandb_project or (
            algo.lower() + "-" + re.sub("\\W+", "-", experiment["env"].lower())
            if file_type == SupportedFileType.python
            else exp_name
        )
        callbacks = [
            WandbLoggerCallback(
                api_key=wandb_key,
                project=project,
                **({"name": wandb_run_name} if wandb_run_name is not None else {}),
            )
        ]

    # If we had to download the config file, remove the temp file.
    if temp_file:
        temp_file.close()

    run_rllib_experiments(
        experiments=experiments,
        v=v,
        vv=vv,
        framework=framework,
        trace=trace,
        ray_num_nodes=ray_num_nodes,
        ray_num_cpus=ray_num_cpus,
        ray_num_gpus=ray_num_gpus,
        ray_object_store_memory=ray_object_store_memory,
        ray_ui=ray_ui,
        ray_address=ray_address,
        local_mode=local_mode,
        resume=resume,
        scheduler=scheduler,
        scheduler_config=scheduler_config,
        algo=algo,
        callbacks=callbacks,
    )
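
# Example: running a Python config file and overriding its stop criteria with
# a JSON string (the file name here is hypothetical):
#   rllib train file my_experiment.py --stop '{"training_iteration": 100}'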


@train_app.callback(invoke_without_command=True)
def run(
    # Context object for subcommands.
    ctx: typer.Context,
    # Config-based arguments.
    algo: str = cli.Algo,
    env: str = cli.Env,
    config: str = cli.Config,
    stop: str = cli.Stop,
    experiment_name: str = cli.ExperimentName,
    num_samples: int = cli.NumSamples,
    checkpoint_freq: int = cli.CheckpointFreq,
    checkpoint_at_end: bool = cli.CheckpointAtEnd,
    local_dir: str = cli.LocalDir,
    restore: str = cli.Restore,
    resources_per_trial: str = cli.ResourcesPerTrial,
    keep_checkpoints_num: int = cli.KeepCheckpointsNum,
    checkpoint_score_attr: str = cli.CheckpointScoreAttr,
    upload_dir: str = cli.UploadDir,
    # Additional config arguments used for overriding.
    v: bool = cli.V,
    vv: bool = cli.VV,
    framework: FrameworkEnum = cli.Framework,
    trace: bool = cli.Trace,
    # Ray cluster options.
    local_mode: bool = cli.LocalMode,
    ray_address: str = cli.RayAddress,
    ray_ui: bool = cli.RayUi,
    ray_num_cpus: int = cli.RayNumCpus,
    ray_num_gpus: int = cli.RayNumGpus,
    ray_num_nodes: int = cli.RayNumNodes,
    ray_object_store_memory: int = cli.RayObjectStoreMemory,
    # Ray scheduling options.
    resume: bool = cli.Resume,
    scheduler: str = cli.Scheduler,
    scheduler_config: str = cli.SchedulerConfig,
):
    """Train a reinforcement learning agent from command line options.

    The options --env and --algo are required to run this command.

    Training example via RLlib CLI:\n
    rllib train --algo DQN --env CartPole-v1\n\n
    """
    # If no subcommand is specified, simply run the following lines as the
    # "rllib train" main command.
    if ctx.invoked_subcommand is None:
        # We only check for backends when actually running the command,
        # otherwise the start-up time is too slow.
        import_backends()

        framework = framework.value if framework else None

        config = json.loads(config)
        resources_per_trial = json_to_resources(resources_per_trial)

        # Load a single experiment from configuration.
        experiments = {
            experiment_name: {  # i.e. log to ~/ray_results/default
                "run": algo,
                "checkpoint_config": {
                    "checkpoint_frequency": checkpoint_freq,
                    "checkpoint_at_end": checkpoint_at_end,
                    "num_to_keep": keep_checkpoints_num,
                    "checkpoint_score_attribute": checkpoint_score_attr,
                },
                "local_dir": local_dir,
                "resources_per_trial": (
                    resources_per_trial and resources_to_json(resources_per_trial)
                ),
                "stop": json.loads(stop),
                "config": dict(config, env=env),
                "restore": restore,
                "num_samples": num_samples,
                "sync_config": {
                    "upload_dir": upload_dir,
                },
            }
        }

        run_rllib_experiments(
            experiments=experiments,
            v=v,
            vv=vv,
            framework=framework,
            trace=trace,
            ray_num_nodes=ray_num_nodes,
            ray_num_cpus=ray_num_cpus,
            ray_num_gpus=ray_num_gpus,
            ray_object_store_memory=ray_object_store_memory,
            ray_ui=ray_ui,
            ray_address=ray_address,
            local_mode=local_mode,
            resume=resume,
            scheduler=scheduler,
            scheduler_config=scheduler_config,
            algo=algo,
        )
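
# Example invocation with inline JSON for --config and --stop (the values are
# illustrative only):
#   rllib train --algo PPO --env CartPole-v1 \
#       --config '{"train_batch_size": 4000}' \
#       --stop '{"episode_reward_mean": 150}'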


def run_rllib_experiments(
    experiments: dict,
    v: cli.V,
    vv: cli.VV,
    framework: str,
    trace: cli.Trace,
    ray_num_nodes: cli.RayNumNodes,
    ray_num_cpus: cli.RayNumCpus,
    ray_num_gpus: cli.RayNumGpus,
    ray_object_store_memory: cli.RayObjectStoreMemory,
    ray_ui: cli.RayUi,
    ray_address: cli.RayAddress,
    local_mode: cli.LocalMode,
    resume: cli.Resume,
    scheduler: cli.Scheduler,
    scheduler_config: cli.SchedulerConfig,
    algo: cli.Algo,
    callbacks=None,
):
    """Main training function for the RLlib CLI, whether you've loaded your
    experiments from a config file or from command line options."""

    # Override experiment data with command line arguments.
    verbose = 1
    for exp in experiments.values():
        # Bazel makes it hard to find files specified in `args` (and `data`).
        # Look for them here.
        # NOTE: Some of our yaml files don't have a `config` section.
        input_ = exp.get("config", {}).get("input")
        if input_ and input_ != "sampler":
            exp["config"]["input"] = _patch_path(input_)

        if not exp.get("env") and not exp.get("config", {}).get("env"):
            raise ValueError(
                "You either need to provide an --env argument (e.g. 'CartPole-v1') "
                "or pass an `env` key with a valid environment to your `config` "
                "argument."
            )
        elif framework is not None:
            exp["config"]["framework"] = framework
        if trace:
            if exp["config"]["framework"] not in ["tf2"]:
                raise ValueError("Must enable framework=tf2 to enable eager tracing.")
            exp["config"]["eager_tracing"] = True
        if v:
            exp["config"]["log_level"] = "INFO"
            verbose = 3  # Print details on trial result.
        if vv:
            exp["config"]["log_level"] = "DEBUG"
            verbose = 3  # Print details on trial result.

    # Initialize the Ray cluster with the specified options.
    if ray_num_nodes:
        # Import this only here so that train.py also works with
        # older versions (and user doesn't use `--ray-num-nodes`).
        from ray.cluster_utils import Cluster

        cluster = Cluster()
        for _ in range(ray_num_nodes):
            cluster.add_node(
                num_cpus=ray_num_cpus or 1,
                num_gpus=ray_num_gpus or 0,
                object_store_memory=ray_object_store_memory,
            )
        ray.init(address=cluster.address)
    else:
        ray.init(
            include_dashboard=ray_ui,
            address=ray_address,
            object_store_memory=ray_object_store_memory,
            num_cpus=ray_num_cpus,
            num_gpus=ray_num_gpus,
            local_mode=local_mode,
        )

    # Run the Tune experiment and return the trials.
    scheduler_config = json.loads(scheduler_config)
    trials = run_experiments(
        experiments,
        scheduler=create_scheduler(scheduler, **scheduler_config),
        resume=resume,
        verbose=verbose,
        concurrent=True,
        callbacks=callbacks,
    )
    ray.shutdown()

    checkpoints = []
    for trial in trials:
        if trial.checkpoint:
            checkpoints.append(trial.checkpoint)

    if checkpoints:
        from rich import print
        from rich.panel import Panel

        print("\nYour training finished.")
        print("Best available checkpoint for each trial:")
        for cp in checkpoints:
            print(f"  {cp.path}")
        print(
            "\nYou can now evaluate your trained algorithm from any "
            "checkpoint, e.g. by running:"
        )
        print(Panel(f"[green] rllib evaluate {checkpoints[0].path} --algo {algo}"))


def main():
    """Run the CLI."""
    train_app()


if __name__ == "__main__":
    main()