scripts.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. #!/usr/bin/env python
  2. import collections
  3. from rich.console import Console
  4. from rich.table import Table
  5. import typer
  6. from ray.rllib import train as train_module
  7. from ray.rllib.common import CLIArguments as cli
  8. from ray.rllib.common import (
  9. EXAMPLES,
  10. FrameworkEnum,
  11. example_help,
  12. download_example_file,
  13. )
  14. # Main Typer CLI app
  15. app = typer.Typer()
  16. example_app = typer.Typer()
  17. def example_error(example_id: str):
  18. return ValueError(
  19. f"Example {example_id} not found. Use `rllib example list` "
  20. f"to see available examples."
  21. )
  22. @example_app.callback()
  23. def example_callback():
  24. """RLlib command-line interface to run built-in examples. You can choose to list
  25. all available examples, get more information on an example or run a specific
  26. example.
  27. """
  28. pass
  29. @example_app.command()
  30. def list(
  31. filter: str = typer.Option(None, "--filter", "-f", help=example_help.get("filter"))
  32. ):
  33. """List all available RLlib examples that can be run from the command line.
  34. Note that many of these examples require specific hardware (e.g. a certain number
  35. of GPUs) to work.\n\n
  36. Example usage: `rllib example list --filter=cartpole`
  37. """
  38. table = Table(title="RLlib Examples")
  39. table.add_column("Example ID", justify="left", style="cyan", no_wrap=True)
  40. table.add_column("Description", justify="left", style="magenta")
  41. sorted_examples = collections.OrderedDict(sorted(EXAMPLES.items()))
  42. for name, value in sorted_examples.items():
  43. if filter:
  44. if filter.lower() in name:
  45. table.add_row(name, value["description"])
  46. else:
  47. table.add_row(name, value["description"])
  48. console = Console()
  49. console.print(table)
  50. console.print(
  51. "Run any RLlib example as using 'rllib example run <Example ID>'."
  52. "See 'rllib example run --help' for more information."
  53. )
  54. def get_example_file(example_id):
  55. """Simple helper function to get the example file for a given example ID."""
  56. if example_id not in EXAMPLES:
  57. raise example_error(example_id)
  58. example = EXAMPLES[example_id]
  59. assert (
  60. "file" in example.keys()
  61. ), f"Example {example_id} does not have a 'file' attribute."
  62. return example.get("file")
  63. @example_app.command()
  64. def get(example_id: str = typer.Argument(..., help="The example ID of the example.")):
  65. """Print the configuration of an example.\n\n
  66. Example usage: `rllib example get atari-a2c`
  67. """
  68. example_file = get_example_file(example_id)
  69. example_file, temp_file = download_example_file(example_file)
  70. with open(example_file) as f:
  71. console = Console()
  72. console.print(f.read())
  73. @example_app.command()
  74. def run(example_id: str = typer.Argument(..., help="Example ID to run.")):
  75. """Run an RLlib example from the command line by simply providing its ID.\n\n
  76. Example usage: `rllib example run pong-impala`
  77. """
  78. example = EXAMPLES[example_id]
  79. example_file = get_example_file(example_id)
  80. example_file, temp_file = download_example_file(example_file)
  81. stop = example.get("stop")
  82. train_module.file(
  83. config_file=example_file,
  84. stop=stop,
  85. checkpoint_freq=1,
  86. checkpoint_at_end=True,
  87. keep_checkpoints_num=None,
  88. checkpoint_score_attr="training_iteration",
  89. framework=FrameworkEnum.tf2,
  90. v=True,
  91. vv=False,
  92. trace=False,
  93. local_mode=False,
  94. ray_address=None,
  95. ray_ui=False,
  96. ray_num_cpus=None,
  97. ray_num_gpus=None,
  98. ray_num_nodes=None,
  99. ray_object_store_memory=None,
  100. resume=False,
  101. scheduler="FIFO",
  102. scheduler_config="{}",
  103. )
  104. if temp_file:
  105. temp_file.close()
  106. # Register all subcommands
  107. app.add_typer(example_app, name="example")
  108. app.add_typer(train_module.train_app, name="train")
  109. @app.command()
  110. def evaluate(
  111. checkpoint: str = cli.Checkpoint,
  112. algo: str = cli.Algo,
  113. env: str = cli.Env,
  114. local_mode: bool = cli.LocalMode,
  115. render: bool = cli.Render,
  116. steps: int = cli.Steps,
  117. episodes: int = cli.Episodes,
  118. out: str = cli.Out,
  119. config: str = cli.Config,
  120. save_info: bool = cli.SaveInfo,
  121. use_shelve: bool = cli.UseShelve,
  122. track_progress: bool = cli.TrackProgress,
  123. ):
  124. """Roll out a reinforcement learning agent given a checkpoint argument.
  125. You have to provide an environment ("--env") an an RLlib algorithm ("--algo") to
  126. evaluate your checkpoint.
  127. Example usage:\n\n
  128. rllib evaluate /tmp/ray/checkpoint_dir/checkpoint-0 --algo DQN --env CartPole-v1
  129. --steps 1000000 --out rollouts.pkl
  130. """
  131. from ray.rllib import evaluate as evaluate_module
  132. evaluate_module.run(
  133. checkpoint=checkpoint,
  134. algo=algo,
  135. env=env,
  136. local_mode=local_mode,
  137. render=render,
  138. steps=steps,
  139. episodes=episodes,
  140. out=out,
  141. config=config,
  142. save_info=save_info,
  143. use_shelve=use_shelve,
  144. track_progress=track_progress,
  145. )
  146. @app.command()
  147. def rollout(
  148. checkpoint: str = cli.Checkpoint,
  149. algo: str = cli.Algo,
  150. env: str = cli.Env,
  151. local_mode: bool = cli.LocalMode,
  152. render: bool = cli.Render,
  153. steps: int = cli.Steps,
  154. episodes: int = cli.Episodes,
  155. out: str = cli.Out,
  156. config: str = cli.Config,
  157. save_info: bool = cli.SaveInfo,
  158. use_shelve: bool = cli.UseShelve,
  159. track_progress: bool = cli.TrackProgress,
  160. ):
  161. """Old rollout script. Please use `rllib evaluate` instead."""
  162. from ray.rllib.utils.deprecation import deprecation_warning
  163. deprecation_warning(old="rllib rollout", new="rllib evaluate", error=True)
