multi_agent.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing import Tuple
  2. from ray.rllib.policy.policy import PolicySpec
  3. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  4. from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, \
  5. PartialTrainerConfigDict
  6. def check_multi_agent(config: PartialTrainerConfigDict) -> \
  7. Tuple[MultiAgentPolicyConfigDict, bool]:
  8. """Checks, whether a (partial) config defines a multi-agent setup.
  9. Args:
  10. config: The user/Trainer/Policy config to check for multi-agent.
  11. Returns:
  12. Tuple consisting of the resulting (all fixed) multi-agent policy
  13. dict and bool indicating whether we have a multi-agent setup or not.
  14. """
  15. multiagent_config = config["multiagent"]
  16. policies = multiagent_config.get("policies")
  17. # Nothing specified in config dict -> Assume simple single agent setup
  18. # with DEFAULT_POLICY_ID as only policy.
  19. if not policies:
  20. policies = {DEFAULT_POLICY_ID}
  21. # Policies given as set (of PolicyIDs) -> Setup each policy automatically
  22. # via empty PolicySpec (will make RLlib infer obs- and action spaces
  23. # as well as the Policy's class).
  24. if isinstance(policies, set):
  25. policies = multiagent_config["policies"] = {
  26. pid: PolicySpec()
  27. for pid in policies
  28. }
  29. # Is this a multi-agent setup? True, iff DEFAULT_POLICY_ID is only
  30. # PolicyID found in policies dict.
  31. is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
  32. return policies, is_multiagent