test_dt_policy.py

import unittest
from typing import Dict

import gymnasium as gym
import numpy as np

from rllib_dt.dt.dt_torch_policy import DTTorchPolicy

import ray
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


def _default_config():
    """Base config to use."""
    return {
        "model": {
            "max_seq_len": 4,
        },
        "embed_dim": 32,
        "num_layers": 2,
        "horizon": 10,
        "num_heads": 2,
        "embed_pdrop": 0.1,
        "resid_pdrop": 0.1,
        "attn_pdrop": 0.1,
        "framework": "torch",
        "lr": 1e-3,
        "lr_schedule": None,
        "optimizer": {
            "weight_decay": 1e-4,
            "betas": [0.9, 0.99],
        },
        "target_return": 200.0,
        "loss_coef_actions": 1.0,
        "loss_coef_obs": 0,
        "loss_coef_returns_to_go": 0,
        "num_gpus": 0,
        "_fake_gpus": None,
        "_enable_new_api_stack": False,
    }
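

# A few notes on the knobs above, as the tests below exercise them:
# - model.max_seq_len is the Decision Transformer's context window, i.e. how many
#   (returns-to-go, obs, action) timesteps the model attends over.
# - target_return seeds the returns-to-go at inference time; compute_actions then
#   presumably decrements it by the observed rewards (see test_torch_action).
# - loss_coef_{actions,obs,returns_to_go} weight the per-head losses; a head with a
#   coefficient of 0 does not contribute to the total loss (see test_loss_coef).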


def _assert_input_dict_equals(d1: Dict[str, np.ndarray], d2: Dict[str, np.ndarray]):
    for key in d1.keys():
        assert key in d2.keys()
    for key in d2.keys():
        assert key in d1.keys()

    for key in d1.keys():
        assert isinstance(d1[key], np.ndarray), "input_dict should only contain numpy arrays."
        assert isinstance(d2[key], np.ndarray), "input_dict should only contain numpy arrays."
        assert d1[key].shape == d2[key].shape, "input_dict values are of different shapes."
        assert np.allclose(d1[key], d2[key]), "input_dict values are not equal."


class TestDTPolicy(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        ray.init()

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_torch_postprocess_trajectory(self):
        """Test postprocess_trajectory."""
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
        action_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))

        # Create policy
        policy = DTTorchPolicy(observation_space, action_space, config)

        # Generate a sample batch with some data
        sample_batch = SampleBatch(
            {
                SampleBatch.REWARDS: np.array([1.0, 2.0, 1.0, 1.0]),
                SampleBatch.EPS_ID: np.array([0, 0, 0, 0]),
            }
        )

        # Do postprocess trajectory to calculate rtg.
        sample_batch = policy.postprocess_trajectory(sample_batch)
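
        # Note: since all four transitions share the same EPS_ID, postprocess_trajectory
        # is expected to treat the last timestep of that episode as terminated when it
        # back-computes returns-to-go; that is what the assertions below check.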

        # Assert that terminateds and truncateds are correctly set.
        assert (
            SampleBatch.TERMINATEDS in sample_batch
        ), "`terminateds` isn't part of the batch."
        assert (
            SampleBatch.TRUNCATEDS not in sample_batch
        ), "`truncateds` shouldn't be part of the batch (in this particular test case)."
        assert np.allclose(
            sample_batch[SampleBatch.TERMINATEDS],
            np.array([False, False, False, True]),
        ), "`terminateds` isn't set correctly."

    def test_torch_input_dict(self):
        """Test inference input_dict methods.

        This is a minimal version of the test in test_dt.py.

        The shapes of the input_dict might be confusing, but they make sense in
        the context of what the function is supposed to do.
        Check action_distribution_fn for an explanation.
        """
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # initial obs and input_dict
            obs = np.array([0.0, 1.0, 2.0])
            input_dict = policy.get_initial_input_dict(obs)

            # Check input_dict matches what it should be
            target_input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 1.0, 2.0],
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[0.0], [0.0], [0.0]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([0, 0, 0], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [0.0, 0.0, 0.0], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.zeros((), dtype=np.float32),
                    SampleBatch.T: np.array([-1, -1, -1], dtype=np.int32),
                }
            )
            _assert_input_dict_equals(input_dict, target_input_dict)

            # Get next input_dict
            input_dict = policy.get_next_input_dict(
                input_dict,
                action=(
                    np.asarray([1.0], dtype=np.float32)
                    if isinstance(action_space, gym.spaces.Box)
                    else np.asarray(1, dtype=np.int32)
                ),
                reward=1.0,
                next_obs=np.array([3.0, 4.0, 5.0]),
                extra={
                    SampleBatch.RETURNS_TO_GO: config["target_return"],
                },
            )
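
            # Expectation: the context window shifts left by one slot, the new
            # observation/action/returns-to-go are appended at the end, REWARDS becomes
            # the reward just passed in, and T advances from the -1 padding to step 0.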

            # Check input_dict matches what it should be
            target_input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [0.0, 0.0, 0.0],
                            [0.0, 0.0, 0.0],
                            [0.0, 1.0, 2.0],
                            [3.0, 4.0, 5.0],
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[0.0], [0.0], [1.0]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([0, 0, 1], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [0.0, 0.0, config["target_return"]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.asarray(1.0, dtype=np.float32),
                    SampleBatch.T: np.array([-1, -1, 0], dtype=np.int32),
                }
            )
            _assert_input_dict_equals(input_dict, target_input_dict)
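
    # Roughly, an inference loop would chain these helpers as sketched below
    # (batching and target-return bookkeeping omitted; test_dt.py exercises the
    # full flow):
    #
    #     input_dict = policy.get_initial_input_dict(obs)
    #     while not done:
    #         action, _, extra = policy.compute_actions_from_input_dict(...)
    #         obs, reward, done = env.step(action)  # schematic
    #         input_dict = policy.get_next_input_dict(
    #             input_dict, action, reward, obs, extra
    #         )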

    def test_torch_action(self):
        """Test the policy's action_distribution_fn and extra_action_out methods by
        calling compute_actions_from_input_dict, which exercises those two methods
        in conjunction.
        """
        config = _default_config()

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # input_dict for initial observation
            input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [0.0]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 0]], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[0.0, 0.0, 0.0]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.array([0.0], dtype=np.float32),
                    SampleBatch.T: np.array([[-1, -1, -1]], dtype=np.int32),
                }
            )

            # Run compute_actions_from_input_dict
            actions, _, extras = policy.compute_actions_from_input_dict(
                input_dict,
                explore=False,
                timestep=None,
            )

            # Check actions
            assert actions.shape == (
                1,
                *action_space.shape,
            ), "actions has incorrect shape."

            # Check extras
            assert (
                SampleBatch.RETURNS_TO_GO in extras
            ), "extras should contain returns_to_go."
            assert extras[SampleBatch.RETURNS_TO_GO].shape == (
                1,
            ), "extras['returns_to_go'] has incorrect shape."
            assert np.isclose(
                extras[SampleBatch.RETURNS_TO_GO],
                np.asarray([config["target_return"]], dtype=np.float32),
            ), "extras['returns_to_go'] should contain target_return."

            # input_dict for non-initial observation
            input_dict = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1]], dtype=np.int32)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[0.0, 0.0, config["target_return"]]], dtype=np.float32
                    ),
                    SampleBatch.REWARDS: np.array([10.0], dtype=np.float32),
                    SampleBatch.T: np.array([[-1, -1, 0]], dtype=np.int32),
                }
            )

            # Run compute_actions_from_input_dict
            actions, _, extras = policy.compute_actions_from_input_dict(
                input_dict,
                explore=False,
                timestep=None,
            )

            # Check actions
            assert actions.shape == (
                1,
                *action_space.shape,
            ), "actions has incorrect shape."

            # Check extras
            assert (
                SampleBatch.RETURNS_TO_GO in extras
            ), "extras should contain returns_to_go."
            assert extras[SampleBatch.RETURNS_TO_GO].shape == (
                1,
            ), "extras['returns_to_go'] has incorrect shape."
            assert np.isclose(
                extras[SampleBatch.RETURNS_TO_GO],
                np.asarray([config["target_return"] - 10.0], dtype=np.float32),
            ), "extras['returns_to_go'] should contain target_return minus the reward."

    def test_loss(self):
        """Test the loss function."""
        config = _default_config()
        # Zero out dropout so the loss comparisons below are deterministic.
        config["embed_pdrop"] = 0
        config["resid_pdrop"] = 0
        config["attn_pdrop"] = 0

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            # Create policy
            policy = DTTorchPolicy(observation_space, action_space, config)

            # Run the loss function on two batches that differ only at masked-out
            # positions, to make sure the attention masks work and the losses match.
            batch1 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[0.0], [0.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[0, 0, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )
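
            # Note: RETURNS_TO_GO has max_seq_len + 1 entries (shape [1, 5, 1]) while the
            # other keys have max_seq_len = 4 steps; the extra element presumably serves
            # as the next-step target for the returns-to-go prediction head.
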
            batch2 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [1.0, 1.0, -1.0],
                                [1.0, 10.0, 12.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[1.0], [-0.5], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[2, 1, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[200.0], [-10.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[9, 3, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            loss1 = policy.loss(policy.model, policy.dist_class, batch1)
            loss2 = policy.loss(policy.model, policy.dist_class, batch2)

            loss1 = loss1.detach().cpu().item()
            loss2 = loss2.detach().cpu().item()

            assert np.isclose(loss1, loss2), "Masks are not working for losses."

            # Run loss on a widely different batch and make sure the loss is different.
            batch3 = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [1.0, 1.0, -20.0],
                                [0.1, 10.0, 12.0],
                                [1.4, 12.0, -9.0],
                                [6.0, 40.0, -2.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[2.0], [-1.5], [0.2], [0.1]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[1, 3, 0, 2]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[90.0], [80.0], [70.0], [60.0], [50.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[3, 4, 5, 6]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[1.0, 1.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            loss3 = policy.loss(policy.model, policy.dist_class, batch3)
            loss3 = loss3.detach().cpu().item()

            assert not np.isclose(
                loss1, loss3
            ), "Widely different inputs are giving the same loss value."

    def test_loss_coef(self):
        """Test the loss_coef_{key} config options."""
        config = _default_config()
        # Zero out dropout so the losses are deterministic and comparable.
        config["embed_pdrop"] = 0
        config["resid_pdrop"] = 0
        config["attn_pdrop"] = 0
        # Set the actions coefficient to 0 as well (obs and returns_to_go already
        # default to 0), so that each loop iteration below isolates a single head.
        config["loss_coef_actions"] = 0

        observation_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
        action_spaces = [
            gym.spaces.Box(-1.0, 1.0, shape=(1,)),
            gym.spaces.Discrete(4),
        ]

        for action_space in action_spaces:
            batch = SampleBatch(
                {
                    SampleBatch.OBS: np.array(
                        [
                            [
                                [0.0, 0.0, 0.0],
                                [0.0, 0.0, 0.0],
                                [0.0, 1.0, 2.0],
                                [3.0, 4.0, 5.0],
                            ]
                        ],
                        dtype=np.float32,
                    ),
                    SampleBatch.ACTIONS: (
                        np.array([[[0.0], [0.0], [1.0], [0.5]]], dtype=np.float32)
                        if isinstance(action_space, gym.spaces.Box)
                        else np.array([[0, 0, 1, 3]], dtype=np.int64)
                    ),
                    SampleBatch.RETURNS_TO_GO: np.array(
                        [[[0.0], [0.0], [100.0], [90.0], [80.0]]], dtype=np.float32
                    ),
                    SampleBatch.T: np.array([[0, 0, 0, 1]], dtype=np.int32),
                    SampleBatch.ATTENTION_MASKS: np.array(
                        [[0.0, 0.0, 1.0, 1.0]], dtype=np.float32
                    ),
                }
            )

            keys = [SampleBatch.ACTIONS, SampleBatch.OBS, SampleBatch.RETURNS_TO_GO]
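            # For each key: since all other coefficients are 0, the total loss is just
            # loss_coef_{key} * loss_{key}. With identical model weights, raising the
            # coefficient from 1 to 10 should therefore scale the loss by exactly 10.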
            for key in keys:
                # Create two policies and run the loss with different coefficients.
                # Policy 1 uses coef = 1.
                config1 = config.copy()
                config1[f"loss_coef_{key}"] = 1.0
                policy1 = DTTorchPolicy(observation_space, action_space, config1)
                loss1 = policy1.loss(policy1.model, policy1.dist_class, batch)
                loss1 = loss1.detach().cpu().item()

                # Policy 2 uses coef = 10.
                config2 = config.copy()
                config2[f"loss_coef_{key}"] = 10.0
                policy2 = DTTorchPolicy(observation_space, action_space, config2)
                # Copy the weights over so both policies output the same unscaled loss.
                policy2.set_weights(policy1.get_weights())
                loss2 = policy2.loss(policy2.model, policy2.dist_class, batch)
                loss2 = loss2.detach().cpu().item()

                # Compare the losses; they should differ by a factor of 10.
                self.assertAlmostEqual(
                    loss2 / loss1,
                    10.0,
                    places=3,
                    msg="The two losses should differ by a factor of 10.",
                )


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))