# test_multi_agent_env.py

import gym
import numpy as np
import random
import unittest

import ray
from ray.tune.registry import register_env
from ray.rllib.agents.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole, \
    BasicMultiAgent, EarlyDoneMultiAgent, FlexAgentsMultiAgent, \
    RoundRobinMultiAgent
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.utils.numpy import one_hot
from ray.rllib.utils.test_utils import check


class TestMultiAgentEnv(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(num_cpus=4)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_basic_mock(self):
        env = BasicMultiAgent(4)
        obs = env.reset()
        self.assertEqual(obs, {0: 0, 1: 0, 2: 0, 3: 0})
        for _ in range(24):
            obs, rew, done, info = env.step({0: 0, 1: 0, 2: 0, 3: 0})
            self.assertEqual(obs, {0: 0, 1: 0, 2: 0, 3: 0})
            self.assertEqual(rew, {0: 1, 1: 1, 2: 1, 3: 1})
            self.assertEqual(
                done,
                {0: False, 1: False, 2: False, 3: False, "__all__": False})
        obs, rew, done, info = env.step({0: 0, 1: 0, 2: 0, 3: 0})
        self.assertEqual(
            done, {0: True, 1: True, 2: True, 3: True, "__all__": True})

    def test_round_robin_mock(self):
        env = RoundRobinMultiAgent(2)
        obs = env.reset()
        self.assertEqual(obs, {0: 0})
        for _ in range(5):
            obs, rew, done, info = env.step({0: 0})
            self.assertEqual(obs, {1: 0})
            self.assertEqual(done["__all__"], False)
            obs, rew, done, info = env.step({1: 0})
            self.assertEqual(obs, {0: 0})
            self.assertEqual(done["__all__"], False)
        obs, rew, done, info = env.step({0: 0})
        self.assertEqual(done["__all__"], True)

    def test_no_reset_until_poll(self):
        env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 1)
        self.assertFalse(env.get_sub_environments()[0].resetted)
        env.poll()
        self.assertTrue(env.get_sub_environments()[0].resetted)

    def test_vectorize_basic(self):
        env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 2)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        self.assertEqual(rew, {0: {}, 1: {}})
        self.assertEqual(dones, {
            0: {"__all__": False},
            1: {"__all__": False},
        })
        for _ in range(24):
            env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
            obs, rew, dones, _, _ = env.poll()
            self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
            self.assertEqual(rew, {0: {0: 1, 1: 1}, 1: {0: 1, 1: 1}})
            self.assertEqual(
                dones, {
                    0: {0: False, 1: False, "__all__": False},
                    1: {0: False, 1: False, "__all__": False},
                })
        env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(
            dones, {
                0: {0: True, 1: True, "__all__": True},
                1: {0: True, 1: True, "__all__": True},
            })
        # Reset processing: sending actions again before the (now done)
        # sub-envs have been reset must raise a ValueError.
        self.assertRaises(
            ValueError,
            lambda: env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}))
        self.assertEqual(env.try_reset(0), {0: 0, 1: 0})
        self.assertEqual(env.try_reset(1), {0: 0, 1: 0})
        env.send_actions({0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
        self.assertEqual(rew, {0: {0: 1, 1: 1}, 1: {0: 1, 1: 1}})
        self.assertEqual(
            dones, {
                0: {0: False, 1: False, "__all__": False},
                1: {0: False, 1: False, "__all__": False},
            })

    def test_vectorize_round_robin(self):
        env = _MultiAgentEnvToBaseEnv(lambda v: RoundRobinMultiAgent(2), [], 2)
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
        self.assertEqual(rew, {0: {}, 1: {}})
        env.send_actions({0: {0: 0}, 1: {0: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}})
        env.send_actions({0: {1: 0}, 1: {1: 0}})
        obs, rew, dones, _, _ = env.poll()
        self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})

    def test_multi_agent_sample(self):
        def policy_mapping_fn(agent_id, episode, worker, **kwargs):
            return "p{}".format(agent_id % 2)

        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=policy_mapping_fn,
            rollout_fragment_length=50)
        batch = ev.sample()
        self.assertEqual(batch.count, 50)
        self.assertEqual(batch.policy_batches["p0"].count, 150)
        self.assertEqual(batch.policy_batches["p1"].count, 100)
        self.assertEqual(batch.policy_batches["p0"]["t"].tolist(),
                         list(range(25)) * 6)

    def test_multi_agent_sample_sync_remote(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            # This old-style signature (agent_id only) raises a
            # soft-deprecation warning in favor of the new signature
            # (agent_id, episode, **kwargs), but should not break this test.
            policy_mapping_fn=(lambda agent_id: "p{}".format(agent_id % 2)),
            rollout_fragment_length=50,
            num_envs=4,
            remote_worker_envs=True,
            remote_env_batch_wait_ms=99999999)
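        # For reference only (a sketch, not used by this test), an equivalent
        # new-style mapping fn would look roughly like:
        #   def policy_mapping_fn(agent_id, episode, **kwargs):
        #       return "p{}".format(agent_id % 2)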
        batch = ev.sample()
        self.assertEqual(batch.count, 200)

    def test_multi_agent_sample_async_remote(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
            rollout_fragment_length=50,
            num_envs=4,
            remote_worker_envs=True)
        batch = ev.sample()
        self.assertEqual(batch.count, 200)

    def test_multi_agent_sample_with_horizon(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(5),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
            episode_horizon=10,  # Test with an episode horizon set.
            rollout_fragment_length=50)
        batch = ev.sample()
        self.assertEqual(batch.count, 50)

    def test_sample_from_early_done_env(self):
        ev = RolloutWorker(
            env_creator=lambda _: EarlyDoneMultiAgent(),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
                "p1": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)),
            batch_mode="complete_episodes",
            rollout_fragment_length=1)
        # This used to raise an error because EarlyDoneMultiAgent terminates
        # e.g. agent0 without publishing any further observation for agent1.
        # This limitation has been fixed: An env may now terminate at any
        # time (and may return rewards for any agent at any time, even when
        # that agent does not get an obs in the same `step()` call).
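        # Illustration only (a rough sketch with made-up agent IDs/values,
        # not what EarlyDoneMultiAgent returns verbatim): a single `step()`
        # may now yield something like:
        #   obs = {0: ...}                       # no obs for agent 1 anymore
        #   rewards = {0: 1.0, 1: 0.5}           # reward for agent 1 w/o obs
        #   dones = {0: True, "__all__": False}  # agent 0 done, env not done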
        ma_batch = ev.sample()
        # Make sure the agents took the correct (alternating-timesteps) path,
        # except for the last timestep, where both agents were terminated.
        ag0_ts = ma_batch.policy_batches["p0"]["t"]
        ag1_ts = ma_batch.policy_batches["p1"]["t"]
        self.assertTrue(np.all(np.abs(ag0_ts[:-1] - ag1_ts[:-1]) == 1.0))
        self.assertTrue(ag0_ts[-1] == ag1_ts[-1])

    def test_multi_agent_with_flex_agents(self):
        register_env("flex_agents_multi_agent_cartpole",
                     lambda _: FlexAgentsMultiAgent())
        pg = PGTrainer(
            env="flex_agents_multi_agent_cartpole",
            config={
                "num_workers": 0,
                "framework": "tf",
            })
        for i in range(10):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))

    def test_multi_agent_sample_round_robin(self):
        ev = RolloutWorker(
            env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
            policy_spec={
                "p0": PolicySpec(policy_class=MockPolicy),
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
            rollout_fragment_length=50)
        batch = ev.sample()
        self.assertEqual(batch.count, 50)
        # Since we introduce agents into the env in a round-robin fashion,
        # some of the env steps don't count as proper transitions.
        self.assertEqual(batch.policy_batches["p0"].count, 42)
        check(batch.policy_batches["p0"]["obs"][:10],
              one_hot(np.array([0, 1, 2, 3, 4] * 2), 10))
        check(batch.policy_batches["p0"]["new_obs"][:10],
              one_hot(np.array([1, 2, 3, 4, 5] * 2), 10))
        self.assertEqual(batch.policy_batches["p0"]["rewards"].tolist()[:10],
                         [100, 100, 100, 100, 0] * 2)
        self.assertEqual(batch.policy_batches["p0"]["dones"].tolist()[:10],
                         [False, False, False, False, True] * 2)
        self.assertEqual(batch.policy_batches["p0"]["t"].tolist()[:10],
                         [4, 9, 14, 19, 24, 5, 10, 15, 20, 25])

    def test_custom_rnn_state_values(self):
        h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}}

        class StatefulPolicy(RandomPolicy):
            def compute_actions(self,
                                obs_batch,
                                state_batches=None,
                                prev_action_batch=None,
                                prev_reward_batch=None,
                                episodes=None,
                                explore=True,
                                timestep=None,
                                **kwargs):
                return [0] * len(obs_batch), [[h] * len(obs_batch)], {}

            def get_initial_state(self):
                return [{}]  # empty dict

        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=StatefulPolicy,
            rollout_fragment_length=5)
        batch = ev.sample()
        self.assertEqual(batch.count, 5)
        self.assertEqual(batch["state_in_0"][0], {})
        self.assertEqual(batch["state_out_0"][0], h)
        self.assertEqual(batch["state_in_0"][1], h)
        self.assertEqual(batch["state_out_0"][1], h)

    def test_returning_model_based_rollouts_data(self):
        class ModelBasedPolicy(DQNTFPolicy):
            def compute_actions_from_input_dict(self,
                                                 input_dict,
                                                 explore=None,
                                                 timestep=None,
                                                 episodes=None,
                                                 **kwargs):
                obs_batch = input_dict["obs"]
                # In the policy-loss initialization phase, no episodes are
                # passed in.
                if episodes is not None:
                    # Pretend we did a model-based rollout and want to return
                    # the extra trajectory.
                    env_id = episodes[0].env_id
                    fake_eps = Episode(episodes[0].policy_map,
                                       episodes[0].policy_mapping_fn,
                                       lambda: None, lambda x: None, env_id)
                    builder = get_global_worker().sampler.sample_collector
                    agent_id = "extra_0"
                    policy_id = "p1"  # Use p1 so we can easily check it.
                    builder.add_init_obs(fake_eps, agent_id, env_id,
                                         policy_id, -1, obs_batch[0])
                    for t in range(4):
                        builder.add_action_reward_next_obs(
                            episode_id=fake_eps.episode_id,
                            agent_id=agent_id,
                            env_id=env_id,
                            policy_id=policy_id,
                            agent_done=t == 3,
                            values=dict(
                                t=t,
                                actions=0,
                                rewards=0,
                                dones=t == 3,
                                infos={},
                                new_obs=obs_batch[0]))
                    batch = builder.postprocess_episode(
                        episode=fake_eps, build=True)
                    episodes[0].add_extra_batch(batch)

                # Just return zeros for actions.
                return [0] * len(obs_batch), [], {}

        ev = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}),
            policy_spec={
                "p0": PolicySpec(policy_class=ModelBasedPolicy),
                "p1": PolicySpec(policy_class=ModelBasedPolicy),
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0",
            rollout_fragment_length=5)
        batch = ev.sample()
        # 5 environment steps (rollout_fragment_length).
        self.assertEqual(batch.count, 5)
        # 10 agent steps for p0: 2 agents, both using p0 as their policy.
        self.assertEqual(batch.policy_batches["p0"].count, 10)
        # 20 agent steps for p1: Each time both(!) agents take 1 env step,
        # 4 fake steps are added for p1: 5 (rollout_fragment_length) * 4 = 20.
        self.assertEqual(batch.policy_batches["p1"].count, 20)

    def test_train_multi_agent_cartpole_single_policy(self):
        n = 10
        register_env("multi_agent_cartpole",
                     lambda _: MultiAgentCartPole({"num_agents": n}))
        pg = PGTrainer(
            env="multi_agent_cartpole",
            config={
                "num_workers": 0,
                "framework": "tf",
            })
        for i in range(50):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))
            if result["episode_reward_mean"] >= 50 * n:
                return
        raise Exception("failed to improve reward")

    def test_train_multi_agent_cartpole_multi_policy(self):
        n = 10
        register_env("multi_agent_cartpole",
                     lambda _: MultiAgentCartPole({"num_agents": n}))

        def gen_policy():
            config = {
                "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
                "n_step": random.choice([1, 2, 3, 4, 5]),
            }
            return PolicySpec(config=config)

        pg = PGTrainer(
            env="multi_agent_cartpole",
            config={
                "num_workers": 0,
                "multiagent": {
                    "policies": {
                        "policy_1": gen_policy(),
                        "policy_2": gen_policy(),
                    },
                    "policy_mapping_fn": lambda aid, **kwargs: "policy_1",
                },
                "framework": "tf",
            })
        # Just check that it runs without crashing.
        for i in range(10):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))
        self.assertTrue(
            pg.compute_single_action([0, 0, 0, 0], policy_id="policy_1") in
            [0, 1])
        self.assertTrue(
            pg.compute_single_action([0, 0, 0, 0], policy_id="policy_2") in
            [0, 1])
        self.assertRaisesRegex(
            KeyError,
            "not found in PolicyMap",
            lambda: pg.compute_single_action(
                [0, 0, 0, 0], policy_id="policy_3"))


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))