# test_rollout_worker.py

from collections import Counter
import gym
from gym.spaces import Box, Discrete
import numpy as np
import os
import random
import tempfile
import time
import unittest

import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.utils import VideoMonitor
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.examples.env.mock_env import MockEnv, MockEnv2, MockVectorEnv, \
    VectorizedMockEnv
from ray.rllib.examples.env.multi_agent import BasicMultiAgent, \
    MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
    STEPS_TRAINED_COUNTER
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, MultiAgentBatch, \
    SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.tune.registry import register_env


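# Helper policies and envs used by the tests below: a random-acting policy,
# a policy that always raises, and an env that fails on reset/step.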
class MockPolicy(RandomPolicy):
    @override(RandomPolicy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        return np.array([random.choice([0, 1])] * len(obs_batch)), [], {}

    @override(Policy)
    def postprocess_trajectory(self,
                               batch,
                               other_agent_batches=None,
                               episode=None):
        assert episode is not None
        super().postprocess_trajectory(batch, other_agent_batches, episode)
        return compute_advantages(
            batch, 100.0, 0.9, use_gae=False, use_critic=False)


class BadPolicy(RandomPolicy):
    @override(RandomPolicy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        raise Exception("intentional error")


class FailOnStepEnv(gym.Env):
    def __init__(self):
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        raise ValueError("kaboom")

    def step(self, action):
        raise ValueError("kaboom")


class TestRolloutWorker(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_basic(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy)
        batch = ev.sample()
        for key in [
                "obs", "actions", "rewards", "dones", "advantages",
                "prev_rewards", "prev_actions"
        ]:
            self.assertIn(key, batch)
            self.assertGreater(np.abs(np.mean(batch[key])), 0)

        def to_prev(vec):
            out = np.zeros_like(vec)
            for i, v in enumerate(vec):
                if i + 1 < len(out) and not batch["dones"][i]:
                    out[i + 1] = v
            return out.tolist()

        self.assertEqual(batch["prev_rewards"].tolist(),
                         to_prev(batch["rewards"]))
        self.assertEqual(batch["prev_actions"].tolist(),
                         to_prev(batch["actions"]))
        self.assertGreater(batch["advantages"][0], 1)
        ev.stop()

    def test_batch_ids(self):
        fragment_len = 100
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            rollout_fragment_length=fragment_len)
        batch1 = ev.sample()
        batch2 = ev.sample()
        unroll_ids_1 = set(batch1["unroll_id"])
        unroll_ids_2 = set(batch2["unroll_id"])
        # Assert no overlap of unroll IDs between sample() calls.
        self.assertTrue(not any(uid in unroll_ids_2 for uid in unroll_ids_1))
        # CartPole episodes should be short initially: Expect more than one
        # unroll ID in each batch.
        self.assertTrue(len(unroll_ids_1) > 1)
        self.assertTrue(len(unroll_ids_2) > 1)
        ev.stop()

    def test_global_vars_update(self):
        for fw in framework_iterator(frameworks=("tf2", "tf")):
            agent = A2CTrainer(
                env="CartPole-v0",
                config={
                    "num_workers": 1,
                    # lr = 0.1 - [(0.1 - 0.000001) / 100000] * ts
                    "lr_schedule": [[0, 0.1], [100000, 0.000001]],
                    "framework": fw,
                })
            policy = agent.get_policy()
            for i in range(3):
                result = agent.train()
                print("{}={}".format(STEPS_TRAINED_COUNTER,
                                     result["info"][STEPS_TRAINED_COUNTER]))
                print("{}={}".format(STEPS_SAMPLED_COUNTER,
                                     result["info"][STEPS_SAMPLED_COUNTER]))
                global_timesteps = policy.global_timestep
                print("global_timesteps={}".format(global_timesteps))
                expected_lr = \
                    0.1 - ((0.1 - 0.000001) / 100000) * global_timesteps
                lr = policy.cur_lr
                if fw == "tf":
                    lr = policy.get_session().run(lr)
                check(lr, expected_lr, rtol=0.05)
            agent.stop()

    def test_no_step_on_init(self):
        register_env("fail", lambda _: FailOnStepEnv())
        for fw in framework_iterator():
            # We expect this to fail already on Trainer init due
            # to the env sanity check right after env creation (inside
            # RolloutWorker).
            self.assertRaises(Exception, lambda: PGTrainer(
                env="fail", config={
                    "num_workers": 2,
                    "framework": fw,
                }))

    def test_callbacks(self):
        for fw in framework_iterator(frameworks=("torch", "tf")):
            counts = Counter()
            pg = PGTrainer(
                env="CartPole-v0", config={
                    "num_workers": 0,
                    "rollout_fragment_length": 50,
                    "train_batch_size": 50,
                    "callbacks": {
                        "on_episode_start":
                            lambda x: counts.update({"start": 1}),
                        "on_episode_step":
                            lambda x: counts.update({"step": 1}),
                        "on_episode_end":
                            lambda x: counts.update({"end": 1}),
                        "on_sample_end":
                            lambda x: counts.update({"sample": 1}),
                    },
                    "framework": fw,
                })
            pg.train()
            pg.train()
            self.assertGreater(counts["sample"], 0)
            self.assertGreater(counts["start"], 0)
            self.assertGreater(counts["end"], 0)
            self.assertGreater(counts["step"], 0)
            pg.stop()

    def test_query_evaluators(self):
        register_env("test", lambda _: gym.make("CartPole-v0"))
        for fw in framework_iterator(frameworks=("torch", "tf")):
            pg = PGTrainer(
                env="test",
                config={
                    "num_workers": 2,
                    "rollout_fragment_length": 5,
                    "num_envs_per_worker": 2,
                    "framework": fw,
                    "create_env_on_driver": True,
                })
            results = pg.workers.foreach_worker(
                lambda ev: ev.rollout_fragment_length)
            results2 = pg.workers.foreach_worker_with_index(
                lambda ev, i: (i, ev.rollout_fragment_length))
            results3 = pg.workers.foreach_worker(
                lambda ev: ev.foreach_env(lambda env: 1))
            self.assertEqual(results, [10, 10, 10])
            self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
            self.assertEqual(results3, [[1, 1], [1, 1], [1, 1]])
            pg.stop()

    def test_action_clipping(self):
        from ray.rllib.examples.env.random_env import RandomEnv

        action_space = gym.spaces.Box(-2.0, 1.0, (3, ))

        # Clipping: True (clip between Policy's action_space.low/high).
        ev = RolloutWorker(
            env_creator=lambda _: RandomEnv(config=dict(
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )),
            policy_spec=RandomPolicy,
            policy_config=dict(
                action_space=action_space,
                ignore_action_bounds=True,
            ),
            normalize_actions=False,
            clip_actions=True,
            batch_mode="complete_episodes")
        sample = ev.sample()
        # The action bounds in the collected batch are expected to be breached
        # (the policy is configured to ignore them). The env did not complain
        # because the actions were clipped to the env's action space before
        # stepping.
        self.assertGreater(np.max(sample["actions"]), action_space.high[0])
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

        # Clipping: False and RandomPolicy produces invalid actions.
        # Expect the env to complain.
        ev2 = RolloutWorker(
            env_creator=lambda _: RandomEnv(config=dict(
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )),
            policy_spec=RandomPolicy,
            policy_config=dict(
                action_space=action_space,
                ignore_action_bounds=True,
            ),
            # Neither normalization nor clipping ->
            # the env should complain.
            normalize_actions=False,
            clip_actions=False,
            batch_mode="complete_episodes")
        self.assertRaisesRegex(ValueError, r"Illegal action", ev2.sample)
        ev2.stop()

        # Clipping: False and RandomPolicy produces valid (bounded) actions.
        # Expect "actions" in the SampleBatch to be unclipped.
        ev3 = RolloutWorker(
            env_creator=lambda _: RandomEnv(config=dict(
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )),
            policy_spec=RandomPolicy,
            policy_config=dict(action_space=action_space),
            # Not a problem here, as RandomPolicy abides by the bounds.
            normalize_actions=False,
            clip_actions=False,
            batch_mode="complete_episodes")
        sample = ev3.sample()
        self.assertGreater(np.min(sample["actions"]), action_space.low[0])
        self.assertLess(np.max(sample["actions"]), action_space.high[0])
        ev3.stop()

    def test_action_normalization(self):
        from ray.rllib.examples.env.random_env import RandomEnv

        action_space = gym.spaces.Box(0.0001, 0.0002, (5, ))

        # Normalize: True (unsquash between Policy's action_space.low/high).
        ev = RolloutWorker(
            env_creator=lambda _: RandomEnv(config=dict(
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )),
            policy_spec=RandomPolicy,
            policy_config=dict(
                action_space=action_space,
                ignore_action_bounds=True,
            ),
            normalize_actions=True,
            clip_actions=False,
            batch_mode="complete_episodes")
        sample = ev.sample()
        # The action bounds in the collected batch are expected to be breached
        # (the policy is configured to ignore them). The env did not complain
        # because the actions were unsquashed into the env's action space
        # before stepping.
        self.assertGreater(np.max(sample["actions"]), action_space.high[0])
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

    def test_reward_clipping(self):
        # Clipping: True (clip between -1.0 and 1.0).
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            policy_spec=MockPolicy,
            clip_rewards=True,
            batch_mode="complete_episodes")
        self.assertEqual(max(ev.sample()["rewards"]), 1)
        result = collect_metrics(ev, [])
        # Metrics are reported on the unclipped env rewards.
        self.assertEqual(result["episode_reward_mean"], 1000)
        ev.stop()

        from ray.rllib.examples.env.random_env import RandomEnv

        # Clipping to a custom range (-2.0, 2.0).
        ev2 = RolloutWorker(
            env_creator=lambda _: RandomEnv(
                dict(
                    reward_space=gym.spaces.Box(low=-10, high=10, shape=()),
                    p_done=0.0,
                    max_episode_len=10,
                )),
            policy_spec=MockPolicy,
            clip_rewards=2.0,
            batch_mode="complete_episodes")
        sample = ev2.sample()
        self.assertEqual(max(sample["rewards"]), 2.0)
        self.assertEqual(min(sample["rewards"]), -2.0)
        self.assertLess(np.mean(sample["rewards"]), 0.5)
        self.assertGreater(np.mean(sample["rewards"]), -0.5)
        ev2.stop()

        # Clipping: Off.
        ev2 = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            policy_spec=MockPolicy,
            clip_rewards=False,
            batch_mode="complete_episodes")
        self.assertEqual(max(ev2.sample()["rewards"]), 100)
        result2 = collect_metrics(ev2, [])
        self.assertEqual(result2["episode_reward_mean"], 1000)
        ev2.stop()

    def test_hard_horizon(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes",
            rollout_fragment_length=10,
            episode_horizon=4,
            soft_horizon=False)
        samples = ev.sample()
        # Three logical episodes and correct episode resets (always after 4
        # steps).
        self.assertEqual(len(set(samples["eps_id"])), 3)
        for i in range(4):
            self.assertEqual(np.argmax(samples["obs"][i]), i)
        self.assertEqual(np.argmax(samples["obs"][4]), 0)
        # 3 done values.
        self.assertEqual(sum(samples["dones"]), 3)
        ev.stop()

        # Use a gym env (CartPole-v0, built-in step limit of 200) together
        # with a smaller Trainer horizon of 6: The horizon must still be
        # honored.
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes",
            rollout_fragment_length=10,
            episode_horizon=6,
            soft_horizon=False)
        samples = ev.sample()
        # 12 steps due to `complete_episodes` batch_mode.
        self.assertEqual(len(samples["eps_id"]), 12)
        # Two logical episodes and correct episode resets (always after 6(!)
        # steps).
        self.assertEqual(len(set(samples["eps_id"])), 2)
        # 2 done values after 6 and 12 steps.
        check(samples["dones"], [
            False, False, False, False, False, True, False, False, False,
            False, False, True
        ])
        ev.stop()

    def test_soft_horizon(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes",
            rollout_fragment_length=10,
            episode_horizon=4,
            soft_horizon=True)
        samples = ev.sample()
        # Three logical episodes.
        self.assertEqual(len(set(samples["eps_id"])), 3)
        # Only one hard done value (no env reset at the soft horizon).
        self.assertEqual(sum(samples["dones"]), 1)
        ev.stop()

    def test_metrics(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes")
        remote_ev = RolloutWorker.as_remote().remote(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes")
        ev.sample()
        ray.get(remote_ev.sample.remote())
        result = collect_metrics(ev, [remote_ev])
        self.assertEqual(result["episodes_this_iter"], 20)
        self.assertEqual(result["episode_reward_mean"], 10)
        ev.stop()

    def test_async(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            sample_async=True,
            policy_spec=MockPolicy)
        batch = ev.sample()
        for key in ["obs", "actions", "rewards", "dones", "advantages"]:
            self.assertIn(key, batch)
        self.assertGreater(batch["advantages"][0], 1)
        ev.stop()

    def test_auto_vectorization(self):
        ev = RolloutWorker(
            env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
            policy_spec=MockPolicy,
            batch_mode="truncate_episodes",
            rollout_fragment_length=2,
            num_envs=8)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 8)
        indices = []
        for env in ev.async_env.vector_env.envs:
            self.assertEqual(env.unwrapped.config.worker_index, 0)
            indices.append(env.unwrapped.config.vector_index)
        self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
        ev.stop()

    def test_batches_larger_when_vectorized(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=8),
            policy_spec=MockPolicy,
            batch_mode="truncate_episodes",
            rollout_fragment_length=4,
            num_envs=4)
        batch = ev.sample()
        self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        batch = ev.sample()
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 4)
        ev.stop()

    def test_vector_env_support(self):
        # Test a vector env that contains 8 actual envs
        # (MockEnv instances).
        ev = RolloutWorker(
            env_creator=(
                lambda _: VectorizedMockEnv(episode_length=20, num_envs=8)),
            policy_spec=MockPolicy,
            batch_mode="truncate_episodes",
            rollout_fragment_length=10)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 8)
        ev.stop()

        # Test a vector env that pretends(!) to contain 4 envs, but actually
        # only has 1 (CartPole).
        ev = RolloutWorker(
            env_creator=(lambda _: MockVectorEnv(20, mocked_num_envs=4)),
            policy_spec=MockPolicy,
            batch_mode="truncate_episodes",
            rollout_fragment_length=10)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertGreater(result["episodes_this_iter"], 3)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertGreater(result["episodes_this_iter"], 7)
        ev.stop()

    def test_truncate_episodes(self):
        ev_env_steps = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy_spec=MockPolicy,
            rollout_fragment_length=15,
            batch_mode="truncate_episodes")
        batch = ev_env_steps.sample()
        self.assertEqual(batch.count, 15)
        self.assertTrue(isinstance(batch, SampleBatch))
        ev_env_steps.stop()

        action_space = Discrete(2)
        obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)
        # Count by env steps: env_steps() must match rollout_fragment_length
        # exactly, while agent_steps() will be larger (4 agents per env step).
        ev_agent_steps = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
            policy_spec={
                "pol0": (MockPolicy, obs_space, action_space, {}),
                "pol1": (MockPolicy, obs_space, action_space, {}),
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs:
            "pol0" if agent_id == 0 else "pol1",
            rollout_fragment_length=301,
            count_steps_by="env_steps",
            batch_mode="truncate_episodes",
        )
        batch = ev_agent_steps.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertGreater(batch.agent_steps(), 301)
        self.assertEqual(batch.env_steps(), 301)
        ev_agent_steps.stop()

        ev_agent_steps = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
            policy_spec={
                "pol0": (MockPolicy, obs_space, action_space, {}),
                "pol1": (MockPolicy, obs_space, action_space, {}),
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs:
            "pol0" if agent_id == 0 else "pol1",
            rollout_fragment_length=301,
            count_steps_by="agent_steps",
            batch_mode="truncate_episodes")
        batch = ev_agent_steps.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertLess(batch.env_steps(), 301)
        # When counting agent steps, the count may be slightly larger than
        # rollout_fragment_length because up to N agents step in each env step
        # and the "should we build the batch?" check only happens after each
        # env step.
        self.assertGreaterEqual(batch.agent_steps(), 301)
        ev_agent_steps.stop()

    def test_complete_episodes(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy_spec=MockPolicy,
            rollout_fragment_length=5,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
        ev.stop()

    def test_complete_episodes_packing(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy_spec=MockPolicy,
            rollout_fragment_length=15,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, 20)
        self.assertEqual(
            batch["t"].tolist(),
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
        ev.stop()

    def test_filter_sync(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        time.sleep(2)
        ev.sample()
        filters = ev.get_filters(flush_after=True)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertNotEqual(obs_f.rs.n, 0)
        self.assertNotEqual(obs_f.buffer.n, 0)
        ev.stop()

    def test_get_filters(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        self.sample_and_flush(ev)
        filters = ev.get_filters(flush_after=False)
        time.sleep(2)
        filters2 = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        obs_f2 = filters2[DEFAULT_POLICY_ID]
        self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
        self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
        ev.stop()

    def test_sync_filter(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        obs_f = self.sample_and_flush(ev)

        # Current state.
        filters = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertLessEqual(obs_f.buffer.n, 20)

        new_obsf = obs_f.copy()
        new_obsf.rs._n = 100
        ev.sync_filters({DEFAULT_POLICY_ID: new_obsf})
        filters = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertGreaterEqual(obs_f.rs.n, 100)
        self.assertLessEqual(obs_f.buffer.n, 20)
        ev.stop()

    def test_extra_python_envs(self):
        extra_envs = {"env_key_1": "env_value_1", "env_key_2": "env_value_2"}
        self.assertFalse("env_key_1" in os.environ)
        self.assertFalse("env_key_2" in os.environ)
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy_spec=MockPolicy,
            extra_python_environs=extra_envs)
        self.assertTrue("env_key_1" in os.environ)
        self.assertTrue("env_key_2" in os.environ)
        ev.stop()
        # Reset to original.
        del os.environ["env_key_1"]
        del os.environ["env_key_2"]

    def test_no_env_seed(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockVectorEnv(20, mocked_num_envs=8),
            policy_spec=MockPolicy,
            seed=1)
        assert not hasattr(ev.env, "seed")
        ev.stop()

    def test_multi_env_seed(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(100),
            num_envs=3,
            policy_spec=MockPolicy,
            seed=1)
        # Make sure we can properly sample from the wrapped env.
        ev.sample()
        # Make sure all environments got a different deterministic seed.
        seeds = ev.foreach_env(lambda env: env.rng_seed)
        self.assertEqual(seeds, [1, 2, 3])
        ev.stop()

    def test_wrap_multi_agent_env(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(10),
            policy_spec=MockPolicy,
            policy_config={
                "in_evaluation": False,
            },
            record_env=tempfile.gettempdir())
        # Make sure we can properly sample from the wrapped env.
        ev.sample()
        # Make sure the resulting environment is indeed still an
        # instance of MultiAgentEnv and VideoMonitor.
        self.assertTrue(isinstance(ev.env.unwrapped, MultiAgentEnv))
        self.assertTrue(isinstance(ev.env, gym.Env))
        self.assertTrue(isinstance(ev.env, VideoMonitor))
        ev.stop()

    def test_no_training(self):
        class NoTrainingEnv(MockEnv):
            def __init__(self, episode_length, training_enabled):
                super(NoTrainingEnv, self).__init__(episode_length)
                self.training_enabled = training_enabled

            def step(self, action):
                obs, rew, done, info = super(NoTrainingEnv, self).step(action)
                return obs, rew, done, {
                    **info, "training_enabled": self.training_enabled
                }

        ev = RolloutWorker(
            env_creator=lambda _: NoTrainingEnv(10, True),
            policy_spec=MockPolicy,
            rollout_fragment_length=5,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
        self.assertEqual(len(batch["obs"]), 10)
        ev.stop()

        ev = RolloutWorker(
            env_creator=lambda _: NoTrainingEnv(10, False),
            policy_spec=MockPolicy,
            rollout_fragment_length=5,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertEqual(len(batch.policy_batches), 0)
        ev.stop()

    def sample_and_flush(self, ev):
        time.sleep(2)
        ev.sample()
        filters = ev.get_filters(flush_after=True)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertNotEqual(obs_f.rs.n, 0)
        self.assertNotEqual(obs_f.buffer.n, 0)
        return obs_f


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))