scripts.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. #!/usr/bin/env python
  2. import collections
  3. from rich.console import Console
  4. from rich.table import Table
  5. import typer
  6. from ray.air.constants import TRAINING_ITERATION
  7. from ray.rllib import train as train_module
  8. from ray.rllib.common import CLIArguments as cli
  9. from ray.rllib.common import (
  10. EXAMPLES,
  11. FrameworkEnum,
  12. example_help,
  13. _download_example_file,
  14. )
  15. from ray.rllib.utils.deprecation import deprecation_warning
# Main Typer CLI app: the root command group that `cli()` (bottom of this
# module) invokes. Commands and sub-apps are attached to it further down.
app = typer.Typer()
# Separate Typer app that collects the example-related commands defined below;
# it is mounted on `app` under the name "example".
example_app = typer.Typer()
  19. def example_error(example_id: str):
  20. return ValueError(
  21. f"Example {example_id} not found. Use `rllib example list` "
  22. f"to see available examples."
  23. )
  24. @example_app.callback()
  25. def example_callback():
  26. """RLlib command-line interface to run built-in examples. You can choose to list
  27. all available examples, get more information on an example or run a specific
  28. example.
  29. """
  30. pass
  31. @example_app.command()
  32. def list(
  33. filter: str = typer.Option(None, "--filter", "-f", help=example_help.get("filter"))
  34. ):
  35. """List all available RLlib examples that can be run from the command line.
  36. Note that many of these examples require specific hardware (e.g. a certain number
  37. of GPUs) to work.\n\n
  38. Example usage: `rllib example list --filter=cartpole`
  39. """
  40. table = Table(title="RLlib Examples")
  41. table.add_column("Example ID", justify="left", style="cyan", no_wrap=True)
  42. table.add_column("Description", justify="left", style="magenta")
  43. sorted_examples = collections.OrderedDict(sorted(EXAMPLES.items()))
  44. for name, value in sorted_examples.items():
  45. if filter:
  46. if filter.lower() in name:
  47. table.add_row(name, value["description"])
  48. else:
  49. table.add_row(name, value["description"])
  50. console = Console()
  51. console.print(table)
  52. console.print(
  53. "Run any RLlib example as using 'rllib example run <Example ID>'."
  54. "See 'rllib example run --help' for more information."
  55. )
  56. def get_example_file(example_id):
  57. """Simple helper function to get the example file for a given example ID."""
  58. if example_id not in EXAMPLES:
  59. raise example_error(example_id)
  60. example = EXAMPLES[example_id]
  61. assert (
  62. "file" in example.keys()
  63. ), f"Example {example_id} does not have a 'file' attribute."
  64. return example.get("file")
  65. @example_app.command()
  66. def get(example_id: str = typer.Argument(..., help="The example ID of the example.")):
  67. """Print the configuration of an example.\n\n
  68. Example usage: `rllib example get atari-a2c`
  69. """
  70. example_file = get_example_file(example_id)
  71. example_file, temp_file = _download_example_file(example_file)
  72. with open(example_file) as f:
  73. console = Console()
  74. console.print(f.read())
  75. @example_app.command()
  76. def run(example_id: str = typer.Argument(..., help="Example ID to run.")):
  77. """Run an RLlib example from the command line by simply providing its ID.\n\n
  78. Example usage: `rllib example run pong-impala`
  79. """
  80. example = EXAMPLES[example_id]
  81. example_file = get_example_file(example_id)
  82. example_file, temp_file = _download_example_file(example_file)
  83. stop = example.get("stop")
  84. train_module.file(
  85. config_file=example_file,
  86. stop=stop,
  87. checkpoint_freq=1,
  88. checkpoint_at_end=True,
  89. keep_checkpoints_num=None,
  90. checkpoint_score_attr=TRAINING_ITERATION,
  91. framework=FrameworkEnum.tf2,
  92. v=True,
  93. vv=False,
  94. trace=False,
  95. local_mode=False,
  96. ray_address=None,
  97. ray_ui=False,
  98. ray_num_cpus=None,
  99. ray_num_gpus=None,
  100. ray_num_nodes=None,
  101. ray_object_store_memory=None,
  102. resume=False,
  103. scheduler="FIFO",
  104. scheduler_config="{}",
  105. )
  106. if temp_file:
  107. temp_file.close()
