123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- from typing import Dict
- from ray.rllib.env import BaseEnv
- from ray.rllib.policy import Policy
- from ray.rllib.evaluation import Episode, RolloutWorker
- from ray.rllib.utils.framework import TensorType
- from ray.rllib.utils.typing import AgentID, PolicyID
class ObservationFunction:
    """Interceptor function for rewriting observations from the environment.

    These callbacks can be used for preprocessing of observations, especially
    in multi-agent scenarios.

    Observation functions can be specified in the multi-agent config by
    specifying ``{"observation_fn": your_obs_func}``. Note that
    ``your_obs_func`` can be a plain Python function.

    This API is **experimental**.
    """

    def __call__(self, agent_obs: Dict[AgentID, TensorType],
                 worker: RolloutWorker, base_env: BaseEnv,
                 policies: Dict[PolicyID, Policy], episode: Episode,
                 **kw) -> Dict[AgentID, TensorType]:
        """Callback run on each environment step to observe the environment.

        This method takes in the original agent observation dict returned by
        a MultiAgentEnv, and returns a possibly modified one. It can be
        thought of as a "wrapper" around the environment.

        TODO(ekl): allow end-to-end differentiation through the observation
            function and policy losses.
        TODO(ekl): enable batch processing.

        Args:
            agent_obs: Dictionary of default observations from the
                environment, keyed by agent ID. The default implementation
                of this method simply returns this dict.
            worker: Reference to the current rollout worker.
            base_env: BaseEnv running the episode. The underlying
                sub environment objects (BaseEnvs are vectorized) can be
                retrieved by calling ``base_env.get_sub_environments()``.
            policies: Mapping of policy id to policy objects. In single
                agent mode there will only be a single default policy.
            episode: Episode state object.
            **kw: Forward compatibility placeholder; ignored by the default
                implementation.

        Returns:
            The agent observation dict to use in place of ``agent_obs``.
            Overriding implementations may rewrite or drop entries as
            needed (e.g., the env can provide a dummy "global" observation,
            and the observer can merge that global state into the
            individual agents' observations). The default implementation
            returns ``agent_obs`` itself, unchanged -- the same object, not
            a copy.

        Examples:
            >>> # Observer that merges global state into individual obs. It
            ... # rewrites each discrete obs into an [obs, global_state] list.
            >>> example_obs_fn1(  # doctest: +SKIP
            ...     {"a": 1, "b": 2, "global_state": 101}, ...)
            {"a": [1, 101], "b": [2, 101]}
            >>> # Observer for e.g., custom centralized critic model. It
            ... # rewrites each discrete obs into a dict with more data.
            >>> example_obs_fn2({"a": 1, "b": 2}, ...)  # doctest: +SKIP
            {"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}}
        """
        # Identity pass-through: the caller's dict object is handed back
        # untouched; subclasses override this to rewrite observations.
        return agent_obs
|