checkpoint_by_custom_criteria.py

import argparse
import os

import ray
from ray import air, tune
from ray.tune.registry import get_trainable_cls

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--stop-reward", type=float, default=150.0)
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.",
)

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None, local_mode=args.local_mode)

    # Simple generic config.
    config = (
        get_trainable_cls(args.run)
        .get_default_config()
        .environment("CartPole-v1")
        # Use the DL framework specified on the command line.
        .framework(args.framework)
        # Run 3 trials (one per learning rate in the grid search).
        .training(
            lr=tune.grid_search([0.01, 0.001, 0.0001]), train_batch_size=2341
        )  # TEST
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }
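    # Note: Tune stops a trial as soon as ANY one of these criteria is met
    # (the conditions are OR-combined, not AND-combined).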

    # Run tune for some iterations and generate checkpoints.
    tuner = tune.Tuner(
        args.run,
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop=stop, checkpoint_config=air.CheckpointConfig(checkpoint_frequency=1)
        ),
    )
    results = tuner.fit()
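
    # With `checkpoint_frequency=1`, each of the three trials has written one
    # checkpoint per training iteration by now.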

    # Get the best of the 3 trials by using some metric.
    # NOTE: Choosing the min `episodes_this_iter` automatically picks the trial
    # with the best performance over the entire run (scope="all"):
    # The fewer episodes per iteration, the longer each episode lasted and the
    # more reward we got per episode.
    # Setting `scope` to "last", "last-5-avg", or "last-10-avg" will only
    # compare (using `mode`=min|max) the average values of the last 1, 5, or 10
    # iterations with each other, respectively.
    # Setting `scope` to "avg" will compare (using `mode`=min|max) the average
    # values over the entire run. A "last-10-avg" lookup is sketched further
    # below.
    metric = "episodes_this_iter"

    # Note that `scope` here is "all", meaning that for each trial, all
    # results (not just the last one) will be examined.
    best_result = results.get_best_result(metric=metric, mode="min", scope="all")
    value_best_metric = best_result.metrics_dataframe[metric].min()
    print(
        "Best trial's lowest `episodes_this_iter` (over all "
        "iterations): {}".format(value_best_metric)
    )

    # Confirm we picked the right trial.
    assert value_best_metric <= results.get_dataframe()[metric].min()
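
    # For comparison, the same lookup over a narrower scope (a sketch; not
    # needed for the rest of this script): "last-10-avg" ranks the trials by
    # `episodes_this_iter` averaged over each trial's last 10 iterations only.
    last10_result = results.get_best_result(
        metric=metric, mode="min", scope="last-10-avg"
    )
    print(
        "Best trial by last-10-avg `{}`: {}".format(
            metric, last10_result.metrics[metric]
        )
    )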

    # Get the best checkpoints from the trial, based on different metrics.
    # Checkpoint with the lowest policy loss value:
    if config._enable_learner_api:
        policy_loss_key = "info/learner/default_policy/policy_loss"
    else:
        policy_loss_key = "info/learner/default_policy/learner_stats/policy_loss"
    ckpt = results.get_best_result(metric=policy_loss_key, mode="min").checkpoint
    print("Lowest policy loss: {}".format(ckpt))

    # Checkpoint with the highest value-function loss:
    if config._enable_learner_api:
        vf_loss_key = "info/learner/default_policy/vf_loss"
    else:
        vf_loss_key = "info/learner/default_policy/learner_stats/vf_loss"
    ckpt = results.get_best_result(metric=vf_loss_key, mode="max").checkpoint
    print("Highest VF loss: {}".format(ckpt))

    ray.shutdown()