import json
import os
from pathlib import Path

import ray
from ray import air, tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune.registry import get_trainable_cls

# Config for RNNSAC on the stateless CartPole env, whose observations omit
# the velocity components; the LSTM models below must infer them from the
# observation history.
param_space = {
    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    "framework": "torch",
    "num_workers": 4,
    "num_envs_per_worker": 1,
    "num_cpus_per_worker": 1,
    "log_level": "INFO",
    "env": StatelessCartPole,
    "gamma": 0.95,
    # Sample whole episodes so the replay buffer can store complete sequences.
    "batch_mode": "complete_episodes",
    "replay_buffer_config": {
        "type": "MultiAgentReplayBuffer",
        "storage_unit": "sequences",
        "capacity": 100000,
        # Burn-in: the first 4 steps of each replayed sequence are unrolled
        # only to warm up the LSTM state, not to compute the loss.
        "replay_burn_in": 4,
    },
    "num_steps_sampled_before_learning_starts": 1000,
    "train_batch_size": 480,
    "target_network_update_freq": 480,
    "tau": 0.3,
    "zero_init_states": False,
    "optimization": {
        "actor_learning_rate": 0.005,
        "critic_learning_rate": 0.005,
        "entropy_learning_rate": 0.0001,
    },
    "model": {
        "max_seq_len": 20,
    },
    # Recurrent policy and Q networks that also receive the previous
    # action and reward as inputs.
    "policy_model_config": {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "fcnet_hiddens": [64, 64],
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    },
    "q_model_config": {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "fcnet_hiddens": [64, 64],
        "lstm_use_prev_action": True,
        "lstm_use_prev_reward": True,
    },
}

if __name__ == "__main__":
    # INIT
    ray.init(num_cpus=5)

    # TRAIN
    results = tune.Tuner(
        "RNNSAC",
        run_config=air.RunConfig(
            name="RNNSAC_example",
            local_dir=str(Path(__file__).parent / "example_out"),
            verbose=2,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                num_to_keep=1,
                checkpoint_score_attribute="episode_reward_mean",
            ),
            # Stop once the mean episode reward reaches 65.0, or after 50k
            # sampled timesteps, whichever comes first.
            stop={
                "episode_reward_mean": 65.0,
                "timesteps_total": 50000,
            },
        ),
        tune_config=tune.TuneConfig(
            metric="episode_reward_mean",
            mode="max",
        ),
        param_space=param_space,
    ).fit()

    # TEST
    # Rebuild the algorithm from the best trial's saved config, with
    # exploration disabled for evaluation.
    checkpoint_config_path = os.path.join(results.get_best_result().path, "params.json")
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)
    checkpoint_config["explore"] = False
    best_checkpoint = results.get_best_result().best_checkpoints[0][0]
    print("Loading checkpoint: {}".format(best_checkpoint))
    algo = get_trainable_cls("RNNSAC")(env=StatelessCartPole, config=checkpoint_config)
    algo.restore(best_checkpoint)

    # Roll out 10 episodes, feeding the recurrent state and the previous
    # action/reward back into the policy at every step.
    env = algo.env_creator({})
    state = algo.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs, info = env.reset()
    eps = 0
    ep_reward = 0
    while eps < 10:
        action, state, info_algo = algo.compute_single_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True,
        )
        obs, reward, terminated, truncated, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        try:
            env.render()
        except Exception:
            pass
        if terminated or truncated:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            # Reset env, LSTM state, and prev action/reward for the next episode.
            ep_reward = 0
            state = algo.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs, info = env.reset()

    ray.shutdown()