"""Deprecated API; see custom_metrics_and_callbacks.py instead.""" import argparse import numpy as np import os import ray from ray import tune def on_episode_start(info): episode = info["episode"] print("episode {} started".format(episode.episode_id)) episode.user_data["pole_angles"] = [] episode.hist_data["pole_angles"] = [] def on_episode_step(info): episode = info["episode"] pole_angle = abs(episode.last_observation_for()[2]) raw_angle = abs(episode.last_raw_obs_for()[2]) assert pole_angle == raw_angle episode.user_data["pole_angles"].append(pole_angle) def on_episode_end(info): episode = info["episode"] pole_angle = np.mean(episode.user_data["pole_angles"]) print("episode {} ended with length {} and pole angles {}".format( episode.episode_id, episode.length, pole_angle)) episode.custom_metrics["pole_angle"] = pole_angle episode.hist_data["pole_angles"] = episode.user_data["pole_angles"] def on_sample_end(info): print("returned sample batch of size {}".format(info["samples"].count)) def on_train_result(info): print("trainer.train() result: {} -> {} episodes".format( info["trainer"], info["result"]["episodes_this_iter"])) # you can mutate the result dict to add new fields to return info["result"]["callback_ok"] = True def on_postprocess_traj(info): episode = info["episode"] batch = info["post_batch"] print("postprocessed {} steps".format(batch.count)) if "num_batches" not in episode.custom_metrics: episode.custom_metrics["num_batches"] = 0 episode.custom_metrics["num_batches"] += 1 if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--stop-iters", type=int, default=2000) args = parser.parse_args() ray.init() trials = tune.run( "PG", stop={ "training_iteration": args.stop_iters, }, config={ "env": "CartPole-v0", "callbacks": { "on_episode_start": on_episode_start, "on_episode_step": on_episode_step, "on_episode_end": on_episode_end, "on_sample_end": on_sample_end, "on_train_result": on_train_result, "on_postprocess_traj": on_postprocess_traj, }, "framework": "tf", # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), }).trials # verify custom metrics for integration tests custom_metrics = trials[0].last_result["custom_metrics"] print(custom_metrics) assert "pole_angle_mean" in custom_metrics assert "pole_angle_min" in custom_metrics assert "pole_angle_max" in custom_metrics assert "num_batches_mean" in custom_metrics assert "callback_ok" in trials[0].last_result