repeated.py 935 B

12345678910111213141516171819202122232425262728293031
  1. import gym
  2. from ray.rllib.utils.annotations import PublicAPI
  3. @PublicAPI
  4. class Repeated(gym.Space):
  5. """Represents a variable-length list of child spaces.
  6. Example:
  7. self.observation_space = spaces.Repeated(spaces.Box(4,), max_len=10)
  8. --> from 0 to 10 boxes of shape (4,)
  9. See also: documentation for rllib.models.RepeatedValues, which shows how
  10. the lists are represented as batched input for ModelV2 classes.
  11. """
  12. def __init__(self, child_space: gym.Space, max_len: int):
  13. super().__init__()
  14. self.child_space = child_space
  15. self.max_len = max_len
  16. def sample(self):
  17. return [
  18. self.child_space.sample()
  19. for _ in range(self.np_random.randint(1, self.max_len + 1))
  20. ]
  21. def contains(self, x):
  22. return (isinstance(x, list) and len(x) <= self.max_len
  23. and all(self.child_space.contains(c) for c in x))