from typing import List, Optional

from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import TensorType, TensorStructType


@PublicAPI
class RepeatedValues:
    """Represents a variable-length list of items from spaces.Repeated.

    RepeatedValues are created when you use spaces.Repeated, and are
    accessible as part of input_dict["obs"] in ModelV2 forward functions.

    Example:
        Suppose the gym space definition was:
            Repeated(Repeated(Box(K), N), M)

        Then in the model forward function, input_dict["obs"] is of type:
            RepeatedValues(RepeatedValues(<Tensor shape=(B, M, N, K)>))

        The tensor is accessible via:
            input_dict["obs"].values.values

        And the actual data lengths via:
            # outer repetition, shape [B], range [0, M]
            input_dict["obs"].lengths

            -and-

            # inner repetition, shape [B, M], range [0, N]
            input_dict["obs"].values.lengths

    Attributes:
        values (Tensor): The padded data tensor of shape [B, max_len, ..., sz],
            where B is the batch dimension, max_len is the max length of
            this list, followed by any number of sub list max lens, followed
            by the actual data size.
        lengths (Tensor): A tensor of shape [B, ...] holding the number of
            valid items in each list. When the list is nested within other
            lists, there are extra dimensions for the parent list max lens.
        max_len (int): The max number of items allowed in each list.

    TODO(ekl): support conversion to tf.RaggedTensor.
    """

    def __init__(self, values: TensorType, lengths: TensorType, max_len: int):
        self.values = values
        self.lengths = lengths
        self.max_len = max_len
        self._unbatched_repr = None

    def unbatch_all(self) -> List[List[TensorType]]:
        """Unbatch both the repeat and batch dimensions into Python lists.

        This is only supported in PyTorch / TF eager mode.

        This lets you view the data unbatched in its original form, but it
        is not efficient for processing.

        Examples:
            >>> batch = RepeatedValues(<Tensor shape=(B, N, K)>)
            >>> items = batch.unbatch_all()
            >>> print(len(items) == B)
            True
            >>> print(max(len(x) for x in items) <= N)
            True
            >>> print(items)
            ... [[<Tensor_1 shape=(K)>, ..., <Tensor_N shape=(K)>],
            ...  ...
            ...  [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
            ...  ...
            ...  [<Tensor_1 shape=(K)>],
            ...  ...
            ...  [<Tensor_1 shape=(K)>, ..., <Tensor_N shape=(K)>]]
        """
        if self._unbatched_repr is None:
            B = _get_batch_dim_helper(self.values)
            if B is None:
                raise ValueError(
                    "Cannot call unbatch_all() when batch_dim is unknown. "
                    "This is probably because you are using TF graph mode.")
            else:
                B = int(B)
            slices = self.unbatch_repeat_dim()
            result = []
            for i in range(B):
                # lengths[i] is a scalar tensor: torch tensors expose
                # .item(), TF eager tensors expose .numpy().
                if hasattr(self.lengths[i], "item"):
                    dynamic_len = int(self.lengths[i].item())
                else:
                    dynamic_len = int(self.lengths[i].numpy())
                dynamic_slice = []
                for j in range(dynamic_len):
                    dynamic_slice.append(_batch_index_helper(slices, i, j))
                result.append(dynamic_slice)
            self._unbatched_repr = result
        return self._unbatched_repr

    def unbatch_repeat_dim(self) -> List[TensorType]:
        """Unbatches the repeat dimension (the one `max_len` in size).

        This removes the repeat dimension. The result will be a Python list
        of length `self.max_len`. Note that the data is still padded.

        Examples:
            >>> batch = RepeatedValues(<Tensor shape=(B, N, K)>)
            >>> items = batch.unbatch_repeat_dim()
            >>> len(items) == batch.max_len
            True
            >>> print(items)
            ... [<Tensor_1 shape=(B, K)>, ..., <Tensor_N shape=(B, K)>]
        """
        return _unbatch_helper(self.values, self.max_len)

    def __repr__(self):
        return "RepeatedValues(values={}, lengths={}, max_len={})".format(
            repr(self.values), repr(self.lengths), self.max_len)

    def __str__(self):
        return repr(self)
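

# The helper functions below recursively walk TensorStructType values
# (dicts, tuples, nested RepeatedValues, and plain tensors), mirroring the
# nesting that spaces.Repeated observations can take.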
def _get_batch_dim_helper(v: TensorStructType) -> Optional[int]:
    """Tries to find the batch dimension size of v, or None."""
    if isinstance(v, dict):
        for u in v.values():
            return _get_batch_dim_helper(u)
    elif isinstance(v, tuple):
        return _get_batch_dim_helper(v[0])
    elif isinstance(v, RepeatedValues):
        return _get_batch_dim_helper(v.values)
    else:
        B = v.shape[0]
        if hasattr(B, "value"):
            B = B.value  # TensorFlow Dimension; None in graph mode.
        return B


def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
    """Recursively unpacks the repeat dimension (max_len)."""
    if isinstance(v, dict):
        return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
    elif isinstance(v, tuple):
        return tuple(_unbatch_helper(u, max_len) for u in v)
    elif isinstance(v, RepeatedValues):
        unbatched = _unbatch_helper(v.values, max_len)
        return [
            RepeatedValues(u, v.lengths[:, i, ...], v.max_len)
            for i, u in enumerate(unbatched)
        ]
    else:
        return [v[:, i, ...] for i in range(max_len)]


def _batch_index_helper(v: TensorStructType, i: int,
                        j: int) -> TensorStructType:
    """Selects the item at the ith batch index and jth repetition."""
    if isinstance(v, dict):
        return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}
    elif isinstance(v, tuple):
        return tuple(_batch_index_helper(u, i, j) for u in v)
    elif isinstance(v, list):
        # This is the output of unbatch_repeat_dim(). Unfortunately we have
        # to process it here instead of in unbatch_all(), since it may be
        # buried under a dict / tuple.
        return _batch_index_helper(v[j], i, j)
    elif isinstance(v, RepeatedValues):
        unbatched = v.unbatch_all()
        # Don't need to select j here; that's already done in unbatch_all.
        return unbatched[i]
    else:
        return v[i, ...]
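

if __name__ == "__main__":
    # A minimal usage sketch, not part of the RLlib API: it assumes eager
    # PyTorch is installed, and the shapes (B=2, max_len=3, sz=4) are
    # purely illustrative.
    import torch

    B, max_len, sz = 2, 3, 4
    values = torch.arange(
        B * max_len * sz, dtype=torch.float32).reshape(B, max_len, sz)
    # Row 0 holds 1 valid item, row 1 holds 3; the rest is padding.
    lengths = torch.tensor([1, 3])
    batch = RepeatedValues(values, lengths, max_len)

    # unbatch_all() strips both the batch and repeat dims into nested lists.
    rows = batch.unbatch_all()
    assert len(rows) == B
    assert [len(r) for r in rows] == [1, 3]
    assert rows[1][2].shape == (sz,)

    # unbatch_repeat_dim() strips only the repeat dim; data stays padded.
    cols = batch.unbatch_repeat_dim()
    assert len(cols) == max_len
    assert cols[0].shape == (B, sz)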