# Register all subcommands: mount the example app under "example" and the
# train module's Typer app under "train" on the main app.
app.add_typer(example_app, name="example")
app.add_typer(train_module.train_app, name="train")
  111. @app.command()
  112. def evaluate(
  113. checkpoint: str = cli.Checkpoint,
  114. algo: str = cli.Algo,
  115. env: str = cli.Env,
  116. local_mode: bool = cli.LocalMode,
  117. render: bool = cli.Render,
  118. steps: int = cli.Steps,
  119. episodes: int = cli.Episodes,
  120. out: str = cli.Out,
  121. config: str = cli.Config,
  122. save_info: bool = cli.SaveInfo,
  123. use_shelve: bool = cli.UseShelve,
  124. track_progress: bool = cli.TrackProgress,
  125. ):
  126. """Roll out a reinforcement learning agent given a checkpoint argument.
  127. You have to provide an environment ("--env") an an RLlib algorithm ("--algo") to
  128. evaluate your checkpoint.
  129. Example usage:\n\n
  130. rllib evaluate /tmp/ray/checkpoint_dir/checkpoint-0 --algo DQN --env CartPole-v1
  131. --steps 1000000 --out rollouts.pkl
  132. """
  133. from ray.rllib import evaluate as evaluate_module
  134. evaluate_module.run(
  135. checkpoint=checkpoint,
  136. algo=algo,
  137. env=env,
  138. local_mode=local_mode,
  139. render=render,
  140. steps=steps,
  141. episodes=episodes,
  142. out=out,
  143. config=config,
  144. save_info=save_info,
  145. use_shelve=use_shelve,
  146. track_progress=track_progress,
  147. )
  148. @app.command()
  149. def rollout(
  150. checkpoint: str = cli.Checkpoint,
  151. algo: str = cli.Algo,
  152. env: str = cli.Env,
  153. local_mode: bool = cli.LocalMode,
  154. render: bool = cli.Render,
  155. steps: int = cli.Steps,
  156. episodes: int = cli.Episodes,
  157. out: str = cli.Out,
  158. config: str = cli.Config,
  159. save_info: bool = cli.SaveInfo,
  160. use_shelve: bool = cli.UseShelve,
  161. track_progress: bool = cli.TrackProgress,
  162. ):
  163. """Old rollout script. Please use `rllib evaluate` instead."""
  164. from ray.rllib.utils.deprecation import deprecation_warning
  165. deprecation_warning(old="rllib rollout", new="rllib evaluate", error=True)
# Callback of the main app. It has no body besides its docstring, which holds
# the banner and the usage examples for the top-level command. The docstring is
# runtime text, so keep it verbatim.
# NOTE(review): the banner's internal spacing looks collapsed in this copy of
# the file -- verify the art against the upstream source before editing it.
@app.callback()
def main_helper():
    """Welcome to the\n
    . ╔▄▓▓▓▓▄\n
    . ╔██▀╙╙╙▀██▄\n
    . ╫█████████████▓ ╫████▓ ╫████▓ ██▌ ▐██ ╫████▒\n
    . ╫███████████████▓ ╫█████▓ ╫█████▓ ╫██ ╫██ ╫██████▒\n
    . ╫█████▓ ████▓ ╫█████▓ ╫█████▓ ╙▓██████▀ ╫██████████████▒\n
    . ╫███████████████▓ ╫█████▓ ╫█████▓ ╫█▒ ╫████████████████▒\n
    . ╫█████████████▓ ╫█████▓ ╫█████▓ ╫█▒ ╫██████▒ ╫█████▒\n
    . ╫█████▓███████▓ ╫█████▓ ╫█████▓ ╫█▒ ╫██████▒ ╫█████▒\n
    . ╫█████▓ ██████▓ ╫████████████████▄ ╫█████▓ ╫█▒ ╫████████████████▒\n
    . ╫█████▓ ████▓ ╫█████████████████ ╫█████▓ ╫█▒ ╫██████████████▒\n
    . ╣▓▓▓▓▓▓▓▓▓▓▓▓██▓▓▓▓▓▓▓▓▓▓▓▓▄\n
    . ╫██╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╫█▒\n
    . ╫█ Command Line Interface █▒\n
    . ╫██▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄╣█▒\n
    . ▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀\n
    .\n
    Example usage for training:\n
    rllib train --algo DQN --env CartPole-v1\n
    rllib train file tuned_examples/ppo/pendulum-ppo.yaml\n\n
    Example usage for evaluation:\n
    rllib evaluate /trial_dir/checkpoint_000001/checkpoint-1 --algo DQN\n\n
    Example usage for built-in examples:\n
    rllib example list\n
    rllib example get atari-ppo\n
    rllib example run atari-ppo\n
    """
  195. def cli():
  196. # Keep this function here, it's referenced in the setup.py file, and exposes
  197. # the CLI as entry point ("rllib" command).
  198. deprecation_warning(
  199. old="RLlib CLI (`rllib train` and `rllib evaluate`)",
  200. help="The RLlib CLI scripts will be deprecated soon! "
  201. "Use RLlib's python API instead, which is more flexible and offers a more "
  202. "unified approach to running RL experiments, evaluating policies, and "
  203. "creating checkpoints for later deployments. See here for a quick intro: "
  204. "https://docs.ray.io/en/latest/rllib/rllib-training.html#using-the-python-api",
  205. error=False,
  206. )
  207. app()
# Run the CLI when this module is executed directly as a script.
if __name__ == "__main__":
    cli()