#!/usr/bin/env python

import importlib
import json
import os
from pathlib import Path
import re
import sys
import typer
from typing import Optional
import uuid
import yaml

import ray
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.common import CLIArguments as cli
from ray.rllib.common import FrameworkEnum, SupportedFileType
from ray.rllib.common import _download_example_file, _get_file_type
from ray.train.constants import _DEPRECATED_VALUE
from ray.tune.resources import resources_to_json, json_to_resources
from ray.tune.tune import run_experiments
from ray.tune.schedulers import create_scheduler
from ray.util.annotations import DeveloperAPI, PublicAPI


def _import_backends():
    """Try to import both backends for flag checking/warnings."""
    tf1, tf, tfv = try_import_tf()
    torch, _ = try_import_torch()


# Create the "train" Typer app.
train_app = typer.Typer()


def _patch_path(path: str):
    """
    Patch a path to be relative to the current working directory.

    Args:
        path: relative input path.

    Returns: Patched path.
    """
    # This script runs in the ray/rllib dir.
    rllib_dir = Path(__file__).parent
    if isinstance(path, list):
        return [_patch_path(i) for i in path]
    elif isinstance(path, dict):
        return {_patch_path(k): _patch_path(v) for k, v in path.items()}
    elif isinstance(path, str):
        if os.path.exists(path):
            return path
        else:
            abs_path = str(rllib_dir.absolute().joinpath(path))
            return abs_path if os.path.exists(abs_path) else path
    else:
        return path
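
# Illustrative usage (not part of the module): a path that doesn't exist
# relative to the current working directory is resolved against this file's
# directory, e.g.
#   _patch_path("tuned_examples/ppo/cartpole-ppo.yaml")
#   # -> "<ray>/rllib/tuned_examples/ppo/cartpole-ppo.yaml", if that file exists.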
  50. @PublicAPI(stability="beta")
  51. def load_experiments_from_file(
  52. config_file: str,
  53. file_type: SupportedFileType,
  54. stop: Optional[str] = None,
  55. checkpoint_config: Optional[dict] = None,
  56. ) -> dict:
  57. """Load experiments from a file. Supports YAML and Python files.
  58. If you want to use a Python file, it has to have a 'config' variable
  59. that is an AlgorithmConfig object and - optionally - a `stop` variable defining
  60. the stop criteria.
  61. Args:
  62. config_file: The yaml or python file to be used as experiment definition.
  63. Must only contain exactly one experiment.
  64. file_type: One value of the `SupportedFileType` enum (yaml or python).
  65. stop: An optional stop json string, only used if file_type is python.
  66. If None (and file_type is python), will try to extract stop information
  67. from a defined `stop` variable in the python file, otherwise, will use {}.
  68. checkpoint_config: An optional checkpoint config to add to the returned
  69. experiments dict.
  70. Returns:
  71. The experiments dict ready to be passed into `tune.run_experiments()`.
  72. """
  73. # Yaml file.
  74. if file_type == SupportedFileType.yaml:
  75. with open(config_file) as f:
  76. experiments = yaml.safe_load(f)
  77. if stop is not None and stop != "{}":
  78. raise ValueError("`stop` criteria only supported for python files.")
  79. # Python file case (ensured by file type enum)
  80. else:
  81. module_name = os.path.basename(config_file).replace(".py", "")
  82. spec = importlib.util.spec_from_file_location(module_name, config_file)
  83. module = importlib.util.module_from_spec(spec)
  84. sys.modules[module_name] = module
  85. spec.loader.exec_module(module)
  86. if not hasattr(module, "config"):
  87. raise ValueError(
  88. "Your Python file must contain a 'config' variable "
  89. "that is an AlgorithmConfig object."
  90. )
  91. algo_config = getattr(module, "config")
  92. if stop is None:
  93. stop = getattr(module, "stop", {})
  94. else:
  95. stop = json.loads(stop)
  96. # Note: we do this gymnastics to support the old format that
  97. # "_run_rllib_experiments" expects. Ideally, we'd just build the config and
  98. # run the algo.
  99. config = algo_config.to_dict()
  100. experiments = {
  101. f"default_{uuid.uuid4().hex}": {
  102. "run": algo_config.algo_class,
  103. "env": config.get("env"),
  104. "config": config,
  105. "stop": stop,
  106. }
  107. }
  108. for key, val in experiments.items():
  109. experiments[key]["checkpoint_config"] = checkpoint_config or {}
  110. return experiments
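
# For reference, a minimal Python config file accepted by
# `load_experiments_from_file` could look like this (a sketch; the algorithm
# and environment are illustrative, not mandated by this module):
#
#     from ray.rllib.algorithms.ppo import PPOConfig
#
#     config = PPOConfig().environment("CartPole-v1")
#     stop = {"training_iteration": 10}  # Optional stop criteria.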


@DeveloperAPI
@train_app.command()
def file(
    # File-based arguments.
    config_file: str = cli.ConfigFile,
    # Stopping conditions.
    stop: Optional[str] = cli.Stop,
    # Environment override.
    env: Optional[str] = cli.Env,
    # Checkpointing.
    checkpoint_freq: int = cli.CheckpointFreq,
    checkpoint_at_end: bool = cli.CheckpointAtEnd,
    keep_checkpoints_num: int = cli.KeepCheckpointsNum,
    checkpoint_score_attr: str = cli.CheckpointScoreAttr,
    # Additional config arguments used for overriding.
    v: bool = cli.V,
    vv: bool = cli.VV,
    framework: FrameworkEnum = cli.Framework,
    trace: bool = cli.Trace,
    # WandB options.
    wandb_key: Optional[str] = cli.WandBKey,
    wandb_project: Optional[str] = cli.WandBProject,
    wandb_run_name: Optional[str] = cli.WandBRunName,
    # Ray cluster options.
    local_mode: bool = cli.LocalMode,
    ray_address: Optional[str] = cli.RayAddress,
    ray_ui: bool = cli.RayUi,
    ray_num_cpus: Optional[int] = cli.RayNumCpus,
    ray_num_gpus: Optional[int] = cli.RayNumGpus,
    ray_num_nodes: Optional[int] = cli.RayNumNodes,
    ray_object_store_memory: Optional[int] = cli.RayObjectStoreMemory,
    # Ray scheduling options.
    resume: bool = cli.Resume,
    scheduler: Optional[str] = cli.Scheduler,
    scheduler_config: str = cli.SchedulerConfig,
):
    """Train a reinforcement learning agent from file.

    The file argument is required to run this command.\n\n

    Grid search example with the RLlib CLI:\n
        rllib train file tuned_examples/ppo/cartpole-ppo.yaml\n\n

    You can also run an example from a URL with the file content:\n
        rllib train file https://raw.githubusercontent.com/ray-project/ray/\
master/rllib/tuned_examples/ppo/cartpole-ppo.yaml
    """
    # Attempt to download the file if it's not found locally.
    config_file, temp_file = _download_example_file(
        example_file=config_file, base_url=None
    )

    _import_backends()
    framework = framework.value if framework else None

    checkpoint_config = {
        "checkpoint_frequency": checkpoint_freq,
        "checkpoint_at_end": checkpoint_at_end,
        "num_to_keep": keep_checkpoints_num,
        "checkpoint_score_attribute": checkpoint_score_attr,
    }

    file_type = _get_file_type(config_file)

    experiments = load_experiments_from_file(
        config_file, file_type, stop, checkpoint_config
    )
    exp_name = list(experiments.keys())[0]
    experiment = experiments[exp_name]
    algo = experiment["run"]

    # Override the env from the config with the value given on the command line.
    if env is not None:
        experiment["env"] = env

    # WandB logging support.
    callbacks = None
    if wandb_key is not None:
        project = wandb_project or (
            algo.lower() + "-" + re.sub("\\W+", "-", experiment["env"].lower())
            if file_type == SupportedFileType.python
            else exp_name
        )
        callbacks = [
            WandbLoggerCallback(
                api_key=wandb_key,
                project=project,
                **({"name": wandb_run_name} if wandb_run_name is not None else {}),
            )
        ]

    # If we had to download the config file, remove the temp file.
    if temp_file:
        temp_file.close()

    _run_rllib_experiments(
        experiments=experiments,
        v=v,
        vv=vv,
        framework=framework,
        trace=trace,
        ray_num_nodes=ray_num_nodes,
        ray_num_cpus=ray_num_cpus,
        ray_num_gpus=ray_num_gpus,
        ray_object_store_memory=ray_object_store_memory,
        ray_ui=ray_ui,
        ray_address=ray_address,
        local_mode=local_mode,
        resume=resume,
        scheduler=scheduler,
        scheduler_config=scheduler_config,
        algo=algo,
        callbacks=callbacks,
    )
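
# Example invocation (illustrative; flag spellings follow Typer's kebab-case
# mapping of the parameters above, and the WandB options are optional):
#   rllib train file tuned_examples/ppo/cartpole-ppo.yaml \
#       --wandb-key=<your_api_key> --wandb-project=my-project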


@DeveloperAPI
@train_app.callback(invoke_without_command=True)
def run(
    # Context object for subcommands.
    ctx: typer.Context,
    # Config-based arguments.
    algo: str = cli.Algo,
    env: str = cli.Env,
    config: str = cli.Config,
    stop: str = cli.Stop,
    experiment_name: str = cli.ExperimentName,
    num_samples: int = cli.NumSamples,
    checkpoint_freq: int = cli.CheckpointFreq,
    checkpoint_at_end: bool = cli.CheckpointAtEnd,
    storage_path: str = cli.StoragePath,
    restore: str = cli.Restore,
    resources_per_trial: str = cli.ResourcesPerTrial,
    keep_checkpoints_num: int = cli.KeepCheckpointsNum,
    checkpoint_score_attr: str = cli.CheckpointScoreAttr,
    # Additional config arguments used for overriding.
    v: bool = cli.V,
    vv: bool = cli.VV,
    framework: FrameworkEnum = cli.Framework,
    trace: bool = cli.Trace,
    # Ray cluster options.
    local_mode: bool = cli.LocalMode,
    ray_address: str = cli.RayAddress,
    ray_ui: bool = cli.RayUi,
    ray_num_cpus: int = cli.RayNumCpus,
    ray_num_gpus: int = cli.RayNumGpus,
    ray_num_nodes: int = cli.RayNumNodes,
    ray_object_store_memory: int = cli.RayObjectStoreMemory,
    # Ray scheduling options.
    resume: bool = cli.Resume,
    scheduler: str = cli.Scheduler,
    scheduler_config: str = cli.SchedulerConfig,
    # TODO(arturn): [Deprecated] Remove in 2.11.
    local_dir: str = cli.LocalDir,
    upload_dir: str = cli.UploadDir,
):
    """Train a reinforcement learning agent from command line options.

    The options --env and --algo are required to run this command.

    Training example via RLlib CLI:\n
        rllib train --algo DQN --env CartPole-v1\n\n
    """
    # If no subcommand is specified, simply run the following lines as the
    # "rllib train" main command.
    if ctx.invoked_subcommand is None:
        # We only check for backends when actually running the command;
        # otherwise the start-up time is too slow.
        _import_backends()

        framework = framework.value if framework else None

        config = json.loads(config)
        resources_per_trial = json_to_resources(resources_per_trial)

        if local_dir != _DEPRECATED_VALUE:
            raise DeprecationWarning(
                "`local_dir` is deprecated. Please use `storage_path` instead."
            )
        if upload_dir != _DEPRECATED_VALUE:
            raise DeprecationWarning(
                "`upload_dir` is deprecated. Please use `storage_path` instead."
            )

        # Load a single experiment from the given configuration.
        experiments = {
            experiment_name: {  # i.e. log to ~/ray_results/default
                "run": algo,
                "checkpoint_config": {
                    "checkpoint_frequency": checkpoint_freq,
                    "checkpoint_at_end": checkpoint_at_end,
                    "num_to_keep": keep_checkpoints_num,
                    "checkpoint_score_attribute": checkpoint_score_attr,
                },
                "storage_path": storage_path,
                "resources_per_trial": (
                    resources_per_trial and resources_to_json(resources_per_trial)
                ),
                "stop": json.loads(stop),
                "config": dict(config, env=env),
                "restore": restore,
                "num_samples": num_samples,
            }
        }

        _run_rllib_experiments(
            experiments=experiments,
            v=v,
            vv=vv,
            framework=framework,
            trace=trace,
            ray_num_nodes=ray_num_nodes,
            ray_num_cpus=ray_num_cpus,
            ray_num_gpus=ray_num_gpus,
            ray_object_store_memory=ray_object_store_memory,
            ray_ui=ray_ui,
            ray_address=ray_address,
            local_mode=local_mode,
            resume=resume,
            scheduler=scheduler,
            scheduler_config=scheduler_config,
            algo=algo,
        )
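
# Example invocation (illustrative values; `--config` and `--stop` take JSON
# strings, which are parsed with `json.loads` above):
#   rllib train --algo PPO --env CartPole-v1 \
#       --config '{"num_workers": 2}' --stop '{"training_iteration": 10}'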


def _run_rllib_experiments(
    experiments: dict,
    v: cli.V,
    vv: cli.VV,
    framework: str,
    trace: cli.Trace,
    ray_num_nodes: cli.RayNumNodes,
    ray_num_cpus: cli.RayNumCpus,
    ray_num_gpus: cli.RayNumGpus,
    ray_object_store_memory: cli.RayObjectStoreMemory,
    ray_ui: cli.RayUi,
    ray_address: cli.RayAddress,
    local_mode: cli.LocalMode,
    resume: cli.Resume,
    scheduler: cli.Scheduler,
    scheduler_config: cli.SchedulerConfig,
    algo: cli.Algo,
    callbacks=None,
):
    """Main training function for the RLlib CLI, whether you've loaded your
    experiments from a config file or from command line options."""

    # Override experiment data with command line arguments.
    verbose = 1
    for exp in experiments.values():
        # Bazel makes it hard to find files specified in `args` (and `data`).
        # Look for them here.
        # NOTE: Some of our yaml files don't have a `config` section.
        input_ = exp.get("config", {}).get("input")
        if input_ and input_ != "sampler":
            exp["config"]["input"] = _patch_path(input_)

        if not exp.get("env") and not exp.get("config", {}).get("env"):
            raise ValueError(
                "You either need to provide an --env argument (e.g. 'CartPole-v1') "
                "or pass an `env` key with a valid environment to your `config` "
                "argument."
            )
        elif framework is not None:
            exp["config"]["framework"] = framework
        if trace:
            if exp["config"]["framework"] not in ["tf2"]:
                raise ValueError("Must enable framework=tf2 to enable eager tracing.")
            exp["config"]["eager_tracing"] = True
        if v:
            exp["config"]["log_level"] = "INFO"
            verbose = 3  # Print details on trial result.
        if vv:
            exp["config"]["log_level"] = "DEBUG"
            verbose = 3  # Print details on trial result.

    # Initialize the Ray cluster with the specified options.
    if ray_num_nodes:
        # Import this only here so that train.py also works with
        # older versions (and user doesn't use `--ray-num-nodes`).
        from ray.cluster_utils import Cluster

        cluster = Cluster()
        for _ in range(ray_num_nodes):
            cluster.add_node(
                num_cpus=ray_num_cpus or 1,
                num_gpus=ray_num_gpus or 0,
                object_store_memory=ray_object_store_memory,
            )
        ray.init(address=cluster.address)
    else:
        ray.init(
            include_dashboard=ray_ui,
            address=ray_address,
            object_store_memory=ray_object_store_memory,
            num_cpus=ray_num_cpus,
            num_gpus=ray_num_gpus,
            local_mode=local_mode,
        )

    # Run the Tune experiment and return the trials.
    scheduler_config = json.loads(scheduler_config)
    trials = run_experiments(
        experiments,
        scheduler=create_scheduler(scheduler, **scheduler_config),
        resume=resume,
        verbose=verbose,
        concurrent=True,
        callbacks=callbacks,
    )
    ray.shutdown()

    checkpoints = []
    for trial in trials:
        if trial.checkpoint:
            checkpoints.append(trial.checkpoint)

    if checkpoints:
        from rich import print
        from rich.panel import Panel

        print("\nYour training finished.")
        print("Best available checkpoint for each trial:")
        for cp in checkpoints:
            print(f" {cp.path}")

        print(
            "\nYou can now evaluate your trained algorithm from any "
            "checkpoint, e.g. by running:"
        )
        print(Panel(f"[green] rllib evaluate {checkpoints[0].path} --algo {algo}"))


@DeveloperAPI
def main():
    """Run the CLI."""
    train_app()


if __name__ == "__main__":
    main()