view_requirement.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import gym
  2. import numpy as np
  3. from typing import List, Optional, Union
  4. from ray.rllib.utils.framework import try_import_torch
  5. torch, _ = try_import_torch()
  6. class ViewRequirement:
  7. """Single view requirement (for one column in an SampleBatch/input_dict).
  8. Policies and ModelV2s return a Dict[str, ViewRequirement] upon calling
  9. their `[train|inference]_view_requirements()` methods, where the str key
  10. represents the column name (C) under which the view is available in the
  11. input_dict/SampleBatch and ViewRequirement specifies the actual underlying
  12. column names (in the original data buffer), timestep shifts, and other
  13. options to build the view.
  14. Examples:
  15. >>> # The default ViewRequirement for a Model is:
  16. >>> req = [ModelV2].view_requirements
  17. >>> print(req)
  18. {"obs": ViewRequirement(shift=0)}
  19. """
  20. def __init__(self,
  21. data_col: Optional[str] = None,
  22. space: gym.Space = None,
  23. shift: Union[int, str, List[int]] = 0,
  24. index: Optional[int] = None,
  25. batch_repeat_value: int = 1,
  26. used_for_compute_actions: bool = True,
  27. used_for_training: bool = True):
  28. """Initializes a ViewRequirement object.
  29. Args:
  30. data_col (Optional[str]): The data column name from the SampleBatch
  31. (str key). If None, use the dict key under which this
  32. ViewRequirement resides.
  33. space (gym.Space): The gym Space used in case we need to pad data
  34. in inaccessible areas of the trajectory (t<0 or t>H).
  35. Default: Simple box space, e.g. rewards.
  36. shift (Union[int, str, List[int]]): Single shift value or
  37. list of relative positions to use (relative to the underlying
  38. `data_col`).
  39. Example: For a view column "prev_actions", you can set
  40. `data_col="actions"` and `shift=-1`.
  41. Example: For a view column "obs" in an Atari framestacking
  42. fashion, you can set `data_col="obs"` and
  43. `shift=[-3, -2, -1, 0]`.
  44. Example: For the obs input to an attention net, you can specify
  45. a range via a str: `shift="-100:0"`, which will pass in
  46. the past 100 observations plus the current one.
  47. index (Optional[int]): An optional absolute position arg,
  48. used e.g. for the location of a requested inference dict within
  49. the trajectory. Negative values refer to counting from the end
  50. of a trajectory.
  51. used_for_compute_actions (bool): Whether the data will be used for
  52. creating input_dicts for `Policy.compute_actions()` calls (or
  53. `Policy.compute_actions_from_input_dict()`).
  54. used_for_training (bool): Whether the data will be used for
  55. training. If False, the column will not be copied into the
  56. final train batch.
  57. """
  58. self.data_col = data_col
  59. self.space = space if space is not None else gym.spaces.Box(
  60. float("-inf"), float("inf"), shape=())
  61. self.shift = shift
  62. if isinstance(self.shift, (list, tuple)):
  63. self.shift = np.array(self.shift)
  64. # Special case: Providing a (probably larger) range of indices, e.g.
  65. # "-100:0" (past 100 timesteps plus current one).
  66. self.shift_from = self.shift_to = None
  67. if isinstance(self.shift, str):
  68. f, t = self.shift.split(":")
  69. self.shift_from = int(f)
  70. self.shift_to = int(t)
  71. self.index = index
  72. self.batch_repeat_value = batch_repeat_value
  73. self.used_for_compute_actions = used_for_compute_actions
  74. self.used_for_training = used_for_training