nested_space_repeat_after_me_env.py

import gym
from gym.spaces import Box, Dict, Discrete, Tuple
import numpy as np
import tree  # pip install dm_tree

from ray.rllib.utils.spaces.space_utils import flatten_space


class NestedSpaceRepeatAfterMeEnv(gym.Env):
    """Env for which policy has to repeat the (possibly complex) observation.

    The action space and observation spaces are always the same and may be
    arbitrarily nested Dict/Tuple Spaces.

    Rewards are given for exactly matching Discrete sub-actions and for being
    as close as possible for Box sub-actions.
    """

    def __init__(self, config):
        self.observation_space = config.get(
            "space", Tuple([Discrete(2),
                            Dict({
                                "a": Box(-1.0, 1.0, (2, ))
                            })]))
        self.action_space = self.observation_space
        self.flattened_action_space = flatten_space(self.action_space)
        self.episode_len = config.get("episode_len", 100)

    def reset(self):
        self.steps = 0
        return self._next_obs()

    def step(self, action):
        self.steps += 1
        action = tree.flatten(action)
        reward = 0.0
        for a, o, space in zip(action, self.current_obs_flattened,
                               self.flattened_action_space):
            # Box: -abs(diff).
            if isinstance(space, gym.spaces.Box):
                reward -= np.sum(np.abs(a - o))
            # Discrete: +1.0 if exact match.
            if isinstance(space, gym.spaces.Discrete):
                reward += 1.0 if a == o else 0.0
        done = self.steps >= self.episode_len
        return self._next_obs(), reward, done, {}

    def _next_obs(self):
        self.current_obs = self.observation_space.sample()
        self.current_obs_flattened = tree.flatten(self.current_obs)
        return self.current_obs