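"""Unit tests for RLlib's RolloutWorker.

Covers sampling and batching behavior, action/reward clipping, action
normalization, episode horizons, env vectorization, observation filters,
env seeding, and multi-agent env wrapping.
"""
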
from collections import Counter
import os
import random
import tempfile
import time
import unittest

import gym
from gym.spaces import Box, Discrete
import numpy as np

import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.utils import VideoMonitor
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.env.mock_env import (
    MockEnv,
    MockEnv2,
    MockVectorEnv,
    VectorizedMockEnv,
)
from ray.rllib.examples.env.multi_agent import (
    BasicMultiAgent,
    MultiAgentCartPole,
)
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.execution.common import (
    STEPS_SAMPLED_COUNTER,
    STEPS_TRAINED_COUNTER,
)
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
    MultiAgentBatch,
    SampleBatch,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.tune.registry import register_env
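

# Policy that takes random binary actions and computes (non-GAE)
# advantages in postprocessing. Serves as a lightweight stand-in for a
# learned policy throughout the tests below.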
class MockPolicy(RandomPolicy):
    @override(RandomPolicy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        **kwargs,
    ):
        return np.array([random.choice([0, 1])] * len(obs_batch)), [], {}

    @override(Policy)
    def postprocess_trajectory(
        self, batch, other_agent_batches=None, episode=None
    ):
        assert episode is not None
        super().postprocess_trajectory(batch, other_agent_batches, episode)
        return compute_advantages(
            batch, 100.0, 0.9, use_gae=False, use_critic=False)
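

# Policy whose compute_actions() always raises, for testing error handling.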
class BadPolicy(RandomPolicy):
    @override(RandomPolicy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        **kwargs,
    ):
        raise Exception("intentional error")
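

# Env that raises on both reset() and step(), to verify that RolloutWorker's
# env sanity check fails fast at construction time.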
class FailOnStepEnv(gym.Env):
    def __init__(self):
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        raise ValueError("kaboom")

    def step(self, action):
        raise ValueError("kaboom")
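

# All tests below share one Ray cluster, started once in setUpClass().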
class TestRolloutWorker(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_basic(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy)
        batch = ev.sample()
        for key in [
            "obs",
            "actions",
            "rewards",
            "dones",
            "advantages",
            "prev_rewards",
            "prev_actions",
        ]:
            self.assertIn(key, batch)
            self.assertGreater(np.abs(np.mean(batch[key])), 0)

        def to_prev(vec):
            out = np.zeros_like(vec)
            for i, v in enumerate(vec):
                if i + 1 < len(out) and not batch["dones"][i]:
                    out[i + 1] = v
            return out.tolist()

        self.assertEqual(batch["prev_rewards"].tolist(),
                         to_prev(batch["rewards"]))
        self.assertEqual(batch["prev_actions"].tolist(),
                         to_prev(batch["actions"]))
        self.assertGreater(batch["advantages"][0], 1)
        ev.stop()

    def test_batch_ids(self):
        fragment_len = 100
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            rollout_fragment_length=fragment_len)
        batch1 = ev.sample()
        batch2 = ev.sample()
        unroll_ids_1 = set(batch1["unroll_id"])
        unroll_ids_2 = set(batch2["unroll_id"])
        # Assert no overlap of unroll IDs between sample() calls.
        self.assertTrue(not any(uid in unroll_ids_2 for uid in unroll_ids_1))
        # Early CartPole episodes are short, so each batch should contain
        # more than one unroll ID.
        self.assertTrue(len(unroll_ids_1) > 1)
        self.assertTrue(len(unroll_ids_2) > 1)
        ev.stop()

    def test_global_vars_update(self):
        for fw in framework_iterator(frameworks=("tf2", "tf")):
            agent = A2CTrainer(
                env="CartPole-v0",
                config={
                    "num_workers": 1,
                    # lr = 0.1 - [(0.1 - 0.000001) / 100000] * ts
                    "lr_schedule": [[0, 0.1], [100000, 0.000001]],
                    "framework": fw,
                })
            policy = agent.get_policy()
            for i in range(3):
                result = agent.train()
                print("{}={}".format(
                    STEPS_TRAINED_COUNTER,
                    result["info"][STEPS_TRAINED_COUNTER]))
                print("{}={}".format(
                    STEPS_SAMPLED_COUNTER,
                    result["info"][STEPS_SAMPLED_COUNTER]))
                global_timesteps = policy.global_timestep
                print("global_timesteps={}".format(global_timesteps))
                expected_lr = \
                    0.1 - ((0.1 - 0.000001) / 100000) * global_timesteps
                lr = policy.cur_lr
                if fw == "tf":
                    lr = policy.get_session().run(lr)
                check(lr, expected_lr, rtol=0.05)
            agent.stop()

    def test_no_step_on_init(self):
        register_env("fail", lambda _: FailOnStepEnv())
        for fw in framework_iterator():
            # We expect this to fail already on Trainer init, due to the env
            # sanity check right after env creation (inside RolloutWorker).
            self.assertRaises(
                Exception,
                lambda: PGTrainer(
                    env="fail",
                    config={
                        "num_workers": 2,
                        "framework": fw,
                    }))

    def test_callbacks(self):
        for fw in framework_iterator(frameworks=("torch", "tf")):
            counts = Counter()
            pg = PGTrainer(
                env="CartPole-v0",
                config={
                    "num_workers": 0,
                    "rollout_fragment_length": 50,
                    "train_batch_size": 50,
                    "callbacks": {
                        "on_episode_start":
                            lambda x: counts.update({"start": 1}),
                        "on_episode_step":
                            lambda x: counts.update({"step": 1}),
                        "on_episode_end":
                            lambda x: counts.update({"end": 1}),
                        "on_sample_end":
                            lambda x: counts.update({"sample": 1}),
                    },
                    "framework": fw,
                })
            pg.train()
            pg.train()
            self.assertGreater(counts["sample"], 0)
            self.assertGreater(counts["start"], 0)
            self.assertGreater(counts["end"], 0)
            self.assertGreater(counts["step"], 0)
            pg.stop()

    def test_query_evaluators(self):
        register_env("test", lambda _: gym.make("CartPole-v0"))
        for fw in framework_iterator(frameworks=("torch", "tf")):
            pg = PGTrainer(
                env="test",
                config={
                    "num_workers": 2,
                    "rollout_fragment_length": 5,
                    "num_envs_per_worker": 2,
                    "framework": fw,
                    "create_env_on_driver": True,
                })
            results = pg.workers.foreach_worker(
                lambda ev: ev.rollout_fragment_length)
            results2 = pg.workers.foreach_worker_with_index(
                lambda ev, i: (i, ev.rollout_fragment_length))
            results3 = pg.workers.foreach_worker(
                lambda ev: ev.foreach_env(lambda env: 1))
            # Each worker's effective fragment length is the configured
            # per-env value (5) times num_envs_per_worker (2) = 10. The
            # driver worker (index 0) is included b/c of
            # `create_env_on_driver=True`.
            self.assertEqual(results, [10, 10, 10])
            self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
            self.assertEqual(results3, [[1, 1], [1, 1], [1, 1]])
            pg.stop()
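
    # The next two tests cover the three ways a RolloutWorker can treat
    # out-of-bounds continuous actions: clip them into the env's bounds,
    # unsquash (normalize) them into the bounds, or pass them through
    # unchanged.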
    def test_action_clipping(self):
        from ray.rllib.examples.env.random_env import RandomEnv

        action_space = gym.spaces.Box(-2.0, 1.0, (3, ))

        # Clipping: True (clip between Policy's action_space.low/high).
        ev = RolloutWorker(
            env_creator=lambda _: RandomEnv(config=dict(
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )),
            policy_spec=RandomPolicy,
            policy_config=dict(
                action_space=action_space,
                ignore_action_bounds=True,
            ),
            normalize_actions=False,
            clip_actions=True,
            batch_mode="complete_episodes")
        sample = ev.sample()
        # Check whether the action bounds have been breached (expected).
        # We still got here (no env error) b/c the actions were clipped to
        # the env's action space before being sent to the env.
        self.assertGreater(np.max(sample["actions"]), action_space.high[0])
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

        # Clipping: False and RandomPolicy produces invalid actions.
        # Expect the env to complain.
        ev2 = RolloutWorker(
            env_creator=lambda _: RandomEnv(config=dict(
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )),
            policy_spec=RandomPolicy,
            policy_config=dict(
                action_space=action_space,
                ignore_action_bounds=True,
            ),
            # Neither normalization nor clipping -> the env should complain.
            normalize_actions=False,
            clip_actions=False,
            batch_mode="complete_episodes")
        self.assertRaisesRegex(ValueError, r"Illegal action", ev2.sample)
        ev2.stop()

        # Clipping: False and RandomPolicy produces valid (bounded) actions.
        # Expect "actions" in the SampleBatch to be unclipped.
        ev3 = RolloutWorker(
            env_creator=lambda _: RandomEnv(config=dict(
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )),
            policy_spec=RandomPolicy,
            policy_config=dict(action_space=action_space),
            # Not a problem, as RandomPolicy abides by the action bounds.
            normalize_actions=False,
            clip_actions=False,
            batch_mode="complete_episodes")
        sample = ev3.sample()
        self.assertGreater(np.min(sample["actions"]), action_space.low[0])
        self.assertLess(np.max(sample["actions"]), action_space.high[0])
        ev3.stop()

    def test_action_normalization(self):
        from ray.rllib.examples.env.random_env import RandomEnv

        action_space = gym.spaces.Box(0.0001, 0.0002, (5, ))

        # Normalize: True (unsquash between Policy's action_space.low/high).
        ev = RolloutWorker(
            env_creator=lambda _: RandomEnv(config=dict(
                action_space=action_space,
                max_episode_len=10,
                p_done=0.0,
                check_action_bounds=True,
            )),
            policy_spec=RandomPolicy,
            policy_config=dict(
                action_space=action_space,
                ignore_action_bounds=True,
            ),
            normalize_actions=True,
            clip_actions=False,
            batch_mode="complete_episodes")
        sample = ev.sample()
        # Check whether the action bounds have been breached (expected).
        # We still got here (no env error) b/c the actions were unsquashed
        # into the env's action space before being sent to the env.
        self.assertGreater(np.max(sample["actions"]), action_space.high[0])
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

    def test_reward_clipping(self):
        # Clipping: True (clip between -1.0 and 1.0).
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            policy_spec=MockPolicy,
            clip_rewards=True,
            batch_mode="complete_episodes")
        self.assertEqual(max(ev.sample()["rewards"]), 1)
        result = collect_metrics(ev, [])
        # Metrics are computed on the raw (unclipped) env rewards:
        # 10 steps x 100 reward each = 1000.
        self.assertEqual(result["episode_reward_mean"], 1000)
        ev.stop()

        from ray.rllib.examples.env.random_env import RandomEnv

        # Clipping into a custom range (-2.0, 2.0).
        ev2 = RolloutWorker(
            env_creator=lambda _: RandomEnv(
                dict(
                    reward_space=gym.spaces.Box(low=-10, high=10, shape=()),
                    p_done=0.0,
                    max_episode_len=10,
                )),
            policy_spec=MockPolicy,
            clip_rewards=2.0,
            batch_mode="complete_episodes")
        sample = ev2.sample()
        self.assertEqual(max(sample["rewards"]), 2.0)
        self.assertEqual(min(sample["rewards"]), -2.0)
        self.assertLess(np.mean(sample["rewards"]), 0.5)
        self.assertGreater(np.mean(sample["rewards"]), -0.5)
        ev2.stop()

        # Clipping: Off.
        ev2 = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            policy_spec=MockPolicy,
            clip_rewards=False,
            batch_mode="complete_episodes")
        self.assertEqual(max(ev2.sample()["rewards"]), 100)
        result2 = collect_metrics(ev2, [])
        self.assertEqual(result2["episode_reward_mean"], 1000)
        ev2.stop()
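
    # With a hard horizon, the env is reset at the horizon and done=True is
    # emitted; with a soft horizon, a new logical episode starts at the
    # horizon, but the underlying env keeps running (no reset, no done).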
    def test_hard_horizon(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes",
            rollout_fragment_length=10,
            episode_horizon=4,
            soft_horizon=False)
        samples = ev.sample()
        # Three logical episodes and correct episode resets (always after 4
        # steps).
        self.assertEqual(len(set(samples["eps_id"])), 3)
        for i in range(4):
            self.assertEqual(np.argmax(samples["obs"][i]), i)
        self.assertEqual(np.argmax(samples["obs"][4]), 0)
        # 3 done values.
        self.assertEqual(sum(samples["dones"]), 3)
        ev.stop()

        # The horizon (6) must be enforced even though the gym env's own
        # max_episode_steps (200 for CartPole) is much larger.
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes",
            rollout_fragment_length=10,
            episode_horizon=6,
            soft_horizon=False)
        samples = ev.sample()
        # 12 steps due to the `complete_episodes` batch_mode.
        self.assertEqual(len(samples["eps_id"]), 12)
        # Two logical episodes and correct episode resets (always after 6(!)
        # steps).
        self.assertEqual(len(set(samples["eps_id"])), 2)
        # 2 done values, after 6 and 12 steps.
        check(samples["dones"], [
            False, False, False, False, False, True, False, False, False,
            False, False, True
        ])
        ev.stop()

    def test_soft_horizon(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes",
            rollout_fragment_length=10,
            episode_horizon=4,
            soft_horizon=True)
        samples = ev.sample()
        # Three logical episodes.
        self.assertEqual(len(set(samples["eps_id"])), 3)
        # Only 1 hard done value (at the true end of the env's episode).
        self.assertEqual(sum(samples["dones"]), 1)
        ev.stop()

    def test_metrics(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes")
        remote_ev = RolloutWorker.as_remote().remote(
            env_creator=lambda _: MockEnv(episode_length=10),
            policy_spec=MockPolicy,
            batch_mode="complete_episodes")
        ev.sample()
        ray.get(remote_ev.sample.remote())
        result = collect_metrics(ev, [remote_ev])
        self.assertEqual(result["episodes_this_iter"], 20)
        self.assertEqual(result["episode_reward_mean"], 10)
        ev.stop()

    def test_async(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            sample_async=True,
            policy_spec=MockPolicy)
        batch = ev.sample()
        for key in ["obs", "actions", "rewards", "dones", "advantages"]:
            self.assertIn(key, batch)
        self.assertGreater(batch["advantages"][0], 1)
        ev.stop()

    def test_auto_vectorization(self):
        ev = RolloutWorker(
            env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
            policy_spec=MockPolicy,
            batch_mode="truncate_episodes",
            rollout_fragment_length=2,
            num_envs=8)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 8)
        indices = []
        for env in ev.async_env.vector_env.envs:
            self.assertEqual(env.unwrapped.config.worker_index, 0)
            indices.append(env.unwrapped.config.vector_index)
        self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
        ev.stop()

    def test_batches_larger_when_vectorized(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=8),
            policy_spec=MockPolicy,
            batch_mode="truncate_episodes",
            rollout_fragment_length=4,
            num_envs=4)
        batch = ev.sample()
        self.assertEqual(batch.count, 16)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        batch = ev.sample()
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 4)
        ev.stop()

    def test_vector_env_support(self):
        # Test a vector env that contains 8 actual envs
        # (MockEnv instances).
        ev = RolloutWorker(
            env_creator=(
                lambda _: VectorizedMockEnv(episode_length=20, num_envs=8)),
            policy_spec=MockPolicy,
            batch_mode="truncate_episodes",
            rollout_fragment_length=10)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertEqual(result["episodes_this_iter"], 8)
        ev.stop()

        # Test a vector env that pretends(!) to contain 4 envs, but actually
        # only has 1 (CartPole).
        ev = RolloutWorker(
            env_creator=(lambda _: MockVectorEnv(20, mocked_num_envs=4)),
            policy_spec=MockPolicy,
            batch_mode="truncate_episodes",
            rollout_fragment_length=10)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertGreater(result["episodes_this_iter"], 3)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ev, [])
        self.assertGreater(result["episodes_this_iter"], 7)
        ev.stop()
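
    # In `truncate_episodes` mode, each sample() call returns a batch of
    # `rollout_fragment_length` steps (possibly slightly more when counting
    # agent steps), where a "step" is either one env step or one agent
    # step, depending on `count_steps_by`.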
    def test_truncate_episodes(self):
        ev_env_steps = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy_spec=MockPolicy,
            rollout_fragment_length=15,
            batch_mode="truncate_episodes")
        batch = ev_env_steps.sample()
        self.assertEqual(batch.count, 15)
        self.assertTrue(isinstance(batch, SampleBatch))
        ev_env_steps.stop()

        action_space = Discrete(2)
        obs_space = Box(float("-inf"), float("inf"), (4, ), dtype=np.float32)
        ev_agent_steps = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
            policy_spec={
                "pol0": (MockPolicy, obs_space, action_space, {}),
                "pol1": (MockPolicy, obs_space, action_space, {}),
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs: (
                "pol0" if agent_id == 0 else "pol1"),
            rollout_fragment_length=301,
            count_steps_by="env_steps",
            batch_mode="truncate_episodes",
        )
        batch = ev_agent_steps.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertGreater(batch.agent_steps(), 301)
        self.assertEqual(batch.env_steps(), 301)
        ev_agent_steps.stop()

        ev_agent_steps = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
            policy_spec={
                "pol0": (MockPolicy, obs_space, action_space, {}),
                "pol1": (MockPolicy, obs_space, action_space, {}),
            },
            policy_mapping_fn=lambda agent_id, episode, **kwargs: (
                "pol0" if agent_id == 0 else "pol1"),
            rollout_fragment_length=301,
            count_steps_by="agent_steps",
            batch_mode="truncate_episodes")
        batch = ev_agent_steps.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertLess(batch.env_steps(), 301)
        # When counting agent steps, the count may be slightly larger than
        # rollout_fragment_length, b/c up to N agents step in each env step
        # and we only check whether the batch is complete after each env
        # step.
        self.assertGreaterEqual(batch.agent_steps(), 301)
        ev_agent_steps.stop()

    def test_complete_episodes(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy_spec=MockPolicy,
            rollout_fragment_length=5,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
        ev.stop()

    def test_complete_episodes_packing(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy_spec=MockPolicy,
            rollout_fragment_length=15,
            batch_mode="complete_episodes")
        batch = ev.sample()
        # Two full 10-step episodes get packed into one batch.
        self.assertEqual(batch.count, 20)
        self.assertEqual(
            batch["t"].tolist(),
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
        ev.stop()
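
    # The filter tests below exercise the "ConcurrentMeanStdFilter"
    # observation filter: stats accumulate in a local buffer, can be
    # retrieved via get_filters() (optionally flushing the buffer), and can
    # be overwritten via sync_filters().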
    def test_filter_sync(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        time.sleep(2)
        ev.sample()
        filters = ev.get_filters(flush_after=True)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertNotEqual(obs_f.rs.n, 0)
        self.assertNotEqual(obs_f.buffer.n, 0)
        ev.stop()

    def test_get_filters(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        self.sample_and_flush(ev)
        filters = ev.get_filters(flush_after=False)
        time.sleep(2)
        filters2 = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        obs_f2 = filters2[DEFAULT_POLICY_ID]
        self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
        self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
        ev.stop()

    def test_sync_filter(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_spec=MockPolicy,
            sample_async=True,
            observation_filter="ConcurrentMeanStdFilter")
        self.sample_and_flush(ev)
        # Current state.
        filters = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertLessEqual(obs_f.buffer.n, 20)

        new_obsf = obs_f.copy()
        new_obsf.rs._n = 100
        ev.sync_filters({DEFAULT_POLICY_ID: new_obsf})
        filters = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertGreaterEqual(obs_f.rs.n, 100)
        self.assertLessEqual(obs_f.buffer.n, 20)
        ev.stop()

    def test_extra_python_envs(self):
        extra_envs = {"env_key_1": "env_value_1", "env_key_2": "env_value_2"}
        self.assertFalse("env_key_1" in os.environ)
        self.assertFalse("env_key_2" in os.environ)
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            policy_spec=MockPolicy,
            extra_python_environs=extra_envs)
        self.assertTrue("env_key_1" in os.environ)
        self.assertTrue("env_key_2" in os.environ)
        ev.stop()
        # Reset to the original environment.
        del os.environ["env_key_1"]
        del os.environ["env_key_2"]

    def test_no_env_seed(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockVectorEnv(20, mocked_num_envs=8),
            policy_spec=MockPolicy,
            seed=1)
        assert not hasattr(ev.env, "seed")
        ev.stop()

    def test_multi_env_seed(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(100),
            num_envs=3,
            policy_spec=MockPolicy,
            seed=1)
        # Make sure we can properly sample from the wrapped env.
        ev.sample()
        # Make sure all environments got a different deterministic seed.
        seeds = ev.foreach_env(lambda env: env.rng_seed)
        self.assertEqual(seeds, [1, 2, 3])
        ev.stop()

    def test_wrap_multi_agent_env(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(10),
            policy_spec=MockPolicy,
            policy_config={
                "in_evaluation": False,
            },
            record_env=tempfile.gettempdir())
        # Make sure we can properly sample from the wrapped env.
        ev.sample()
        # Make sure the resulting environment is indeed still an instance
        # of MultiAgentEnv and VideoMonitor.
        self.assertTrue(isinstance(ev.env.unwrapped, MultiAgentEnv))
        self.assertTrue(isinstance(ev.env, gym.Env))
        self.assertTrue(isinstance(ev.env, VideoMonitor))
        ev.stop()

    def test_no_training(self):
        class NoTrainingEnv(MockEnv):
            def __init__(self, episode_length, training_enabled):
                super().__init__(episode_length)
                self.training_enabled = training_enabled

            def step(self, action):
                obs, rew, done, info = super().step(action)
                return obs, rew, done, {
                    **info,
                    "training_enabled": self.training_enabled,
                }

        ev = RolloutWorker(
            env_creator=lambda _: NoTrainingEnv(10, True),
            policy_spec=MockPolicy,
            rollout_fragment_length=5,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
        self.assertEqual(len(batch["obs"]), 10)
        ev.stop()

        ev = RolloutWorker(
            env_creator=lambda _: NoTrainingEnv(10, False),
            policy_spec=MockPolicy,
            rollout_fragment_length=5,
            batch_mode="complete_episodes")
        batch = ev.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertEqual(len(batch.policy_batches), 0)
        ev.stop()

    def sample_and_flush(self, ev):
        time.sleep(2)
        ev.sample()
        filters = ev.get_filters(flush_after=True)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertNotEqual(obs_f.rs.n, 0)
        self.assertNotEqual(obs_f.buffer.n, 0)
        return obs_f


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))