# test_dt_model.py
  1. import unittest
  2. import gymnasium as gym
  3. import numpy as np
  4. from rllib_dt.dt.dt_torch_model import DTTorchModel
  5. import ray
  6. from ray.rllib.policy.sample_batch import SampleBatch
  7. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  8. from ray.rllib.utils.numpy import convert_to_numpy
  9. from ray.rllib.utils.torch_utils import convert_to_torch_tensor
# Framework handles from RLlib's optional-import helpers: try_import_tf()
# is unpacked into three names and try_import_torch() into two.
# NOTE(review): none of tf1 / tf / tfv / torch is referenced by the code
# visible in this module; they may be kept only for their import side
# effects -- confirm before removing. The TF-then-torch call order is left
# as-is, since import order between deep-learning frameworks can matter.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
  12. def _assert_outputs_equal(outputs):
  13. for i in range(1, len(outputs)):
  14. for key in outputs[0].keys():
  15. assert np.allclose(
  16. outputs[0][key], outputs[i][key]
  17. ), "outputs are different but they shouldn't be."
  18. def _assert_outputs_not_equal(outputs):
  19. for i in range(1, len(outputs)):
  20. for key in outputs[0].keys():
  21. assert not np.allclose(
  22. outputs[0][key], outputs[i][key]
  23. ), "some outputs are the same but they shouldn't be."
  24. def _generate_input_dict(B, T, obs_space, action_space):
  25. """Generate input_dict that has completely fake values."""
  26. # generate deterministic inputs
  27. # obs
  28. obs = np.arange(B * T * obs_space.shape[0], dtype=np.float32).reshape(
  29. (B, T, obs_space.shape[0])
  30. )
  31. # actions
  32. if isinstance(action_space, gym.spaces.Box):
  33. act = np.arange(B * T * action_space.shape[0], dtype=np.float32).reshape(
  34. (B, T, action_space.shape[0])
  35. )
  36. else:
  37. act = np.mod(np.arange(B * T, dtype=np.int32).reshape((B, T)), action_space.n)
  38. # returns to go
  39. rtg = np.arange(B * (T + 1), dtype=np.float32).reshape((B, T + 1, 1))
  40. # timesteps
  41. timesteps = np.stack([np.arange(T, dtype=np.int32) for _ in range(B)], axis=0)
  42. # attention mask
  43. mask = np.ones((B, T), dtype=np.float32)
  44. input_dict = SampleBatch(
  45. {
  46. SampleBatch.OBS: obs,
  47. SampleBatch.ACTIONS: act,
  48. SampleBatch.RETURNS_TO_GO: rtg,
  49. SampleBatch.T: timesteps,
  50. SampleBatch.ATTENTION_MASKS: mask,
  51. }
  52. )
  53. input_dict = convert_to_torch_tensor(input_dict)
  54. return input_dict
  55. class TestDTModel(unittest.TestCase):
  56. @classmethod
  57. def setUpClass(cls):
  58. ray.init()
  59. @classmethod
  60. def tearDownClass(cls):
  61. ray.shutdown()
  62. def test_torch_model_init(self):
  63. """Test models are initialized properly"""
  64. model_config = {
  65. "embed_dim": 32,
  66. "num_layers": 2,
  67. "max_seq_len": 4,
  68. "max_ep_len": 10,
  69. "num_heads": 2,
  70. "embed_pdrop": 0.1,
  71. "resid_pdrop": 0.1,
  72. "attn_pdrop": 0.1,
  73. "use_obs_output": False,
  74. "use_return_output": False,
  75. }
  76. num_outputs = 2
  77. observation_space = gym.spaces.Box(-1.0, 1.0, shape=(num_outputs,))
  78. action_dim = 5
  79. action_spaces = [
  80. gym.spaces.Box(-1.0, 1.0, shape=(action_dim,)),
  81. gym.spaces.Discrete(action_dim),
  82. ]
  83. B, T = 3, 4
  84. for action_space in action_spaces:
  85. # Generate input dict.
  86. input_dict = _generate_input_dict(B, T, observation_space, action_space)
  87. # Do random initialization a few times and make sure outputs are different
  88. outputs = []
  89. for _ in range(10):
  90. model = DTTorchModel(
  91. observation_space,
  92. action_space,
  93. num_outputs,
  94. model_config,
  95. "model",
  96. )
  97. # so dropout is not in effect
  98. model.eval()
  99. model_out, _ = model(input_dict)
  100. output = model.get_prediction(model_out, input_dict)
  101. outputs.append(convert_to_numpy(output))
  102. _assert_outputs_not_equal(outputs)
  103. # Initialize once and make sure dropout is working
  104. model = DTTorchModel(
  105. observation_space,
  106. action_space,
  107. num_outputs,
  108. model_config,
  109. "model",
  110. )
  111. # Dropout should make outputs different in training mode
  112. model.train()
  113. outputs = []
  114. for _ in range(10):
  115. model_out, _ = model(input_dict)
  116. output = model.get_prediction(model_out, input_dict)
  117. outputs.append(convert_to_numpy(output))
  118. _assert_outputs_not_equal(outputs)
  119. # Dropout should make outputs the same in eval mode
  120. model.eval()
  121. outputs = []
  122. for _ in range(10):
  123. model_out, _ = model(input_dict)
  124. output = model.get_prediction(model_out, input_dict)
  125. outputs.append(convert_to_numpy(output))
  126. _assert_outputs_equal(outputs)
  127. def test_torch_model_prediction_target(self):
  128. """Test the get_prediction and get_targets function."""
  129. model_config = {
  130. "embed_dim": 16,
  131. "num_layers": 3,
  132. "max_seq_len": 3,
  133. "max_ep_len": 9,
  134. "num_heads": 1,
  135. "embed_pdrop": 0.2,
  136. "resid_pdrop": 0.2,
  137. "attn_pdrop": 0.2,
  138. "use_obs_output": True,
  139. "use_return_output": True,
  140. }
  141. num_outputs = 5
  142. observation_space = gym.spaces.Box(-1.0, 1.0, shape=(num_outputs,))
  143. action_dim = 2
  144. action_spaces = [
  145. gym.spaces.Box(-1.0, 1.0, shape=(action_dim,)),
  146. gym.spaces.Discrete(action_dim),
  147. ]
  148. B, T = 2, 3
  149. for action_space in action_spaces:
  150. # Generate input dict.
  151. input_dict = _generate_input_dict(B, T, observation_space, action_space)
  152. # Make model and forward pass.
  153. model = DTTorchModel(
  154. observation_space,
  155. action_space,
  156. num_outputs,
  157. model_config,
  158. "model",
  159. )
  160. model_out, _ = model(input_dict)
  161. preds = model.get_prediction(model_out, input_dict)
  162. target = model.get_targets(model_out, input_dict)
  163. preds = convert_to_numpy(preds)
  164. target = convert_to_numpy(target)
  165. # Test the content and shape of output and target
  166. if isinstance(action_space, gym.spaces.Box):
  167. # test preds shape
  168. self.assertEqual(preds[SampleBatch.ACTIONS].shape, (B, T, action_dim))
  169. # test target shape and content
  170. self.assertEqual(target[SampleBatch.ACTIONS].shape, (B, T, action_dim))
  171. assert np.allclose(
  172. target[SampleBatch.ACTIONS],
  173. input_dict[SampleBatch.ACTIONS],
  174. )
  175. else:
  176. # test preds shape
  177. self.assertEqual(preds[SampleBatch.ACTIONS].shape, (B, T, action_dim))
  178. # test target shape and content
  179. self.assertEqual(target[SampleBatch.ACTIONS].shape, (B, T))
  180. assert np.allclose(
  181. target[SampleBatch.ACTIONS],
  182. input_dict[SampleBatch.ACTIONS],
  183. )
  184. # test preds shape
  185. self.assertEqual(preds[SampleBatch.OBS].shape, (B, T, num_outputs))
  186. # test target shape and content
  187. self.assertEqual(target[SampleBatch.OBS].shape, (B, T, num_outputs))
  188. assert np.allclose(
  189. target[SampleBatch.OBS],
  190. input_dict[SampleBatch.OBS],
  191. )
  192. # test preds shape
  193. self.assertEqual(preds[SampleBatch.RETURNS_TO_GO].shape, (B, T, 1))
  194. # test target shape and content
  195. self.assertEqual(target[SampleBatch.RETURNS_TO_GO].shape, (B, T, 1))
  196. assert np.allclose(
  197. target[SampleBatch.RETURNS_TO_GO],
  198. input_dict[SampleBatch.RETURNS_TO_GO][:, 1:, :],
  199. )
  200. def test_causal_masking(self):
  201. """Test that the transformer model' causal masking works."""
  202. model_config = {
  203. "embed_dim": 16,
  204. "num_layers": 2,
  205. "max_seq_len": 4,
  206. "max_ep_len": 10,
  207. "num_heads": 2,
  208. "embed_pdrop": 0,
  209. "resid_pdrop": 0,
  210. "attn_pdrop": 0,
  211. "use_obs_output": True,
  212. "use_return_output": True,
  213. }
  214. observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
  215. action_space = gym.spaces.Box(-1.0, 1.0, shape=(2,))
  216. B = 2
  217. T = model_config["max_seq_len"]
  218. # Generate input dict.
  219. input_dict = _generate_input_dict(B, T, observation_space, action_space)
  220. # make model and forward with attention
  221. model = DTTorchModel(
  222. observation_space,
  223. action_space,
  224. 4,
  225. model_config,
  226. "model",
  227. )
  228. model_out, _ = model(input_dict)
  229. preds = model.get_prediction(model_out, input_dict, return_attentions=True)
  230. preds = convert_to_numpy(preds)
  231. # test properties of attentions
  232. attentions = preds["attentions"]
  233. self.assertEqual(
  234. len(attentions),
  235. model_config["num_layers"],
  236. "there should as many attention tensors as layers.",
  237. )
  238. # used to select the causal padded element of each attention tensor
  239. select_mask = np.triu(np.ones((3 * T, 3 * T), dtype=np.bool_), k=1)
  240. select_mask = np.tile(select_mask, (B, model_config["num_heads"], 1, 1))
  241. for attention in attentions:
  242. # check shape
  243. self.assertEqual(
  244. attention.shape, (B, model_config["num_heads"], T * 3, T * 3)
  245. )
  246. # check the upper triangular masking
  247. assert np.allclose(
  248. attention[select_mask], 0.0
  249. ), "masked elements should be zero."
  250. # check that the non-masked elements have non 0 scores
  251. # Note: it is very unlikely that randomly initialized weights will make
  252. # one of the scores be 0, as these scores are probabilities.
  253. assert not np.any(
  254. np.isclose(attention[np.logical_not(select_mask)], 0.0)
  255. ), "non masked elements should be nonzero."
  256. if __name__ == "__main__":
  257. import sys
  258. import pytest
  259. sys.exit(pytest.main(["-v", __file__]))