trajectory_view_api.py

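# This example script trains PPO on StatelessCartPole (CartPole without
# velocity information in the observations) using a custom model that
# frame-stacks the last `num_frames` observations, actions, and rewards via
# RLlib's trajectory view API, then runs a manual inference loop against the
# restored checkpoint.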
import argparse

import numpy as np

import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.examples.models.trajectory_view_utilizing_models import \
    FrameStackingCartPoleModel, TorchFrameStackingCartPoleModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
from ray import tune

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PPO",
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=50,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=200000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=150.0,
    help="Reward at which we stop training.")
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=3)

    num_frames = 16
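    # Register the custom frame-stacking model under a name that the config
    # below can reference. The torch variant is used when --framework=torch.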
    ModelCatalog.register_custom_model(
        "frame_stack_model", FrameStackingCartPoleModel
        if args.framework != "torch" else TorchFrameStackingCartPoleModel)
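    # The custom model consumes the last `num_frames` observations, actions,
    # and rewards. It gets them by declaring extra trajectory-view
    # requirements in its constructor, roughly along these lines (a sketch;
    # see ray/rllib/examples/models/trajectory_view_utilizing_models.py for
    # the actual implementation; ViewRequirement lives in
    # ray.rllib.policy.view_requirement):
    #
    #   self.view_requirements["prev_n_obs"] = ViewRequirement(
    #       data_col="obs", shift="-{}:0".format(num_frames - 1),
    #       space=obs_space)
    #   self.view_requirements["prev_n_actions"] = ViewRequirement(
    #       data_col="actions", shift="-{}:-1".format(num_frames),
    #       space=action_space)
    #   self.view_requirements["prev_n_rewards"] = ViewRequirement(
    #       data_col="rewards", shift="-{}:-1".format(num_frames))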
    config = {
        "env": StatelessCartPole,
        "model": {
            "vf_share_layers": True,
            "custom_model": "frame_stack_model",
            "custom_model_config": {
                "num_frames": num_frames,
            },
            # To compare against a simple LSTM:
            # "use_lstm": True,
            # "lstm_use_prev_action": True,
            # "lstm_use_prev_reward": True,
            # To compare against a simple attention net:
            # "use_attention": True,
            # "attention_use_n_prev_actions": 1,
            # "attention_use_n_prev_rewards": 1,
        },
        "num_sgd_iter": 5,
        "vf_loss_coeff": 0.0001,
        "framework": args.framework,
    }
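    # Stop the run as soon as any one of these criteria is met.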
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(
        args.run, config=config, stop=stop, verbose=2, checkpoint_at_end=True)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
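    # Get the checkpoint written at the end of the best trial and restore it
    # into a fresh PPOTrainer for manual inference.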
    checkpoints = results.get_trial_checkpoints_paths(
        trial=results.get_best_trial("episode_reward_mean", mode="max"),
        metric="episode_reward_mean")
    checkpoint_path = checkpoints[0][0]
    trainer = PPOTrainer(config)
    trainer.restore(checkpoint_path)

    # Inference loop.
    env = StatelessCartPole()

    # Run manual inference loop for n episodes.
    for _ in range(10):
        episode_reward = 0.0
        reward = 0.0
        action = 0
        done = False
        obs = env.reset()
        while not done:
            # Query the trained policy for an action. Since we step the env
            # manually (no RLlib sample collectors), we build the model's
            # input_dict ourselves: the current observation repeated
            # num_frames times as `prev_n_obs`, plus dummy prev-n-actions
            # and prev-n-rewards.
            action, state, logits = trainer.compute_single_action(
                input_dict={
                    "obs": obs,
                    "prev_n_obs": np.stack([obs for _ in range(num_frames)]),
                    "prev_n_actions": np.stack([0 for _ in range(num_frames)]),
                    "prev_n_rewards": np.stack(
                        [1.0 for _ in range(num_frames)]),
                },
                full_fetch=True)
            obs, reward, done, info = env.step(action)
            episode_reward += reward

        print(f"Episode reward={episode_reward}")

    ray.shutdown()