1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- """Example of a custom experiment wrapped around an RLlib trainer."""
- import argparse
- import ray
- from ray import tune
- from ray.rllib.agents import ppo
# Command-line interface for the example script.
parser = argparse.ArgumentParser(
    description="Custom Tune experiment wrapped around an RLlib PPO trainer.")
parser.add_argument(
    "--train-iterations",
    type=int,
    default=10,
    help="Number of trainer.train() iterations to run before the manual "
    "evaluation episode (default: %(default)s).")
def experiment(config):
    """Train PPO on CartPole-v0, then evaluate the last checkpoint manually.

    Intended to be passed to ``tune.run`` as a function trainable.

    Args:
        config: Trainer config dict. Must contain the extra key
            ``"train-iterations"`` (number of training iterations); that key
            is removed from ``config`` before the dict is handed to
            ``PPOTrainer``. ``config["num_workers"]`` is overwritten with 0
            for the evaluation trainer.

    Reports:
        The training result dict after every training iteration, and finally
        the last training results merged with ``eval_reward`` and
        ``eval_eps_length`` from a single evaluation episode.
    """
    # Pop the custom key so the trainer only sees the remaining config keys.
    iterations = config.pop("train-iterations")
    train_agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
    checkpoint = None
    train_results = {}

    # Train
    for i in range(iterations):
        train_results = train_agent.train()
        # Checkpoint on every even iteration and always on the last one, so
        # the evaluation below restores the final trained state.
        if i % 2 == 0 or i == iterations - 1:
            checkpoint = train_agent.save(tune.get_trial_dir())
        tune.report(**train_results)
    train_agent.stop()

    # Manual Eval
    config["num_workers"] = 0
    eval_agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
    eval_results = {"eval_reward": 0, "eval_eps_length": 0}
    try:
        # No checkpoint exists when zero training iterations were requested;
        # in that case the freshly initialised policy is evaluated.
        if checkpoint is not None:
            eval_agent.restore(checkpoint)
        env = eval_agent.workers.local_worker().env

        obs = env.reset()
        done = False
        while not done:
            action = eval_agent.compute_single_action(obs)
            # Feed the new observation back in; otherwise every action would
            # be computed from the initial observation.
            obs, reward, done, _ = env.step(action)
            eval_results["eval_reward"] += reward
            eval_results["eval_eps_length"] += 1
    finally:
        # Release the evaluation trainer even if the rollout fails.
        eval_agent.stop()

    results = {**train_results, **eval_results}
    # Report as keyword metrics, matching the per-iteration report above.
    tune.report(**results)
if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=3)

    # Start from PPO's defaults and add the settings used by this example.
    config = ppo.DEFAULT_CONFIG.copy()
    # Extra, non-RLlib key; experiment() pops it before building the trainer.
    config["train-iterations"] = args.train_iterations
    config["env"] = "CartPole-v0"

    tune.run(
        experiment,
        config=config,
        # Ask the trainer class which resources a trial with this config needs.
        resources_per_trial=ppo.PPOTrainer.default_resource_request(config))
|