repeated_values.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
from typing import List, Optional

from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.typing import TensorType, TensorStructType
  4. @PublicAPI
  5. class RepeatedValues:
  6. """Represents a variable-length list of items from spaces.Repeated.
  7. RepeatedValues are created when you use spaces.Repeated, and are
  8. accessible as part of input_dict["obs"] in ModelV2 forward functions.
  9. Example:
  10. Suppose the gym space definition was:
  11. Repeated(Repeated(Box(K), N), M)
  12. Then in the model forward function, input_dict["obs"] is of type:
  13. RepeatedValues(RepeatedValues(<Tensor shape=(B, M, N, K)>))
  14. The tensor is accessible via:
  15. input_dict["obs"].values.values
  16. And the actual data lengths via:
  17. # outer repetition, shape [B], range [0, M]
  18. input_dict["obs"].lengths
  19. -and-
  20. # inner repetition, shape [B, M], range [0, N]
  21. input_dict["obs"].values.lengths
  22. Attributes:
  23. values (Tensor): The padded data tensor of shape [B, max_len, ..., sz],
  24. where B is the batch dimension, max_len is the max length of this
  25. list, followed by any number of sub list max lens, followed by the
  26. actual data size.
  27. lengths (List[int]): Tensor of shape [B, ...] that represents the
  28. number of valid items in each list. When the list is nested within
  29. other lists, there will be extra dimensions for the parent list
  30. max lens.
  31. max_len (int): The max number of items allowed in each list.
  32. TODO(ekl): support conversion to tf.RaggedTensor.
  33. """
  34. def __init__(self, values: TensorType, lengths: List[int], max_len: int):
  35. self.values = values
  36. self.lengths = lengths
  37. self.max_len = max_len
  38. self._unbatched_repr = None
  39. def unbatch_all(self) -> List[List[TensorType]]:
  40. """Unbatch both the repeat and batch dimensions into Python lists.
  41. This is only supported in PyTorch / TF eager mode.
  42. This lets you view the data unbatched in its original form, but is
  43. not efficient for processing.
  44. Examples:
  45. >>> batch = RepeatedValues(<Tensor shape=(B, N, K)>)
  46. >>> items = batch.unbatch_all()
  47. >>> print(len(items) == B)
  48. True
  49. >>> print(max(len(x) for x in items) <= N)
  50. True
  51. >>> print(items)
  52. ... [[<Tensor_1 shape=(K)>, ..., <Tensor_N, shape=(K)>],
  53. ... ...
  54. ... [<Tensor_1 shape=(K)>, <Tensor_2 shape=(K)>],
  55. ... ...
  56. ... [<Tensor_1 shape=(K)>],
  57. ... ...
  58. ... [<Tensor_1 shape=(K)>, ..., <Tensor_N shape=(K)>]]
  59. """
  60. if self._unbatched_repr is None:
  61. B = _get_batch_dim_helper(self.values)
  62. if B is None:
  63. raise ValueError(
  64. "Cannot call unbatch_all() when batch_dim is unknown. "
  65. "This is probably because you are using TF graph mode.")
  66. else:
  67. B = int(B)
  68. slices = self.unbatch_repeat_dim()
  69. result = []
  70. for i in range(B):
  71. if hasattr(self.lengths[i], "item"):
  72. dynamic_len = int(self.lengths[i].item())
  73. else:
  74. dynamic_len = int(self.lengths[i].numpy())
  75. dynamic_slice = []
  76. for j in range(dynamic_len):
  77. dynamic_slice.append(_batch_index_helper(slices, i, j))
  78. result.append(dynamic_slice)
  79. self._unbatched_repr = result
  80. return self._unbatched_repr
  81. def unbatch_repeat_dim(self) -> List[TensorType]:
  82. """Unbatches the repeat dimension (the one `max_len` in size).
  83. This removes the repeat dimension. The result will be a Python list of
  84. with length `self.max_len`. Note that the data is still padded.
  85. Examples:
  86. >>> batch = RepeatedValues(<Tensor shape=(B, N, K)>)
  87. >>> items = batch.unbatch()
  88. >>> len(items) == batch.max_len
  89. True
  90. >>> print(items)
  91. ... [<Tensor_1 shape=(B, K)>, ..., <Tensor_N shape=(B, K)>]
  92. """
  93. return _unbatch_helper(self.values, self.max_len)
  94. def __repr__(self):
  95. return "RepeatedValues(value={}, lengths={}, max_len={})".format(
  96. repr(self.values), repr(self.lengths), self.max_len)
  97. def __str__(self):
  98. return repr(self)
  99. def _get_batch_dim_helper(v: TensorStructType) -> int:
  100. """Tries to find the batch dimension size of v, or None."""
  101. if isinstance(v, dict):
  102. for u in v.values():
  103. return _get_batch_dim_helper(u)
  104. elif isinstance(v, tuple):
  105. return _get_batch_dim_helper(v[0])
  106. elif isinstance(v, RepeatedValues):
  107. return _get_batch_dim_helper(v.values)
  108. else:
  109. B = v.shape[0]
  110. if hasattr(B, "value"):
  111. B = B.value # TensorFlow
  112. return B
  113. def _unbatch_helper(v: TensorStructType, max_len: int) -> TensorStructType:
  114. """Recursively unpacks the repeat dimension (max_len)."""
  115. if isinstance(v, dict):
  116. return {k: _unbatch_helper(u, max_len) for (k, u) in v.items()}
  117. elif isinstance(v, tuple):
  118. return tuple(_unbatch_helper(u, max_len) for u in v)
  119. elif isinstance(v, RepeatedValues):
  120. unbatched = _unbatch_helper(v.values, max_len)
  121. return [
  122. RepeatedValues(u, v.lengths[:, i, ...], v.max_len)
  123. for i, u in enumerate(unbatched)
  124. ]
  125. else:
  126. return [v[:, i, ...] for i in range(max_len)]
  127. def _batch_index_helper(v: TensorStructType, i: int,
  128. j: int) -> TensorStructType:
  129. """Selects the item at the ith batch index and jth repetition."""
  130. if isinstance(v, dict):
  131. return {k: _batch_index_helper(u, i, j) for (k, u) in v.items()}
  132. elif isinstance(v, tuple):
  133. return tuple(_batch_index_helper(u, i, j) for u in v)
  134. elif isinstance(v, list):
  135. # This is the output of unbatch_repeat_dim(). Unfortunately we have to
  136. # process it here instead of in unbatch_all(), since it may be buried
  137. # under a dict / tuple.
  138. return _batch_index_helper(v[j], i, j)
  139. elif isinstance(v, RepeatedValues):
  140. unbatched = v.unbatch_all()
  141. # Don't need to select j here; that's already done in unbatch_all.
  142. return unbatched[i]
  143. else:
  144. return v[i, ...]