view_requirement.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import dataclasses
  2. import gymnasium as gym
  3. from typing import Dict, List, Optional, Union
  4. import numpy as np
  5. from ray.rllib.utils.annotations import PublicAPI
  6. from ray.rllib.utils.framework import try_import_torch
  7. from ray.rllib.utils.serialization import (
  8. gym_space_to_dict,
  9. gym_space_from_dict,
  10. )
  11. torch, _ = try_import_torch()
  12. @PublicAPI
  13. @dataclasses.dataclass
  14. class ViewRequirement:
  15. """Single view requirement (for one column in an SampleBatch/input_dict).
  16. Policies and ModelV2s return a Dict[str, ViewRequirement] upon calling
  17. their `[train|inference]_view_requirements()` methods, where the str key
  18. represents the column name (C) under which the view is available in the
  19. input_dict/SampleBatch and ViewRequirement specifies the actual underlying
  20. column names (in the original data buffer), timestep shifts, and other
  21. options to build the view.
  22. Examples:
  23. >>> from ray.rllib.models.modelv2 import ModelV2
  24. >>> # The default ViewRequirement for a Model is:
  25. >>> req = ModelV2(...).view_requirements # doctest: +SKIP
  26. >>> print(req) # doctest: +SKIP
  27. {"obs": ViewRequirement(shift=0)}
  28. Args:
  29. data_col: The data column name from the SampleBatch
  30. (str key). If None, use the dict key under which this
  31. ViewRequirement resides.
  32. space: The gym Space used in case we need to pad data
  33. in inaccessible areas of the trajectory (t<0 or t>H).
  34. Default: Simple box space, e.g. rewards.
  35. shift: Single shift value or
  36. list of relative positions to use (relative to the underlying
  37. `data_col`).
  38. Example: For a view column "prev_actions", you can set
  39. `data_col="actions"` and `shift=-1`.
  40. Example: For a view column "obs" in an Atari framestacking
  41. fashion, you can set `data_col="obs"` and
  42. `shift=[-3, -2, -1, 0]`.
  43. Example: For the obs input to an attention net, you can specify
  44. a range via a str: `shift="-100:0"`, which will pass in
  45. the past 100 observations plus the current one.
  46. index: An optional absolute position arg,
  47. used e.g. for the location of a requested inference dict within
  48. the trajectory. Negative values refer to counting from the end
  49. of a trajectory. (#TODO: Is this still used?)
  50. batch_repeat_value: determines how many time steps we should skip
  51. before we repeat the view indexing for the next timestep. For RNNs this
  52. number is usually the sequence length that we will rollout over.
  53. Example:
  54. view_col = "state_in_0", data_col = "state_out_0"
  55. batch_repeat_value = 5, shift = -1
  56. buffer["state_out_0"] = [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
  57. output["state_in_0"] = [-1, 4, 9]
  58. Explanation: For t=0, we output buffer["state_out_0"][-1]. We then skip 5
  59. time steps and repeat the view. for t=5, we output buffer["state_out_0"][4]
  60. . Continuing on this pattern, for t=10, we output buffer["state_out_0"][9].
  61. used_for_compute_actions: Whether the data will be used for
  62. creating input_dicts for `Policy.compute_actions()` calls (or
  63. `Policy.compute_actions_from_input_dict()`).
  64. used_for_training: Whether the data will be used for
  65. training. If False, the column will not be copied into the
  66. final train batch.
  67. """
  68. data_col: Optional[str] = None
  69. space: gym.Space = None
  70. shift: Union[int, str, List[int]] = 0
  71. index: Optional[int] = None
  72. batch_repeat_value: int = 1
  73. used_for_compute_actions: bool = True
  74. used_for_training: bool = True
  75. shift_arr: Optional[np.ndarray] = dataclasses.field(init=False)
  76. def __post_init__(self):
  77. """Initializes a ViewRequirement object.
  78. shift_arr is infered from the shift value.
  79. For example:
  80. - if shift is -1, then shift_arr is np.array([-1]).
  81. - if shift is [-1, -2], then shift_arr is np.array([-2, -1]).
  82. - if shift is "-2:2", then shift_arr is np.array([-2, -1, 0, 1, 2]).
  83. """
  84. if self.space is None:
  85. self.space = gym.spaces.Box(float("-inf"), float("inf"), shape=())
  86. # TODO: ideally we won't need shift_from and shift_to, and shift_step.
  87. # all of them should be captured within shift_arr.
  88. # Special case: Providing a (probably larger) range of indices, e.g.
  89. # "-100:0" (past 100 timesteps plus current one).
  90. self.shift_from = self.shift_to = self.shift_step = None
  91. if isinstance(self.shift, str):
  92. split = self.shift.split(":")
  93. assert len(split) in [2, 3], f"Invalid shift str format: {self.shift}"
  94. if len(split) == 2:
  95. f, t = split
  96. self.shift_step = 1
  97. else:
  98. f, t, s = split
  99. self.shift_step = int(s)
  100. self.shift_from = int(f)
  101. self.shift_to = int(t)
  102. shift = self.shift
  103. self.shfit_arr = None
  104. if self.shift_from:
  105. self.shift_arr = np.arange(
  106. self.shift_from, self.shift_to + 1, self.shift_step
  107. )
  108. else:
  109. if isinstance(shift, int):
  110. self.shift_arr = np.array([shift])
  111. elif isinstance(shift, list):
  112. self.shift_arr = np.array(shift)
  113. else:
  114. ValueError(f'unrecognized shift type: "{shift}"')
  115. def to_dict(self) -> Dict:
  116. """Return a dict for this ViewRequirement that can be JSON serialized."""
  117. return {
  118. "data_col": self.data_col,
  119. "space": gym_space_to_dict(self.space),
  120. "shift": self.shift,
  121. "index": self.index,
  122. "batch_repeat_value": self.batch_repeat_value,
  123. "used_for_training": self.used_for_training,
  124. "used_for_compute_actions": self.used_for_compute_actions,
  125. }
  126. @classmethod
  127. def from_dict(cls, d: Dict):
  128. """Construct a ViewRequirement instance from JSON deserialized dict."""
  129. d["space"] = gym_space_from_dict(d["space"])
  130. return cls(**d)