# look_and_push.py
  1. import gymnasium as gym
  2. import numpy as np
  3. class LookAndPush(gym.Env):
  4. """Memory-requiring Env: Best sequence of actions depends on prev. states.
  5. Optimal behavior:
  6. 0) a=0 -> observe next state (s'), which is the "hidden" state.
  7. If a=1 here, the hidden state is not observed.
  8. 1) a=1 to always jump to s=2 (not matter what the prev. state was).
  9. 2) a=1 to move to s=3.
  10. 3) a=1 to move to s=4.
  11. 4) a=0 OR 1 depending on s' observed after 0): +10 reward and done.
  12. otherwise: -10 reward and done.
  13. """
  14. def __init__(self):
  15. self.action_space = gym.spaces.Discrete(2)
  16. self.observation_space = gym.spaces.Discrete(5)
  17. self._state = None
  18. self._case = None
  19. def reset(self, *, seed=None, options=None):
  20. self._state = 2
  21. self._case = np.random.choice(2)
  22. return self._state, {}
  23. def step(self, action):
  24. assert self.action_space.contains(action)
  25. if self._state == 4:
  26. if action and self._case:
  27. return self._state, 10.0, True, {}
  28. else:
  29. return self._state, -10, True, {}
  30. else:
  31. if action:
  32. if self._state == 0:
  33. self._state = 2
  34. else:
  35. self._state += 1
  36. elif self._state == 2:
  37. self._state = self._case
  38. return self._state, -1, False, False, {}
  39. class OneHot(gym.Wrapper):
  40. def __init__(self, env):
  41. super(OneHot, self).__init__(env)
  42. self.observation_space = gym.spaces.Box(0.0, 1.0, (env.observation_space.n,))
  43. def reset(self, *, seed=None, options=None):
  44. obs, info = self.env.reset(seed=seed, options=options)
  45. return self._encode_obs(obs), info
  46. def step(self, action):
  47. obs, reward, terminated, truncated, info = self.env.step(action)
  48. return self._encode_obs(obs), reward, terminated, truncated, info
  49. def _encode_obs(self, obs):
  50. new_obs = np.ones(self.env.observation_space.n)
  51. new_obs[obs] = 1.0
  52. return new_obs