custom_metrics_and_callbacks_legacy.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. """Deprecated API; see custom_metrics_and_callbacks.py instead."""
  2. import argparse
  3. import numpy as np
  4. import os
  5. import ray
  6. from ray import tune
  7. def on_episode_start(info):
  8. episode = info["episode"]
  9. print("episode {} started".format(episode.episode_id))
  10. episode.user_data["pole_angles"] = []
  11. episode.hist_data["pole_angles"] = []
  12. def on_episode_step(info):
  13. episode = info["episode"]
  14. pole_angle = abs(episode.last_observation_for()[2])
  15. raw_angle = abs(episode.last_raw_obs_for()[2])
  16. assert pole_angle == raw_angle
  17. episode.user_data["pole_angles"].append(pole_angle)
  18. def on_episode_end(info):
  19. episode = info["episode"]
  20. pole_angle = np.mean(episode.user_data["pole_angles"])
  21. print("episode {} ended with length {} and pole angles {}".format(
  22. episode.episode_id, episode.length, pole_angle))
  23. episode.custom_metrics["pole_angle"] = pole_angle
  24. episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
  25. def on_sample_end(info):
  26. print("returned sample batch of size {}".format(info["samples"].count))
  27. def on_train_result(info):
  28. print("trainer.train() result: {} -> {} episodes".format(
  29. info["trainer"], info["result"]["episodes_this_iter"]))
  30. # you can mutate the result dict to add new fields to return
  31. info["result"]["callback_ok"] = True
  32. def on_postprocess_traj(info):
  33. episode = info["episode"]
  34. batch = info["post_batch"]
  35. print("postprocessed {} steps".format(batch.count))
  36. if "num_batches" not in episode.custom_metrics:
  37. episode.custom_metrics["num_batches"] = 0
  38. episode.custom_metrics["num_batches"] += 1
  39. if __name__ == "__main__":
  40. parser = argparse.ArgumentParser()
  41. parser.add_argument("--stop-iters", type=int, default=2000)
  42. args = parser.parse_args()
  43. ray.init()
  44. trials = tune.run(
  45. "PG",
  46. stop={
  47. "training_iteration": args.stop_iters,
  48. },
  49. config={
  50. "env": "CartPole-v0",
  51. "callbacks": {
  52. "on_episode_start": on_episode_start,
  53. "on_episode_step": on_episode_step,
  54. "on_episode_end": on_episode_end,
  55. "on_sample_end": on_sample_end,
  56. "on_train_result": on_train_result,
  57. "on_postprocess_traj": on_postprocess_traj,
  58. },
  59. "framework": "tf",
  60. # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
  61. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
  62. }).trials
  63. # verify custom metrics for integration tests
  64. custom_metrics = trials[0].last_result["custom_metrics"]
  65. print(custom_metrics)
  66. assert "pole_angle_mean" in custom_metrics
  67. assert "pole_angle_min" in custom_metrics
  68. assert "pole_angle_max" in custom_metrics
  69. assert "num_batches_mean" in custom_metrics
  70. assert "callback_ok" in trials[0].last_result