from collections import defaultdict
import numpy as np
import random
import tree  # pip install dm_tree
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.utils.annotations import Deprecated, DeveloperAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.utils.typing import SampleBatchType, AgentID, PolicyID, \
    EnvActionType, EnvID, EnvInfoDict, EnvObsType
from ray.util import log_once

if TYPE_CHECKING:
    from ray.rllib.evaluation.rollout_worker import RolloutWorker
    from ray.rllib.evaluation.sample_batch_builder import \
        MultiAgentSampleBatchBuilder


@DeveloperAPI
class Episode:
    """Tracks the current state of a (possibly multi-agent) episode.

    Attributes:
        new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
        add_extra_batch (func): Return a built MultiAgentBatch to the sampler.
        batch_builder (obj): Batch builder for the current episode.
        total_reward (float): Summed reward across all agents in this episode.
        length (int): Length of this episode.
        episode_id (int): Unique id identifying this trajectory.
        agent_rewards (dict): Summed rewards broken down by agent.
        custom_metrics (dict): Dict where you can add custom metrics.
        user_data (dict): Dict that you can use for temporary storage, e.g.
            in between two custom callbacks referring to the same episode.
        hist_data (dict): Dict mapping str keys to List[float] for storage of
            per-timestep float data throughout the episode.

    Use case 1: Model-based rollouts in multi-agent:
        A custom compute_actions() function in a policy can inspect the
        current episode state and perform a number of rollouts based on the
        policies and state of other agents in the environment, as sketched
        below.
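
        A hedged sketch of such a custom compute_actions() override; all
        names here are illustrative, not part of this API:

        >>> def compute_actions(self, obs_batch, episodes=None, **kwargs):
        ...     # Peek at another agent's most recent state in the episode.
        ...     other_obs = episodes[0].last_observation_for("other_agent")
        ...     ...  # perform model-based rollouts against `other_obs`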

    Use case 2: Returning extra rollouts data.
        The model rollouts can be returned back to the sampler by calling:

        >>> batch = episode.new_batch_builder()
        >>> for each transition:
        ...     batch.add_values(...)  # see sampler for usage
        >>> episode.add_extra_batch(batch.build_and_reset())
- """

    def __init__(
            self,
            policies: PolicyMap,
            policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"],
                                        PolicyID],
            batch_builder_factory: Callable[[],
                                            "MultiAgentSampleBatchBuilder"],
            extra_batch_callback: Callable[[SampleBatchType], None],
            env_id: EnvID,
            *,
            worker: Optional["RolloutWorker"] = None,
    ):
- """Initializes an Episode instance.
- Args:
- policies: The PolicyMap object (mapping PolicyIDs to Policy
- objects) to use for determining, which policy is used for
- which agent.
- policy_mapping_fn: The mapping function mapping AgentIDs to
- PolicyIDs.
- batch_builder_factory:
- extra_batch_callback:
- env_id: The environment's ID in which this episode runs.
- worker: The RolloutWorker instance, in which this episode runs.
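
        Example:
            A minimal construction sketch; `my_policy_map`, `my_build_fn`,
            and `my_store_fn` are hypothetical stand-ins for objects the
            sampler normally provides:

            >>> episode = Episode(
            ...     policies=my_policy_map,
            ...     policy_mapping_fn=lambda aid, eps, **kw: "pol0",
            ...     batch_builder_factory=my_build_fn,
            ...     extra_batch_callback=my_store_fn,
            ...     env_id=0)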
- """
        self.new_batch_builder: Callable[
            [], "MultiAgentSampleBatchBuilder"] = batch_builder_factory
        self.add_extra_batch: Callable[[SampleBatchType],
                                       None] = extra_batch_callback
        self.batch_builder: "MultiAgentSampleBatchBuilder" = \
            batch_builder_factory()
        self.total_reward: float = 0.0
        self.length: int = 0
        # Use an int bound here: float args to `randrange` are deprecated
        # (and removed in newer Python versions).
        self.episode_id: int = random.randrange(int(2e9))
        self.env_id = env_id
        self.worker = worker
        self.agent_rewards: Dict[AgentID, float] = defaultdict(float)
        self.custom_metrics: Dict[str, float] = {}
        self.user_data: Dict[str, Any] = {}
        self.hist_data: Dict[str, List[float]] = {}
        self.media: Dict[str, Any] = {}
        self.policy_map: PolicyMap = policies
        self._policies = self.policy_map  # backward compatibility
        self.policy_mapping_fn: Callable[
            [AgentID, "Episode", "RolloutWorker"],
            PolicyID] = policy_mapping_fn
        self._next_agent_index: int = 0
        self._agent_to_index: Dict[AgentID, int] = {}
        self._agent_to_policy: Dict[AgentID, PolicyID] = {}
        self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {}
        self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {}
        self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {}
        self._agent_to_last_done: Dict[AgentID, bool] = {}
        self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {}
        self._agent_to_last_action: Dict[AgentID, EnvActionType] = {}
        self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {}
        self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {}
        # Rewards are floats, so the history holds List[float] values.
        self._agent_reward_history: Dict[AgentID, List[float]] = defaultdict(
            list)

    @DeveloperAPI
    def soft_reset(self) -> None:
        """Clears rewards and metrics, but retains RNN and other state.

        This is used to carry state across multiple logical episodes in the
        same env (i.e., if `soft_horizon` is set).
        """
        self.length = 0
        self.episode_id = random.randrange(int(2e9))
        self.total_reward = 0.0
        self.agent_rewards = defaultdict(float)
        self._agent_reward_history = defaultdict(list)

    @DeveloperAPI
    def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID:
        """Returns and stores the policy ID for the specified agent.

        If the agent is new, the policy mapping fn will be called to bind the
        agent to a policy for the duration of the entire episode (even if the
        policy_mapping_fn is changed in the meantime!).

        Args:
            agent_id: The agent ID to look up the policy ID for.

        Returns:
            The policy ID for the specified agent.
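
        Example:
            A hedged sketch; assumes this episode's `policy_map` contains a
            policy with ID "pol0" and that the mapping fn returns it:

            >>> episode.policy_for("agent_0")  # first call binds the agent
            'pol0'
            >>> episode.policy_for("agent_0")  # later calls reuse the binding
            'pol0'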
- """
- # Perform a new policy_mapping_fn lookup and bind AgentID for the
- # duration of this episode to the returned PolicyID.
- if agent_id not in self._agent_to_policy:
- # Try new API: pass in agent_id and episode as named args.
- # New signature should be: (agent_id, episode, worker, **kwargs)
- try:
- policy_id = self._agent_to_policy[agent_id] = \
- self.policy_mapping_fn(agent_id, self, worker=self.worker)
- except TypeError as e:
- if "positional argument" in e.args[0] or \
- "unexpected keyword argument" in e.args[0]:
- if log_once("policy_mapping_new_signature"):
- deprecation_warning(
- old="policy_mapping_fn(agent_id)",
- new="policy_mapping_fn(agent_id, episode, "
- "worker, **kwargs)")
- policy_id = self._agent_to_policy[agent_id] = \
- self.policy_mapping_fn(agent_id)
- else:
- raise e
- # Use already determined PolicyID.
- else:
- policy_id = self._agent_to_policy[agent_id]
- # PolicyID not found in policy map -> Error.
- if policy_id not in self.policy_map:
- raise KeyError("policy_mapping_fn returned invalid policy id "
- f"'{policy_id}'!")
- return policy_id

    @DeveloperAPI
    def last_observation_for(
            self,
            agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
        """Returns the last observation for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last observation for.

        Returns:
            Last observation the specified AgentID has seen. None in case
            the agent has never made any observations in the episode.
        """
        return self._agent_to_last_obs.get(agent_id)

    @DeveloperAPI
    def last_raw_obs_for(
            self,
            agent_id: AgentID = _DUMMY_AGENT_ID) -> Optional[EnvObsType]:
        """Returns the last un-preprocessed obs for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last un-preprocessed
                observation for.

        Returns:
            Last un-preprocessed observation the specified AgentID has seen.
            None in case the agent has never made any observations in the
            episode.
        """
        return self._agent_to_last_raw_obs.get(agent_id)

    @DeveloperAPI
    def last_info_for(self, agent_id: AgentID = _DUMMY_AGENT_ID
                      ) -> Optional[EnvInfoDict]:
        """Returns the last info for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last info for.

        Returns:
            Last info dict the specified AgentID has seen.
            None in case the agent has never made any observations in the
            episode.
        """
        return self._agent_to_last_info.get(agent_id)

    @DeveloperAPI
    def last_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the last action for the specified AgentID, or zeros.

        The "last" action is the most recent one taken by the agent.

        Args:
            agent_id: The agent's ID to get the last action for.

        Returns:
            Last action the specified AgentID has executed.
            Zeros in case the agent has never performed any actions in the
            episode.
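
        Example:
            A hedged sketch; assumes "agent_0" has not acted yet and its
            policy uses a Box(low=-1, high=1, shape=(2,)) action space, so
            an all-zeros array shaped like a flattened action comes back:

            >>> episode.last_action_for("agent_0")
            array([0., 0.], dtype=float32)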
- """
- policy_id = self.policy_for(agent_id)
- policy = self.policy_map[policy_id]
- # Agent has already taken at least one action in the episode.
- if agent_id in self._agent_to_last_action:
- if policy.config.get("_disable_action_flattening"):
- return self._agent_to_last_action[agent_id]
- else:
- return flatten_to_single_ndarray(
- self._agent_to_last_action[agent_id])
- # Agent has not acted yet, return all zeros.
- else:
- if policy.config.get("_disable_action_flattening"):
- return tree.map_structure(
- lambda s: np.zeros_like(s.sample(), s.dtype) if
- hasattr(s, "dtype") else np.zeros_like(s.sample()),
- policy.action_space_struct,
- )
- else:
- flat = flatten_to_single_ndarray(policy.action_space.sample())
- if hasattr(policy.action_space, "dtype"):
- return np.zeros_like(flat, dtype=policy.action_space.dtype)
- return np.zeros_like(flat)

    @DeveloperAPI
    def prev_action_for(self,
                        agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType:
        """Returns the previous action for the specified agent, or zeros.

        The "previous" action is the one taken one timestep before the
        most recent action taken by the agent.

        Args:
            agent_id: The agent's ID to get the previous action for.

        Returns:
            Previous action the specified AgentID has executed.
            Zeros in case the agent has never performed any actions (or only
            one) in the episode.
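
        Example:
            A hedged sketch; with at most one recorded action (and the same
            assumed Box(shape=(2,)) space as above), the fallback is zeros
            shaped like `last_action_for()`'s result:

            >>> episode.prev_action_for("agent_0")  # t <= 1 -> all zeros
            array([0., 0.], dtype=float32)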
- """
- policy_id = self.policy_for(agent_id)
- policy = self.policy_map[policy_id]
- # We are at t > 1 -> There has been a previous action by this agent.
- if agent_id in self._agent_to_prev_action:
- if policy.config.get("_disable_action_flattening"):
- return self._agent_to_prev_action[agent_id]
- else:
- return flatten_to_single_ndarray(
- self._agent_to_prev_action[agent_id])
- # We're at t <= 1, so return all zeros.
- else:
- if policy.config.get("_disable_action_flattening"):
- return tree.map_structure(
- lambda a: np.zeros_like(a, a.dtype) if # noqa
- hasattr(a, "dtype") else np.zeros_like(a), # noqa
- self.last_action_for(agent_id),
- )
- else:
- return np.zeros_like(self.last_action_for(agent_id))

    @DeveloperAPI
    def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
        """Returns the last reward for the specified agent, or zero.

        The "last" reward is the one received most recently by the agent.

        Args:
            agent_id: The agent's ID to get the last reward for.

        Returns:
            Last reward for the specified AgentID.
            Zero in case the agent has never performed any actions
            (and thus never received rewards) in the episode.
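
        Example:
            A hedged sketch of the reward-history semantics; assumes the
            episode can resolve "agent_0" via its policy mapping fn (and
            uses the private `_add_agent_rewards` purely for illustration):

            >>> episode.last_reward_for("agent_0")  # no rewards yet
            0.0
            >>> episode._add_agent_rewards({"agent_0": 1.5})
            >>> episode.last_reward_for("agent_0")
            1.5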
- """
- history = self._agent_reward_history[agent_id]
- # We are at t > 0 -> Return previously received reward.
- if len(history) >= 1:
- return history[-1]
- # We're at t=0, so there is no previous reward, just return zero.
- else:
- return 0.0

    @DeveloperAPI
    def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float:
        """Returns the previous reward for the specified agent, or zero.

        The "previous" reward is the one received one timestep before the
        most recently received reward of the agent.

        Args:
            agent_id: The agent's ID to get the previous reward for.

        Returns:
            Previous reward for the specified AgentID.
            Zero in case the agent has never performed any actions (or only
            one) in the episode.
        """
        history = self._agent_reward_history[agent_id]
        # We are at t > 1 -> Return reward prior to most recent (last) one.
        if len(history) >= 2:
            return history[-2]
        # We're at t <= 1, so there is no previous reward, just return zero.
        else:
            return 0.0

    @DeveloperAPI
    def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]:
        """Returns the last RNN state for the specified agent.

        Args:
            agent_id: The agent's ID to get the most recent RNN state for.

        Returns:
            Most recent RNN state of the specified AgentID.
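
        Example:
            A hedged sketch; for an agent not seen before, the state is
            lazily initialized via the policy's `get_initial_state()`,
            which is an empty list for stateless models:

            >>> episode.rnn_state_for("agent_0")  # stateless policy
            []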
- """
- if agent_id not in self._agent_to_rnn_state:
- policy_id = self.policy_for(agent_id)
- policy = self.policy_map[policy_id]
- self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
- return self._agent_to_rnn_state[agent_id]

    @DeveloperAPI
    def last_done_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool:
        """Returns the last done flag for the specified AgentID.

        Args:
            agent_id: The agent's ID to get the last done flag for.

        Returns:
            Last done flag for the specified AgentID.
        """
        if agent_id not in self._agent_to_last_done:
            self._agent_to_last_done[agent_id] = False
        return self._agent_to_last_done[agent_id]

    @DeveloperAPI
    def last_extra_action_outs_for(
            self,
            agent_id: AgentID = _DUMMY_AGENT_ID,
    ) -> dict:
        """Returns the last extra-action outputs for the specified agent.

        This data is returned by a call to
        `Policy.compute_actions_from_input_dict` as the 3rd return value
        (1st return value = action; 2nd return value = RNN state outs).

        Args:
            agent_id: The agent's ID to get the last extra-action outs for.

        Returns:
            The last extra-action outs for the specified AgentID.
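
        Example:
            A hedged sketch; the exact keys depend on the policy, but
            action log-probs and distribution inputs are typical:

            >>> episode.last_extra_action_outs_for("agent_0")  # doctest: +SKIP
            {'action_logp': -0.69, 'action_dist_inputs': array([...])}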
- """
- return self._agent_to_last_extra_action_outs[agent_id]

    @DeveloperAPI
    def get_agents(self) -> List[AgentID]:
        """Returns list of agent IDs that have appeared in this episode.

        Returns:
            The list of all agent IDs that have appeared so far in this
            episode.
        """
        return list(self._agent_to_index.keys())

    def _add_agent_rewards(self, reward_dict: Dict[AgentID, float]) -> None:
        for agent_id, reward in reward_dict.items():
            if reward is not None:
                # Note: `agent_rewards` is keyed by (agent_id, policy_id)
                # tuples, so per-agent sums are broken down by policy.
                self.agent_rewards[agent_id,
                                   self.policy_for(agent_id)] += reward
                self.total_reward += reward
                self._agent_reward_history[agent_id].append(reward)

    def _set_rnn_state(self, agent_id, rnn_state):
        self._agent_to_rnn_state[agent_id] = rnn_state

    def _set_last_observation(self, agent_id, obs):
        self._agent_to_last_obs[agent_id] = obs

    def _set_last_raw_obs(self, agent_id, obs):
        self._agent_to_last_raw_obs[agent_id] = obs

    def _set_last_done(self, agent_id, done):
        self._agent_to_last_done[agent_id] = done

    def _set_last_info(self, agent_id, info):
        self._agent_to_last_info[agent_id] = info

    def _set_last_action(self, agent_id, action):
        # Shift the current "last" action into the "prev" slot first.
        if agent_id in self._agent_to_last_action:
            self._agent_to_prev_action[agent_id] = \
                self._agent_to_last_action[agent_id]
        self._agent_to_last_action[agent_id] = action

    def _set_last_extra_action_outs(self, agent_id, pi_info):
        self._agent_to_last_extra_action_outs[agent_id] = pi_info

    def _agent_index(self, agent_id):
        # Assign a new, stable index to first-time agents.
        if agent_id not in self._agent_to_index:
            self._agent_to_index[agent_id] = self._next_agent_index
            self._next_agent_index += 1
        return self._agent_to_index[agent_id]

    @property
    def _policy_mapping_fn(self):
        deprecation_warning(
            old="Episode._policy_mapping_fn",
            new="Episode.policy_mapping_fn",
            error=False,
        )
        return self.policy_mapping_fn

    @Deprecated(new="Episode.last_extra_action_outs_for", error=False)
    def last_pi_info_for(self, *args, **kwargs):
        return self.last_extra_action_outs_for(*args, **kwargs)


# Backward compatibility. The name Episode implies that there is
# also a (single agent?) Episode.
@Deprecated(new="ray.rllib.evaluation.episode.Episode", error=False)
class MultiAgentEpisode(Episode):
    pass