policy_inference_after_training.py

  1. """
  2. Example showing how you can use your trained policy for inference
  3. (computing actions) in an environment.
  4. Includes options for LSTM-based models (--use-lstm), attention-net models
  5. (--use-attention), and plain (non-recurrent) models.
  6. """
import argparse
import os

import gymnasium as gym
import ray
from ray import air, tune
from ray.rllib.algorithms.algorithm import Algorithm
from ray.tune.registry import get_trainable_cls

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train before we do inference.",
)
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train before we do inference.",
)
parser.add_argument(
    "--stop-reward",
    type=float,
    default=150.0,
    help="Reward at which we stop training before we do inference.",
)
parser.add_argument(
    "--explore-during-inference",
    action="store_true",
    help="Whether the trained policy should use exploration during action "
    "inference.",
)
parser.add_argument(
    "--num-episodes-during-inference",
    type=int,
    default=10,
    help="Number of episodes to do inference over after training.",
)

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # Build the algorithm's config: the default config of the chosen algorithm
    # (PPO by default), pointed at the FrozenLake-v1 env.
    config = (
        get_trainable_cls(args.run)
        .get_default_config()
        .environment("FrozenLake-v1")
        .framework(args.framework)
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )
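
    # For reference, with the default --run=PPO the fluent config above is
    # roughly equivalent to building the algorithm's config class directly
    # (a sketch, assuming PPO; other algorithms have their own *Config classes):
    #
    #   from ray.rllib.algorithms.ppo import PPOConfig
    #
    #   config = (
    #       PPOConfig()
    #       .environment("FrozenLake-v1")
    #       .framework(args.framework)
    #       .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    #   )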

    # Stopping criteria for Tune: training stops as soon as ANY one of these
    # metrics is reached.
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }
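
    # Note: FrozenLake-v1 yields a reward of 1.0 only when the goal is reached,
    # so `episode_reward_mean` can never exceed 1.0. With the default
    # --stop-reward of 150.0, training therefore stops on iterations or
    # timesteps; pass a value below 1.0 via --stop-reward to make the reward
    # criterion meaningful for this env.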

    print("Training policy until desired reward/timesteps/iterations. ...")
    tuner = tune.Tuner(
        args.run,
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop,
            verbose=2,
            # Checkpoint every iteration and at the end, so a trained
            # Algorithm can be restored below.
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=1, checkpoint_at_end=True
            ),
        ),
    )
    results = tuner.fit()

    print("Training completed. Restoring new Algorithm for action inference.")
    # Get the last checkpoint from the above training run.
    checkpoint = results.get_best_result().checkpoint
    # Create new Algorithm and restore its state from the last checkpoint.
    algo = Algorithm.from_checkpoint(checkpoint)
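
    # Note: if you only need the trained policy (without restoring the full
    # Algorithm and its workers), RLlib also allows restoring policies directly
    # from a checkpoint -- a sketch, assuming the default policy ID:
    #
    #   from ray.rllib.policy.policy import Policy
    #
    #   policy = Policy.from_checkpoint(checkpoint)["default_policy"]
    #   action = policy.compute_single_action(obs)[0]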

    # Create the env to do inference in.
    env = gym.make("FrozenLake-v1")
    obs, info = env.reset()

    num_episodes = 0
    episode_reward = 0.0

    while num_episodes < args.num_episodes_during_inference:
        # Compute an action (`a`) from the trained policy.
        a = algo.compute_single_action(
            observation=obs,
            explore=args.explore_during_inference,
            policy_id="default_policy",  # <- default value
        )
        # Send the computed action `a` to the env.
        obs, reward, done, truncated, _ = env.step(a)
        episode_reward += reward
        # Is the episode over (terminated or truncated)? -> Reset.
        if done or truncated:
            print(f"Episode done: Total reward = {episode_reward}")
            obs, info = env.reset()
            num_episodes += 1
            episode_reward = 0.0

    algo.stop()

    ray.shutdown()