repeated.py 1.0 KB

1234567891011121314151617181920212223242526272829303132333435
  1. import gym
  2. import numpy as np
  3. from ray.rllib.utils.annotations import PublicAPI
  @PublicAPI
  class Repeated(gym.Space):
      """Represents a variable-length list of child spaces.

      A member of this space is a list (or numpy array) holding at most
      ``max_len`` items, each of which must be a member of ``child_space``.

      Example:
          self.observation_space = spaces.Repeated(
              spaces.Box(-1.0, 1.0, shape=(4,)), max_len=10)
          --> from 0 to 10 boxes of shape (4,)

      Note: ``contains`` accepts empty lists, but ``sample`` only draws
      lists of length 1 to ``max_len`` (an empty list is never sampled).

      See also: documentation for rllib.models.RepeatedValues, which shows how
      the lists are represented as batched input for ModelV2 classes.
      """

      def __init__(self, child_space: gym.Space, max_len: int):
          """Create a Repeated space.

          Args:
              child_space: Space that every element of a member list must
                  belong to.
              max_len: Maximum number of elements a member list may hold.
          """
          super().__init__()
          self.child_space = child_space
          self.max_len = max_len

      def sample(self):
          """Return a list of 1 to ``max_len`` samples from ``child_space``.

          The list length is drawn uniformly from ``[1, max_len]`` using this
          space's ``np_random``; each element comes from
          ``child_space.sample()``.
          """
          rng = self.np_random
          # Older gym versions expose an ``np.random.RandomState`` here
          # (``randint``); newer ones expose an ``np.random.Generator``
          # (``integers``), which has no ``randint``. Both treat the upper
          # bound as exclusive, so the sampled range is the same.
          draw_int = getattr(rng, "integers", None) or rng.randint
          length = draw_int(1, self.max_len + 1)
          return [self.child_space.sample() for _ in range(length)]

      def contains(self, x):
          """Return True if ``x`` is a member of this space.

          Args:
              x: Candidate value. Must be a list or a numpy array with at
                  least one dimension.

          Returns:
              True iff ``x`` has at most ``max_len`` elements and every
              element is contained in ``child_space``; False otherwise.
          """
          if not isinstance(x, (list, np.ndarray)):
              return False
          # A 0-d array is an ndarray but has no length (``len()`` raises
          # TypeError), so it cannot be a list of child elements.
          if isinstance(x, np.ndarray) and x.ndim == 0:
              return False
          return len(x) <= self.max_len and all(
              self.child_space.contains(c) for c in x)

      def __repr__(self):
          return "Repeated({}, {})".format(self.child_space, self.max_len)