12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import gym
- import numpy as np
- import unittest
- import ray
- from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
- from ray.rllib.evaluation.rollout_worker import RolloutWorker
- from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
- from ray.rllib.examples.env.multi_agent import BasicMultiAgent
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.tests.test_external_env import make_simple_serving
# Serving-env class produced by the shared `make_simple_serving` factory with
# ExternalMultiAgentEnv as the base; every test below wraps a BasicMultiAgent
# instance in it via the RolloutWorker `env_creator`.
# NOTE(review): the positional `True` presumably switches the factory into
# multi-agent mode -- confirm against make_simple_serving's signature.
SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)
class TestExternalMultiAgentEnv(unittest.TestCase):
    """Sampling tests for ExternalMultiAgentEnv driven through RolloutWorker.

    Each test wraps a ``BasicMultiAgent`` env in ``SimpleMultiServing``,
    builds a ``RolloutWorker`` around it and inspects the batches returned
    by ``sample()``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # A single Ray session is shared by every test in this class.
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    @staticmethod
    def _serving_worker(num_agents, batch_mode):
        """Return a MockPolicy worker over ``num_agents`` served agents.

        The fragment length is fixed at 40; ``batch_mode`` is passed through
        to RolloutWorker unchanged.
        """
        return RolloutWorker(
            env_creator=lambda _: SimpleMultiServing(
                BasicMultiAgent(num_agents)),
            policy_spec=MockPolicy,
            rollout_fragment_length=40,
            batch_mode=batch_mode)

    def _check_samples(self, worker, expected_count, num_agents):
        """Draw three batches, checking size and agent coverage of each.

        For every batch, ``count`` must equal ``expected_count`` and the
        number of distinct AGENT_INDEX values must equal ``num_agents``.
        """
        for _round in range(3):
            sample = worker.sample()
            self.assertEqual(sample.count, expected_count)
            distinct_agents = np.unique(sample[SampleBatch.AGENT_INDEX])
            self.assertEqual(len(distinct_agents), num_agents)

    def test_external_multi_agent_env_complete_episodes(self):
        num_agents = 4
        worker = self._serving_worker(num_agents, "complete_episodes")
        self._check_samples(worker, 40, num_agents)

    def test_external_multi_agent_env_truncate_episodes(self):
        num_agents = 4
        worker = self._serving_worker(num_agents, "truncate_episodes")
        self._check_samples(worker, 160, num_agents)

    def test_external_multi_agent_env_sample(self):
        num_agents = 2
        act_space = gym.spaces.Discrete(2)
        obs_space = gym.spaces.Discrete(2)
        # Two identically-configured MockPolicy entries, "p0" and "p1".
        policies = {
            f"p{idx}": (MockPolicy, obs_space, act_space, {})
            for idx in range(2)
        }
        worker = RolloutWorker(
            env_creator=lambda _: SimpleMultiServing(
                BasicMultiAgent(num_agents)),
            policy_spec=policies,
            # Alternate agents between the two policies by id parity.
            policy_mapping_fn=lambda aid, **kwargs: f"p{aid % 2}",
            rollout_fragment_length=50)
        sample = worker.sample()
        self.assertEqual(sample.count, 50)
if __name__ == "__main__":
    # Running this file directly hands control to pytest in verbose mode and
    # propagates pytest's return value as the process exit status.
    import sys

    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
|