from collections import defaultdict
import numpy as np
import random
import tree  # pip install dm_tree
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
    EnvActionType, EnvID, EnvInfoDict, EnvObsType
from ray.util import log_once

if TYPE_CHECKING:
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.evaluation.sample_batch_builder import \
        MultiAgentSampleBatchBuilder


@DeveloperAPI
class Episode:
    """Tracks the current state of a (possibly multi-agent) episode.

    Attributes:
        new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
        add_extra_batch (func): Return a built MultiAgentBatch to the sampler.
        batch_builder (obj): Batch builder for the current episode.
        total_reward (float): Summed reward across all agents in this episode.
        length (int): Length of this episode.
        episode_id (int): Unique id identifying this trajectory.
        agent_rewards (dict): Summed rewards broken down by agent.
        custom_metrics (dict): Dict where you can add custom metrics.
        user_data (dict): Dict that you can use for temporary storage, e.g.
            in between two custom callbacks referring to the same episode.
        hist_data (dict): Dict mapping str keys to List[float] for storage of
            per-timestep float data throughout the episode.

    Use case 1: Model-based rollouts in multi-agent:
        A custom compute_actions() function in a policy can inspect the
        current episode state and perform a number of rollouts based on the
        policies and state of other agents in the environment.

    Use case 2: Returning extra rollouts data.
        The model rollouts can be returned back to the sampler by calling:

        >>> batch = episode.new_batch_builder()
        >>> for transition in rollout:
        ...     batch.add_values(...)  # see sampler for usage
        >>> episode.add_extra_batch(batch.build_and_reset())
    """

    def __init__(
            self,
            policies: PolicyMap,
            policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
                                        PolicyID],
            batch_builder_factory: Callable[[],
                                            "MultiAgentSampleBatchBuilder"],
            extra_batch_callback: Callable[[SampleBatchType], None],
            env_id: EnvID,
            *,
            worker: Optional["RolloutWorker"] = None,
    ):
        """Initializes an Episode instance.

        Args:
            policies: The PolicyMap object (mapping PolicyIDs to Policy
                objects) to use for determining which policy is used for
                which agent.
            policy_mapping_fn: The mapping function mapping AgentIDs to
                PolicyIDs.
            batch_builder_factory: A callable returning a new
                MultiAgentSampleBatchBuilder for this episode.
            extra_batch_callback: A callable taking a built batch and
                returning it to the sampler (used for extra rollouts).
            env_id: The environment's ID in which this episode runs.
            worker: The RolloutWorker instance, in which this episode runs.
        """
        self.new_batch_builder: Callable[
            [], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
        self.add_extra_batch: Callable[[SampleBatchType],
                                       None] = extra_batch_callback
        self.batch_builder: "MultiAgentSampleBatchBuilder" = \
            batch_builder_factory()
        self.total_reward: float = 0.0
        self.length: int = 0
        self.episode_id: int = random.randrange(int(2e9))
        self.env_id = env_id
        self.worker = worker
        self.agent_rewards: Dict[AgentID, float] = defaultdict(float)
        self.custom_metrics: Dict[str, float] = {}
        self.user_data: Dict[str, Any] = {}
        self.hist_data: Dict[str, List[float]] = {}
        self.media: Dict[str, Any] = {}
        self.policy_map: PolicyMap = policies
        self._policies = self.policy_map  # backward compatibility
        self.policy_mapping_fn: Callable[
            [AgentID, "Episode", "RolloutWorker"],
            PolicyID] = policy_mapping_fn
        self._next_agent_index: int = 0
        self._agent_to_index: Dict[AgentID, int] = {}
        self._agent_to_policy: Dict[AgentID, PolicyID] = {}
        self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
        self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
        self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
        self._agent_to_last_done: Dict[AgentID, bool] = {}
        self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
        self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
        self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
        self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
        self._agent_reward_history: Dict[AgentID, List[float]] = defaultdict(
            list)

    @DeveloperAPI
    def soft_reset(self) -> None:
        """Clears rewards and metrics, but retains RNN and other state.

        This is used to carry state across multiple logical episodes in the
        same env (i.e., if `soft_horizon` is set).
        """
        self.length = 0
        self.episode_id = random.randrange(int(2e9))
        self.total_reward = 0.0
        self.agent_rewards = defaultdict(float)
        self._agent_reward_history = defaultdict(list)

    @DeveloperAPI
    def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
        """Returns and stores the policy ID for the specified agent.

        If the agent is new, the policy mapping fn will be called to bind the
        agent to a policy for the duration of the entire episode (even if the
        policy_mapping_fn is changed in the meantime!).

        Args:
            agent_id: The agent ID to lookup the policy ID for.

        Returns:
            The policy ID for the specified agent.
        """
        # Perform a new policy_mapping_fn lookup and bind AgentID for the
        # duration of this episode to the returned PolicyID.
        if agent_id not in self._agent_to_policy:
            # Try new API: pass in agent_id and episode as named args.
            # New signature should be: (agent_id, episode, worker, **kwargs)
            try:
                policy_id = self._agent_to_policy[agent_id] = \
                    self.policy_mapping_fn(agent_id, self, worker=self.worker)
            except TypeError as e:
                if "positional argument" in e.args[0] or \
                        "unexpected keyword argument" in e.args[0]:
                    if log_once("policy_mapping_new_signature"):
                        deprecation_warning(
                            old="policy_mapping_fn(agent_id)",
                            new="policy_mapping_fn(agent_id, episode, "
                            "worker, **kwargs)")
                    policy_id = self._agent_to_policy[agent_id] = \
                        self.policy_mapping_fn(agent_id)
                else:
                    raise e
        # Use already determined PolicyID.
        else:
            policy_id = self._agent_to_policy[agent_id]
        # PolicyID not found in policy map -> Error.
        if policy_id not in self.policy_map:
            raise KeyError("policy_mapping_fn returned invalid policy id "
                           f"'{policy_id}'!")
        return policy_id
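
    # A minimal sketch (in comments, not part of this module) of a
    # policy_mapping_fn using the new signature expected above. The policy IDs
    # "learner" and "opponent" are illustrative assumptions, not RLlib
    # defaults; the mapping must return a key present in the PolicyMap.
    #
    #     def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    #         # Bind agent "player_0" to one policy, everyone else to another.
    #         return "learner" if agent_id == "player_0" else "opponent"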

    @DeveloperAPI
    def last_observation_for(
            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
        """Returns the last observation for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last observation for.

        Returns:
            Last observation the specified AgentID has seen. None in case
            the agent has never made any observations in the episode.
        """
        return self._agent_to_last_obs.get(agent_id)

    @DeveloperAPI
    def last_raw_obs_for(
            self, agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
        """Returns the last un-preprocessed obs for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last un-preprocessed
                observation for.

        Returns:
            Last un-preprocessed observation the specified AgentID has seen.
            None in case the agent has never made any observations in the
            episode.
        """
        return self._agent_to_last_raw_obs.get(agent_id)

    @DeveloperAPI
    def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
                      ) -> Optional[EnvInfoDict]:
        """Returns the last info for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last info for.

        Returns:
            Last info dict the specified AgentID has seen.
            None in case the agent has never made any observations in the
            episode.
        """
        return self._agent_to_last_info.get(agent_id)

    @DeveloperAPI
    def last_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the last action for the specified AgentID, or zeros.

        The "last" action is the most recent one taken by the agent.

        Args:
            agent_id: The agent's ID to get the last action for.

        Returns:
            Last action the specified AgentID has executed.
            Zeros in case the agent has never performed any actions in the
            episode.
        """
        policy_id = self.policy_for(agent_id)
        policy = self.policy_map[policy_id]

        # Agent has already taken at least one action in the episode.
        if agent_id in self._agent_to_last_action:
            if policy.config.get("_disable_action_flattening"):
                return self._agent_to_last_action[agent_id]
            else:
                return flatten_to_single_ndarray(
                    self._agent_to_last_action[agent_id])
        # Agent has not acted yet, return all zeros.
        else:
            if policy.config.get("_disable_action_flattening"):
                return tree.map_structure(
                    lambda s: np.zeros_like(s.sample(), s.dtype)
                    if hasattr(s, "dtype") else np.zeros_like(s.sample()),
                    policy.action_space_struct,
                )
            else:
                flat = flatten_to_single_ndarray(policy.action_space.sample())
                if hasattr(policy.action_space, "dtype"):
                    return np.zeros_like(flat, dtype=policy.action_space.dtype)
                return np.zeros_like(flat)

    @DeveloperAPI
    def prev_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the previous action for the specified agent, or zeros.

        The "previous" action is the one taken one timestep before the
        most recent action taken by the agent.

        Args:
            agent_id: The agent's ID to get the previous action for.

        Returns:
            Previous action the specified AgentID has executed.
            Zeros in case the agent has never performed any actions (or only
            one) in the episode.
        """
        policy_id = self.policy_for(agent_id)
        policy = self.policy_map[policy_id]

        # We are at t > 1 -> There has been a previous action by this agent.
        if agent_id in self._agent_to_prev_action:
            if policy.config.get("_disable_action_flattening"):
                return self._agent_to_prev_action[agent_id]
            else:
                return flatten_to_single_ndarray(
                    self._agent_to_prev_action[agent_id])
        # We're at t <= 1, so return all zeros.
        else:
            if policy.config.get("_disable_action_flattening"):
                return tree.map_structure(
                    lambda a: np.zeros_like(a, a.dtype)  # noqa
                    if hasattr(a, "dtype") else np.zeros_like(a),  # noqa
                    self.last_action_for(agent_id),
                )
            else:
                return np.zeros_like(self.last_action_for(agent_id))
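
    # Example of the "last" vs. "previous" convention used above and below
    # (hypothetical values, not code in this module): after an agent has taken
    # actions a0, a1, a2, `last_action_for()` returns a2 and
    # `prev_action_for()` returns a1. The reward accessors below follow the
    # same convention over `self._agent_reward_history`.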

    @DeveloperAPI
    def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
        """Returns the last reward for the specified agent, or zero.

        The "last" reward is the one received most recently by the agent.

        Args:
            agent_id: The agent's ID to get the last reward for.

        Returns:
            Last reward for the specified AgentID.
            Zero in case the agent has never performed any actions
            (and thus received rewards) in the episode.
        """
        history = self._agent_reward_history[agent_id]
        # We are at t > 0 -> Return previously received reward.
        if len(history) >= 1:
            return history[-1]
        # We're at t=0, so there is no previous reward, just return zero.
        else:
            return 0.0

    @DeveloperAPI
    def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
        """Returns the previous reward for the specified agent, or zero.

        The "previous" reward is the one received one timestep before the
        most recently received reward of the agent.

        Args:
            agent_id: The agent's ID to get the previous reward for.

        Returns:
            Previous reward for the specified AgentID.
            Zero in case the agent has never performed any actions (or only
            one) in the episode.
        """
        history = self._agent_reward_history[agent_id]
        # We are at t > 1 -> Return reward prior to most recent (last) one.
        if len(history) >= 2:
            return history[-2]
        # We're at t <= 1, so there is no previous reward, just return zero.
        else:
            return 0.0

    @DeveloperAPI
    def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
        """Returns the last RNN state for the specified agent.

        Args:
            agent_id: The agent's ID to get the most recent RNN state for.

        Returns:
            Most recent RNN state of the specified AgentID.
        """
        if agent_id not in self._agent_to_rnn_state:
            policy_id = self.policy_for(agent_id)
            policy = self.policy_map[policy_id]
            self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
        return self._agent_to_rnn_state[agent_id]

    @DeveloperAPI
    def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
        """Returns the last done flag for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last done flag for.

        Returns:
            Last done flag for the specified AgentID.
        """
        if agent_id not in self._agent_to_last_done:
            self._agent_to_last_done[agent_id] = False
        return self._agent_to_last_done[agent_id]

    @DeveloperAPI
    def last_extra_action_outs_for(
            self,
            agent_id: AgentID = _DUMMY_AGENT_ID,
    ) -> dict:
        """Returns the last extra-action outputs for the specified agent.

        This data is returned by a call to
        `Policy.compute_actions_from_input_dict` as the 3rd return value
        (1st return value = action; 2nd return value = RNN state outs).

        Args:
            agent_id: The agent's ID to get the last extra-action outs for.

        Returns:
            The last extra-action outs for the specified AgentID.
        """
        return self._agent_to_last_extra_action_outs[agent_id]

    @DeveloperAPI
    def get_agents(self) -> List[AgentID]:
        """Returns list of agent IDs that have appeared in this episode.

        Returns:
            The list of all agent IDs that have appeared so far in this
            episode.
        """
        return list(self._agent_to_index.keys())

    def _add_agent_rewards(self, reward_dict: Dict[AgentID, float]) -> None:
        for agent_id, reward in reward_dict.items():
            if reward is not None:
                self.agent_rewards[agent_id,
                                   self.policy_for(agent_id)] += reward
                self.total_reward += reward
                self._agent_reward_history[agent_id].append(reward)

    def _set_rnn_state(self, agent_id, rnn_state):
        self._agent_to_rnn_state[agent_id] = rnn_state

    def _set_last_observation(self, agent_id, obs):
        self._agent_to_last_obs[agent_id] = obs

    def _set_last_raw_obs(self, agent_id, obs):
        self._agent_to_last_raw_obs[agent_id] = obs

    def _set_last_done(self, agent_id, done):
        self._agent_to_last_done[agent_id] = done

    def _set_last_info(self, agent_id, info):
        self._agent_to_last_info[agent_id] = info

    def _set_last_action(self, agent_id, action):
        if agent_id in self._agent_to_last_action:
            self._agent_to_prev_action[agent_id] = \
                self._agent_to_last_action[agent_id]
        self._agent_to_last_action[agent_id] = action

    def _set_last_extra_action_outs(self, agent_id, pi_info):
        self._agent_to_last_extra_action_outs[agent_id] = pi_info

    def _agent_index(self, agent_id):
        if agent_id not in self._agent_to_index:
            self._agent_to_index[agent_id] = self._next_agent_index
            self._next_agent_index += 1
        return self._agent_to_index[agent_id]

    @property
    def _policy_mapping_fn(self):
        deprecation_warning(
            old="Episode._policy_mapping_fn",
            new="Episode.policy_mapping_fn",
            error=False,
        )
        return self.policy_mapping_fn

    @Deprecated(new="Episode.last_extra_action_outs_for", error=False)
    def last_pi_info_for(self, *args, **kwargs):
        return self.last_extra_action_outs_for(*args, **kwargs)


# Backward compatibility: the old name `MultiAgentEpisode` implied that there
# is also a separate (single-agent?) Episode class; it is now just a
# deprecated alias for Episode.
@Deprecated(new="ray.rllib.evaluation.episode.Episode", error=False)
class MultiAgentEpisode(Episode):
    pass
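

# A minimal usage sketch of "use case 2" from the Episode class docstring
# (returning extra model-based rollouts to the sampler). This helper is not
# part of RLlib; its name and the `rollout` argument (an iterable of keyword
# dicts accepted by the batch builder's `add_values()`) are hypothetical
# stand-ins for whatever transition data a custom policy produces.
def _add_model_rollout_example(episode: Episode, rollout: List[dict]) -> None:
    # Build a fresh batch builder for the extra trajectory.
    batch = episode.new_batch_builder()
    for transition in rollout:
        # See the sampler for the exact keyword arguments expected here.
        batch.add_values(**transition)
    # Hand the finished batch back to the sampler.
    episode.add_extra_batch(batch.build_and_reset())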