test_lstm.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import numpy as np
  2. import pickle
  3. import unittest
  4. import ray
  5. from ray.rllib.agents.ppo import PPOTrainer
  6. from ray.rllib.examples.env.debug_counter_env import DebugCounterEnv
  7. from ray.rllib.examples.models.rnn_spy_model import RNNSpyModel
  8. from ray.rllib.models import ModelCatalog
  9. from ray.rllib.policy.rnn_sequencing import chop_into_sequences
  10. from ray.rllib.policy.sample_batch import SampleBatch
  11. from ray.rllib.utils.test_utils import check
  12. from ray.tune.registry import register_env
  13. class TestLSTMUtils(unittest.TestCase):
  14. def test_basic(self):
  15. eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
  16. agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
  17. f = [[101, 102, 103, 201, 202, 203, 204, 205],
  18. [[101], [102], [103], [201], [202], [203], [204], [205]]]
  19. s = [[209, 208, 207, 109, 108, 107, 106, 105]]
  20. f_pad, s_init, seq_lens = chop_into_sequences(
  21. episode_ids=eps_ids,
  22. unroll_ids=np.ones_like(eps_ids),
  23. agent_indices=agent_ids,
  24. feature_columns=f,
  25. state_columns=s,
  26. max_seq_len=4)
  27. self.assertEqual([f.tolist() for f in f_pad], [
  28. [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
  29. [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
  30. [0], [0]],
  31. ])
  32. self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
  33. self.assertEqual(seq_lens.tolist(), [3, 4, 1])
  34. def test_nested(self):
  35. eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
  36. agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
  37. f = [{
  38. "a": np.array([1, 2, 3, 4, 13, 14, 15, 16]),
  39. "b": {
  40. "ba": np.array([5, 6, 7, 8, 9, 10, 11, 12])
  41. }
  42. }]
  43. s = [[209, 208, 207, 109, 108, 107, 106, 105]]
  44. f_pad, s_init, seq_lens = chop_into_sequences(
  45. episode_ids=eps_ids,
  46. unroll_ids=np.ones_like(eps_ids),
  47. agent_indices=agent_ids,
  48. feature_columns=f,
  49. state_columns=s,
  50. max_seq_len=4,
  51. handle_nested_data=True,
  52. )
  53. check(f_pad, [[[1, 2, 3, 0, 4, 13, 14, 15, 16, 0, 0, 0],
  54. [5, 6, 7, 0, 8, 9, 10, 11, 12, 0, 0, 0]]])
  55. self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
  56. self.assertEqual(seq_lens.tolist(), [3, 4, 1])
  57. def test_multi_dim(self):
  58. eps_ids = [1, 1, 1]
  59. agent_ids = [1, 1, 1]
  60. obs = np.ones((84, 84, 4))
  61. f = [[obs, obs * 2, obs * 3]]
  62. s = [[209, 208, 207]]
  63. f_pad, s_init, seq_lens = chop_into_sequences(
  64. episode_ids=eps_ids,
  65. unroll_ids=np.ones_like(eps_ids),
  66. agent_indices=agent_ids,
  67. feature_columns=f,
  68. state_columns=s,
  69. max_seq_len=4)
  70. self.assertEqual([f.tolist() for f in f_pad], [
  71. np.array([obs, obs * 2, obs * 3]).tolist(),
  72. ])
  73. self.assertEqual([s.tolist() for s in s_init], [[209]])
  74. self.assertEqual(seq_lens.tolist(), [3])
  75. def test_batch_id(self):
  76. eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
  77. batch_ids = [1, 1, 2, 2, 3, 3, 4, 4]
  78. agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
  79. f = [[101, 102, 103, 201, 202, 203, 204, 205],
  80. [[101], [102], [103], [201], [202], [203], [204], [205]]]
  81. s = [[209, 208, 207, 109, 108, 107, 106, 105]]
  82. _, _, seq_lens = chop_into_sequences(
  83. episode_ids=eps_ids,
  84. unroll_ids=batch_ids,
  85. agent_indices=agent_ids,
  86. feature_columns=f,
  87. state_columns=s,
  88. max_seq_len=4)
  89. self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2])
  90. def test_multi_agent(self):
  91. eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
  92. agent_ids = [1, 1, 2, 1, 1, 2, 2, 3]
  93. f = [[101, 102, 103, 201, 202, 203, 204, 205],
  94. [[101], [102], [103], [201], [202], [203], [204], [205]]]
  95. s = [[209, 208, 207, 109, 108, 107, 106, 105]]
  96. f_pad, s_init, seq_lens = chop_into_sequences(
  97. episode_ids=eps_ids,
  98. unroll_ids=np.ones_like(eps_ids),
  99. agent_indices=agent_ids,
  100. feature_columns=f,
  101. state_columns=s,
  102. max_seq_len=4,
  103. dynamic_max=False)
  104. self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
  105. self.assertEqual(len(f_pad[0]), 20)
  106. self.assertEqual(len(s_init[0]), 5)
  107. def test_dynamic_max_len(self):
  108. eps_ids = [5, 2, 2]
  109. agent_ids = [2, 2, 2]
  110. f = [[1, 1, 1]]
  111. s = [[1, 1, 1]]
  112. f_pad, s_init, seq_lens = chop_into_sequences(
  113. episode_ids=eps_ids,
  114. unroll_ids=np.ones_like(eps_ids),
  115. agent_indices=agent_ids,
  116. feature_columns=f,
  117. state_columns=s,
  118. max_seq_len=4)
  119. self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
  120. self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
  121. self.assertEqual(seq_lens.tolist(), [1, 2])
  122. class TestRNNSequencing(unittest.TestCase):
  123. def setUp(self) -> None:
  124. ray.init(num_cpus=4)
  125. def tearDown(self) -> None:
  126. ray.shutdown()
  127. def test_simple_optimizer_sequencing(self):
  128. ModelCatalog.register_custom_model("rnn", RNNSpyModel)
  129. register_env("counter", lambda _: DebugCounterEnv())
  130. ppo = PPOTrainer(
  131. env="counter",
  132. config={
  133. "num_workers": 0,
  134. "rollout_fragment_length": 10,
  135. "train_batch_size": 10,
  136. "sgd_minibatch_size": 10,
  137. "num_sgd_iter": 1,
  138. "simple_optimizer": True,
  139. "model": {
  140. "custom_model": "rnn",
  141. "max_seq_len": 4,
  142. "vf_share_layers": True,
  143. },
  144. "framework": "tf",
  145. })
  146. ppo.train()
  147. ppo.train()
  148. batch0 = pickle.loads(
  149. ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_0"))
  150. self.assertEqual(
  151. batch0["sequences"].tolist(),
  152. [[[0], [1], [2], [3]], [[4], [5], [6], [7]], [[8], [9], [0], [0]]])
  153. self.assertEqual(batch0[SampleBatch.SEQ_LENS].tolist(), [4, 4, 2])
  154. self.assertEqual(batch0["state_in"][0][0].tolist(), [0, 0, 0])
  155. self.assertEqual(batch0["state_in"][1][0].tolist(), [0, 0, 0])
  156. self.assertGreater(abs(np.sum(batch0["state_in"][0][1])), 0)
  157. self.assertGreater(abs(np.sum(batch0["state_in"][1][1])), 0)
  158. self.assertTrue(
  159. np.allclose(batch0["state_in"][0].tolist()[1:],
  160. batch0["state_out"][0].tolist()[:-1]))
  161. self.assertTrue(
  162. np.allclose(batch0["state_in"][1].tolist()[1:],
  163. batch0["state_out"][1].tolist()[:-1]))
  164. batch1 = pickle.loads(
  165. ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_1"))
  166. self.assertEqual(batch1["sequences"].tolist(), [
  167. [[10], [11], [12], [13]],
  168. [[14], [0], [0], [0]],
  169. [[0], [1], [2], [3]],
  170. [[4], [0], [0], [0]],
  171. ])
  172. self.assertEqual(batch1[SampleBatch.SEQ_LENS].tolist(), [4, 1, 4, 1])
  173. self.assertEqual(batch1["state_in"][0][2].tolist(), [0, 0, 0])
  174. self.assertEqual(batch1["state_in"][1][2].tolist(), [0, 0, 0])
  175. self.assertGreater(abs(np.sum(batch1["state_in"][0][0])), 0)
  176. self.assertGreater(abs(np.sum(batch1["state_in"][1][0])), 0)
  177. self.assertGreater(abs(np.sum(batch1["state_in"][0][1])), 0)
  178. self.assertGreater(abs(np.sum(batch1["state_in"][1][1])), 0)
  179. self.assertGreater(abs(np.sum(batch1["state_in"][0][3])), 0)
  180. self.assertGreater(abs(np.sum(batch1["state_in"][1][3])), 0)
  181. def test_minibatch_sequencing(self):
  182. ModelCatalog.register_custom_model("rnn", RNNSpyModel)
  183. register_env("counter", lambda _: DebugCounterEnv())
  184. ppo = PPOTrainer(
  185. env="counter",
  186. config={
  187. "shuffle_sequences": False, # for deterministic testing
  188. "num_workers": 0,
  189. "rollout_fragment_length": 20,
  190. "train_batch_size": 20,
  191. "sgd_minibatch_size": 10,
  192. "num_sgd_iter": 1,
  193. "model": {
  194. "custom_model": "rnn",
  195. "max_seq_len": 4,
  196. "vf_share_layers": True,
  197. },
  198. "framework": "tf",
  199. })
  200. ppo.train()
  201. ppo.train()
  202. # first epoch: 20 observations get split into 2 minibatches of 8
  203. # four observations are discarded
  204. batch0 = pickle.loads(
  205. ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_0"))
  206. batch1 = pickle.loads(
  207. ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_1"))
  208. if batch0["sequences"][0][0][0] > batch1["sequences"][0][0][0]:
  209. batch0, batch1 = batch1, batch0 # sort minibatches
  210. self.assertEqual(batch0[SampleBatch.SEQ_LENS].tolist(), [4, 4, 2])
  211. self.assertEqual(batch1[SampleBatch.SEQ_LENS].tolist(), [2, 3, 4, 1])
  212. check(batch0["sequences"], [
  213. [[0], [1], [2], [3]],
  214. [[4], [5], [6], [7]],
  215. [[8], [9], [0], [0]],
  216. ])
  217. check(batch1["sequences"], [
  218. [[10], [11], [0], [0]],
  219. [[12], [13], [14], [0]],
  220. [[0], [1], [2], [3]],
  221. [[4], [0], [0], [0]],
  222. ])
  223. # second epoch: 20 observations get split into 2 minibatches of 8
  224. # four observations are discarded
  225. batch2 = pickle.loads(
  226. ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_2"))
  227. batch3 = pickle.loads(
  228. ray.experimental.internal_kv._internal_kv_get("rnn_spy_in_3"))
  229. if batch2["sequences"][0][0][0] > batch3["sequences"][0][0][0]:
  230. batch2, batch3 = batch3, batch2
  231. self.assertEqual(batch2[SampleBatch.SEQ_LENS].tolist(), [4, 4, 2])
  232. self.assertEqual(batch3[SampleBatch.SEQ_LENS].tolist(), [4, 4, 2])
  233. check(batch2["sequences"], [
  234. [[0], [1], [2], [3]],
  235. [[4], [5], [6], [7]],
  236. [[8], [9], [0], [0]],
  237. ])
  238. check(batch3["sequences"], [
  239. [[5], [6], [7], [8]],
  240. [[9], [10], [11], [12]],
  241. [[13], [14], [0], [0]],
  242. ])
  243. if __name__ == "__main__":
  244. import pytest
  245. import sys
  246. sys.exit(pytest.main(["-v", __file__]))