scripts.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. #!/usr/bin/env python
  2. import argparse
  3. from ray.rllib import evaluate, train
  4. from ray.rllib.utils.deprecation import deprecation_warning
# Help-text epilog for the command-line parser built in cli(): it is passed
# as `epilog=` there, so it is shown after the generated argument help.
EXAMPLE_USAGE = """
Example usage for training:
rllib train --run DQN --env CartPole-v0
Example usage for evaluate (aka: "rollout"):
rllib evaluate /trial_dir/checkpoint_000001/checkpoint-1 --run DQN
"""
  11. def cli():
  12. parser = argparse.ArgumentParser(
  13. description="Train or evaluate an RLlib Trainer.",
  14. formatter_class=argparse.RawDescriptionHelpFormatter,
  15. epilog=EXAMPLE_USAGE)
  16. subcommand_group = parser.add_subparsers(
  17. help="Commands to train or evaluate an RLlib agent.", dest="command")
  18. # see _SubParsersAction.add_parser in
  19. # https://github.com/python/cpython/blob/master/Lib/argparse.py
  20. train_parser = train.create_parser(
  21. lambda **kwargs: subcommand_group.add_parser("train", **kwargs))
  22. evaluate_parser = evaluate.create_parser(
  23. lambda **kwargs: subcommand_group.add_parser("evaluate", **kwargs))
  24. rollout_parser = evaluate.create_parser(
  25. lambda **kwargs: subcommand_group.add_parser("rollout", **kwargs))
  26. options = parser.parse_args()
  27. if options.command == "train":
  28. train.run(options, train_parser)
  29. elif options.command == "evaluate":
  30. evaluate.run(options, evaluate_parser)
  31. elif options.command == "rollout":
  32. deprecation_warning(
  33. old="rllib rollout", new="rllib evaluate", error=False)
  34. evaluate.run(options, rollout_parser)
  35. else:
  36. parser.print_help()