12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- """Deprecated API; see custom_metrics_and_callbacks.py instead."""
- import argparse
- import numpy as np
- import os
- import ray
- from ray import tune
def on_episode_start(info):
    """Announce a new episode and reset its pole-angle buffers.

    Args:
        info: Callback payload; only ``info["episode"]`` is read. The
            episode must expose ``episode_id``, ``user_data`` and
            ``hist_data``.
    """
    ep = info["episode"]
    print("episode {} started".format(ep.episode_id))
    # Each store gets its own fresh list (they must not alias each other).
    for store in (ep.user_data, ep.hist_data):
        store["pole_angles"] = []
def on_episode_step(info):
    """Record the absolute pole angle of the latest step.

    Reads element 2 of both the last observation and the last raw
    observation (this script treats that element as the pole angle),
    checks that the two agree, and appends the value to
    ``episode.user_data["pole_angles"]``.

    Args:
        info: Callback payload; only ``info["episode"]`` is read.
    """
    ep = info["episode"]
    angle_slot = 2
    filtered_angle = abs(ep.last_observation_for()[angle_slot])
    unfiltered_angle = abs(ep.last_raw_obs_for()[angle_slot])
    # Sanity check kept from the example: both views must report the same angle.
    assert filtered_angle == unfiltered_angle
    ep.user_data["pole_angles"].append(filtered_angle)
def on_episode_end(info):
    """Summarize the pole angles collected during a finished episode.

    Stores the mean angle under ``custom_metrics["pole_angle"]`` and
    exposes the full list of angles under ``hist_data["pole_angles"]``.

    Args:
        info: Callback payload; only ``info["episode"]`` is read.
    """
    ep = info["episode"]
    angles = ep.user_data["pole_angles"]
    mean_angle = np.mean(angles)
    summary = "episode {} ended with length {} and pole angles {}".format(
        ep.episode_id, ep.length, mean_angle)
    print(summary)
    ep.custom_metrics["pole_angle"] = mean_angle
    # The same list object is shared with hist_data (no copy is made).
    ep.hist_data["pole_angles"] = angles
def on_sample_end(info):
    """Print the size of the sample batch that was just returned.

    Args:
        info: Callback payload; ``info["samples"].count`` is reported.
    """
    batch_size = info["samples"].count
    print("returned sample batch of size {}".format(batch_size))
def on_train_result(info):
    """Log a training result and tag it with ``callback_ok``.

    Args:
        info: Callback payload; ``info["trainer"]`` is printed and
            ``info["result"]`` is printed from and mutated in place.
    """
    trainer = info["trainer"]
    result = info["result"]
    print("trainer.train() result: {} -> {} episodes".format(
        trainer, result["episodes_this_iter"]))
    # The result dict may be mutated to return additional fields.
    result["callback_ok"] = True
def on_postprocess_traj(info):
    """Count postprocessed batches per episode.

    Prints the step count of the postprocessed batch and increments
    ``episode.custom_metrics["num_batches"]``, creating it at 0 first
    if it does not exist yet.

    Args:
        info: Callback payload; ``info["episode"]`` and
            ``info["post_batch"]`` are read.
    """
    ep = info["episode"]
    post_batch = info["post_batch"]
    print("postprocessed {} steps".format(post_batch.count))
    metrics = ep.custom_metrics
    metrics.setdefault("num_batches", 0)
    metrics["num_batches"] += 1
if __name__ == "__main__":
    # The only command-line knob is the training-iteration budget.
    parser = argparse.ArgumentParser()
    parser.add_argument("--stop-iters", type=int, default=2000)
    args = parser.parse_args()

    ray.init()

    # Legacy dict-style callbacks: hook name -> function defined above.
    legacy_callbacks = {
        "on_episode_start": on_episode_start,
        "on_episode_step": on_episode_step,
        "on_episode_end": on_episode_end,
        "on_sample_end": on_sample_end,
        "on_train_result": on_train_result,
        "on_postprocess_traj": on_postprocess_traj,
    }
    pg_config = {
        "env": "CartPole-v0",
        "callbacks": legacy_callbacks,
        "framework": "tf",
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    }
    stop_criteria = {
        "training_iteration": args.stop_iters,
    }

    analysis = tune.run("PG", stop=stop_criteria, config=pg_config)
    trials = analysis.trials

    # verify custom metrics for integration tests
    first_result = trials[0].last_result
    custom_metrics = first_result["custom_metrics"]
    print(custom_metrics)
    for expected_key in (
            "pole_angle_mean",
            "pole_angle_min",
            "pole_angle_max",
            "num_batches_mean",
    ):
        assert expected_key in custom_metrics
    assert "callback_ok" in first_result
|