sb2rllib_rllib_example.py

"""
Example script on how to train, save, load, and test an RLlib agent.
Equivalent script with stable baselines: sb2rllib_sb_example.py.
Demonstrates transition from stable_baselines to Ray RLlib.

Run example: python sb2rllib_rllib_example.py
"""
import gymnasium as gym
from ray import tune, air
import ray.rllib.algorithms.ppo as ppo

# settings used for both stable baselines and rllib
env_name = "CartPole-v1"
train_steps = 10000
learning_rate = 1e-3
save_dir = "saved_models"

# training and saving
analysis = tune.Tuner(
    "PPO",
    run_config=air.RunConfig(
        stop={"timesteps_total": train_steps},
        local_dir=save_dir,
        checkpoint_config=air.CheckpointConfig(
            checkpoint_at_end=True,
        ),
    ),
    param_space={"env": env_name, "lr": learning_rate},
).fit()
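
# Tuner.fit() blocks until the stop condition is met and returns a ResultGrid
# with one Result per trial (a single PPO trial in this run).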

# retrieve the checkpoint of the best trial (highest mean episode reward)
best_result = analysis.get_best_result(metric="episode_reward_mean", mode="max")
checkpoint_path = best_result.checkpoint
print(f"Trained model saved at {checkpoint_path}")
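
# as a quick sanity check, the final training metrics of the best trial can be
# inspected as well, e.g. the mean episode reward reached during training
print(f"Mean episode reward: {best_result.metrics['episode_reward_mean']}")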

# load and restore model
agent = ppo.PPO(env=env_name)
agent.restore(checkpoint_path)
print(f"Agent loaded from saved model at {checkpoint_path}")
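
# alternative sketch, assuming Ray >= 2.1: rebuild the algorithm in one step
# directly from the checkpoint instead of constructing PPO and restoring:
# agent = ppo.PPO.from_checkpoint(checkpoint_path)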

# inference
# note: gymnasium requires render_mode="human" at creation time for
# env.render() to actually open a viewer window
env = gym.make(env_name, render_mode="human")
obs, info = env.reset()
for i in range(1000):
    action = agent.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    if terminated or truncated:
        print(f"Cart pole ended after {i} steps.")
        break
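
# optionally close the environment to shut down the render window cleanly
env.close()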