episode.py

from collections import defaultdict
import numpy as np
import random
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
    EnvActionType, EnvID, EnvInfoDict, EnvObsType
from ray.util import log_once

if TYPE_CHECKING:
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.evaluation.sample_batch_builder import \
        MultiAgentSampleBatchBuilder


@DeveloperAPI
class Episode:
    """Tracks the current state of a (possibly multi-agent) episode.

    Attributes:
        new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
        add_extra_batch (func): Return a built MultiAgentBatch to the sampler.
        batch_builder (obj): Batch builder for the current episode.
        total_reward (float): Summed reward across all agents in this episode.
        length (int): Length of this episode.
        episode_id (int): Unique id identifying this trajectory.
        agent_rewards (dict): Summed rewards broken down by agent.
        custom_metrics (dict): Dict where you can add custom metrics.
        user_data (dict): Dict that you can use for temporary storage, e.g.
            in between two custom callbacks referring to the same episode.
        hist_data (dict): Dict mapping str keys to List[float] for storage of
            per-timestep float data throughout the episode.

    Use case 1: Model-based rollouts in multi-agent:
        A custom compute_actions() function in a policy can inspect the
        current episode state and perform a number of rollouts based on the
        policies and state of other agents in the environment.

    Use case 2: Returning extra rollouts data.
        The model rollouts can be returned back to the sampler by calling:

        >>> batch = episode.new_batch_builder()
        >>> for each transition:
                batch.add_values(...)  # see sampler for usage
        >>> episode.extra_batches.add(batch.build_and_reset())
    """

    def __init__(
            self,
            policies: PolicyMap,
            policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
                                        PolicyID],
            batch_builder_factory: Callable[[],
                                            "MultiAgentSampleBatchBuilder"],
            extra_batch_callback: Callable[[SampleBatchType], None],
            env_id: EnvID,
            *,
            worker: Optional["RolloutWorker"] = None,
    ):
        """Initializes an Episode instance.

        Args:
            policies: The PolicyMap object (mapping PolicyIDs to Policy
                objects) to use for determining which policy is used for
                which agent.
            policy_mapping_fn: The mapping function mapping AgentIDs to
                PolicyIDs.
            batch_builder_factory: Factory callable returning a new
                MultiAgentSampleBatchBuilder (exposed as
                `self.new_batch_builder`).
            extra_batch_callback: Callable used to hand a built extra batch
                back to the sampler (exposed as `self.add_extra_batch`).
            env_id: The environment's ID in which this episode runs.
            worker: The RolloutWorker instance, in which this episode runs.
        """
        self.new_batch_builder: Callable[
            [], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
        self.add_extra_batch: Callable[[SampleBatchType],
                                       None] = extra_batch_callback
        self.batch_builder: "MultiAgentSampleBatchBuilder" = \
            batch_builder_factory()
        self.total_reward: float = 0.0
        self.length: int = 0
        # `random.randrange` expects an int upper bound.
        self.episode_id: int = random.randrange(int(2e9))
        self.env_id = env_id
        self.worker = worker
        self.agent_rewards: Dict[AgentID, float] = defaultdict(float)
        self.custom_metrics: Dict[str, float] = {}
        self.user_data: Dict[str, Any] = {}
        self.hist_data: Dict[str, List[float]] = {}
        self.media: Dict[str, Any] = {}
        self.policy_map: PolicyMap = policies
        self._policies = self.policy_map  # Backward compatibility.
        self.policy_mapping_fn: Callable[
            [AgentID, "Episode", "RolloutWorker"],
            PolicyID] = policy_mapping_fn
        self._next_agent_index: int = 0
        self._agent_to_index: Dict[AgentID, int] = {}
        self._agent_to_policy: Dict[AgentID, PolicyID] = {}
        self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
        self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
        self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
        self._agent_to_last_done: Dict[AgentID, bool] = {}
        self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
        self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
        self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
        self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
        self._agent_reward_history: Dict[AgentID, List[float]] = defaultdict(
            list)

    @DeveloperAPI
    def soft_reset(self) -> None:
        """Clears rewards and metrics, but retains RNN and other state.

        This is used to carry state across multiple logical episodes in the
        same env (i.e., if `soft_horizon` is set).
        """
        self.length = 0
        self.episode_id = random.randrange(int(2e9))
        self.total_reward = 0.0
        self.agent_rewards = defaultdict(float)
        self._agent_reward_history = defaultdict(list)

    @DeveloperAPI
    def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
        """Returns and stores the policy ID for the specified agent.

        If the agent is new, the policy mapping fn will be called to bind the
        agent to a policy for the duration of the entire episode (even if the
        policy_mapping_fn is changed in the meantime!).

        Args:
            agent_id: The agent ID to lookup the policy ID for.

        Returns:
            The policy ID for the specified agent.
        """
        # Perform a new policy_mapping_fn lookup and bind AgentID for the
        # duration of this episode to the returned PolicyID.
        if agent_id not in self._agent_to_policy:
            # Try new API: pass in agent_id and episode as named args.
            # New signature should be: (agent_id, episode, worker, **kwargs)
            try:
                policy_id = self._agent_to_policy[agent_id] = \
                    self.policy_mapping_fn(agent_id, self, worker=self.worker)
            except TypeError as e:
                if "positional argument" in e.args[0] or \
                        "unexpected keyword argument" in e.args[0]:
                    if log_once("policy_mapping_new_signature"):
                        deprecation_warning(
                            old="policy_mapping_fn(agent_id)",
                            new="policy_mapping_fn(agent_id, episode, "
                                "worker, **kwargs)")
                    policy_id = self._agent_to_policy[agent_id] = \
                        self.policy_mapping_fn(agent_id)
                else:
                    raise e
        # Use already determined PolicyID.
        else:
            policy_id = self._agent_to_policy[agent_id]

        # PolicyID not found in policy map -> Error.
        if policy_id not in self.policy_map:
            raise KeyError("policy_mapping_fn returned invalid policy id "
                           f"'{policy_id}'!")
        return policy_id
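
    # Illustrative sketch (not part of the original file; assumes a running
    # multi-agent episode with a valid AgentID such as "agent_0"): once an
    # agent is bound, repeated lookups return the same PolicyID for the rest
    # of the episode.
    #
    #   pid = episode.policy_for("agent_0")           # binds agent_0 -> pid
    #   policy = episode.policy_map[pid]              # the actual Policy object
    #   assert episode.policy_for("agent_0") == pid   # stable for this episode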

    @DeveloperAPI
    def last_observation_for(
            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
        """Returns the last observation for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last observation for.

        Returns:
            Last observation the specified AgentID has seen. None in case
            the agent has never made any observations in the episode.
        """
        return self._agent_to_last_obs.get(agent_id)

    @DeveloperAPI
    def last_raw_obs_for(
            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
        """Returns the last un-preprocessed obs for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last un-preprocessed
                observation for.

        Returns:
            Last un-preprocessed observation the specified AgentID has seen.
            None in case the agent has never made any observations in the
            episode.
        """
        return self._agent_to_last_raw_obs.get(agent_id)

    @DeveloperAPI
    def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
                      ) -> Optional[EnvInfoDict]:
        """Returns the last info for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last info for.

        Returns:
            Last info dict the specified AgentID has seen.
            None in case the agent has never made any observations in the
            episode.
        """
        return self._agent_to_last_info.get(agent_id)

    @DeveloperAPI
    def last_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the last action for the specified AgentID, or zeros.

        The "last" action is the most recent one taken by the agent.

        Args:
            agent_id: The agent's ID to get the last action for.

        Returns:
            Last action the specified AgentID has executed.
            Zeros in case the agent has never performed any actions in the
            episode.
        """
        # Agent has already taken at least one action in the episode.
        if agent_id in self._agent_to_last_action:
            return flatten_to_single_ndarray(
                self._agent_to_last_action[agent_id])
        # Agent has not acted yet, return all zeros.
        else:
            policy_id = self.policy_for(agent_id)
            policy = self.policy_map[policy_id]
            flat = flatten_to_single_ndarray(policy.action_space.sample())
            if hasattr(policy.action_space, "dtype"):
                return np.zeros_like(flat, dtype=policy.action_space.dtype)
            return np.zeros_like(flat)

    @DeveloperAPI
    def prev_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the previous action for the specified agent, or zeros.

        The "previous" action is the one taken one timestep before the
        most recent action taken by the agent.

        Args:
            agent_id: The agent's ID to get the previous action for.

        Returns:
            Previous action the specified AgentID has executed.
            Zero in case the agent has never performed any actions (or only
            one) in the episode.
        """
        # We are at t > 1 -> There has been a previous action by this agent.
        if agent_id in self._agent_to_prev_action:
            return flatten_to_single_ndarray(
                self._agent_to_prev_action[agent_id])
        # We're at t <= 1, so return all zeros.
        else:
            return np.zeros_like(self.last_action_for(agent_id))
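
    # Illustrative sketch (not part of the original file; "agent_0" is a
    # hypothetical AgentID): before the agent ever acts, both getters return
    # all-zeros arrays shaped like a flattened action-space sample.
    #
    #   a_last = episode.last_action_for("agent_0")  # zeros until the 1st action
    #   a_prev = episode.prev_action_for("agent_0")  # zeros until the 2nd action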

    @DeveloperAPI
    def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
        """Returns the last reward for the specified agent, or zero.

        The "last" reward is the one received most recently by the agent.

        Args:
            agent_id: The agent's ID to get the last reward for.

        Returns:
            Last reward for the specified AgentID.
            Zero in case the agent has never performed any actions
            (and thus never received any rewards) in the episode.
        """
        history = self._agent_reward_history[agent_id]
        # We are at t > 0 -> Return the most recently received reward.
        if len(history) >= 1:
            return history[-1]
        # We're at t=0, so there is no previous reward; just return zero.
        else:
            return 0.0

    @DeveloperAPI
    def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
        """Returns the previous reward for the specified agent, or zero.

        The "previous" reward is the one received one timestep before the
        most recently received reward of the agent.

        Args:
            agent_id: The agent's ID to get the previous reward for.

        Returns:
            Previous reward for the specified AgentID.
            Zero in case the agent has never performed any actions (or only
            one) in the episode.
        """
        history = self._agent_reward_history[agent_id]
        # We are at t > 1 -> Return the reward prior to the most recent one.
        if len(history) >= 2:
            return history[-2]
        # We're at t <= 1, so there is no previous reward; just return zero.
        else:
            return 0.0
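
    # Illustrative sketch (not part of the original file; "agent_0" is a
    # hypothetical AgentID): both getters read from the per-agent reward
    # history, falling back to 0.0 when too few rewards have been recorded.
    #
    #   r_last = episode.last_reward_for("agent_0")  # history[-1], else 0.0
    #   r_prev = episode.prev_reward_for("agent_0")  # history[-2], else 0.0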

    @DeveloperAPI
    def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
        """Returns the last RNN state for the specified agent.

        Args:
            agent_id: The agent's ID to get the most recent RNN state for.

        Returns:
            Most recent RNN state of the specified AgentID.
        """
        if agent_id not in self._agent_to_rnn_state:
            policy_id = self.policy_for(agent_id)
            policy = self.policy_map[policy_id]
            self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
        return self._agent_to_rnn_state[agent_id]

    @DeveloperAPI
    def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
        """Returns the last done flag for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last done flag for.

        Returns:
            Last done flag for the specified AgentID.
        """
        if agent_id not in self._agent_to_last_done:
            self._agent_to_last_done[agent_id] = False
        return self._agent_to_last_done[agent_id]

    @DeveloperAPI
    def last_extra_action_outs_for(
            self,
            agent_id: AgentID = _DUMMY_AGENT_ID,
    ) -> dict:
        """Returns the last extra-action outputs for the specified agent.

        This data is returned by a call to
        `Policy.compute_actions_from_input_dict` as the 3rd return value
        (1st return value = action; 2nd return value = RNN state outs).

        Args:
            agent_id: The agent's ID to get the last extra-action outs for.

        Returns:
            The last extra-action outs for the specified AgentID.
        """
        return self._agent_to_last_extra_action_outs[agent_id]
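
    # Illustrative sketch (not part of the original file; "agent_0" is a
    # hypothetical AgentID): which keys the returned dict contains depends on
    # the Policy in use (e.g. action log-probs or value-function estimates may
    # or may not be present), so treat lookups as policy-specific.
    #
    #   extra = episode.last_extra_action_outs_for("agent_0")
    #   logp = extra.get("action_logp")  # key availability is policy-specific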

    @DeveloperAPI
    def get_agents(self) -> List[AgentID]:
        """Returns list of agent IDs that have appeared in this episode.

        Returns:
            The list of all agent IDs that have appeared so far in this
            episode.
        """
        return list(self._agent_to_index.keys())

    def _add_agent_rewards(self, reward_dict: Dict[AgentID, float]) -> None:
        for agent_id, reward in reward_dict.items():
            if reward is not None:
                self.agent_rewards[agent_id,
                                   self.policy_for(agent_id)] += reward
                self.total_reward += reward
                self._agent_reward_history[agent_id].append(reward)

    def _set_rnn_state(self, agent_id, rnn_state):
        self._agent_to_rnn_state[agent_id] = rnn_state

    def _set_last_observation(self, agent_id, obs):
        self._agent_to_last_obs[agent_id] = obs

    def _set_last_raw_obs(self, agent_id, obs):
        self._agent_to_last_raw_obs[agent_id] = obs

    def _set_last_done(self, agent_id, done):
        self._agent_to_last_done[agent_id] = done

    def _set_last_info(self, agent_id, info):
        self._agent_to_last_info[agent_id] = info

    def _set_last_action(self, agent_id, action):
        # Shift the current "last" action to "prev" before storing the new one.
        if agent_id in self._agent_to_last_action:
            self._agent_to_prev_action[agent_id] = \
                self._agent_to_last_action[agent_id]
        self._agent_to_last_action[agent_id] = action

    def _set_last_extra_action_outs(self, agent_id, pi_info):
        self._agent_to_last_extra_action_outs[agent_id] = pi_info

    def _agent_index(self, agent_id):
        if agent_id not in self._agent_to_index:
            self._agent_to_index[agent_id] = self._next_agent_index
            self._next_agent_index += 1
        return self._agent_to_index[agent_id]

    @property
    def _policy_mapping_fn(self):
        deprecation_warning(
            old="Episode._policy_mapping_fn",
            new="Episode.policy_mapping_fn",
            error=False,
        )
        return self.policy_mapping_fn

    @Deprecated(new="Episode.last_extra_action_outs_for", error=False)
    def last_pi_info_for(self, *args, **kwargs):
        return self.last_extra_action_outs_for(*args, **kwargs)


# Backward compatibility: The name `MultiAgentEpisode` implies that there is
# also a (single-agent?) Episode, so keep the old alias around.
@Deprecated(new="ray.rllib.evaluation.episode.Episode", error=False)
class MultiAgentEpisode(Episode):
    pass
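

# Typical consumption sketch (not part of the original file): Episode objects
# are created and stepped by the sampler, then handed to user callbacks, which
# commonly use `user_data`, `custom_metrics` and `hist_data` as documented in
# the class docstring above. A rough outline (hook names per RLlib's
# DefaultCallbacks; the "angle" metric and the 1D-observation assumption are
# made up for illustration):
#
#   class MyCallbacks(DefaultCallbacks):
#       def on_episode_start(self, *, episode, **kwargs):
#           episode.user_data["angles"] = []
#
#       def on_episode_step(self, *, episode, **kwargs):
#           episode.user_data["angles"].append(
#               episode.last_observation_for()[0])
#
#       def on_episode_end(self, *, episode, **kwargs):
#           episode.custom_metrics["mean_angle"] = float(
#               np.mean(episode.user_data["angles"]))
#           episode.hist_data["angles"] = episode.user_data["angles"]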