# rllib_in_60s.py (502 B)
  1. # flake8: noqa
  2. # __rllib-in-60s-begin__
  3. from ray.rllib.algorithms.ppo import PPOConfig
  4. config = ( # 1. Configure the algorithm,
  5. PPOConfig()
  6. .environment("Taxi-v3")
  7. .rollouts(num_rollout_workers=2)
  8. .framework("torch")
  9. .training(model={"fcnet_hiddens": [64, 64]})
  10. .evaluation(evaluation_num_workers=1)
  11. )
  12. algo = config.build() # 2. build the algorithm,
  13. for _ in range(5):
  14. print(algo.train()) # 3. train it,
  15. algo.evaluate() # 4. and evaluate it.
  16. # __rllib-in-60s-end__