test_supported_multi_agent.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import unittest
  2. import ray
  3. from ray.rllib.agents.registry import get_trainer_class
  4. from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
  5. MultiAgentMountainCar
  6. from ray.rllib.policy.policy import PolicySpec
  7. from ray.rllib.utils.test_utils import check_train_results, \
  8. framework_iterator
  9. from ray.tune import register_env
  10. def check_support_multiagent(alg, config):
  11. register_env("multi_agent_mountaincar",
  12. lambda _: MultiAgentMountainCar({"num_agents": 2}))
  13. register_env("multi_agent_cartpole",
  14. lambda _: MultiAgentCartPole({"num_agents": 2}))
  15. # Simulate a simple multi-agent setup.
  16. policies = {
  17. "policy_0": PolicySpec(config={"gamma": 0.99}),
  18. "policy_1": PolicySpec(config={"gamma": 0.95}),
  19. }
  20. policy_ids = list(policies.keys())
  21. def policy_mapping_fn(agent_id, episode, worker, **kwargs):
  22. pol_id = policy_ids[agent_id]
  23. return pol_id
  24. config["multiagent"] = {
  25. "policies": policies,
  26. "policy_mapping_fn": policy_mapping_fn,
  27. }
  28. for fw in framework_iterator(config):
  29. if fw in ["tf2", "tfe"] and \
  30. alg in ["A3C", "APEX", "APEX_DDPG", "IMPALA"]:
  31. continue
  32. if alg in ["DDPG", "APEX_DDPG", "SAC"]:
  33. a = get_trainer_class(alg)(
  34. config=config, env="multi_agent_mountaincar")
  35. else:
  36. a = get_trainer_class(alg)(
  37. config=config, env="multi_agent_cartpole")
  38. results = a.train()
  39. check_train_results(results)
  40. print(results)
  41. a.stop()
  42. class TestSupportedMultiAgentPG(unittest.TestCase):
  43. @classmethod
  44. def setUpClass(cls) -> None:
  45. ray.init(num_cpus=4)
  46. @classmethod
  47. def tearDownClass(cls) -> None:
  48. ray.shutdown()
  49. def test_a3c_multiagent(self):
  50. check_support_multiagent("A3C", {
  51. "num_workers": 1,
  52. "optimizer": {
  53. "grads_per_step": 1
  54. }
  55. })
  56. def test_impala_multiagent(self):
  57. check_support_multiagent("IMPALA", {"num_gpus": 0})
  58. def test_pg_multiagent(self):
  59. check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}})
  60. def test_ppo_multiagent(self):
  61. check_support_multiagent(
  62. "PPO", {
  63. "num_workers": 1,
  64. "num_sgd_iter": 1,
  65. "train_batch_size": 10,
  66. "rollout_fragment_length": 10,
  67. "sgd_minibatch_size": 1,
  68. })
  69. class TestSupportedMultiAgentOffPolicy(unittest.TestCase):
  70. @classmethod
  71. def setUpClass(cls) -> None:
  72. ray.init(num_cpus=6)
  73. @classmethod
  74. def tearDownClass(cls) -> None:
  75. ray.shutdown()
  76. def test_apex_multiagent(self):
  77. check_support_multiagent(
  78. "APEX", {
  79. "num_workers": 2,
  80. "timesteps_per_iteration": 100,
  81. "num_gpus": 0,
  82. "buffer_size": 1000,
  83. "min_time_s_per_reporting": 1,
  84. "learning_starts": 10,
  85. "target_network_update_freq": 100,
  86. "optimizer": {
  87. "num_replay_buffer_shards": 1,
  88. },
  89. })
  90. def test_apex_ddpg_multiagent(self):
  91. check_support_multiagent(
  92. "APEX_DDPG", {
  93. "num_workers": 2,
  94. "timesteps_per_iteration": 100,
  95. "buffer_size": 1000,
  96. "num_gpus": 0,
  97. "min_time_s_per_reporting": 1,
  98. "learning_starts": 10,
  99. "target_network_update_freq": 100,
  100. "use_state_preprocessor": True,
  101. })
  102. def test_ddpg_multiagent(self):
  103. check_support_multiagent(
  104. "DDPG", {
  105. "timesteps_per_iteration": 1,
  106. "buffer_size": 1000,
  107. "use_state_preprocessor": True,
  108. "learning_starts": 500,
  109. })
  110. def test_dqn_multiagent(self):
  111. check_support_multiagent("DQN", {
  112. "timesteps_per_iteration": 1,
  113. "buffer_size": 1000,
  114. })
  115. def test_sac_multiagent(self):
  116. check_support_multiagent("SAC", {
  117. "num_workers": 0,
  118. "buffer_size": 1000,
  119. "normalize_actions": False,
  120. })
  121. if __name__ == "__main__":
  122. import pytest
  123. import sys
  124. # One can specify the specific TestCase class to run.
  125. # None for all unittest.TestCase classes in this file.
  126. class_ = sys.argv[1] if len(sys.argv) > 1 else None
  127. sys.exit(
  128. pytest.main(
  129. ["-v", __file__ + ("" if class_ is None else "::" + class_)]))