rnnsac_stateless_cartpole.py 3.6 KB

import json
import os
from pathlib import Path

import ray
from ray import air, tune
from ray.tune.registry import get_trainable_cls
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
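
# RNNSAC config (passed to Tune as the param space). StatelessCartPole is a
# partially observable CartPole variant (the velocity components are missing
# from the observation), which is why both the policy and the Q-networks
# below use recurrent (LSTM) models.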
param_space = {
    "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
    "framework": "torch",
    "num_workers": 4,
    "num_envs_per_worker": 1,
    "num_cpus_per_worker": 1,
    "log_level": "INFO",
    "env": StatelessCartPole,
    "gamma": 0.95,
    "batch_mode": "complete_episodes",
  18. "replay_buffer_config": {
  19. "type": "MultiAgentReplayBuffer",
  20. "storage_unit": "sequences",
  21. "capacity": 100000,
  22. "replay_burn_in": 4,
  23. },
  24. "num_steps_sampled_before_learning_starts": 1000,
  25. "train_batch_size": 480,
  26. "target_network_update_freq": 480,
  27. "tau": 0.3,
  28. "zero_init_states": False,
  29. "optimization": {
  30. "actor_learning_rate": 0.005,
  31. "critic_learning_rate": 0.005,
  32. "entropy_learning_rate": 0.0001,
  33. },
  34. "model": {
  35. "max_seq_len": 20,
  36. },
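    # Policy and Q-models are LSTMs that also receive the previous action and
    # reward as inputs, helping the agent infer the missing velocity
    # information from the observation history.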
  37. "policy_model_config": {
  38. "use_lstm": True,
  39. "lstm_cell_size": 64,
  40. "fcnet_hiddens": [64, 64],
  41. "lstm_use_prev_action": True,
  42. "lstm_use_prev_reward": True,
  43. },
  44. "q_model_config": {
  45. "use_lstm": True,
  46. "lstm_cell_size": 64,
  47. "fcnet_hiddens": [64, 64],
  48. "lstm_use_prev_action": True,
  49. "lstm_use_prev_reward": True,
  50. },
  51. }


if __name__ == "__main__":
    # INIT
    ray.init(num_cpus=5)

    # TRAIN
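    # Stop a trial once episode_reward_mean reaches 65.0 or 50k timesteps have
    # been sampled; checkpoint at the end and keep a single checkpoint, scored
    # by episode_reward_mean.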
    results = tune.Tuner(
        "RNNSAC",
        run_config=air.RunConfig(
            name="RNNSAC_example",
            local_dir=str(Path(__file__).parent / "example_out"),
            verbose=2,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                num_to_keep=1,
                checkpoint_score_attribute="episode_reward_mean",
            ),
            stop={
                "episode_reward_mean": 65.0,
                "timesteps_total": 50000,
            },
        ),
        tune_config=tune.TuneConfig(
            metric="episode_reward_mean",
            mode="max",
        ),
        param_space=param_space,
    ).fit()

    # TEST
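    # Rebuild the algorithm from the best trial's saved config (params.json),
    # with exploration disabled, and restore its best checkpoint.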
    checkpoint_config_path = os.path.join(
        results.get_best_result().path, "params.json"
    )
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)
    checkpoint_config["explore"] = False
    best_checkpoint = results.get_best_result().best_checkpoints[0][0]
    print("Loading checkpoint: {}".format(best_checkpoint))
    algo = get_trainable_cls("RNNSAC")(
        env=StatelessCartPole, config=checkpoint_config
    )
    algo.restore(best_checkpoint)
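
    # Roll out 10 evaluation episodes, manually carrying the LSTM state and
    # the previous action/reward from step to step.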
    env = algo.env_creator({})
    state = algo.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs, info = env.reset()
    eps = 0
    ep_reward = 0
    while eps < 10:
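        # With full_fetch=True, compute_single_action() returns the action,
        # the new RNN state, and a dict of extra action info.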
        action, state, info_algo = algo.compute_single_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True,
        )
        obs, reward, terminated, truncated, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        try:
            env.render()
        except Exception:
            pass
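        # Episode finished: log its return, then reset the env, the RNN state,
        # and the previous action/reward for the next episode.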
        if terminated or truncated:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            ep_reward = 0
            state = algo.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs, info = env.reset()

    ray.shutdown()