# test_check_multi_agent.py (2.6 KB)
  1. import unittest
  2. from ray.rllib.algorithms.pg import PGConfig
  3. from ray.rllib.policy.policy import PolicySpec
  class TestCheckMultiAgent(unittest.TestCase):
      """Checks that ``PGConfig.multi_agent()`` rejects malformed arguments.

      Each test builds a ``PGConfig`` with one invalid input and asserts both
      the exception type and a regex that must match the error message.
      """

      def test_multi_agent_invalid_args(self):
          """An unknown keyword argument (``wrong_key``) raises ``TypeError``."""
          with self.assertRaisesRegex(
              TypeError, "got an unexpected keyword argument 'wrong_key'"
          ):
              PGConfig().multi_agent(
                  policies={"p0"}, policies_to_train=["p0"], wrong_key=1
              )

      def test_multi_agent_bad_policy_ids(self):
          """A policy ID of the wrong type (here the int ``1``) raises ``KeyError``."""
          with self.assertRaisesRegex(KeyError, "Policy IDs must always be of type"):
              PGConfig().multi_agent(
                  policies={1, "good_id"},
                  policy_mapping_fn=lambda agent_id, episode, worker, **kw: "good_id",
              )

      def test_multi_agent_invalid_sub_values(self):
          """An unsupported ``count_steps_by`` value raises ``ValueError``."""
          with self.assertRaisesRegex(
              ValueError,
              r"config.multi_agent\(count_steps_by=..\) must be one of",
          ):
              PGConfig().multi_agent(count_steps_by="invalid_value")

      def test_multi_agent_invalid_override_configs(self):
          """An unknown property name passed to ``PGConfig.overrides`` raises ``KeyError``."""
          # NOTE(review): the KeyError may be raised by PGConfig.overrides()
          # itself, before multi_agent() is reached -- confirm against the
          # AlgorithmConfig implementation.
          with self.assertRaisesRegex(
              KeyError, "Invalid property name invdli for config class PGConfig"
          ):
              PGConfig().multi_agent(
                  policies={
                      "p0": PolicySpec(config=PGConfig.overrides(invdli=42.0)),
                  }
              )

      def test_setting_multiagent_key_in_config_should_fail(self):
          """Item-assigning the ``"multiagent"`` key on a config raises ``AttributeError``."""
          config = PGConfig().multi_agent(
              policies={
                  "pol1": (None, None, None, None),
                  "pol2": (None, None, None, PGConfig.overrides(lr=0.001)),
              }
          )
          # Not ok: cannot set "multiagent" key in AlgorithmConfig anymore.
          with self.assertRaisesRegex(
              AttributeError, "Cannot set `multiagent` key in an AlgorithmConfig!"
          ):
              config["multiagent"] = {"policies": {"pol1", "pol2"}}
  if __name__ == "__main__":
      import sys

      import pytest

      # Run only this file (not whatever pytest would discover from the cwd)
      # and propagate pytest's exit status so failures are visible to the shell.
      sys.exit(pytest.main(["-v", __file__]))