
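"""Unit tests for RLlib's trajectory view API.

Summary of the tests below: default and LSTM/attention ViewRequirements
of Models and Policies, custom shifted view columns ("next_actions"),
multi-agent LSTM/attention sampling, counting by agent steps, and
`SampleBatch.get_single_step_input_dict()` with different
`batch_repeat_value` settings.

Run with: pytest -v test_trajectory_view_api.py
"""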
import copy
import gym
from gym.spaces import Box, Discrete
import numpy as np
import unittest

import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.dqn as dqn
import ray.rllib.agents.ppo as ppo
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
from ray.rllib.examples.env.multi_agent import MultiAgentPendulum
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.examples.policy.episode_env_aware_policy import \
    EpisodeEnvAwareAttentionPolicy, EpisodeEnvAwareLSTMPolicy
from ray.rllib.models.tf.attention_net import GTrXLNet
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.test_utils import framework_iterator, check


class MyCallbacks(DefaultCallbacks):
    """Asserts the train-batch structure in `test_traj_view_attention_net`.

    201 is that test's `sgd_minibatch_size`; observations are expected to
    count up by 1 within an episode (DebugCounterEnv starting at 1.0), and
    the running check restarts whenever an obs of 15 is reached.
    """

    @override(DefaultCallbacks)
    def on_learn_on_batch(self, *, policy, train_batch, result, **kwargs):
        assert train_batch.count == 201
        assert sum(train_batch[SampleBatch.SEQ_LENS]) == 201
        for k, v in train_batch.items():
            if k in ["state_in_0", SampleBatch.SEQ_LENS]:
                assert len(v) == len(train_batch[SampleBatch.SEQ_LENS])
            else:
                assert len(v) == 201

        current = None
        for o in train_batch[SampleBatch.OBS]:
            if current:
                assert o == current + 1
            current = o
            if o == 15:
                current = None


class TestTrajectoryViewAPI(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_traj_view_normal_case(self):
        """Tests whether Model and Policy return the correct ViewRequirements.
        """
        config = dqn.DEFAULT_CONFIG.copy()
        config["num_envs_per_worker"] = 10
        config["rollout_fragment_length"] = 4

        for _ in framework_iterator(config):
            trainer = dqn.DQNTrainer(
                config,
                env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv")
            policy = trainer.get_policy()
            view_req_model = policy.model.view_requirements
            view_req_policy = policy.view_requirements
            assert len(view_req_model) == 1, view_req_model
            assert len(view_req_policy) == 10, view_req_policy
            for key in [
                    SampleBatch.OBS,
                    SampleBatch.ACTIONS,
                    SampleBatch.REWARDS,
                    SampleBatch.DONES,
                    SampleBatch.NEXT_OBS,
                    SampleBatch.EPS_ID,
                    SampleBatch.AGENT_INDEX,
                    "weights",
            ]:
                assert key in view_req_policy
                # None of the view cols has a special underlying data_col,
                # except next-obs.
                if key != SampleBatch.NEXT_OBS:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].shift == 1
            rollout_worker = trainer.workers.local_worker()
            sample_batch = rollout_worker.sample()
            expected_count = \
                config["num_envs_per_worker"] * \
                config["rollout_fragment_length"]
            assert sample_batch.count == expected_count
            for v in sample_batch.values():
                assert len(v) == expected_count
            trainer.stop()

    def test_traj_view_lstm_prev_actions_and_rewards(self):
        """Tests whether Policy/Model return correct LSTM ViewRequirements.
        """
        config = ppo.DEFAULT_CONFIG.copy()
        config["model"] = config["model"].copy()
        # Activate LSTM + prev-action + prev-reward.
        config["model"]["use_lstm"] = True
        config["model"]["lstm_use_prev_action"] = True
        config["model"]["lstm_use_prev_reward"] = True

        for _ in framework_iterator(config):
            trainer = ppo.PPOTrainer(config, env="CartPole-v0")
            policy = trainer.get_policy()
            view_req_model = policy.model.view_requirements
            view_req_policy = policy.view_requirements
            # 7 = obs + prev-a + prev-r + 2x state-in + 2x state-out.
            assert len(view_req_model) == 7, view_req_model
            assert len(view_req_policy) == 20, \
                (len(view_req_policy), view_req_policy)
            for key in [
                    SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.REWARDS,
                    SampleBatch.DONES, SampleBatch.NEXT_OBS,
                    SampleBatch.VF_PREDS, SampleBatch.PREV_ACTIONS,
                    SampleBatch.PREV_REWARDS, "advantages", "value_targets",
                    SampleBatch.ACTION_DIST_INPUTS, SampleBatch.ACTION_LOGP
            ]:
                assert key in view_req_policy

                if key == SampleBatch.PREV_ACTIONS:
                    assert view_req_policy[key].data_col == SampleBatch.ACTIONS
                    assert view_req_policy[key].shift == -1
                elif key == SampleBatch.PREV_REWARDS:
                    assert view_req_policy[key].data_col == SampleBatch.REWARDS
                    assert view_req_policy[key].shift == -1
                elif key not in [
                        SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
                        SampleBatch.PREV_REWARDS
                ]:
                    assert view_req_policy[key].data_col is None
                else:
                    assert view_req_policy[key].data_col == SampleBatch.OBS
                    assert view_req_policy[key].shift == 1
            trainer.stop()

    def test_traj_view_attention_net(self):
        """Tests training a GTrXL attention net with odd batch sizes."""
        config = ppo.DEFAULT_CONFIG.copy()
        # Set up attention net.
        config["model"] = config["model"].copy()
        config["model"]["max_seq_len"] = 50
        config["model"]["custom_model"] = GTrXLNet
        config["model"]["custom_model_config"] = {
            "num_transformer_units": 1,
            "attention_dim": 64,
            "num_heads": 2,
            "memory_inference": 50,
            "memory_training": 50,
            "head_dim": 32,
            "ff_hidden_dim": 32,
        }
        # Test with odd batch numbers.
        config["train_batch_size"] = 1031
        config["sgd_minibatch_size"] = 201
        config["num_sgd_iter"] = 5
        config["num_workers"] = 0
        config["callbacks"] = MyCallbacks
        config["env_config"] = {
            "config": {
                "start_at_t": 1
            }
        }  # first obs is [1.0]

        for _ in framework_iterator(config, frameworks="tf2"):
            trainer = ppo.PPOTrainer(
                config,
                env="ray.rllib.examples.env.debug_counter_env.DebugCounterEnv",
            )
            rw = trainer.workers.local_worker()
            sample = rw.sample()
            assert sample.count == trainer.config["rollout_fragment_length"]
            results = trainer.train()
            assert results["timesteps_total"] == config["train_batch_size"]
            trainer.stop()

    def test_traj_view_next_action(self):
        """Tests adding custom, forward-shifted action view columns."""
        action_space = Discrete(2)
        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_config=ppo.DEFAULT_CONFIG,
            rollout_fragment_length=200,
            policy_spec=ppo.PPOTorchPolicy,
            policy_mapping_fn=None,
            num_envs=1,
        )
        # Add the next action (a') and the 2nd-next action (a'') to the
        # policy's view requirements. These columns should then show up in
        # postprocessing and train batches.
        # Switch them off for action computations (they cannot be provided
        # there, since the next actions are not yet known at
        # action-computation time).
        rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
            "next_actions"] = ViewRequirement(
                SampleBatch.ACTIONS,
                shift=1,
                space=action_space,
                used_for_compute_actions=False)
        rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
            "2nd_next_actions"] = ViewRequirement(
                SampleBatch.ACTIONS,
                shift=2,
                space=action_space,
                used_for_compute_actions=False)
        # Make sure we have DONEs as well.
        rollout_worker_w_api.policy_map[DEFAULT_POLICY_ID].view_requirements[
            "dones"] = ViewRequirement()
        batch = rollout_worker_w_api.sample()
        self.assertTrue("next_actions" in batch)
        self.assertTrue("2nd_next_actions" in batch)
        expected_a_ = None  # expected next action
        expected_a__ = None  # expected 2nd-next action
        for i in range(len(batch["actions"])):
            a, d, a_, a__ = \
                batch["actions"][i], batch["dones"][i], \
                batch["next_actions"][i], batch["2nd_next_actions"][i]
            # Episode done: next action and 2nd-next action should be 0.
            if d:
                check(a_, 0)
                check(a__, 0)
                expected_a_ = None
                expected_a__ = None
                continue
            # Episode is not done and we have an expected next-a.
            if expected_a_ is not None:
                check(a, expected_a_)
            if expected_a__ is not None:
                check(a_, expected_a__)
            expected_a__ = a__
            expected_a_ = a_

    def test_traj_view_lstm_functionality(self):
        """Tests RNN-batch consistency for a multi-agent LSTM policy."""
        action_space = Box(float("-inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 200
        assert rollout_fragment_length % max_seq_len == 0
        policies = {
            "pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
        }

        def policy_fn(agent_id, episode, **kwargs):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "use_lstm": True,
                "max_seq_len": max_seq_len,
            },
        }

        rw = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=config,
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            normalize_actions=False,
            num_envs=1,
        )

        for iteration in range(20):
            result = rw.sample()
            check(result.count, rollout_fragment_length)
            pol_batch_w = result.policy_batches["pol0"]
            assert pol_batch_w.count >= rollout_fragment_length
            analyze_rnn_batch(
                pol_batch_w,
                max_seq_len,
                view_requirements=rw.policy_map["pol0"].view_requirements)

    def test_traj_view_attention_functionality(self):
        """Smoke-tests sampling with a multi-agent attention policy."""
        action_space = Box(float("-inf"), float("inf"), shape=(3, ))
        obs_space = Box(float("-inf"), float("inf"), (4, ))
        max_seq_len = 50
        rollout_fragment_length = 201
        policies = {
            "pol0": (EpisodeEnvAwareAttentionPolicy, obs_space, action_space,
                     {}),
        }

        def policy_fn(agent_id, episode, **kwargs):
            return "pol0"

        config = {
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_fn,
            },
            "model": {
                "max_seq_len": max_seq_len,
            },
        }

        rollout_worker_w_api = RolloutWorker(
            env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}),
            policy_config=config,
            rollout_fragment_length=rollout_fragment_length,
            policy_spec=policies,
            policy_mapping_fn=policy_fn,
            normalize_actions=False,
            num_envs=1,
        )
        batch = rollout_worker_w_api.sample()
        print(batch)

    def test_counting_by_agent_steps(self):
        """Tests whether counting by agent steps works for multi-agent PPO."""
        config = copy.deepcopy(ppo.DEFAULT_CONFIG)
        num_agents = 3

        config["num_workers"] = 2
        config["num_sgd_iter"] = 2
        config["framework"] = "torch"
        config["rollout_fragment_length"] = 21
        config["train_batch_size"] = 147
        config["multiagent"] = {
            "policies": {f"p{i}" for i in range(num_agents)},
            "policy_mapping_fn": lambda aid, **kwargs: "p{}".format(aid),
            "count_steps_by": "agent_steps",
        }
        # Env setup.
        config["env"] = MultiAgentPendulum
        config["env_config"] = {"num_agents": num_agents}

        num_iterations = 2
        trainer = ppo.PPOTrainer(config=config)
        results = None
        for i in range(num_iterations):
            results = trainer.train()
        self.assertEqual(results["agent_timesteps_total"],
                         results["timesteps_total"] * num_agents)
        self.assertGreaterEqual(results["agent_timesteps_total"],
                                num_iterations * config["train_batch_size"])
        self.assertLessEqual(results["agent_timesteps_total"],
                             (num_iterations + 1) * config["train_batch_size"])
        trainer.stop()

    def test_get_single_step_input_dict_batch_repeat_value_larger_1(self):
        """Tests whether a SampleBatch produces the correct 1-step input dict.
        """
        space = Box(-1.0, 1.0, ())

        # With batch-repeat-value > 1: state_in_0 is only built every n
        # timesteps (here: every 5th). shift="-5:-1" means: view the
        # previous 5 state_out_0 values.
        view_reqs = {
            "state_in_0": ViewRequirement(
                data_col="state_out_0",
                shift="-5:-1",
                space=space,
                batch_repeat_value=5,
            ),
            "state_out_0": ViewRequirement(
                space=space, used_for_compute_actions=False),
        }

        # Trajectory of 1 ts (0) (we would like to compute the 1st).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
            ]),
            "state_out_0": np.array([1]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[0, 0, 0, 0, 1]],  # ts=1
                "seq_lens": [1],
            })

        # Trajectory of 6 ts (0-5) (we would like to compute the 6th).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [1, 2, 3, 4, 5],  # ts=5
            ]),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[2, 3, 4, 5, 6]],  # ts=6
                "seq_lens": [1],
            })

        # Trajectory of 12 ts (0-11) (we would like to compute the 12th).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [1, 2, 3, 4, 5],  # ts=5
                [6, 7, 8, 9, 10],  # ts=10
            ]),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[8, 9, 10, 11, 12]],  # ts=12
                "seq_lens": [1],
            })

    def test_get_single_step_input_dict_batch_repeat_value_1(self):
        """Tests whether a SampleBatch produces the correct 1-step input dict.
        """
        space = Box(-1.0, 1.0, ())

        # With batch-repeat-value==1: state_in_0 is built each timestep.
        view_reqs = {
            "state_in_0": ViewRequirement(
                data_col="state_out_0",
                shift="-5:-1",
                space=space,
                batch_repeat_value=1,
            ),
            "state_out_0": ViewRequirement(
                space=space, used_for_compute_actions=False),
        }

        # Trajectory of 1 ts (0) (we would like to compute the 1st).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
            ]),
            "state_out_0": np.array([1]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[0, 0, 0, 0, 1]],  # ts=1
                "seq_lens": [1],
            })

        # Trajectory of 6 ts (0-5) (we would like to compute the 6th).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [0, 0, 0, 0, 1],  # ts=1
                [0, 0, 0, 1, 2],  # ts=2
                [0, 0, 1, 2, 3],  # ts=3
                [0, 1, 2, 3, 4],  # ts=4
                [1, 2, 3, 4, 5],  # ts=5
            ]),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[2, 3, 4, 5, 6]],  # ts=6
                "seq_lens": [1],
            })

        # Trajectory of 12 ts (0-11) (we would like to compute the 12th).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [0, 0, 0, 0, 1],  # ts=1
                [0, 0, 0, 1, 2],  # ts=2
                [0, 0, 1, 2, 3],  # ts=3
                [0, 1, 2, 3, 4],  # ts=4
                [1, 2, 3, 4, 5],  # ts=5
                [2, 3, 4, 5, 6],  # ts=6
                [3, 4, 5, 6, 7],  # ts=7
                [4, 5, 6, 7, 8],  # ts=8
                [5, 6, 7, 8, 9],  # ts=9
                [6, 7, 8, 9, 10],  # ts=10
                [7, 8, 9, 10, 11],  # ts=11
            ]),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[8, 9, 10, 11, 12]],  # ts=12
                "seq_lens": [1],
            })


def analyze_rnn_batch(batch, max_seq_len, view_requirements):
    """Checks an RNN policy batch for consistency before/after padding."""
    count = batch.count

    # Check prev_reward/action, next_obs consistency.
    for idx in range(count):
        # If the timestep is tracked by the batch, use it directly.
        if "t" in batch:
            ts = batch["t"][idx]
        # Else, the timestep is the 4th element of the obs vector.
        else:
            ts = batch["obs"][idx][3]

        obs_t = batch["obs"][idx]
        a_t = batch["actions"][idx]
        r_t = batch["rewards"][idx]
        state_in_0 = batch["state_in_0"][idx]
        state_in_1 = batch["state_in_1"][idx]

        # Check postprocessing outputs.
        if "2xobs" in batch:
            postprocessed_col_t = batch["2xobs"][idx]
            assert (obs_t == postprocessed_col_t / 2.0).all()

        # Check state-in/out and next-obs values.
        if idx > 0:
            next_obs_t_m_1 = batch["new_obs"][idx - 1]
            state_out_0_t_m_1 = batch["state_out_0"][idx - 1]
            state_out_1_t_m_1 = batch["state_out_1"][idx - 1]
            # Same trajectory as for t-1 -> Should be able to match.
            if (batch[SampleBatch.AGENT_INDEX][idx] ==
                    batch[SampleBatch.AGENT_INDEX][idx - 1]
                    and batch[SampleBatch.EPS_ID][idx] ==
                    batch[SampleBatch.EPS_ID][idx - 1]):
                assert batch["unroll_id"][idx - 1] == batch["unroll_id"][idx]
                assert (obs_t == next_obs_t_m_1).all()
                assert (state_in_0 == state_out_0_t_m_1).all()
                assert (state_in_1 == state_out_1_t_m_1).all()
            # Different trajectory.
            else:
                assert batch["unroll_id"][idx - 1] != batch["unroll_id"][idx]
                assert not (obs_t == next_obs_t_m_1).all()
                assert not (state_in_0 == state_out_0_t_m_1).all()
                assert not (state_in_1 == state_out_1_t_m_1).all()
                # Check initial 0-internal states.
                if ts == 0:
                    assert (state_in_0 == 0.0).all()
                    assert (state_in_1 == 0.0).all()

        # Check initial 0-internal states (at ts=0).
        if ts == 0:
            assert (state_in_0 == 0.0).all()
            assert (state_in_1 == 0.0).all()
        # Check prev. a/r values.
        if idx < count - 1:
            prev_actions_t_p_1 = batch["prev_actions"][idx + 1]
            prev_rewards_t_p_1 = batch["prev_rewards"][idx + 1]
            # Same trajectory as for t+1 -> Should be able to match.
            if batch[SampleBatch.AGENT_INDEX][idx] == \
                    batch[SampleBatch.AGENT_INDEX][idx + 1] and \
                    batch[SampleBatch.EPS_ID][idx] == \
                    batch[SampleBatch.EPS_ID][idx + 1]:
                assert (a_t == prev_actions_t_p_1).all()
                assert r_t == prev_rewards_t_p_1
            # Different (new) trajectory: prev-a/r at ts=0 should always be
            # 0.0s.
            elif ts == 0:
                assert (prev_actions_t_p_1 == 0).all()
                assert prev_rewards_t_p_1 == 0.0

    pad_batch_to_sequences_of_same_size(
        batch,
        max_seq_len=max_seq_len,
        shuffle=False,
        batch_divisibility_req=1,
        view_requirements=view_requirements,
    )

    # Check after seq-len 0-padding.
    cursor = 0
    for i, seq_len in enumerate(batch[SampleBatch.SEQ_LENS]):
        state_in_0 = batch["state_in_0"][i]
        state_in_1 = batch["state_in_1"][i]
        for j in range(seq_len):
            k = cursor + j
            ts = batch["t"][k]
            obs_t = batch["obs"][k]
            a_t = batch["actions"][k]
            r_t = batch["rewards"][k]

            # Check postprocessing outputs.
            if "2xobs" in batch:
                postprocessed_col_t = batch["2xobs"][k]
                assert (obs_t == postprocessed_col_t / 2.0).all()

            # Check state-in/out and next-obs values.
            if j > 0:
                next_obs_t_m_1 = batch["new_obs"][k - 1]
                # state_out_0_t_m_1 = batch["state_out_0"][k - 1]
                # state_out_1_t_m_1 = batch["state_out_1"][k - 1]
                # Always the same trajectory as for t-1 within a sequence.
                assert batch["unroll_id"][k - 1] == batch["unroll_id"][k]
                assert (obs_t == next_obs_t_m_1).all()
                # assert (state_in_0 == state_out_0_t_m_1).all()
                # assert (state_in_1 == state_out_1_t_m_1).all()
            # Check initial 0-internal states.
            elif ts == 0:
                assert (state_in_0 == 0.0).all()
                assert (state_in_1 == 0.0).all()

        # The remainder of each sequence must be all 0.0s (padding).
        for j in range(seq_len, max_seq_len):
            k = cursor + j
            obs_t = batch["obs"][k]
            a_t = batch["actions"][k]
            r_t = batch["rewards"][k]
            assert (obs_t == 0.0).all()
            assert (a_t == 0.0).all()
            assert (r_t == 0.0).all()
        cursor += max_seq_len


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))