# Callback of the top-level `app`: Typer renders this function's docstring as the
# `rllib --help` text, so the banner and usage examples below are user-facing
# output. The function itself does nothing.
# NOTE(review): the explicit "\n" escapes presumably keep each banner row on its
# own line when the help text is re-wrapped — confirm before removing them. The
# logo's column alignment also looks collapsed in this copy; verify against the
# upstream file.
@app.callback()
def main_helper():
    """Welcome to the\n
    . ╔▄▓▓▓▓▄\n
    . ╔██▀╙╙╙▀██▄\n
    . ╫█████████████▓ ╫████▓ ╫████▓ ██▌ ▐██ ╫████▒\n
    . ╫███████████████▓ ╫█████▓ ╫█████▓ ╫██ ╫██ ╫██████▒\n
    . ╫█████▓ ████▓ ╫█████▓ ╫█████▓ ╙▓██████▀ ╫██████████████▒\n
    . ╫███████████████▓ ╫█████▓ ╫█████▓ ╫█▒ ╫████████████████▒\n
    . ╫█████████████▓ ╫█████▓ ╫█████▓ ╫█▒ ╫██████▒ ╫█████▒\n
    . ╫█████▓███████▓ ╫█████▓ ╫█████▓ ╫█▒ ╫██████▒ ╫█████▒\n
    . ╫█████▓ ██████▓ ╫████████████████▄ ╫█████▓ ╫█▒ ╫████████████████▒\n
    . ╫█████▓ ████▓ ╫█████████████████ ╫█████▓ ╫█▒ ╫██████████████▒\n
    . ╣▓▓▓▓▓▓▓▓▓▓▓▓██▓▓▓▓▓▓▓▓▓▓▓▓▄\n
    . ╫██╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╙╫█▒\n
    . ╫█ Command Line Interface █▒\n
    . ╫██▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄╣█▒\n
    . ▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀▀\n
    .\n
    Example usage for training:\n
    rllib train --algo DQN --env CartPole-v1\n
    rllib train file tuned_examples/ppo/pendulum-ppo.yaml\n\n
    Example usage for evaluation:\n
    rllib evaluate /trial_dir/checkpoint_000001/checkpoint-1 --algo DQN\n\n
    Example usage for built-in examples:\n
    rllib example list\n
    rllib example get atari-ppo\n
    rllib example run atari-ppo\n
    """
def cli():
    """Run the ``rllib`` Typer application (the console-script entry point)."""
    # Keep this function here, it's referenced in the setup.py file, and exposes
    # the CLI as entry point ("rllib" command).
    # NOTE: `def cli` rebinds the module-level name `cli`, which the imports use as
    # an alias for `CLIArguments`. The command signatures above already evaluated
    # their `cli.*` defaults when they were defined, so this only works because
    # this function is defined after them — keep it at the end of the module.
    app()


if __name__ == "__main__":
    cli()