rllib_on_rllib_readme.py

import gym
from ray.rllib.agents.ppo import PPOTrainer


# Define your problem using Python and OpenAI's gym API:
class ParrotEnv(gym.Env):
    """Environment in which an agent must learn to repeat the seen observations.

    Observations are float numbers indicating the to-be-repeated values,
    e.g. -1.0, 5.1, or 3.2.
    The action space is always the same as the observation space.
    Rewards are r=-abs(observation - action), for all steps.
    """

    def __init__(self, config):
        # Make the space (for actions and observations) configurable.
        self.action_space = config.get(
            "parrot_shriek_range", gym.spaces.Box(-1.0, 1.0, shape=(1, )))
        # Since actions should repeat observations, their spaces must be the
        # same.
        self.observation_space = self.action_space
        self.cur_obs = None
        self.episode_len = 0

    def reset(self):
        """Resets the episode and returns the initial observation of the new one."""
        # Reset the episode len.
        self.episode_len = 0
        # Sample a random number from our observation space.
        self.cur_obs = self.observation_space.sample()
        # Return initial observation.
        return self.cur_obs

    def step(self, action):
        """Takes a single step in the episode given `action`.

        Returns:
            New observation, reward, done-flag, info-dict (empty).
        """
        # Set `done` flag after 10 steps.
        self.episode_len += 1
        done = self.episode_len >= 10
        # r = -abs(obs - action)
        reward = -sum(abs(self.cur_obs - action))
        # Set a new observation (random sample).
        self.cur_obs = self.observation_space.sample()
        return self.cur_obs, reward, done, {}
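

# Quick sanity check (a minimal sketch, not required for training): roll out
# one episode with random actions to confirm the env follows the gym API.
# Every reward should be <= 0.0 and the episode ends after exactly 10 steps.
check_env = ParrotEnv(config={})
check_obs = check_env.reset()
check_done, check_steps = False, 0
while not check_done:
    check_obs, check_reward, check_done, _ = check_env.step(
        check_env.action_space.sample())
    assert check_reward <= 0.0
    check_steps += 1
assert check_steps == 10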

# Create an RLlib Trainer instance to learn how to act in the above
# environment.
trainer = PPOTrainer(
    config={
        # Env class to use (here: our gym.Env sub-class from above).
        "env": ParrotEnv,
        # Config dict to be passed to our custom env's constructor.
        "env_config": {
            "parrot_shriek_range": gym.spaces.Box(-5.0, 5.0, (1, ))
        },
        # Parallelize environment rollouts.
        "num_workers": 3,
    })
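
# The same training could also be driven through Ray Tune instead of the
# manual `trainer.train()` loop below (a rough sketch; left commented out so
# this script performs exactly one training run):
#
# from ray import tune
# tune.run(
#     "PPO",
#     stop={"training_iteration": 5},
#     config={
#         "env": ParrotEnv,
#         "env_config": {
#             "parrot_shriek_range": gym.spaces.Box(-5.0, 5.0, (1, ))
#         },
#         "num_workers": 3,
#     })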

# Train for n iterations and report results (mean episode rewards).
# Since each episode lasts 10 steps and the best per-step reward is 0.0
# (exact match between observation and action value),
# we can expect to approach an optimal episode reward of 0.0.
for i in range(5):
    results = trainer.train()
    print(f"Iter: {i}; avg. reward={results['episode_reward_mean']}")
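
# To keep the learned policy beyond this script, the trainer state can be
# checkpointed and restored later (a minimal sketch using the Trainer's
# built-in save/restore; the checkpoint is written to the trainer's logdir,
# by default under ~/ray_results):
checkpoint_path = trainer.save()
print(f"Checkpoint written to: {checkpoint_path}")
# A fresh PPOTrainer with the same config could pick the policy back up via:
# trainer.restore(checkpoint_path)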

# Perform inference (action computations) based on given env observations.
# Note that we are using a slightly narrower env here (-3.0 to 3.0, instead
# of -5.0 to 5.0); this should still work, as the agent has (hopefully)
# learned to "just always repeat the observation!".
env = ParrotEnv({"parrot_shriek_range": gym.spaces.Box(-3.0, 3.0, (1, ))})
# Get the initial observation (some value between -3.0 and 3.0).
obs = env.reset()
done = False
total_reward = 0.0
# Play one episode.
while not done:
    # Compute a single action, given the current observation
    # from the environment.
    action = trainer.compute_single_action(obs)
    # Apply the computed action in the environment.
    obs, reward, done, info = env.step(action)
    # Sum up rewards for reporting purposes.
    total_reward += reward
# Report results.
print(f"Played 1 episode; total-reward={total_reward}")
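
# For a deterministic readout of the learned policy, exploration noise can be
# switched off via the `explore` flag of `compute_single_action` (sketch), and
# the trainer's worker processes can be released once we are done:
obs = env.reset()
greedy_action = trainer.compute_single_action(obs, explore=False)
print(f"Greedy action for obs={obs}: {greedy_action}")
# Free rollout workers and other resources held by the trainer.
trainer.stop()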