from collections import defaultdict
import numpy as np
import random
import tree  # pip install dm_tree
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
    EnvActionType, EnvID, EnvInfoDict, EnvObsType
from ray.util import log_once

if TYPE_CHECKING:
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.evaluation.sample_batch_builder import \
        MultiAgentSampleBatchBuilder


@DeveloperAPI
class Episode:
    """Tracks the current state of a (possibly multi-agent) episode.

    Attributes:
        new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
        add_extra_batch (func): Return a built MultiAgentBatch to the sampler.
        batch_builder (obj): Batch builder for the current episode.
        total_reward (float): Summed reward across all agents in this episode.
        length (int): Length of this episode.
        episode_id (int): Unique id identifying this trajectory.
        agent_rewards (dict): Summed rewards broken down by agent and policy.
        custom_metrics (dict): Dict where you can add custom metrics.
        user_data (dict): Dict that you can use for temporary storage. E.g.
            in between two custom callbacks referring to the same episode.
        hist_data (dict): Dict mapping str keys to List[float] for storage of
            per-timestep float data throughout the episode.

    Use case 1: Model-based rollouts in multi-agent:
        A custom compute_actions() function in a policy can inspect the
        current episode state and perform a number of rollouts based on the
        policies and state of other agents in the environment.

    Use case 2: Returning extra rollout data.
        The model rollouts can be returned to the sampler by calling:

        >>> batch = episode.new_batch_builder()
        >>> for each transition:
        ...     batch.add_values(...)  # see sampler for usage
        >>> episode.add_extra_batch(batch.build_and_reset())
    """

    def __init__(
            self,
            policies: PolicyMap,
            policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
                                        PolicyID],
            batch_builder_factory: Callable[[],
                                            "MultiAgentSampleBatchBuilder"],
            extra_batch_callback: Callable[[SampleBatchType], None],
            env_id: EnvID,
            *,
            worker: Optional["RolloutWorker"] = None,
    ):
        """Initializes an Episode instance.

        Args:
            policies: The PolicyMap object (mapping PolicyIDs to Policy
                objects) to use for determining which policy is used for
                which agent.
            policy_mapping_fn: The mapping function mapping AgentIDs to
                PolicyIDs.
            batch_builder_factory: Function that returns a fresh
                MultiAgentSampleBatchBuilder (used for `new_batch_builder`
                and to create `batch_builder`).
            extra_batch_callback: Function to call with a built extra batch
                (e.g. from model-based rollouts) in order to return it to
                the sampler.
            env_id: The environment's ID in which this episode runs.
            worker: The RolloutWorker instance, in which this episode runs.
        """
        self.new_batch_builder: Callable[
            [], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
        self.add_extra_batch: Callable[
            [SampleBatchType], None] = extra_batch_callback
        self.batch_builder: "MultiAgentSampleBatchBuilder" = \
            batch_builder_factory()
        self.total_reward: float = 0.0
        self.length: int = 0
        self.episode_id: int = random.randrange(int(2e9))
        self.env_id = env_id
        self.worker = worker
        self.agent_rewards: Dict[Tuple[AgentID, PolicyID], float] = \
            defaultdict(float)
        self.custom_metrics: Dict[str, float] = {}
        self.user_data: Dict[str, Any] = {}
        self.hist_data: Dict[str, List[float]] = {}
        self.media: Dict[str, Any] = {}
        self.policy_map: PolicyMap = policies
        self._policies = self.policy_map  # backward compatibility
        self.policy_mapping_fn: Callable[
            [AgentID, "Episode", "RolloutWorker"],
            PolicyID] = policy_mapping_fn
        self._next_agent_index: int = 0
        self._agent_to_index: Dict[AgentID, int] = {}
        self._agent_to_policy: Dict[AgentID, PolicyID] = {}
        self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
        self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
        self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
        self._agent_to_last_done: Dict[AgentID, bool] = {}
        self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
        self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
        self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
        self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
        self._agent_reward_history: Dict[AgentID, List[float]] = defaultdict(
            list)

    @DeveloperAPI
    def soft_reset(self) -> None:
        """Clears rewards and metrics, but retains RNN and other state.

        This is used to carry state across multiple logical episodes in the
        same env (i.e., if `soft_horizon` is set).
        """
        self.length = 0
        self.episode_id = random.randrange(int(2e9))
        self.total_reward = 0.0
        self.agent_rewards = defaultdict(float)
        self._agent_reward_history = defaultdict(list)

    @DeveloperAPI
    def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
        """Returns and stores the policy ID for the specified agent.

        If the agent is new, the policy mapping fn will be called to bind the
        agent to a policy for the duration of the entire episode (even if the
        policy_mapping_fn is changed in the meantime!).

        Args:
            agent_id: The agent ID to lookup the policy ID for.

        Returns:
            The policy ID for the specified agent.
        """
        # Perform a new policy_mapping_fn lookup and bind AgentID for the
        # duration of this episode to the returned PolicyID.
        if agent_id not in self._agent_to_policy:
            # Try new API: pass in agent_id and episode as named args.
            # New signature should be: (agent_id, episode, worker, **kwargs)
            try:
                policy_id = self._agent_to_policy[agent_id] = \
                    self.policy_mapping_fn(agent_id, self, worker=self.worker)
            except TypeError as e:
                if "positional argument" in e.args[0] or \
                        "unexpected keyword argument" in e.args[0]:
                    if log_once("policy_mapping_new_signature"):
                        deprecation_warning(
                            old="policy_mapping_fn(agent_id)",
                            new="policy_mapping_fn(agent_id, episode, "
                                "worker, **kwargs)")
                    policy_id = self._agent_to_policy[agent_id] = \
                        self.policy_mapping_fn(agent_id)
                else:
                    raise e
        # Use already determined PolicyID.
        else:
            policy_id = self._agent_to_policy[agent_id]

        # PolicyID not found in policy map -> Error.
        if policy_id not in self.policy_map:
            raise KeyError("policy_mapping_fn returned invalid policy id "
                           f"'{policy_id}'!")
        return policy_id

    @DeveloperAPI
    def last_observation_for(
            self,
            agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
        """Returns the last observation for the specified AgentID.
        Args:
            agent_id: The agent's ID to get the last observation for.

        Returns:
            Last observation the specified AgentID has seen. None in case
            the agent has never made any observations in the episode.
        """
        return self._agent_to_last_obs.get(agent_id)

    @DeveloperAPI
    def last_raw_obs_for(
            self,
            agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
        """Returns the last un-preprocessed obs for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last un-preprocessed
                observation for.

        Returns:
            Last un-preprocessed observation the specified AgentID has seen.
            None in case the agent has never made any observations in the
            episode.
        """
        return self._agent_to_last_raw_obs.get(agent_id)

    @DeveloperAPI
    def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
                      ) -> Optional[EnvInfoDict]:
        """Returns the last info for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last info for.

        Returns:
            Last info dict the specified AgentID has seen. None in case the
            agent has never made any observations in the episode.
        """
        return self._agent_to_last_info.get(agent_id)

    @DeveloperAPI
    def last_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the last action for the specified AgentID, or zeros.

        The "last" action is the most recent one taken by the agent.

        Args:
            agent_id: The agent's ID to get the last action for.

        Returns:
            Last action the specified AgentID has executed. Zeros in case
            the agent has never performed any actions in the episode.
        """
        policy_id = self.policy_for(agent_id)
        policy = self.policy_map[policy_id]

        # Agent has already taken at least one action in the episode.
        if agent_id in self._agent_to_last_action:
            if policy.config.get("_disable_action_flattening"):
                return self._agent_to_last_action[agent_id]
            else:
                return flatten_to_single_ndarray(
                    self._agent_to_last_action[agent_id])
        # Agent has not acted yet, return all zeros.
        else:
            if policy.config.get("_disable_action_flattening"):
                return tree.map_structure(
                    lambda s: np.zeros_like(s.sample(), s.dtype)
                    if hasattr(s, "dtype") else np.zeros_like(s.sample()),
                    policy.action_space_struct,
                )
            else:
                flat = flatten_to_single_ndarray(policy.action_space.sample())
                if hasattr(policy.action_space, "dtype"):
                    return np.zeros_like(
                        flat, dtype=policy.action_space.dtype)
                return np.zeros_like(flat)

    @DeveloperAPI
    def prev_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the previous action for the specified agent, or zeros.

        The "previous" action is the one taken one timestep before the
        most recent action taken by the agent.

        Args:
            agent_id: The agent's ID to get the previous action for.

        Returns:
            Previous action the specified AgentID has executed. Zero in case
            the agent has never performed any actions (or only one) in the
            episode.
        """
        policy_id = self.policy_for(agent_id)
        policy = self.policy_map[policy_id]

        # We are at t > 1 -> There has been a previous action by this agent.
        if agent_id in self._agent_to_prev_action:
            if policy.config.get("_disable_action_flattening"):
                return self._agent_to_prev_action[agent_id]
            else:
                return flatten_to_single_ndarray(
                    self._agent_to_prev_action[agent_id])
        # We're at t <= 1, so return all zeros.
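        # (Shaped like the output of `last_action_for()`, which itself falls
        # back to a zero-filled sample from the agent's action space if the
        # agent has not acted at all yet.)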
        else:
            if policy.config.get("_disable_action_flattening"):
                return tree.map_structure(
                    lambda a: np.zeros_like(a, a.dtype) if  # noqa
                    hasattr(a, "dtype") else np.zeros_like(a),  # noqa
                    self.last_action_for(agent_id),
                )
            else:
                return np.zeros_like(self.last_action_for(agent_id))

    @DeveloperAPI
    def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
        """Returns the last reward for the specified agent, or zero.

        The "last" reward is the one received most recently by the agent.

        Args:
            agent_id: The agent's ID to get the last reward for.

        Returns:
            Last reward for the specified AgentID. Zero in case the agent
            has never performed any actions (and thus received rewards) in
            the episode.
        """
        history = self._agent_reward_history[agent_id]
        # We are at t > 0 -> Return previously received reward.
        if len(history) >= 1:
            return history[-1]
        # We're at t=0, so there is no previous reward, just return zero.
        else:
            return 0.0

    @DeveloperAPI
    def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
        """Returns the previous reward for the specified agent, or zero.

        The "previous" reward is the one received one timestep before the
        most recently received reward of the agent.

        Args:
            agent_id: The agent's ID to get the previous reward for.

        Returns:
            Previous reward for the specified AgentID. Zero in case the
            agent has never performed any actions (or only one) in the
            episode.
        """
        history = self._agent_reward_history[agent_id]
        # We are at t > 1 -> Return reward prior to most recent (last) one.
        if len(history) >= 2:
            return history[-2]
        # We're at t <= 1, so there is no previous reward, just return zero.
        else:
            return 0.0

    @DeveloperAPI
    def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
        """Returns the last RNN state for the specified agent.

        Args:
            agent_id: The agent's ID to get the most recent RNN state for.

        Returns:
            Most recent RNN state of the specified AgentID.
        """
        if agent_id not in self._agent_to_rnn_state:
            policy_id = self.policy_for(agent_id)
            policy = self.policy_map[policy_id]
            self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
        return self._agent_to_rnn_state[agent_id]

    @DeveloperAPI
    def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
        """Returns the last done flag for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last done flag for.

        Returns:
            Last done flag for the specified AgentID.
        """
        if agent_id not in self._agent_to_last_done:
            self._agent_to_last_done[agent_id] = False
        return self._agent_to_last_done[agent_id]

    @DeveloperAPI
    def last_extra_action_outs_for(
            self,
            agent_id: AgentID = _DUMMY_AGENT_ID,
    ) -> dict:
        """Returns the last extra-action outputs for the specified agent.

        This data is returned by a call to
        `Policy.compute_actions_from_input_dict` as the 3rd return value
        (1st return value = action; 2nd return value = RNN state outs).

        Args:
            agent_id: The agent's ID to get the last extra-action outs for.

        Returns:
            The last extra-action outs for the specified AgentID.
        """
        return self._agent_to_last_extra_action_outs[agent_id]

    @DeveloperAPI
    def get_agents(self) -> List[AgentID]:
        """Returns list of agent IDs that have appeared in this episode.

        Returns:
            The list of all agent IDs that have appeared so far in this
            episode.
        """
        return list(self._agent_to_index.keys())

    def _add_agent_rewards(self, reward_dict: Dict[AgentID, float]) -> None:
        for agent_id, reward in reward_dict.items():
            if reward is not None:
                self.agent_rewards[agent_id,
                                   self.policy_for(agent_id)] += reward
                self.total_reward += reward
                self._agent_reward_history[agent_id].append(reward)

    def _set_rnn_state(self, agent_id, rnn_state):
        self._agent_to_rnn_state[agent_id] = rnn_state

    def _set_last_observation(self, agent_id, obs):
        self._agent_to_last_obs[agent_id] = obs

    def _set_last_raw_obs(self, agent_id, obs):
        self._agent_to_last_raw_obs[agent_id] = obs

    def _set_last_done(self, agent_id, done):
        self._agent_to_last_done[agent_id] = done

    def _set_last_info(self, agent_id, info):
        self._agent_to_last_info[agent_id] = info

    def _set_last_action(self, agent_id, action):
        if agent_id in self._agent_to_last_action:
            self._agent_to_prev_action[agent_id] = \
                self._agent_to_last_action[agent_id]
        self._agent_to_last_action[agent_id] = action

    def _set_last_extra_action_outs(self, agent_id, pi_info):
        self._agent_to_last_extra_action_outs[agent_id] = pi_info

    def _agent_index(self, agent_id):
        if agent_id not in self._agent_to_index:
            self._agent_to_index[agent_id] = self._next_agent_index
            self._next_agent_index += 1
        return self._agent_to_index[agent_id]

    @property
    def _policy_mapping_fn(self):
        deprecation_warning(
            old="Episode._policy_mapping_fn",
            new="Episode.policy_mapping_fn",
            error=False,
        )
        return self.policy_mapping_fn

    @Deprecated(new="Episode.last_extra_action_outs_for", error=False)
    def last_pi_info_for(self, *args, **kwargs):
        return self.last_extra_action_outs_for(*args, **kwargs)


# Backward compatibility: The old name `MultiAgentEpisode` implied that there
# is also a separate (single-agent) `Episode` class, which is not the case;
# `Episode` handles both the single- and multi-agent cases.
@Deprecated(new="ray.rllib.evaluation.episode.Episode", error=False)
class MultiAgentEpisode(Episode):
    pass