test_nested_observation_spaces.py
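
"""Test nested observation spaces in RLlib.

Checks that Dict, Tuple, and Repeated observation spaces are reconstructed
correctly by the time they reach custom TF and Torch models, across plain gym
envs, vectorized envs, external (serving) envs, BaseEnv conversion, and a
multi-agent env.
"""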

from gymnasium import spaces
import gymnasium as gym
import numpy as np
import pickle
import unittest

import ray
from ray.rllib.algorithms.a2c import A2CConfig
from ray.rllib.algorithms.pg import PGConfig
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import convert_to_base_env
from ray.rllib.env.tests.test_external_env import SimpleServing
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.evaluate import rollout
from ray.tune.registry import register_env
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import one_hot
from ray.rllib.utils.spaces.repeated import Repeated
from ray.rllib.utils.test_utils import check

tf1, tf, tfv = try_import_tf()
_, nn = try_import_torch()

DICT_SPACE = spaces.Dict(
    {
        "sensors": spaces.Dict(
            {
                "position": spaces.Box(low=-100, high=100, shape=(3,)),
                "velocity": spaces.Box(low=-1, high=1, shape=(3,)),
                "front_cam": spaces.Tuple(
                    (
                        spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                        spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                    )
                ),
                "rear_cam": spaces.Box(low=0, high=1, shape=(10, 10, 3)),
            }
        ),
        "inner_state": spaces.Dict(
            {
                "charge": spaces.Discrete(100),
                "job_status": spaces.Dict(
                    {
                        "task": spaces.Discrete(5),
                        "progress": spaces.Box(low=0, high=100, shape=()),
                    }
                ),
            }
        ),
    }
)
DICT_SAMPLES = [DICT_SPACE.sample() for _ in range(10)]

TUPLE_SPACE = spaces.Tuple(
    [
        spaces.Box(low=-100, high=100, shape=(3,)),
        spaces.Tuple(
            (
                spaces.Box(low=0, high=1, shape=(10, 10, 3)),
                spaces.Box(low=0, high=1, shape=(10, 10, 3)),
            )
        ),
        spaces.Discrete(5),
    ]
)
TUPLE_SAMPLES = [TUPLE_SPACE.sample() for _ in range(10)]

# Constraints on the Repeated space.
MAX_PLAYERS = 4
MAX_ITEMS = 7
MAX_EFFECTS = 2

ITEM_SPACE = spaces.Box(-5, 5, shape=(1,))
EFFECT_SPACE = spaces.Box(9000, 9999, shape=(4,))
PLAYER_SPACE = spaces.Dict(
    {
        "location": spaces.Box(-100, 100, shape=(2,)),
        "items": Repeated(ITEM_SPACE, max_len=MAX_ITEMS),
        "effects": Repeated(EFFECT_SPACE, max_len=MAX_EFFECTS),
        "status": spaces.Box(-1, 1, shape=(10,)),
    }
)
REPEATED_SPACE = Repeated(PLAYER_SPACE, max_len=MAX_PLAYERS)
REPEATED_SAMPLES = [REPEATED_SPACE.sample() for _ in range(10)]
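
# The sample lists above are drawn once at import time so the envs below emit
# deterministic observations that the spy models can later be checked against,
# element by element.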


class NestedDictEnv(gym.Env):
    def __init__(self):
        self.action_space = spaces.Discrete(2)
        self.observation_space = DICT_SPACE
        self.steps = 0

    def reset(self, *, seed=None, options=None):
        self.steps = 0
        return DICT_SAMPLES[0], {}

    def step(self, action):
        self.steps += 1
        terminated = False
        truncated = self.steps >= 5
        return DICT_SAMPLES[self.steps], 1, terminated, truncated, {}


class NestedTupleEnv(gym.Env):
    def __init__(self):
        self.action_space = spaces.Discrete(2)
        self.observation_space = TUPLE_SPACE
        self.steps = 0

    def reset(self, *, seed=None, options=None):
        self.steps = 0
        return TUPLE_SAMPLES[0], {}

    def step(self, action):
        self.steps += 1
        terminated = False
        truncated = self.steps >= 5
        return TUPLE_SAMPLES[self.steps], 1, terminated, truncated, {}


class RepeatedSpaceEnv(gym.Env):
    def __init__(self):
        self.action_space = spaces.Discrete(2)
        self.observation_space = REPEATED_SPACE
        self.steps = 0

    def reset(self, *, seed=None, options=None):
        self.steps = 0
        return REPEATED_SAMPLES[0], {}

    def step(self, action):
        self.steps += 1
        terminated = False
        truncated = self.steps >= 5
        return REPEATED_SAMPLES[self.steps], 1, terminated, truncated, {}


class NestedMultiAgentEnv(MultiAgentEnv):
    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Dict(
            {"dict_agent": DICT_SPACE, "tuple_agent": TUPLE_SPACE}
        )
        self.action_space = spaces.Dict(
            {"dict_agent": spaces.Discrete(1), "tuple_agent": spaces.Discrete(1)}
        )
        self._agent_ids = {"dict_agent", "tuple_agent"}
        self.steps = 0

    def reset(self, *, seed=None, options=None):
        return {
            "dict_agent": DICT_SAMPLES[0],
            "tuple_agent": TUPLE_SAMPLES[0],
        }, {}

    def step(self, actions):
        self.steps += 1
        obs = {
            "dict_agent": DICT_SAMPLES[self.steps],
            "tuple_agent": TUPLE_SAMPLES[self.steps],
        }
        rew = {
            "dict_agent": 0,
            "tuple_agent": 0,
        }
        terminateds = {"__all__": self.steps >= 5}
        truncateds = {"__all__": self.steps >= 5}
        infos = {
            "dict_agent": {},
            "tuple_agent": {},
        }
        return obs, rew, terminateds, truncateds, infos


class InvalidModel(TorchModelV2):
    def forward(self, input_dict, state, seq_lens):
        return "not", "valid"


class InvalidModel2(TFModelV2):
    def forward(self, input_dict, state, seq_lens):
        return tf.constant(0), tf.constant(0)
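
# The "spy" models below capture the unpacked nested observations they receive
# in forward() and write them to Ray's internal KV store, so the test process
# can read them back and verify that the nested structures survived batching
# and flattening intact.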


class TorchSpyModel(TorchModelV2, nn.Module):
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        self.fc = FullyConnectedNetwork(
            obs_space.original_space["sensors"].spaces["position"],
            action_space,
            num_outputs,
            model_config,
            name,
        )

    def forward(self, input_dict, state, seq_lens):
        pos = input_dict["obs"]["sensors"]["position"].detach().cpu().numpy()
        front_cam = input_dict["obs"]["sensors"]["front_cam"][0].detach().cpu().numpy()
        task = (
            input_dict["obs"]["inner_state"]["job_status"]["task"]
            .detach()
            .cpu()
            .numpy()
        )
        ray.experimental.internal_kv._internal_kv_put(
            "torch_spy_in_{}".format(TorchSpyModel.capture_index),
            pickle.dumps((pos, front_cam, task)),
            overwrite=True,
        )
        TorchSpyModel.capture_index += 1
        return self.fc(
            {"obs": input_dict["obs"]["sensors"]["position"]}, state, seq_lens
        )

    def value_function(self):
        return self.fc.value_function()


class TorchRepeatedSpyModel(TorchModelV2, nn.Module):
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        self.fc = FullyConnectedNetwork(
            obs_space.original_space.child_space["location"],
            action_space,
            num_outputs,
            model_config,
            name,
        )

    def forward(self, input_dict, state, seq_lens):
        ray.experimental.internal_kv._internal_kv_put(
            "torch_rspy_in_{}".format(TorchRepeatedSpyModel.capture_index),
            pickle.dumps(input_dict["obs"].unbatch_all()),
            overwrite=True,
        )
        TorchRepeatedSpyModel.capture_index += 1
        return self.fc(
            {"obs": input_dict["obs"].values["location"][:, 0]}, state, seq_lens
        )

    def value_function(self):
        return self.fc.value_function()
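
# Recursively convert tensors / ndarrays inside nested structures to plain
# Python lists so they can be compared with assertEqual.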


def to_list(value):
    if isinstance(value, list):
        return [to_list(x) for x in value]
    elif isinstance(value, dict):
        return {k: to_list(v) for k, v in value.items()}
    elif isinstance(value, np.ndarray):
        return value.tolist()
    elif isinstance(value, int):
        return value
    else:
        return value.detach().cpu().numpy().tolist()


class DictSpyModel(TFModelV2):
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, None, model_config, name)
        # Will only feed in sensors->pos.
        input_ = tf.keras.layers.Input(
            shape=self.obs_space["sensors"]["position"].shape
        )
        self.num_outputs = num_outputs or 64
        out = tf.keras.layers.Dense(self.num_outputs)(input_)
        self._main_layer = tf.keras.models.Model([input_], [out])

    def forward(self, input_dict, state, seq_lens):
        def spy(pos, front_cam, task):
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite.
            ray.experimental.internal_kv._internal_kv_put(
                "d_spy_in_{}".format(DictSpyModel.capture_index),
                pickle.dumps((pos, front_cam, task)),
                overwrite=True,
            )
            DictSpyModel.capture_index += 1
            return np.array(0, dtype=np.int64)

        spy_fn = tf1.py_func(
            spy,
            [
                input_dict["obs"]["sensors"]["position"],
                input_dict["obs"]["sensors"]["front_cam"][0],
                input_dict["obs"]["inner_state"]["job_status"]["task"],
            ],
            tf.int64,
            stateful=True,
        )

        with tf1.control_dependencies([spy_fn]):
            output = self._main_layer([input_dict["obs"]["sensors"]["position"]])

        return output, []


class TupleSpyModel(TFModelV2):
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, None, model_config, name)
        # Will only feed in 0th index of observation Tuple space.
        input_ = tf.keras.layers.Input(shape=self.obs_space[0].shape)
        self.num_outputs = num_outputs or 64
        out = tf.keras.layers.Dense(self.num_outputs)(input_)
        self._main_layer = tf.keras.models.Model([input_], [out])

    def forward(self, input_dict, state, seq_lens):
        def spy(pos, cam, task):
            # TF runs this function in an isolated context, so we have to use
            # redis to communicate back to our suite.
            ray.experimental.internal_kv._internal_kv_put(
                "t_spy_in_{}".format(TupleSpyModel.capture_index),
                pickle.dumps((pos, cam, task)),
                overwrite=True,
            )
            TupleSpyModel.capture_index += 1
            return np.array(0, dtype=np.int64)

        spy_fn = tf1.py_func(
            spy,
            [
                input_dict["obs"][0],
                input_dict["obs"][1][0],
                input_dict["obs"][2],
            ],
            tf.int64,
            stateful=True,
        )

        with tf1.control_dependencies([spy_fn]):
            output = tf1.layers.dense(input_dict["obs"][0], self.num_outputs)

        return output, []
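
# The test cases below push nested Dict, Tuple, and Repeated observations
# through plain gym envs, vectorized envs, (external) serving envs, BaseEnv
# conversion, and a multi-agent env, for both the TF and Torch model stacks.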


class TestNestedObservationSpaces(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_invalid_model(self):
        ModelCatalog.register_custom_model("invalid", InvalidModel)
        config = (
            PGConfig()
            .environment("CartPole-v1")
            .framework("torch")
            .training(model={"custom_model": "invalid"})
        )
        self.assertRaisesRegex(
            ValueError,
            "Subclasses of TorchModelV2 must also inherit from nn.Module",
            lambda: config.build(),
        )

    def test_invalid_model2(self):
        ModelCatalog.register_custom_model("invalid2", InvalidModel2)
        config = (
            PGConfig()
            .environment("CartPole-v1")
            .framework("tf")
            .training(model={"custom_model": "invalid2"})
        )
        self.assertRaisesRegex(
            ValueError,
            "State output is not a list",
            lambda: config.build(),
        )

    def do_test_nested_dict(self, make_env, test_lstm=False, disable_connectors=False):
        ModelCatalog.register_custom_model("composite", DictSpyModel)
        register_env("nested", make_env)
        config = (
            PGConfig()
            .environment("nested", disable_env_checking=True)
            .rollouts(num_rollout_workers=0, rollout_fragment_length=5)
            .framework("tf")
            .training(
                model={"custom_model": "composite", "use_lstm": test_lstm},
                train_batch_size=5,
            )
        )
        if disable_connectors:
            # Manually disable the connectors.
            # TODO(avnishn): remove this after deprecating external_env
            config = config.rollouts(enable_connectors=False)
        pg = config.build()

        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        DictSpyModel.capture_index = 0
        pg.train()

        # Check that the model sees the correct reconstructed observations.
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get("d_spy_in_{}".format(i))
            )

            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = DICT_SAMPLES[i]["inner_state"]["job_status"]["task"]
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            check(seen[2][0], task_i)

    def do_test_nested_tuple(self, make_env, disable_connectors=False):
        ModelCatalog.register_custom_model("composite2", TupleSpyModel)
        register_env("nested2", make_env)
        config = (
            PGConfig()
            .environment("nested2", disable_env_checking=True)
            .rollouts(num_rollout_workers=0, rollout_fragment_length=5)
            .framework("tf")
            .training(model={"custom_model": "composite2"}, train_batch_size=5)
        )
        if disable_connectors:
            # Manually disable the connectors.
            # TODO(avnishn): remove this after deprecating external_env
            config = config.rollouts(enable_connectors=False)
        pg = config.build()

        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TupleSpyModel.capture_index = 0
        pg.train()

        # Check that the model sees the correct reconstructed observations.
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get("t_spy_in_{}".format(i))
            )

            pos_i = TUPLE_SAMPLES[i][0].tolist()
            cam_i = TUPLE_SAMPLES[i][1][0].tolist()
            task_i = TUPLE_SAMPLES[i][2]
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            check(seen[2][0], task_i)

    def test_nested_dict_gym(self):
        self.do_test_nested_dict(lambda _: NestedDictEnv())

    def test_nested_dict_gym_lstm(self):
        self.do_test_nested_dict(lambda _: NestedDictEnv(), test_lstm=True)

    def test_nested_dict_vector(self):
        self.do_test_nested_dict(
            lambda _: VectorEnv.vectorize_gym_envs(lambda i: NestedDictEnv())
        )

    def test_nested_dict_serving(self):
        # TODO: (Artur) Enable this test again for connectors if discrepancies
        #  between EnvRunnerV2 and ExternalEnv are resolved.
        if not PGConfig().enable_connectors:
            self.do_test_nested_dict(lambda _: SimpleServing(NestedDictEnv()))

    def test_nested_dict_async(self):
        self.do_test_nested_dict(lambda _: convert_to_base_env(NestedDictEnv()))

    def test_nested_tuple_gym(self):
        self.do_test_nested_tuple(lambda _: NestedTupleEnv())

    def test_nested_tuple_vector(self):
        self.do_test_nested_tuple(
            lambda _: VectorEnv.vectorize_gym_envs(lambda i: NestedTupleEnv())
        )

    def test_nested_tuple_serving(self):
        # TODO: (Artur) Enable this test again for connectors if discrepancies
        #  between EnvRunnerV2 and ExternalEnv are resolved.
        if not PGConfig().enable_connectors:
            self.do_test_nested_tuple(lambda _: SimpleServing(NestedTupleEnv()))

    def test_nested_tuple_async(self):
        self.do_test_nested_tuple(lambda _: convert_to_base_env(NestedTupleEnv()))

    def test_multi_agent_complex_spaces(self):
        ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
        ModelCatalog.register_custom_model("tuple_spy", TupleSpyModel)
        register_env("nested_ma", lambda _: NestedMultiAgentEnv())
        act_space = spaces.Discrete(2)
        config = (
            PGConfig()
            .environment("nested_ma", disable_env_checking=True)
            .framework("tf")
            .rollouts(num_rollout_workers=0, rollout_fragment_length=5)
            .training(train_batch_size=5)
            .multi_agent(
                policies={
                    "tuple_policy": (
                        None,
                        TUPLE_SPACE,
                        act_space,
                        PGConfig.overrides(model={"custom_model": "tuple_spy"}),
                    ),
                    "dict_policy": (
                        None,
                        DICT_SPACE,
                        act_space,
                        PGConfig.overrides(model={"custom_model": "dict_spy"}),
                    ),
                },
                policy_mapping_fn=(
                    lambda agent_id, episode, worker, **kwargs: {
                        "tuple_agent": "tuple_policy",
                        "dict_agent": "dict_policy",
                    }[agent_id]
                ),
            )
        )
        pg = config.build()

        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TupleSpyModel.capture_index = DictSpyModel.capture_index = 0
        pg.train()

        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get("d_spy_in_{}".format(i))
            )
            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = DICT_SAMPLES[i]["inner_state"]["job_status"]["task"]
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            check(seen[2][0], task_i)

        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get("t_spy_in_{}".format(i))
            )
            pos_i = TUPLE_SAMPLES[i][0].tolist()
            cam_i = TUPLE_SAMPLES[i][1][0].tolist()
            task_i = TUPLE_SAMPLES[i][2]
            self.assertEqual(seen[0][0].tolist(), pos_i)
            self.assertEqual(seen[1][0].tolist(), cam_i)
            check(seen[2][0], task_i)

    def test_rollout_dict_space(self):
        register_env("nested", lambda _: NestedDictEnv())

        config = PGConfig().environment("nested").framework("tf")
        algo = config.build()
        algo.train()
        result = algo.save()
        algo.stop()

        # Test train works on restore.
        algo2 = config.build()
        algo2.restore(result)
        algo2.train()

        # Test rollout works on restore.
        rollout(algo2, "nested", 100)

    def test_py_torch_model(self):
        ModelCatalog.register_custom_model("composite", TorchSpyModel)
        register_env("nested", lambda _: NestedDictEnv())
        config = (
            A2CConfig()
            .environment("nested")
            .framework("torch")
            .rollouts(num_rollout_workers=0, rollout_fragment_length=5)
            .training(train_batch_size=5, model={"custom_model": "composite"})
        )
        a2c = config.build()

        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TorchSpyModel.capture_index = 0
        a2c.train()

        # Check that the model sees the correct reconstructed observations.
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_spy_in_{}".format(i)
                )
            )

            pos_i = DICT_SAMPLES[i]["sensors"]["position"].tolist()
            cam_i = DICT_SAMPLES[i]["sensors"]["front_cam"][0].tolist()
            task_i = one_hot(DICT_SAMPLES[i]["inner_state"]["job_status"]["task"], 5)
            # Only look at the last entry (-1) in `seen` as we reset (re-use)
            # the ray-kv indices before training.
            self.assertEqual(seen[0][-1].tolist(), pos_i)
            self.assertEqual(seen[1][-1].tolist(), cam_i)
            check(seen[2][-1], task_i)

    # TODO(ekl) should probably also add a test for TF/eager
    def test_torch_repeated(self):
        ModelCatalog.register_custom_model("r1", TorchRepeatedSpyModel)
        register_env("repeat", lambda _: RepeatedSpaceEnv())
        config = (
            A2CConfig()
            .environment("repeat")
            .framework("torch")
            .rollouts(num_rollout_workers=0, rollout_fragment_length=5)
            .training(train_batch_size=5, model={"custom_model": "r1"})
        )
        a2c = config.build()

        # Skip first passes as they came from the TorchPolicy loss
        # initialization.
        TorchRepeatedSpyModel.capture_index = 0
        a2c.train()

        # Check that the model sees the correct reconstructed observations.
        for i in range(4):
            seen = pickle.loads(
                ray.experimental.internal_kv._internal_kv_get(
                    "torch_rspy_in_{}".format(i)
                )
            )

            # Only look at the last entry (-1) in `seen` as we reset (re-use)
            # the ray-kv indices before training.
            self.assertEqual(to_list(seen[:][-1]), to_list(REPEATED_SAMPLES[i]))


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))