repeat_initial_obs_env.py 904 B

1234567891011121314151617181920212223242526272829303132
  1. import gymnasium as gym
  2. from gymnasium.spaces import Discrete
  3. import random
  4. class RepeatInitialObsEnv(gym.Env):
  5. """Env in which the initial observation has to be repeated all the time.
  6. Runs for n steps.
  7. r=1 if action correct, -1 otherwise (max. R=100).
  8. """
  9. def __init__(self, episode_len=100):
  10. self.observation_space = Discrete(2)
  11. self.action_space = Discrete(2)
  12. self.token = None
  13. self.episode_len = episode_len
  14. self.num_steps = 0
  15. def reset(self, *, seed=None, options=None):
  16. self.token = random.choice([0, 1])
  17. self.num_steps = 0
  18. return self.token, {}
  19. def step(self, action):
  20. if action == self.token:
  21. reward = 1
  22. else:
  23. reward = -1
  24. self.num_steps += 1
  25. done = truncated = self.num_steps >= self.episode_len
  26. return 0, reward, done, truncated, {}