# sample_collector.py

from abc import abstractmethod, ABCMeta
import logging
from typing import Dict, List, Optional, TYPE_CHECKING, Union

from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.typing import AgentID, EnvID, EpisodeID, PolicyID, \
    TensorType

if TYPE_CHECKING:
    from ray.rllib.agents.callbacks import DefaultCallbacks

logger = logging.getLogger(__name__)


# yapf: disable
# __sphinx_doc_begin__
class SampleCollector(metaclass=ABCMeta):
    """Collects samples for all policies and agents from a multi-agent env.

    This API is controlled by RolloutWorker objects to store all data
    generated by Environments and Policies/Models during rollout and
    postprocessing. Its purposes are a) to make data collection and
    SampleBatch/input_dict generation from this data faster, b) to unify
    the way we collect samples from environments and models (outputs),
    thereby allowing for possible user customizations, and c) to allow for
    more complex inputs fed into different policies (e.g. the multi-agent
    case with an inter-agent communication channel).
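
    Examples:
        A rough sketch of the collection loop a RolloutWorker-style caller
        might drive (`env`, `policy`, `my_episode`, agent id 0, env index 0,
        and policy id "pol0" are placeholders; `collector` stands for some
        concrete SampleCollector subclass):

        >>> obs = env.reset()
        >>> collector.add_init_obs(my_episode, 0, "pol0", -1, obs)
        >>> while True:
        ...     input_dict = collector.get_inference_input_dict("pol0")
        ...     action = policy.compute_actions_from_input_dict(input_dict)
        ...     obs, r, done, info = env.step(action)
        ...     collector.add_action_reward_next_obs(
        ...         my_episode.episode_id, 0, 0, "pol0", done,
        ...         {"actions": action, "new_obs": obs, "rewards": r,
        ...          "dones": done})
        ...     collector.episode_step(my_episode)
        ...     batches = (
        ...         collector.try_build_truncated_episode_multi_agent_batch())
        ...     if batches:
        ...         break  # hand the built batch(es) over to the learner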
  24. """
    def __init__(self,
                 policy_map: PolicyMap,
                 clip_rewards: Union[bool, float],
                 callbacks: "DefaultCallbacks",
                 multiple_episodes_in_batch: bool = True,
                 rollout_fragment_length: int = 200,
                 count_steps_by: str = "env_steps"):
        """Initializes a SampleCollector instance.

        Args:
            policy_map (PolicyMap): Maps policy ids to policy instances.
            clip_rewards (Union[bool, float]): Whether to clip rewards before
                postprocessing (at +/-1.0) or the actual value to +/- clip.
            callbacks (DefaultCallbacks): RLlib callbacks.
            multiple_episodes_in_batch (bool): Whether it's allowed to pack
                multiple episodes into the same built batch.
            rollout_fragment_length (int): The number of steps (counted in
                `count_steps_by` units) to collect before a batch is built.
            count_steps_by (str): One of "env_steps" (default) or
                "agent_steps". Determines how `rollout_fragment_length` is
                counted.
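
        Examples:
            Constructing a (hypothetical) concrete subclass,
            `MySampleCollector`; the keyword values are placeholders, not
            recommended settings:

            >>> collector = MySampleCollector(
            ...     policy_map=worker.policy_map,
            ...     clip_rewards=1.0,
            ...     callbacks=DefaultCallbacks(),
            ...     multiple_episodes_in_batch=True,
            ...     rollout_fragment_length=200,
            ...     count_steps_by="env_steps")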
  41. """
  42. self.policy_map = policy_map
  43. self.clip_rewards = clip_rewards
  44. self.callbacks = callbacks
  45. self.multiple_episodes_in_batch = multiple_episodes_in_batch
  46. self.rollout_fragment_length = rollout_fragment_length
  47. self.count_steps_by = count_steps_by
    @abstractmethod
    def add_init_obs(self, episode: Episode, agent_id: AgentID,
                     policy_id: PolicyID, t: int,
                     init_obs: TensorType) -> None:
        """Adds an initial obs (after reset) to this collector.

        Since the very first observation in an environment is collected w/o
        additional data (w/o actions, w/o reward) after env.reset() is
        called, this method initializes a new trajectory for the given agent.
        `add_init_obs()` has to be called first for each agent/episode-ID
        combination. After this, only `add_action_reward_next_obs()` must be
        called for that same agent/episode pair.

        Args:
            episode (Episode): The Episode, for which we are adding an
                Agent's initial observation.
            agent_id (AgentID): Unique id for the agent we are adding
                values for.
            policy_id (PolicyID): Unique id for policy controlling the agent.
            t (int): The time step (episode length - 1). The initial obs has
                ts=-1(!), then an action/reward/next-obs at t=0, etc..
            init_obs (TensorType): Initial observation (after env.reset()).

        Examples:
            >>> obs = env.reset()
            >>> collector.add_init_obs(my_episode, 0, "pol0", -1, obs)
            >>> obs, r, done, info = env.step(action)
            >>> collector.add_action_reward_next_obs(
            ...     12345, 0, 0, "pol0", False,
            ...     {"actions": action, "new_obs": obs, "rewards": r,
            ...      "dones": done})
        """
        raise NotImplementedError

    @abstractmethod
    def add_action_reward_next_obs(self, episode_id: EpisodeID,
                                   agent_id: AgentID, env_id: EnvID,
                                   policy_id: PolicyID, agent_done: bool,
                                   values: Dict[str, TensorType]) -> None:
        """Adds the given dictionary (row) of values to this collector.

        The incoming data (`values`) must include action, reward, done, and
        next-obs information and may include any other information.
        For the initial observation (after Env.reset()) of the given agent/
        episode-ID combination, `add_init_obs()` must be called instead.

        Args:
            episode_id (EpisodeID): Unique id for the episode we are adding
                values for.
            agent_id (AgentID): Unique id for the agent we are adding
                values for.
            env_id (EnvID): The environment index (in a vectorized setup).
            policy_id (PolicyID): Unique id for policy controlling the agent.
            agent_done (bool): Whether the given agent is done with its
                trajectory (the multi-agent episode may still be ongoing).
            values (Dict[str, TensorType]): Row of values to add for this
                agent. This row must contain the keys SampleBatch.ACTIONS,
                REWARDS, NEXT_OBS, and DONES.

        Examples:
            >>> obs = env.reset()
            >>> collector.add_init_obs(my_episode, 0, "pol0", -1, obs)
            >>> obs, r, done, info = env.step(action)
            >>> collector.add_action_reward_next_obs(
            ...     12345, 0, 0, "pol0", False,
            ...     {"actions": action, "new_obs": obs, "rewards": r,
            ...      "dones": done})
        """
        raise NotImplementedError

    @abstractmethod
    def episode_step(self, episode: Episode) -> None:
        """Increases the episode step counter (across all agents) by one.

        Args:
            episode (Episode): The Episode we are stepping through.
                Useful for step counting because this method is called
                exactly once per env step, across all agents inside the
                episode.
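
        Examples:
            A sketch of one env step in a 2-agent episode (`my_episode`,
            agent ids 0/1, env index 0, policy id "pol0", and
            `values_for_agent` are placeholders):

            >>> for agent_id in [0, 1]:
            ...     collector.add_action_reward_next_obs(
            ...         my_episode.episode_id, agent_id, 0, "pol0", False,
            ...         values_for_agent[agent_id])
            >>> # Call exactly once per env step (not once per agent):
            >>> collector.episode_step(my_episode)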
  116. """
  117. raise NotImplementedError
    @abstractmethod
    def total_env_steps(self) -> int:
        """Returns the total number of env-steps taken so far.

        A step in an N-agent multi-agent environment counts as only 1 for
        this metric. The returned count contains everything that has not
        been built yet (and returned as MultiAgentBatches by the
        `try_build_truncated_episode_multi_agent_batch` or
        `postprocess_episode(build=True)` methods). After such a build, this
        counter is reset to 0.

        Returns:
            int: The number of env-steps taken in total in the environment(s)
                so far.
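
        Examples:
            Illustration only (hypothetical values), showing that the counter
            resets once a batch has been built:

            >>> collector.total_env_steps()
            200
            >>> _ = collector.try_build_truncated_episode_multi_agent_batch()
            >>> collector.total_env_steps()
            0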
  130. """
  131. raise NotImplementedError
    @abstractmethod
    def total_agent_steps(self) -> int:
        """Returns the total number of (individual) agent-steps taken so far.

        A step in an N-agent multi-agent environment counts as N for this
        metric. If fewer than N agents have stepped (because some agents were
        not required to send actions), the count is increased by less than N.
        The returned count contains everything that has not been built yet
        (and returned as MultiAgentBatches by the
        `try_build_truncated_episode_multi_agent_batch` or
        `postprocess_episode(build=True)` methods). After such a build, this
        counter is reset to 0.

        Returns:
            int: The number of (individual) agent-steps taken in total in the
                environment(s) so far.
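
        Examples:
            Illustration only: a hypothetical 3-agent episode in which all
            agents act on each of 10 env steps:

            >>> collector.total_env_steps()
            10
            >>> collector.total_agent_steps()
            30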
  146. """
  147. raise NotImplementedError
    @abstractmethod
    def get_inference_input_dict(self, policy_id: PolicyID) -> \
            Dict[str, TensorType]:
        """Returns an input_dict for an (inference) forward pass from our data.

        The input_dict can then be used for action computations inside a
        Policy via `Policy.compute_actions_from_input_dict()`.

        Args:
            policy_id (PolicyID): The Policy ID to get the input dict for.

        Returns:
            Dict[str, TensorType]: The input_dict to be passed into the
                ModelV2 for inference/training.

        Examples:
            >>> obs, r, done, info = env.step(action)
            >>> collector.add_action_reward_next_obs(
            ...     12345, 0, 0, "pol0", False,
            ...     {"actions": action, "new_obs": obs, "rewards": r,
            ...      "dones": done})
            >>> input_dict = collector.get_inference_input_dict("pol0")
            >>> action = policy.compute_actions_from_input_dict(input_dict)
            >>> # repeat
        """
        raise NotImplementedError

    @abstractmethod
    def postprocess_episode(self,
                            episode: Episode,
                            is_done: bool = False,
                            check_dones: bool = False,
                            build: bool = False) -> Optional[MultiAgentBatch]:
        """Postprocesses all agents' trajectories in a given episode.

        Generates (single-trajectory) SampleBatches for all Policies/Agents
        and calls Policy.postprocess_trajectory on each of these.
        Postprocessing may happen in-place, meaning any changes to the viewed
        data columns are directly reflected inside this collector's buffers.
        Also makes sure that additional (newly created) data columns are
        correctly added to the buffers.

        Args:
            episode (Episode): The Episode object for which to post-process
                data.
            is_done (bool): Whether the given episode is actually terminated
                (all agents are done OR we hit a hard horizon). If True, the
                episode will no longer be used/continued and we may need to
                recycle/erase it internally. If a soft horizon is hit, the
                episode will continue to be used and `is_done` should be set
                to False here.
            check_dones (bool): Whether we need to check that all agents'
                trajectories have dones=True at the end.
            build (bool): Whether to build a MultiAgentBatch from the given
                episode (and only that episode!) and return it. Used for
                batch_mode=`complete_episodes`.

        Returns:
            Optional[MultiAgentBatch]: If `build` is True, the
                SampleBatch or MultiAgentBatch built from `episode` (either
                just from that episode or from the `_PolicyCollectorGroup`
                in the `episode.batch_builder` property).
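
        Examples:
            A sketch of how a caller using batch_mode="complete_episodes"
            might finalize a terminated episode (`my_episode` is a
            placeholder):

            >>> ma_batch = collector.postprocess_episode(
            ...     my_episode, is_done=True, check_dones=True, build=True)
            >>> # `ma_batch` now contains only this episode's postprocessed
            >>> # data and can be handed to the learner.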
  201. """
  202. raise NotImplementedError
    @abstractmethod
    def try_build_truncated_episode_multi_agent_batch(self) -> \
            List[Union[MultiAgentBatch, SampleBatch]]:
        """Tries to build an MA-batch, if `rollout_fragment_length` is reached.

        Any unprocessed data will first be postprocessed with a policy
        postprocessor. This is usually called to collect samples for policy
        training. If not enough data has been collected yet
        (`rollout_fragment_length`), returns an empty list.

        Returns:
            List[Union[MultiAgentBatch, SampleBatch]]: A (possibly empty)
                list of MultiAgentBatches (containing the accumulated
                SampleBatches for each policy, or a plain SampleBatch if
                there is only one policy). The list is empty if
                `self.rollout_fragment_length` has not been reached yet.
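
        Examples:
            A sketch of a truncated-episodes sampling loop; `send_to_learner`
            is a placeholder, not an RLlib API:

            >>> batches = (
            ...     collector.try_build_truncated_episode_multi_agent_batch())
            >>> for ma_batch in batches:  # empty if fragment not yet full
            ...     send_to_learner(ma_batch)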
  218. """
  219. raise NotImplementedError
  220. # __sphinx_doc_end__