sample_batch_builder.py

import collections
import logging
import numpy as np
from typing import List, Any, Dict, Optional, Union, TYPE_CHECKING

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import Episode
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
from ray.rllib.utils.typing import PolicyID, AgentID
from ray.util.debug import log_once

if TYPE_CHECKING:
    from ray.rllib.agents.callbacks import DefaultCallbacks

logger = logging.getLogger(__name__)
def to_float_array(v: List[Any]) -> np.ndarray:
    arr = np.array(v)
    if arr.dtype == np.float64:
        return arr.astype(np.float32)  # Downcast to save some memory.
    return arr
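
# Example (illustrative sketch): float64 inputs are downcast to float32;
# all other dtypes pass through unchanged.
#
#     to_float_array([1.0, 2.0]).dtype  # -> dtype("float32")
#     to_float_array([1, 2]).dtype      # -> the platform int dtype, e.g. int64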
@Deprecated(new="a child class of `SampleCollector`", error=False)
class SampleBatchBuilder:
    """Util to build a SampleBatch incrementally.

    For efficiency, SampleBatches hold values in column form (as arrays).
    However, it is useful to add data one row (dict) at a time.
    """

    _next_unroll_id = 0  # Disambiguates unrolls within a single episode.

    def __init__(self):
        self.buffers: Dict[str, List] = collections.defaultdict(list)
        self.count = 0

    def add_values(self, **values: Any) -> None:
        """Add the given dictionary (row) of values to this batch."""
        for k, v in values.items():
            self.buffers[k].append(v)
        self.count += 1

    def add_batch(self, batch: SampleBatch) -> None:
        """Add the given batch of values to this batch."""
        for k, column in batch.items():
            self.buffers[k].extend(column)
        self.count += batch.count

    def build_and_reset(self) -> SampleBatch:
        """Returns a sample batch including all previously added values."""
        batch = SampleBatch(
            {k: to_float_array(v)
             for k, v in self.buffers.items()})
        if SampleBatch.UNROLL_ID not in batch:
            batch[SampleBatch.UNROLL_ID] = np.repeat(
                SampleBatchBuilder._next_unroll_id, batch.count)
            SampleBatchBuilder._next_unroll_id += 1
        self.buffers.clear()
        self.count = 0
        return batch
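
# Example usage (illustrative sketch; the column names "obs", "actions", and
# "rewards" are arbitrary keys chosen for the demo):
#
#     builder = SampleBatchBuilder()
#     builder.add_values(obs=[0.0, 1.0], actions=0, rewards=1.0)
#     builder.add_values(obs=[1.0, 2.0], actions=1, rewards=0.5)
#     batch = builder.build_and_reset()  # -> SampleBatch with batch.count == 2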
# Deprecated class: Use a child class of `SampleCollector` instead
# (which handles multi-agent setups as well).
@DeveloperAPI
class MultiAgentSampleBatchBuilder:
    """Util to build SampleBatches for each policy in a multi-agent env.

    Input data is per-agent, while output data is per-policy. There is an M:N
    mapping between agents and policies. We retain one local batch builder
    per agent. When an agent is done, then its local batch is appended into
    the corresponding policy batch for the agent's policy.
    """
    def __init__(self, policy_map: Dict[PolicyID, Policy],
                 clip_rewards: Union[bool, float],
                 callbacks: "DefaultCallbacks"):
        """Initialize a MultiAgentSampleBatchBuilder.

        Args:
            policy_map (Dict[str, Policy]): Maps policy ids to policy
                instances.
            clip_rewards (Union[bool, float]): Whether to clip rewards before
                postprocessing (at +/-1.0) or the actual value to +/- clip.
            callbacks (DefaultCallbacks): RLlib callbacks.
        """
        if log_once("MultiAgentSampleBatchBuilder"):
            deprecation_warning(
                old="MultiAgentSampleBatchBuilder", error=False)
        self.policy_map = policy_map
        self.clip_rewards = clip_rewards
        # Build the Policies' SampleBatchBuilders.
        self.policy_builders = {
            k: SampleBatchBuilder()
            for k in policy_map.keys()
        }
        # Whenever we observe a new agent, add a new SampleBatchBuilder for
        # this agent.
        self.agent_builders = {}
        # Internal agent-to-policy map.
        self.agent_to_policy = {}
        self.callbacks = callbacks
        # Number of "inference" steps taken in the environment, regardless of
        # the number of agents involved in each of these steps.
        self.count = 0
    def total(self) -> int:
        """Returns the total number of steps taken in the env (all agents).

        Returns:
            int: The number of steps taken in total in the environment over
                all agents.
        """
        return sum(a.count for a in self.agent_builders.values())

    def has_pending_agent_data(self) -> bool:
        """Returns whether there is pending unprocessed data.

        Returns:
            bool: True if there is at least one per-agent builder (with data
                in it).
        """
        return len(self.agent_builders) > 0
    @DeveloperAPI
    def add_values(self, agent_id: AgentID, policy_id: PolicyID,
                   **values: Any) -> None:
        """Add the given dictionary (row) of values to this batch.

        Args:
            agent_id (obj): Unique id for the agent we are adding values for.
            policy_id (obj): Unique id for policy controlling the agent.
            values (dict): Row of values to add for this agent.
        """
        if agent_id not in self.agent_builders:
            self.agent_builders[agent_id] = SampleBatchBuilder()
            self.agent_to_policy[agent_id] = policy_id

        # Include the current agent id for multi-agent algorithms.
        if agent_id != _DUMMY_AGENT_ID:
            values["agent_id"] = agent_id

        self.agent_builders[agent_id].add_values(**values)
    def postprocess_batch_so_far(self,
                                 episode: Optional[Episode] = None) -> None:
        """Apply policy postprocessors to any unprocessed rows.

        This pushes the postprocessed per-agent batches onto the per-policy
        builders, clearing per-agent state.

        Args:
            episode (Optional[Episode]): The Episode object that
                holds this MultiAgentBatchBuilder object.
        """

        # Materialize the batches so far.
        pre_batches = {}
        for agent_id, builder in self.agent_builders.items():
            pre_batches[agent_id] = (
                self.policy_map[self.agent_to_policy[agent_id]],
                builder.build_and_reset())

        # Apply postprocessor.
        post_batches = {}
        if self.clip_rewards is True:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.sign(pre_batch["rewards"])
        elif self.clip_rewards:
            for _, (_, pre_batch) in pre_batches.items():
                pre_batch["rewards"] = np.clip(
                    pre_batch["rewards"],
                    a_min=-self.clip_rewards,
                    a_max=self.clip_rewards)
        for agent_id, (_, pre_batch) in pre_batches.items():
            other_batches = pre_batches.copy()
            del other_batches[agent_id]
            policy = self.policy_map[self.agent_to_policy[agent_id]]
            if any(pre_batch["dones"][:-1]) or len(set(
                    pre_batch["eps_id"])) > 1:
                raise ValueError(
                    "Batches sent to postprocessing must only contain steps "
                    "from a single trajectory.", pre_batch)
            # Call the Policy's Exploration's postprocess method.
            post_batches[agent_id] = pre_batch
            if getattr(policy, "exploration", None) is not None:
                policy.exploration.postprocess_trajectory(
                    policy, post_batches[agent_id], policy.get_session())
            post_batches[agent_id] = policy.postprocess_trajectory(
                post_batches[agent_id], other_batches, episode)

        if log_once("after_post"):
            logger.info(
                "Trajectory fragment after postprocess_trajectory():\n\n{}\n".
                format(summarize(post_batches)))

        # Append into policy batches and reset.
        from ray.rllib.evaluation.rollout_worker import get_global_worker
        for agent_id, post_batch in sorted(post_batches.items()):
            self.callbacks.on_postprocess_trajectory(
                worker=get_global_worker(),
                episode=episode,
                agent_id=agent_id,
                policy_id=self.agent_to_policy[agent_id],
                policies=self.policy_map,
                postprocessed_batch=post_batch,
                original_batches=pre_batches)
            self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
                post_batch)

        self.agent_builders.clear()
        self.agent_to_policy.clear()
    def check_missing_dones(self) -> None:
        """Raises an error if any pending per-agent batch lacks a final done.

        Called when the env reports '__all__' done to make sure every live
        agent's trajectory actually ended with dones=True.
        """
        for agent_id, builder in self.agent_builders.items():
            if builder.buffers["dones"][-1] is not True:
                raise ValueError(
                    "The environment terminated for all agents, but we still "
                    "don't have a last observation for "
                    "agent {} (policy {}). ".format(
                        agent_id, self.agent_to_policy[agent_id]) +
                    "Please ensure that you include the last observations "
                    "of all live agents when setting '__all__' done to True. "
                    "Alternatively, set no_done_at_end=True to allow this.")
    @DeveloperAPI
    def build_and_reset(self,
                        episode: Optional[Episode] = None) -> MultiAgentBatch:
        """Returns the accumulated sample batches for each policy.

        Any unprocessed rows will be first postprocessed with a policy
        postprocessor. The internal state of this builder will be reset.

        Args:
            episode (Optional[Episode]): The Episode object that
                holds this MultiAgentBatchBuilder object or None.

        Returns:
            MultiAgentBatch: Returns the accumulated sample batches for each
                policy.
        """
        self.postprocess_batch_so_far(episode)
        policy_batches = {}
        for policy_id, builder in self.policy_builders.items():
            if builder.count > 0:
                policy_batches[policy_id] = builder.build_and_reset()
        old_count = self.count
        self.count = 0
        return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
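
# Example usage (illustrative sketch; `my_policy` is a stand-in for a real
# Policy instance and `DefaultCallbacks()` for your callbacks object; neither
# is defined in this module):
#
#     ma_builder = MultiAgentSampleBatchBuilder(
#         policy_map={"pol0": my_policy},
#         clip_rewards=False,
#         callbacks=DefaultCallbacks())
#     # Rows are added per-agent; each agent's finished batch is later routed
#     # to its policy's builder during postprocessing.
#     ma_builder.add_values("agent_0", "pol0", obs=[0.0], actions=0,
#                           rewards=1.0, dones=True, eps_id=0)
#     ma_batch = ma_builder.build_and_reset()  # -> MultiAgentBatch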