mbmpo_cartpole_v1_model_based.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import argparse
  2. from gymnasium.wrappers import TimeLimit
  3. from rllib_mbmpo.env.mbmpo_env import CartPoleWrapper
  4. from rllib_mbmpo.mbmpo import MBMPO, MBMPOConfig
  5. import ray
  6. from ray import air, tune
  7. from ray.tune.registry import register_env
  8. def get_cli_args():
  9. """Create CLI parser and return parsed arguments"""
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument("--run-as-test", action="store_true", default=False)
  12. args = parser.parse_args()
  13. print(f"Running with following CLI args: {args}")
  14. return args
  15. if __name__ == "__main__":
  16. args = get_cli_args()
  17. ray.init()
  18. register_env(
  19. "cartpole-mbmpo",
  20. lambda env_ctx: TimeLimit(CartPoleWrapper(), max_episode_steps=200),
  21. )
  22. config = (
  23. MBMPOConfig()
  24. # .rollouts(num_rollout_workers=7, num_envs_per_worker=20)
  25. .framework("torch")
  26. .environment("cartpole-mbmpo")
  27. .rollouts(num_rollout_workers=4)
  28. # .training(dynamics_model={"ensemble_size": 2})
  29. # )
  30. .training(
  31. inner_adaptation_steps=1,
  32. maml_optimizer_steps=8,
  33. gamma=0.99,
  34. lambda_=1.0,
  35. lr=0.001,
  36. clip_param=0.5,
  37. kl_target=0.003,
  38. kl_coeff=0.0000000001,
  39. inner_lr=0.001,
  40. num_maml_steps=15,
  41. model={"fcnet_hiddens": [32, 32], "free_log_std": True},
  42. )
  43. )
  44. if args.run_as_test:
  45. stop = {
  46. "episode_reward_mean": 190,
  47. "training_iteration": 20,
  48. }
  49. else:
  50. stop = {"training_iteration": 1}
  51. tuner = tune.Tuner(
  52. MBMPO,
  53. param_space=config.to_dict(),
  54. run_config=air.RunConfig(
  55. stop=stop,
  56. failure_config=air.FailureConfig(fail_fast="raise"),
  57. ),
  58. )
  59. results = tuner.fit()