rnnsac_stateless_cartpole.py

import json
from pathlib import Path

import ray
from ray import tune
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

envs = {
    "RepeatAfterMeEnv": RepeatAfterMeEnv,
    "StatelessCartPole": StatelessCartPole,
}
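
# The outer keys of `config` are tune.run() arguments; the nested "config"
# dict is the RNNSAC trainer configuration.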
config = {
    "name": "RNNSAC_example",
    "local_dir": str(Path(__file__).parent / "example_out"),
    "checkpoint_freq": 1,
    "keep_checkpoints_num": 1,
    "checkpoint_score_attr": "episode_reward_mean",
    "stop": {
        "episode_reward_mean": 65.0,
        "timesteps_total": 100000,
    },
    "metric": "episode_reward_mean",
    "mode": "max",
    "verbose": 2,
    "config": {
        "framework": "torch",
        "num_workers": 4,
        "num_envs_per_worker": 1,
        "num_cpus_per_worker": 1,
        "log_level": "INFO",
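
        # StatelessCartPole hides the velocity components of the CartPole
        # observation, so the agent has to recover them via its LSTM memory.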
        # "env": envs["RepeatAfterMeEnv"],
        "env": envs["StatelessCartPole"],
        "horizon": 1000,
        "gamma": 0.95,
        "batch_mode": "complete_episodes",
        "prioritized_replay": False,
        "buffer_size": 100000,
        "learning_starts": 1000,
        "train_batch_size": 480,
        "target_network_update_freq": 480,
        "tau": 0.3,
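
        # Recurrent replay settings: the first `burn_in` timesteps of each
        # replayed sequence only warm up the LSTM state (no loss is computed
        # on them); zero_init_states=False replays the RNN states stored at
        # sampling time instead of zero states.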
        "burn_in": 4,
        "zero_init_states": False,
        "optimization": {
            "actor_learning_rate": 0.005,
            "critic_learning_rate": 0.005,
            "entropy_learning_rate": 0.0001,
        },
        "model": {
            "max_seq_len": 20,
        },
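
        # Both the policy network and the Q-network get their own LSTM wrapper
        # and receive the previous action and reward as extra inputs.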
        "policy_model": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
        "Q_model": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
    },
}


if __name__ == "__main__":
    # INIT
    ray.init(num_cpus=5)

    # TRAIN
    results = tune.run("RNNSAC", **config)

    # TEST
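    # Rebuild the trainer from the best trial's saved params.json and disable
    # exploration so that the evaluation rollouts are deterministic.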
    best_checkpoint = results.best_checkpoint
    print("Loading checkpoint: {}".format(best_checkpoint))
    checkpoint_config_path = str(
        Path(best_checkpoint).parent.parent / "params.json")
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)
    checkpoint_config["explore"] = False

    agent = get_trainer_class("RNNSAC")(
        env=config["config"]["env"], config=checkpoint_config)
    agent.restore(best_checkpoint)
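
    # Roll out 10 evaluation episodes, feeding the LSTM state and the previous
    # action/reward back into compute_action() at every step.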
    env = agent.env_creator({})
    state = agent.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs = env.reset()

    eps = 0
    ep_reward = 0
    while eps < 10:
        action, state, info_trainer = agent.compute_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True)
        obs, reward, done, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        try:
            env.render()
        except (NotImplementedError, ImportError):
            pass
        if done:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            ep_reward = 0
            state = agent.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs = env.reset()
    ray.shutdown()