import unittest
from typing import Dict

import gymnasium as gym
import numpy as np

from rllib_dt.dt.dt import DTConfig

import ray
from ray.rllib import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.test_utils import check_train_results

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


def _assert_input_dict_equals(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]):
    """Assert that two input dicts have identical keys and matching array values."""
    for key in d1.keys():
        assert key in d2.keys()

    for key in d2.keys():
        assert key in d1.keys()

    for key in d1.keys():
        assert isinstance(d1[key], np.ndarray)
        assert isinstance(d2[key], np.ndarray)
        assert d1[key].shape == d2[key].shape
        assert np.allclose(d1[key], d2[key])


class TestDT(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()
    def test_dt_compilation(self):
        """Test whether a DT algorithm can be built with all supported frameworks."""
        config = (
            DTConfig()
            .environment(
                env="Pendulum-v1",
                clip_actions=True,
                normalize_actions=True,
            )
            .framework("torch")
            .offline_data(
                input_="dataset",
                input_config={
                    "format": "json",
                    "paths": [
                        "s3://anonymous@air-example-data/rllib/pendulum/large.json"
                    ],
                },
                actions_in_input_normalized=True,
            )
            .training(
                train_batch_size=200,
                replay_buffer_config={
                    "capacity": 8,
                },
                model={
                    "max_seq_len": 4,
                },
                num_layers=1,
                num_heads=1,
                embed_dim=64,
                horizon=200,
            )
            .evaluation(
                target_return=-120,
                evaluation_interval=2,
                evaluation_num_workers=0,
                evaluation_duration=10,
                evaluation_duration_unit="episodes",
                evaluation_parallel_to_training=False,
                evaluation_config=DTConfig.overrides(input_="sampler", explore=False),
            )
            .rollouts(
                num_rollout_workers=0,
            )
            .reporting(
                min_train_timesteps_per_iteration=10,
            )
            .experimental(
                _disable_preprocessor_api=True,
            )
        )

        num_iterations = 4

        for _ in ["torch"]:
            algo = config.build()
            # check if 4 iterations raises any errors
            for i in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)

                if (i + 1) % 2 == 0:
                    # evaluation happens every 2 iterations
                    eval_results = results["evaluation"]
                    print(
                        f"iter={algo.iteration} "
                        f"R={eval_results['episode_reward_mean']}"
                    )

            # do example inference rollout
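            # DT inference keeps a sliding context of recent transitions in
            # ``input_dict``: each step feeds it to compute_single_action()
            # and then extends it with the newest (action, reward, obs).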
            env = gym.make("Pendulum-v1")

            obs, _ = env.reset()
            input_dict = algo.get_initial_input_dict(obs)

            for _ in range(200):
                action, _, extra = algo.compute_single_action(input_dict=input_dict)
                obs, reward, terminated, truncated, _ = env.step(action)
                if terminated or truncated:
                    break
                else:
                    input_dict = algo.get_next_input_dict(
                        input_dict,
                        action,
                        reward,
                        obs,
                        extra,
                    )

            env.close()

            algo.stop()

    def test_inference_methods(self):
        """Test inference methods."""
        config = (
            DTConfig()
            .environment(
                env="Pendulum-v1",
                clip_actions=True,
                normalize_actions=True,
            )
            .framework("torch")
            .training(
                train_batch_size=200,
                replay_buffer_config={
                    "capacity": 8,
                },
                model={
                    "max_seq_len": 3,
                },
                num_layers=1,
                num_heads=1,
                embed_dim=64,
                horizon=200,
            )
            .evaluation(
                target_return=-120,
            )
            .rollouts(
                num_rollout_workers=0,
            )
            .experimental(_disable_preprocessor_api=True)
        )
        algo = config.build()

        # Do a controlled fake rollout for 2 steps and check input_dict
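        # Note: with max_seq_len=3 the context holds 3 observations but only
        # 2 past actions / returns-to-go / timesteps, since the action for
        # the current timestep is the one the model is asked to predict.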

        # first input_dict
        obs = np.array([0.0, 1.0, 2.0])
        input_dict = algo.get_initial_input_dict(obs)
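        # The initial input_dict is left-padded with zeros: only the last
        # OBS slot holds the real observation, and T == -1 marks the padded
        # steps.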
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0],
                        [0.0, 1.0, 2.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.0], [0.0]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([0.0, 0.0], dtype=np.float32),
                SampleBatch.REWARDS: np.zeros((), dtype=np.float32),
                SampleBatch.T: np.array([-1, -1], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)

        # forward pass with first input_dict
        action, _, extra = algo.compute_single_action(input_dict=input_dict)
        assert action.shape == (1,)
        assert SampleBatch.RETURNS_TO_GO in extra
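        # On the very first step the return-to-go in ``extra`` is simply the
        # configured target_return (-120): no reward has been collected yet.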
        assert np.isclose(extra[SampleBatch.RETURNS_TO_GO], -120.0)

        # second input_dict
        action = np.array([0.5])
        obs = np.array([3.0, 4.0, 5.0])
        reward = -10.0

        input_dict = algo.get_next_input_dict(
            input_dict,
            action,
            reward,
            obs,
            extra,
        )
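        # get_next_input_dict() shifts the context one slot to the left: the
        # oldest entry drops out, the new obs and action are appended, and
        # the timestep counter advances from -1 (padding) to 0.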
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 0.0, 0.0],
                        [0.0, 1.0, 2.0],
                        [3.0, 4.0, 5.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.0], [0.5]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([0.0, -120.0], dtype=np.float32),
                SampleBatch.REWARDS: np.asarray(-10.0),
                SampleBatch.T: np.array([-1, 0], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)

        # forward pass with second input_dict
        action, _, extra = algo.compute_single_action(input_dict=input_dict)
        assert action.shape == (1,)
        assert SampleBatch.RETURNS_TO_GO in extra
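        # The return-to-go decays by the reward just received:
        # -120 - (-10) = -110.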
        assert np.isclose(extra[SampleBatch.RETURNS_TO_GO], -110.0)

        # third input_dict
        action = np.array([-0.2])
        obs = np.array([6.0, 7.0, 8.0])
        reward = -20.0

        input_dict = algo.get_next_input_dict(
            input_dict,
            action,
            reward,
            obs,
            extra,
        )
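        # By now the zero padding has been pushed out entirely; all three
        # observations in the context are real.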
        target = SampleBatch(
            {
                SampleBatch.OBS: np.array(
                    [
                        [0.0, 1.0, 2.0],
                        [3.0, 4.0, 5.0],
                        [6.0, 7.0, 8.0],
                    ],
                    dtype=np.float32,
                ),
                SampleBatch.ACTIONS: np.array([[0.5], [-0.2]], dtype=np.float32),
                SampleBatch.RETURNS_TO_GO: np.array([-120.0, -110.0], dtype=np.float32),
                SampleBatch.REWARDS: np.asarray(-20.0),
                SampleBatch.T: np.array([0, 1], dtype=np.int32),
            }
        )
        _assert_input_dict_equals(input_dict, target)


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))