custom_experiment.py

  1. """Example of a custom experiment wrapped around an RLlib trainer."""
  2. import argparse
  3. import ray
  4. from ray import tune
  5. from ray.rllib.agents import ppo
  6. parser = argparse.ArgumentParser()
  7. parser.add_argument("--train-iterations", type=int, default=10)
  8. def experiment(config):
  9. iterations = config.pop("train-iterations")
  10. train_agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
  11. checkpoint = None
  12. train_results = {}
  13. # Train
  14. for i in range(iterations):
  15. train_results = train_agent.train()
  16. if i % 2 == 0 or i == iterations - 1:
  17. checkpoint = train_agent.save(tune.get_trial_dir())
  18. tune.report(**train_results)
  19. train_agent.stop()
  20. # Manual Eval
  21. config["num_workers"] = 0
  22. eval_agent = ppo.PPOTrainer(config=config, env="CartPole-v0")
  23. eval_agent.restore(checkpoint)
  24. env = eval_agent.workers.local_worker().env
  25. obs = env.reset()
  26. done = False
  27. eval_results = {"eval_reward": 0, "eval_eps_length": 0}
  28. while not done:
  29. action = eval_agent.compute_single_action(obs)
  30. next_obs, reward, done, info = env.step(action)
  31. eval_results["eval_reward"] += reward
  32. eval_results["eval_eps_length"] += 1
  33. results = {**train_results, **eval_results}
  34. tune.report(results)
  35. if __name__ == "__main__":
  36. args = parser.parse_args()
  37. ray.init(num_cpus=3)
  38. config = ppo.DEFAULT_CONFIG.copy()
  39. config["train-iterations"] = args.train_iterations
  40. config["env"] = "CartPole-v0"
  41. tune.run(
  42. experiment,
  43. config=config,
  44. resources_per_trial=ppo.PPOTrainer.default_resource_request(config))
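Usage note: the script's only command-line flag is --train-iterations (default 10), defined by the argparse parser above. A minimal invocation, assuming the file is saved as custom_experiment.py and the value 20 is just an illustrative choice:

python custom_experiment.py --train-iterations 20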