# sb2rllib_sb_example.py
  1. """
  2. Example script on how to train, save, load, and test a stable baselines 2 agent
  3. Code taken and adjusted from SB2 docs:
  4. https://stable-baselines.readthedocs.io/en/master/guide/quickstart.html
  5. Equivalent script with RLlib: sb2rllib_rllib_example.py
  6. """
  7. import gym
  8. from stable_baselines.common.policies import MlpPolicy
  9. from stable_baselines import PPO2
  10. # settings used for both stable baselines and rllib
  11. env_name = "CartPole-v1"
  12. train_steps = 10000
  13. learning_rate = 1e-3
  14. save_dir = "saved_models"
  15. save_path = f"{save_dir}/sb_model_{train_steps}steps"
  16. env = gym.make(env_name)
  17. # training and saving
  18. model = PPO2(MlpPolicy, env, learning_rate=learning_rate, verbose=1)
  19. model.learn(total_timesteps=train_steps)
  20. model.save(save_path)
  21. print(f"Trained model saved at {save_path}")
  22. # delete and load model (just for illustration)
  23. del model
  24. model = PPO2.load(save_path)
  25. print(f"Agent loaded from saved model at {save_path}")
  26. # inference
  27. obs = env.reset()
  28. for i in range(1000):
  29. action, _states = model.predict(obs)
  30. obs, reward, done, info = env.step(action)
  31. env.render()
  32. if done:
  33. print(f"Cart pole dropped after {i} steps.")
  34. break