rnnsac_stateless_cartpole.py

import json
import os
from pathlib import Path

import ray
from ray import tune
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole

envs = {
    "RepeatAfterMeEnv": RepeatAfterMeEnv,
    "StatelessCartPole": StatelessCartPole,
}
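# Tune experiment settings: output location, checkpointing, stopping
# criteria, and the metric/mode used to pick the best trial.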
config = {
    "name": "RNNSAC_example",
    "local_dir": str(Path(__file__).parent / "example_out"),
    "checkpoint_freq": 1,
    "keep_checkpoints_num": 1,
    "checkpoint_score_attr": "episode_reward_mean",
    "stop": {
        "episode_reward_mean": 65.0,
        "timesteps_total": 50000,
    },
    "metric": "episode_reward_mean",
    "mode": "max",
    "verbose": 2,
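    # RNNSAC trainer config, passed through by Tune to the trainer.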
    "config": {
        "seed": 42,
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "framework": "torch",
        "num_workers": 4,
        "num_envs_per_worker": 1,
        "num_cpus_per_worker": 1,
        "log_level": "INFO",
        # "env": envs["RepeatAfterMeEnv"],
        "env": envs["StatelessCartPole"],
        "horizon": 1000,
        "gamma": 0.95,
        "batch_mode": "complete_episodes",
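        # Replay and target-network settings. "burn_in" is the number of steps
        # at the start of each replayed sequence used only to warm up the LSTM
        # state before the trained portion of the sequence.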
        "prioritized_replay": False,
        "buffer_size": 100000,
        "learning_starts": 1000,
        "train_batch_size": 480,
        "target_network_update_freq": 480,
        "tau": 0.3,
        "burn_in": 4,
        "zero_init_states": False,
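        # Separate learning rates for the actor, critic, and entropy (alpha)
        # optimizers.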
        "optimization": {
            "actor_learning_rate": 0.005,
            "critic_learning_rate": 0.005,
            "entropy_learning_rate": 0.0001,
        },
        "model": {
            "max_seq_len": 20,
        },
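        # Both the policy and Q networks get an LSTM wrapper that also sees the
        # previous action and reward, so the recurrent agent can infer the
        # velocity information that StatelessCartPole hides.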
        "policy_model": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
        "Q_model": {
            "use_lstm": True,
            "lstm_cell_size": 64,
            "fcnet_hiddens": [64, 64],
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
        },
    },
}
if __name__ == "__main__":
    # INIT
    ray.init(num_cpus=5)

    # TRAIN
    results = tune.run("RNNSAC", **config)

    # TEST
    best_checkpoint = results.best_checkpoint
    print("Loading checkpoint: {}".format(best_checkpoint))
    checkpoint_config_path = str(
        Path(best_checkpoint).parent.parent / "params.json")
    with open(checkpoint_config_path, "rb") as f:
        checkpoint_config = json.load(f)
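    # Evaluate deterministically: turn off exploration in the restored config.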
    checkpoint_config["explore"] = False

    agent = get_trainer_class("RNNSAC")(
        env=config["config"]["env"], config=checkpoint_config)
    agent.restore(best_checkpoint)
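    # Roll out the restored policy manually, carrying the LSTM state and the
    # previous action/reward from one step to the next.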
    env = agent.env_creator({})
    state = agent.get_policy().get_initial_state()
    prev_action = 0
    prev_reward = 0
    obs = env.reset()

    eps = 0
    ep_reward = 0
    while eps < 10:
        action, state, info_trainer = agent.compute_action(
            obs,
            state=state,
            prev_action=prev_action,
            prev_reward=prev_reward,
            full_fetch=True)
        obs, reward, done, info = env.step(action)
        prev_action = action
        prev_reward = reward
        ep_reward += reward
        try:
            env.render()
        except (NotImplementedError, ImportError):
            pass
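        # Episode finished: report the return, then reset the env, RNN state,
        # and previous action/reward for the next episode.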
        if done:
            eps += 1
            print("Episode {}: {}".format(eps, ep_reward))
            ep_reward = 0
            state = agent.get_policy().get_initial_state()
            prev_action = 0
            prev_reward = 0
            obs = env.reset()
    ray.shutdown()