import gymnasium as gym
from gymnasium.spaces import Box, Discrete
import json
import numpy as np
import os
import random
import tempfile
import time
import unittest

import ray
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.pg import PGConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.examples.env.mock_env import (
    MockEnv,
    MockEnv2,
    MockVectorEnv,
    VectorizedMockEnv,
)
from ray.rllib.examples.env.multi_agent import BasicMultiAgent, MultiAgentCartPole
from ray.rllib.examples.policy.random_policy import RandomPolicy
from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import (
    DEFAULT_POLICY_ID,
    MultiAgentBatch,
    SampleBatch,
    convert_ma_batch_to_sample_batch,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED
from ray.rllib.utils.test_utils import check, framework_iterator
from ray.tune.registry import register_env
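

# Test helpers used throughout this file: a policy that acts randomly and
# post-processes its trajectories via compute_advantages(), a policy that
# always raises in compute_actions(), and an env that fails on reset()/step().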
class MockPolicy(RandomPolicy):
    @override(RandomPolicy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        **kwargs
    ):
        return np.array([random.choice([0, 1])] * len(obs_batch)), [], {}

    @override(Policy)
    def postprocess_trajectory(self, batch, other_agent_batches=None, episode=None):
        assert episode is not None
        super().postprocess_trajectory(batch, other_agent_batches, episode)
        return compute_advantages(batch, 100.0, 0.9, use_gae=False, use_critic=False)


class BadPolicy(RandomPolicy):
    @override(RandomPolicy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        **kwargs
    ):
        raise Exception("intentional error")


class FailOnStepEnv(gym.Env):
    def __init__(self):
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, *, seed=None, options=None):
        raise ValueError("kaboom")

    def step(self, action):
        raise ValueError("kaboom")


class TestRolloutWorker(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init(num_cpus=5)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_basic(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(num_rollout_workers=0),
        )
        batch = convert_ma_batch_to_sample_batch(ev.sample())
        for key in [
            "obs",
            "actions",
            "rewards",
            "terminateds",
            "advantages",
            "prev_rewards",
            "prev_actions",
        ]:
            self.assertIn(key, batch)
            self.assertGreater(np.abs(np.mean(batch[key])), 0)

        # Our MockPolicy should never reach a full truncated episode.
        # Expect all truncateds flags to be False.
        self.assertEqual(np.abs(np.mean(batch["truncateds"])), 0.0)

        def to_prev(vec):
            out = np.zeros_like(vec)
            for i, v in enumerate(vec):
                if i + 1 < len(out) and not batch["terminateds"][i]:
                    out[i + 1] = v
            return out.tolist()

        self.assertEqual(batch["prev_rewards"].tolist(), to_prev(batch["rewards"]))
        self.assertEqual(batch["prev_actions"].tolist(), to_prev(batch["actions"]))
        self.assertGreater(batch["advantages"][0], 1)
        ev.stop()

    def test_batch_ids(self):
        fragment_len = 100
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=fragment_len, num_rollout_workers=0
            ),
        )
        batch1 = convert_ma_batch_to_sample_batch(ev.sample())
        batch2 = convert_ma_batch_to_sample_batch(ev.sample())
        unroll_ids_1 = set(batch1["unroll_id"])
        unroll_ids_2 = set(batch2["unroll_id"])
        # Assert no overlap of unroll IDs between sample() calls.
        self.assertTrue(not any(uid in unroll_ids_2 for uid in unroll_ids_1))
        # CartPole episodes should be short initially: Expect more than one
        # unroll ID in each batch.
        self.assertTrue(len(unroll_ids_1) > 1)
        self.assertTrue(len(unroll_ids_2) > 1)
        ev.stop()

    def test_global_vars_update(self):
        config = (
            A2CConfig()
            .environment("CartPole-v1")
            .rollouts(num_envs_per_worker=1)
            # lr = 0.1 - [(0.1 - 0.000001) / 100000] * ts
            .training(lr_schedule=[[0, 0.1], [100000, 0.000001]])
        )
        for fw in framework_iterator(config, frameworks=("tf2", "tf")):
            algo = config.build()
            policy = algo.get_policy()
            for i in range(3):
                result = algo.train()
                print(
                    "{}={}".format(
                        NUM_AGENT_STEPS_TRAINED, result["info"][NUM_AGENT_STEPS_TRAINED]
                    )
                )
                print(
                    "{}={}".format(
                        NUM_AGENT_STEPS_SAMPLED, result["info"][NUM_AGENT_STEPS_SAMPLED]
                    )
                )
                global_timesteps = (
                    policy.global_timestep
                    if fw == "tf"
                    else policy.global_timestep.numpy()
                )
                print("global_timesteps={}".format(global_timesteps))
                expected_lr = 0.1 - ((0.1 - 0.000001) / 100000) * global_timesteps
                lr = policy.cur_lr
                if fw == "tf":
                    lr = policy.get_session().run(lr)
                check(lr, expected_lr, rtol=0.05)
            algo.stop()

    def test_no_step_on_init(self):
        register_env("fail", lambda _: FailOnStepEnv())
        config = PGConfig().environment("fail").rollouts(num_rollout_workers=2)
        for _ in framework_iterator(config):
            # We expect this to fail already on Algorithm init due
            # to the env sanity check right after env creation (inside
            # RolloutWorker).
            self.assertRaises(
                Exception,
                lambda: config.build(),
            )

    def test_query_evaluators(self):
        register_env("test", lambda _: gym.make("CartPole-v1"))
        config = (
            PGConfig()
            .environment("test")
            .rollouts(
                num_rollout_workers=2,
                num_envs_per_worker=2,
                create_env_on_local_worker=True,
            )
            .training(train_batch_size=20)
        )
        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            pg = config.build()
            results = pg.workers.foreach_worker(
                lambda w: w.total_rollout_fragment_length
            )
            results2 = pg.workers.foreach_worker_with_id(
                lambda i, w: (i, w.total_rollout_fragment_length)
            )
            results3 = pg.workers.foreach_worker(lambda w: w.foreach_env(lambda env: 1))
            self.assertEqual(results, [10, 10, 10])
            self.assertEqual(results2, [(0, 10), (1, 10), (2, 10)])
            self.assertEqual(results3, [[1, 1], [1, 1], [1, 1]])
            pg.stop()
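
    # Actions outside the env's Box bounds: with clip_actions=True they get
    # clipped before reaching the env; with clipping and normalization off,
    # out-of-bounds actions reach a bounds-checking env unchanged (and it
    # complains); a policy that respects the bounds needs neither.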
    def test_action_clipping(self):
        from ray.rllib.examples.env.random_env import RandomEnv

        action_space = gym.spaces.Box(-2.0, 1.0, (3,))

        # Clipping: True (clip between Policy's action_space.low/high).
        ev = RolloutWorker(
            env_creator=lambda _: RandomEnv(
                config=dict(
                    action_space=action_space,
                    max_episode_len=10,
                    p_terminated=0.0,
                    check_action_bounds=True,
                )
            ),
            config=AlgorithmConfig()
            .multi_agent(
                policies={
                    "default_policy": PolicySpec(
                        policy_class=RandomPolicy,
                        config={"ignore_action_bounds": True},
                    )
                }
            )
            .rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
            .environment(
                action_space=action_space, normalize_actions=False, clip_actions=True
            ),
        )
        sample = convert_ma_batch_to_sample_batch(ev.sample())
        # Check whether the action bounds have been breached (expected).
        # We still arrived here b/c we clipped according to the Env's action
        # space.
        self.assertGreater(np.max(sample["actions"]), action_space.high[0])
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

        # Clipping: False and RandomPolicy produces invalid actions.
        # Expect Env to complain.
        ev2 = RolloutWorker(
            env_creator=lambda _: RandomEnv(
                config=dict(
                    action_space=action_space,
                    max_episode_len=10,
                    p_terminated=0.0,
                    check_action_bounds=True,
                )
            ),
            # No normalization (+clipping) and no clipping ->
            # Should lead to Env complaining.
            config=AlgorithmConfig()
            .environment(
                normalize_actions=False,
                clip_actions=False,
                action_space=action_space,
            )
            .rollouts(batch_mode="complete_episodes", num_rollout_workers=0)
            .multi_agent(
                policies={
                    "default_policy": PolicySpec(
                        policy_class=RandomPolicy,
                        config={"ignore_action_bounds": True},
                    )
                }
            ),
        )
        self.assertRaisesRegex(ValueError, r"Illegal action", ev2.sample)
        ev2.stop()

        # Clipping: False and RandomPolicy produces valid (bounded) actions.
        # Expect "actions" in SampleBatch to be unclipped.
        ev3 = RolloutWorker(
            env_creator=lambda _: RandomEnv(
                config=dict(
                    action_space=action_space,
                    max_episode_len=10,
                    p_terminated=0.0,
                    check_action_bounds=True,
                )
            ),
            default_policy_class=RandomPolicy,
            config=AlgorithmConfig().rollouts(
                num_rollout_workers=0, batch_mode="complete_episodes"
            )
            # Should not be a problem as RandomPolicy abides by the bounds.
            .environment(
                action_space=action_space, normalize_actions=False, clip_actions=False
            ),
        )
        sample = convert_ma_batch_to_sample_batch(ev3.sample())
        self.assertGreater(np.min(sample["actions"]), action_space.low[0])
        self.assertLess(np.max(sample["actions"]), action_space.high[0])
        ev3.stop()

    def test_action_normalization(self):
        from ray.rllib.examples.env.random_env import RandomEnv

        action_space = gym.spaces.Box(0.0001, 0.0002, (5,))

        # Normalize: True (unsquash between Policy's action_space.low/high).
        ev = RolloutWorker(
            env_creator=lambda _: RandomEnv(
                config=dict(
                    action_space=action_space,
                    max_episode_len=10,
                    p_terminated=0.0,
                    check_action_bounds=True,
                )
            ),
            config=AlgorithmConfig()
            .multi_agent(
                policies={
                    "default_policy": PolicySpec(
                        policy_class=RandomPolicy,
                        config={"ignore_action_bounds": True},
                    )
                }
            )
            .rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
            .environment(
                action_space=action_space, normalize_actions=True, clip_actions=False
            ),
        )
        sample = convert_ma_batch_to_sample_batch(ev.sample())
        # Check whether the action bounds have been breached (expected).
        # We still arrived here b/c we unsquashed according to the Env's action
        # space.
        self.assertGreater(np.max(sample["actions"]), action_space.high[0])
        self.assertLess(np.min(sample["actions"]), action_space.low[0])
        ev.stop()

    def test_action_normalization_offline_dataset(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Create environment.
            env = gym.make("Pendulum-v1")

            # Create temp data with actions at min and max.
            data = {
                "type": "SampleBatch",
                "actions": [[2.0], [-2.0]],
                "terminateds": [0.0, 0.0],
                "truncateds": [0.0, 0.0],
                "rewards": [0.0, 0.0],
                "obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                "new_obs": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            }
            data_file = os.path.join(tmp_dir, "data.json")
            with open(data_file, "w") as f:
                json.dump(data, f)

            # Create input reader functions.
            def dataset_reader_creator(ioctx):
                config = AlgorithmConfig().offline_data(
                    input_="dataset",
                    input_config={"format": "json", "paths": data_file},
                )
                _, shards = get_dataset_and_shards(config, num_workers=0)
                return DatasetReader(shards[0], ioctx)

            def json_reader_creator(ioctx):
                return JsonReader(data_file, ioctx)

            input_creators = [dataset_reader_creator, json_reader_creator]

            # (actions_in_input_normalized, normalize_actions)
            parameters = [
                (True, True),
                (True, False),
                (False, True),
                (False, False),
            ]

            # Check that samples from the dataset will be normalized if and
            # only if actions_in_input_normalized == False and
            # normalize_actions == True.
            for input_creator in input_creators:
                for actions_in_input_normalized, normalize_actions in parameters:
                    ev = RolloutWorker(
                        env_creator=lambda _: env,
                        default_policy_class=MockPolicy,
                        config=AlgorithmConfig()
                        .rollouts(
                            num_rollout_workers=0,
                            rollout_fragment_length=1,
                        )
                        .environment(
                            normalize_actions=normalize_actions,
                            clip_actions=False,
                        )
                        .training(train_batch_size=1)
                        .offline_data(
                            offline_sampling=True,
                            actions_in_input_normalized=actions_in_input_normalized,
                            input_=input_creator,
                        ),
                    )
                    sample = ev.sample()
                    if normalize_actions and not actions_in_input_normalized:
                        # Check that the samples from the dataset are
                        # normalized properly.
                        self.assertLessEqual(np.max(sample["actions"]), 1.0)
                        self.assertGreaterEqual(np.min(sample["actions"]), -1.0)
                    else:
                        # Check that the samples from the dataset are NOT
                        # normalized.
                        self.assertGreater(np.max(sample["actions"]), 1.5)
                        self.assertLess(np.min(sample["actions"]), -1.5)
                    ev.stop()
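
    # Actions passed into env.step() during sampling should be read-only
    # numpy arrays; the env below asserts this on every step.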
    def test_action_immutability(self):
        from ray.rllib.examples.env.random_env import RandomEnv

        action_space = gym.spaces.Box(0.0001, 0.0002, (5,))

        class ActionMutationEnv(RandomEnv):
            def __init__(self, config):
                self.test_case = config["test_case"]
                super().__init__(config=config)

            def step(self, action):
                # Ensure that it is called from inside the sampling process.
                import inspect

                curframe = inspect.currentframe()
                called_from_check = any(
                    frame[3] == "check_gym_environments"
                    for frame in inspect.getouterframes(curframe, 2)
                )
                # Check whether the action is immutable.
                if action.flags.writeable and not called_from_check:
                    self.test_case.assertFalse(
                        action.flags.writeable, "Action is mutable"
                    )
                return super().step(action)

        ev = RolloutWorker(
            env_creator=lambda _: ActionMutationEnv(
                config=dict(
                    test_case=self,
                    action_space=action_space,
                    max_episode_len=10,
                    p_terminated=0.0,
                    check_action_bounds=True,
                )
            ),
            config=AlgorithmConfig()
            .multi_agent(
                policies={
                    "default_policy": PolicySpec(
                        policy_class=RandomPolicy,
                        config={"ignore_action_bounds": True},
                    )
                }
            )
            .environment(action_space=action_space, clip_actions=False)
            .rollouts(batch_mode="complete_episodes", num_rollout_workers=0),
        )
        ev.sample()
        ev.stop()

    def test_reward_clipping(self):
        # Clipping: True (clip between -1.0 and 1.0).
        config = (
            AlgorithmConfig()
            .rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
            .environment(clip_rewards=True)
        )
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            default_policy_class=MockPolicy,
            config=config,
        )
        sample = convert_ma_batch_to_sample_batch(ev.sample())
        ws = WorkerSet._from_existing(
            local_worker=ev,
            remote_workers=[],
        )
        self.assertEqual(max(sample["rewards"]), 1)
        result = collect_metrics(ws, [])
        # Shows different behavior when connector is on/off.
        if config.enable_connectors:
            # episode_reward_mean shows the correct clipped value.
            self.assertEqual(result["episode_reward_mean"], 10)
        else:
            # episode_reward_mean shows the unclipped raw value
            # when connector is off, and old env_runner v1 is used.
            self.assertEqual(result["episode_reward_mean"], 1000)
        ev.stop()

        from ray.rllib.examples.env.random_env import RandomEnv

        # Clipping in certain range (-2.0, 2.0).
        ev2 = RolloutWorker(
            env_creator=lambda _: RandomEnv(
                dict(
                    reward_space=gym.spaces.Box(low=-10, high=10, shape=()),
                    p_terminated=0.0,
                    max_episode_len=10,
                )
            ),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig()
            .rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
            .environment(clip_rewards=2.0),
        )
        sample = convert_ma_batch_to_sample_batch(ev2.sample())
        self.assertEqual(max(sample["rewards"]), 2.0)
        self.assertEqual(min(sample["rewards"]), -2.0)
        self.assertLess(np.mean(sample["rewards"]), 0.5)
        self.assertGreater(np.mean(sample["rewards"]), -0.5)
        ev2.stop()

        # Clipping: Off.
        ev2 = RolloutWorker(
            env_creator=lambda _: MockEnv2(episode_length=10),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig()
            .rollouts(num_rollout_workers=0, batch_mode="complete_episodes")
            .environment(clip_rewards=False),
        )
        sample = convert_ma_batch_to_sample_batch(ev2.sample())
        ws2 = WorkerSet._from_existing(
            local_worker=ev2,
            remote_workers=[],
        )
        self.assertEqual(max(sample["rewards"]), 100)
        result2 = collect_metrics(ws2, [])
        self.assertEqual(result2["episode_reward_mean"], 1000)
        ev2.stop()
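
    # collect_metrics() should aggregate episodes from both the local worker
    # and a remote worker created via ray.remote(RolloutWorker).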
    def test_metrics(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=10),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=100,
                num_rollout_workers=0,
                batch_mode="complete_episodes",
            ),
        )
        remote_ev = ray.remote(RolloutWorker).remote(
            env_creator=lambda _: MockEnv(episode_length=10),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=100,
                num_rollout_workers=0,
                batch_mode="complete_episodes",
            ),
        )
        ws = WorkerSet._from_existing(
            local_worker=ev,
            remote_workers=[remote_ev],
        )
        ev.sample()
        ray.get(remote_ev.sample.remote())
        result = collect_metrics(ws)
        self.assertEqual(result["episodes_this_iter"], 20)
        self.assertEqual(result["episode_reward_mean"], 10)
        ev.stop()

    def test_async(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                sample_async=True, num_rollout_workers=0
            ),
        )
        batch = convert_ma_batch_to_sample_batch(ev.sample())
        for key in [
            "obs",
            "actions",
            "rewards",
            "terminateds",
            "truncateds",
            "advantages",
        ]:
            self.assertIn(key, batch)
        self.assertGreater(batch["advantages"][0], 1)
        ev.stop()

    def test_auto_vectorization(self):
        ev = RolloutWorker(
            env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=2,
                num_envs_per_worker=8,
                num_rollout_workers=0,
                batch_mode="truncate_episodes",
            ),
        )
        ws = WorkerSet._from_existing(
            local_worker=ev,
            remote_workers=[],
        )
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 16)
        result = collect_metrics(ws, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 16)
        result = collect_metrics(ws, [])
        self.assertEqual(result["episodes_this_iter"], 8)
        indices = []
        for env in ev.async_env.vector_env.envs:
            self.assertEqual(env.unwrapped.config.worker_index, 0)
            indices.append(env.unwrapped.config.vector_index)
        self.assertEqual(indices, [0, 1, 2, 3, 4, 5, 6, 7])
        ev.stop()

    def test_batches_larger_when_vectorized(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(episode_length=8),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=4,
                num_envs_per_worker=4,
                num_rollout_workers=0,
                batch_mode="truncate_episodes",
            ),
        )
        ws = WorkerSet._from_existing(
            local_worker=ev,
            remote_workers=[],
        )
        batch = ev.sample()
        self.assertEqual(batch.count, 16)
        result = collect_metrics(ws, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        batch = ev.sample()
        result = collect_metrics(ws, [])
        self.assertEqual(result["episodes_this_iter"], 4)
        ev.stop()

    def test_vector_env_support(self):
        # Test a vector env that contains 8 actual envs
        # (MockEnv instances).
        ev = RolloutWorker(
            env_creator=(lambda _: VectorizedMockEnv(episode_length=20, num_envs=8)),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=10,
                num_rollout_workers=0,
                batch_mode="truncate_episodes",
            ),
        )
        ws = WorkerSet._from_existing(
            local_worker=ev,
            remote_workers=[],
        )
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ws, [])
        self.assertEqual(result["episodes_this_iter"], 0)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ws, [])
        self.assertEqual(result["episodes_this_iter"], 8)
        ev.stop()

        # Test a vector env that pretends(!) to contain 4 envs, but actually
        # only has 1 (CartPole).
        ev = RolloutWorker(
            env_creator=(lambda _: MockVectorEnv(20, mocked_num_envs=4)),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=10,
                num_rollout_workers=0,
                batch_mode="truncate_episodes",
            ),
        )
        ws = WorkerSet._from_existing(
            local_worker=ev,
            remote_workers=[],
        )
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ws, [])
        self.assertGreater(result["episodes_this_iter"], 3)
        for _ in range(8):
            batch = ev.sample()
            self.assertEqual(batch.count, 10)
        result = collect_metrics(ws, [])
        self.assertGreater(result["episodes_this_iter"], 6)
        ev.stop()

    def test_truncate_episodes(self):
        ev_env_steps = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=15,
                num_rollout_workers=0,
                batch_mode="truncate_episodes",
            ),
        )
        batch = ev_env_steps.sample()
        self.assertEqual(batch.count, 15)
        self.assertTrue(issubclass(type(batch), (SampleBatch, MultiAgentBatch)))
        ev_env_steps.stop()

        action_space = Discrete(2)
        obs_space = Box(float("-inf"), float("inf"), (4,), dtype=np.float32)
        ev_agent_steps = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig()
            .rollouts(
                num_rollout_workers=0,
                batch_mode="truncate_episodes",
                rollout_fragment_length=301,
            )
            .multi_agent(
                policies={"pol0", "pol1"},
                policy_mapping_fn=(
                    lambda agent_id, episode, worker, **kwargs: "pol0"
                    if agent_id == 0
                    else "pol1"
                ),
            )
            .environment(action_space=action_space, observation_space=obs_space),
        )
        batch = ev_agent_steps.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertGreater(batch.agent_steps(), 301)
        self.assertEqual(batch.env_steps(), 301)
        ev_agent_steps.stop()

        ev_agent_steps = RolloutWorker(
            env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig()
            .rollouts(
                num_rollout_workers=0,
                rollout_fragment_length=301,
            )
            .multi_agent(
                count_steps_by="agent_steps",
                policies={"pol0", "pol1"},
                policy_mapping_fn=(
                    lambda agent_id, episode, worker, **kwargs: "pol0"
                    if agent_id == 0
                    else "pol1"
                ),
            ),
        )
        batch = ev_agent_steps.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertLess(batch.env_steps(), 301)
        # When counting agent steps, the count may be slightly larger than
        # rollout_fragment_length, b/c we have up to N agents stepping in each
        # env step and we only check whether we should build after each env
        # step.
        self.assertGreaterEqual(batch.agent_steps(), 301)
        ev_agent_steps.stop()
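
    # With batch_mode="complete_episodes", sampled fragments are extended to
    # episode boundaries even if rollout_fragment_length is smaller.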
    def test_complete_episodes(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=5,
                num_rollout_workers=0,
                batch_mode="complete_episodes",
            ),
        )
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
        ev.stop()

    def test_complete_episodes_packing(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=15,
                num_rollout_workers=0,
                batch_mode="complete_episodes",
            ),
        )
        batch = ev.sample()
        batch = convert_ma_batch_to_sample_batch(batch)
        self.assertEqual(batch.count, 20)
        self.assertEqual(
            batch["t"].tolist(),
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        )
        ev.stop()
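
    # The next three tests exercise the ConcurrentMeanStdFilter observation
    # filter: flushing after sampling, querying without flushing, and syncing
    # externally provided filter state into the worker.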
    def test_filter_sync(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                sample_async=True,
                num_rollout_workers=0,
                observation_filter="ConcurrentMeanStdFilter",
            ),
        )
        time.sleep(2)
        ev.sample()
        filters = ev.get_filters(flush_after=True)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertNotEqual(obs_f.running_stats.n, 0)
        self.assertNotEqual(obs_f.buffer.n, 0)
        ev.stop()

    def test_get_filters(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                observation_filter="ConcurrentMeanStdFilter",
                num_rollout_workers=0,
                sample_async=True,
            ),
        )
        self.sample_and_flush(ev)
        filters = ev.get_filters(flush_after=False)
        time.sleep(2)
        filters2 = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        obs_f2 = filters2[DEFAULT_POLICY_ID]
        self.assertGreaterEqual(obs_f2.running_stats.n, obs_f.running_stats.n)
        self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)
        ev.stop()

    def test_sync_filter(self):
        ev = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v1"),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                observation_filter="ConcurrentMeanStdFilter",
                num_rollout_workers=0,
                sample_async=True,
            ),
        )
        obs_f = self.sample_and_flush(ev)

        # Current state.
        filters = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertLessEqual(obs_f.buffer.n, 20)

        new_obsf = obs_f.copy()
        new_obsf.running_stats.num_pushes = 100
        ev.sync_filters({DEFAULT_POLICY_ID: new_obsf})
        filters = ev.get_filters(flush_after=False)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertGreaterEqual(obs_f.running_stats.n, 100)
        self.assertLessEqual(obs_f.buffer.n, 20)
        ev.stop()

    def test_extra_python_envs(self):
        extra_envs = {"env_key_1": "env_value_1", "env_key_2": "env_value_2"}
        self.assertFalse("env_key_1" in os.environ)
        self.assertFalse("env_key_2" in os.environ)
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv(10),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig()
            .python_environment(extra_python_environs_for_driver=extra_envs)
            .rollouts(num_rollout_workers=0),
        )
        self.assertTrue("env_key_1" in os.environ)
        self.assertTrue("env_key_2" in os.environ)
        ev.stop()
        # Reset to original.
        del os.environ["env_key_1"]
        del os.environ["env_key_2"]

    def test_no_env_seed(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockVectorEnv(20, mocked_num_envs=8),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(num_rollout_workers=0).debugging(seed=1),
        )
        assert not hasattr(ev.env, "seed")
        ev.stop()

    def test_multi_env_seed(self):
        ev = RolloutWorker(
            env_creator=lambda _: MockEnv2(100),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig()
            .rollouts(num_envs_per_worker=3, num_rollout_workers=0)
            .debugging(seed=1),
        )
        # Make sure we can properly sample from the wrapped env.
        ev.sample()
        # Make sure all environments got a different deterministic seed.
        seeds = ev.foreach_env(lambda env: env.rng_seed)
        self.assertEqual(seeds, [1, 2, 3])
        ev.stop()

    def test_determine_spaces_for_multi_agent_dict(self):
        class MockMultiAgentEnv(MultiAgentEnv):
            """A mock testing MultiAgentEnv that doesn't call super().__init__()."""

            def __init__(self):
                # Intentionally do NOT call super().__init__(), so this env
                # doesn't have the
                # `self._[action|observation]_space_in_preferred_format`
                # attributes.
                self.observation_space = gym.spaces.Discrete(2)
                self.action_space = gym.spaces.Discrete(2)

            def reset(self, *, seed=None, options=None):
                pass

            def step(self, action_dict):
                obs = {1: [0, 0], 2: [1, 1]}
                rewards = {1: 0, 2: 0}
                terminateds = truncateds = {1: False, 2: False, "__all__": False}
                infos = {1: {}, 2: {}}
                return obs, rewards, terminateds, truncateds, infos

        ev = RolloutWorker(
            env_creator=lambda _: MockMultiAgentEnv(),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig()
            .rollouts(num_envs_per_worker=3, num_rollout_workers=0)
            .multi_agent(policies={"policy_1", "policy_2"})
            .debugging(seed=1),
        )
        # The fact that this RolloutWorker can be created without throwing
        # exceptions means AlgorithmConfig.get_multi_agent_setup() is
        # handling multi-agent user environments properly.
        self.assertIsNotNone(ev)

    def test_wrap_multi_agent_env(self):
        ev = RolloutWorker(
            env_creator=lambda _: BasicMultiAgent(10),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=5,
                batch_mode="complete_episodes",
                num_rollout_workers=0,
            ),
        )
        # Make sure we can properly sample from the wrapped env.
        ev.sample()
        # Make sure the resulting environment is indeed still a MultiAgentEnv
        # (and a gym.Env).
        self.assertTrue(isinstance(ev.env.unwrapped, MultiAgentEnv))
        self.assertTrue(isinstance(ev.env, gym.Env))
        ev.stop()

    def test_no_training(self):
        class NoTrainingEnv(MockEnv):
            def __init__(self, episode_length, training_enabled):
                super().__init__(episode_length)
                self.training_enabled = training_enabled

            def step(self, action):
                obs, rew, terminated, truncated, info = super().step(action)
                return (
                    obs,
                    rew,
                    terminated,
                    truncated,
                    {**info, "training_enabled": self.training_enabled},
                )

        ev = RolloutWorker(
            env_creator=lambda _: NoTrainingEnv(10, True),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=5,
                batch_mode="complete_episodes",
                num_rollout_workers=0,
            ),
        )
        batch = ev.sample()
        batch = convert_ma_batch_to_sample_batch(batch)
        self.assertEqual(batch.count, 10)
        self.assertEqual(len(batch["obs"]), 10)
        ev.stop()

        ev = RolloutWorker(
            env_creator=lambda _: NoTrainingEnv(10, False),
            default_policy_class=MockPolicy,
            config=AlgorithmConfig().rollouts(
                rollout_fragment_length=5,
                batch_mode="complete_episodes",
                num_rollout_workers=0,
            ),
        )
        batch = ev.sample()
        self.assertTrue(isinstance(batch, MultiAgentBatch))
        self.assertEqual(len(batch.policy_batches), 0)
        ev.stop()
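
    # Helper used by the filter tests above: sample once, flush the filters,
    # and return the default policy's observation filter.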
    def sample_and_flush(self, ev):
        time.sleep(2)
        ev.sample()
        filters = ev.get_filters(flush_after=True)
        obs_f = filters[DEFAULT_POLICY_ID]
        self.assertNotEqual(obs_f.running_stats.n, 0)
        self.assertNotEqual(obs_f.buffer.n, 0)
        return obs_f


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))