# bare_metal_policy_with_custom_view_reqs.py
  1. import argparse
  2. import os
  3. import ray
  4. from ray.rllib.agents.trainer_template import build_trainer
  5. from ray.rllib.examples.policy.bare_metal_policy_with_custom_view_reqs \
  6. import BareMetalPolicyWithCustomViewReqs
  7. from ray import tune
  8. def get_cli_args():
  9. """Create CLI parser and return parsed arguments"""
  10. parser = argparse.ArgumentParser()
  11. # general args
  12. parser.add_argument(
  13. "--run", default="PPO", help="The RLlib-registered algorithm to use.")
  14. parser.add_argument("--num-cpus", type=int, default=3)
  15. parser.add_argument(
  16. "--stop-iters",
  17. type=int,
  18. default=1,
  19. help="Number of iterations to train.")
  20. parser.add_argument(
  21. "--stop-timesteps",
  22. type=int,
  23. default=100000,
  24. help="Number of timesteps to train.")
  25. parser.add_argument(
  26. "--local-mode",
  27. action="store_true",
  28. help="Init Ray in local mode for easier debugging.")
  29. args = parser.parse_args()
  30. print(f"Running with following CLI args: {args}")
  31. return args
  32. if __name__ == "__main__":
  33. args = get_cli_args()
  34. ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)
  35. # Create q custom Trainer class using our custom Policy.
  36. BareMetalPolicyTrainer = build_trainer(
  37. name="MyPolicy", default_policy=BareMetalPolicyWithCustomViewReqs)
  38. config = {
  39. "env": "CartPole-v0",
  40. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  41. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  42. "model": {
  43. # Necessary to get the whole trajectory of 'state_in_0' in the
  44. # sample batch.
  45. "max_seq_len": 1,
  46. },
  47. "num_workers": 1,
  48. # NOTE: Does this have consequences?
  49. # I use it for not loading tensorflow/pytorch.
  50. "framework": None,
  51. "log_level": "DEBUG",
  52. "create_env_on_driver": True,
  53. }
  54. stop = {
  55. "training_iteration": args.stop_iters,
  56. "timesteps_total": args.stop_timesteps,
  57. }
  58. # Train the Trainer with our policy.
  59. results = tune.run(BareMetalPolicyTrainer, config=config, stop=stop)
  60. print(results)