observation_function.py 3.0 KB

  1. from typing import Dict
  2. from ray.rllib.env import BaseEnv
  3. from ray.rllib.policy import Policy
  4. from ray.rllib.evaluation import Episode, RolloutWorker
  5. from ray.rllib.utils.framework import TensorType
  6. from ray.rllib.utils.typing import AgentID, PolicyID
  7. class ObservationFunction:
  8. """Interceptor function for rewriting observations from the environment.
  9. These callbacks can be used for preprocessing of observations, especially
  10. in multi-agent scenarios.
  11. Observation functions can be specified in the multi-agent config by
  12. specifying ``{"observation_fn": your_obs_func}``. Note that
  13. ``your_obs_func`` can be a plain Python function.
  14. This API is **experimental**.
  15. """
  16. def __call__(self, agent_obs: Dict[AgentID, TensorType],
  17. worker: RolloutWorker, base_env: BaseEnv,
  18. policies: Dict[PolicyID, Policy], episode: Episode,
  19. **kw) -> Dict[AgentID, TensorType]:
  20. """Callback run on each environment step to observe the environment.
  21. This method takes in the original agent observation dict returned by
  22. a MultiAgentEnv, and returns a possibly modified one. It can be
  23. thought of as a "wrapper" around the environment.
  24. TODO(ekl): allow end-to-end differentiation through the observation
  25. function and policy losses.
  26. TODO(ekl): enable batch processing.
  27. Args:
  28. agent_obs (dict): Dictionary of default observations from the
  29. environment. The default implementation of observe() simply
  30. returns this dict.
  31. worker (RolloutWorker): Reference to the current rollout worker.
  32. base_env (BaseEnv): BaseEnv running the episode. The underlying
  33. sub environment objects (BaseEnvs are vectorized) can be
  34. retrieved by calling `base_env.get_sub_environments()`.
  35. policies (dict): Mapping of policy id to policy objects. In single
  36. agent mode there will only be a single "default" policy.
  37. episode (Episode): Episode state object.
  38. kwargs: Forward compatibility placeholder.
  39. Returns:
  40. new_agent_obs (dict): copy of agent obs with updates. You can
  41. rewrite or drop data from the dict if needed (e.g., the env
  42. can have a dummy "global" observation, and the observer can
  43. merge the global state into individual observations.
  44. Examples:
  45. >>> # Observer that merges global state into individual obs. It is
  46. ... # rewriting the discrete obs into a tuple with global state.
  47. >>> example_obs_fn1({"a": 1, "b": 2, "global_state": 101}, ...)
  48. {"a": [1, 101], "b": [2, 101]}
  49. >>> # Observer for e.g., custom centralized critic model. It is
  50. ... # rewriting the discrete obs into a dict with more data.
  51. >>> example_obs_fn2({"a": 1, "b": 2}, ...)
  52. {"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}}
  53. """
  54. return agent_obs