correlated_actions_env.py

import gymnasium as gym
from gymnasium.spaces import Discrete, Tuple


class CorrelatedActionsEnv(gym.Env):
    """Simple env in which the policy has to emit a tuple of equal actions.

    In each step, the agent observes a random number (0 or 1) and has to
    choose two actions a1 and a2. It gets +5 reward for matching a1 to the
    random obs and +5 for matching a2 to a1, i.e. at most +10 per step.

    One way to learn this effectively is through correlated action
    distributions, e.g. the autoregressive action distribution in
    examples/autoregressive_action_dist.py.

    Episodes last 20 steps, hence the best possible score is 200 reward.
    """

    def __init__(self, config=None):
        # `config` is the env config dict that RLlib passes to env
        # constructors; it is unused here.
        self.observation_space = Discrete(2)
        self.action_space = Tuple([Discrete(2), Discrete(2)])
        self.last_observation = None
        self.t = 0

    def reset(self, *, seed=None, options=None):
        # Seed the RNG that gym.Env provides (`self.np_random`) so that
        # episodes are reproducible.
        super().reset(seed=seed)
        self.t = 0
        self.last_observation = int(self.np_random.integers(2))
        return self.last_observation, {}

    def step(self, action):
        self.t += 1
        a1, a2 = action
        reward = 0
        # Encourage correlation between the most recent observation and a1.
        if a1 == self.last_observation:
            reward += 5
        # Encourage correlation between a1 and a2.
        if a1 == a2:
            reward += 5
        # End the episode after exactly 20 steps (matching the docstring;
        # `self.t > 20` would allow a 21st step).
        terminated = truncated = self.t >= 20
        self.last_observation = int(self.np_random.integers(2))
        return self.last_observation, reward, terminated, truncated, {}
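

# Minimal usage sketch: the hand-coded optimal policy copies the observation
# into a1 and a1 into a2 (a deterministic special case of the autoregressive
# scheme the docstring alludes to), which should collect the maximum return
# of 200.
if __name__ == "__main__":
    env = CorrelatedActionsEnv()
    obs, _ = env.reset(seed=0)
    total_reward = 0
    terminated = truncated = False
    while not (terminated or truncated):
        action = (obs, obs)  # a1 = obs, a2 = a1 -> +10 per step.
        obs, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
    print(f"Optimal-policy return: {total_reward}")  # -> 200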