1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- import argparse
- from gymnasium.wrappers import TimeLimit
- from rllib_mbmpo.env.mbmpo_env import CartPoleWrapper
- from rllib_mbmpo.mbmpo import MBMPO, MBMPOConfig
- import ray
- from ray import air, tune
- from ray.tune.registry import register_env
def get_cli_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of this example script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, so existing no-argument callers
            behave exactly as before.

    Returns:
        The parsed namespace. Its only option is ``run_as_test`` (bool),
        set by the ``--run-as-test`` flag and ``False`` when the flag is
        absent.
    """
    parser = argparse.ArgumentParser()
    # ``store_true`` already implies a default of False.
    parser.add_argument("--run-as-test", action="store_true")
    args = parser.parse_args(argv)
    print(f"Running with following CLI args: {args}")
    return args
if __name__ == "__main__":
    # Script entry point: parse CLI flags, start Ray, register the wrapped
    # CartPole environment, build the MBMPO config and run it through Tune.
    args = get_cli_args()

    ray.init()

    env_name = "cartpole-mbmpo"

    def _make_env(env_ctx):
        """Build the CartPole wrapper capped at 200 steps per episode."""
        return TimeLimit(CartPoleWrapper(), max_episode_steps=200)

    register_env(env_name, _make_env)

    # Hyperparameters forwarded verbatim to ``MBMPOConfig.training``.
    training_kwargs = dict(
        inner_adaptation_steps=1,
        maml_optimizer_steps=8,
        gamma=0.99,
        lambda_=1.0,
        lr=0.001,
        clip_param=0.5,
        kl_target=0.003,
        kl_coeff=1e-10,
        inner_lr=0.001,
        num_maml_steps=15,
        model={"fcnet_hiddens": [32, 32], "free_log_std": True},
    )

    config = MBMPOConfig()
    config = config.framework("torch")
    config = config.environment(env_name)
    config = config.rollouts(num_rollout_workers=4)
    config = config.training(**training_kwargs)

    # With --run-as-test the stop dict carries a reward threshold (190) and
    # an iteration cap (20); without it the run stops after one iteration.
    # NOTE(review): the non-test path being the shorter one looks possibly
    # inverted -- confirm this is intended.
    stop = (
        {"episode_reward_mean": 190, "training_iteration": 20}
        if args.run_as_test
        else {"training_iteration": 1}
    )

    # fail_fast="raise" is passed so trial errors are raised rather than
    # swallowed (per Ray's FailureConfig; verify against the Ray version used).
    run_config = air.RunConfig(
        stop=stop,
        failure_config=air.FailureConfig(fail_fast="raise"),
    )

    tuner = tune.Tuner(
        MBMPO,
        param_space=config.to_dict(),
        run_config=run_config,
    )
    results = tuner.fit()
|