from functools import partial
from gymnasium.spaces import Box, Dict, Tuple
import numpy as np
from scipy.stats import beta, norm
import tree  # pip install dm_tree
import unittest

from ray.rllib.models.jax.jax_action_dist import JAXCategorical
from ray.rllib.models.tf.tf_action_dist import (
    Beta,
    Categorical,
    DiagGaussian,
    GumbelSoftmax,
    MultiActionDistribution,
    MultiCategorical,
    SquashedGaussian,
)
from ray.rllib.models.torch.torch_action_dist import (
    TorchBeta,
    TorchCategorical,
    TorchDiagGaussian,
    TorchMultiActionDistribution,
    TorchMultiCategorical,
    TorchSquashedGaussian,
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import (
    MIN_LOG_NN_OUTPUT,
    MAX_LOG_NN_OUTPUT,
    softmax,
    SMALL_NUMBER,
    LARGE_INTEGER,
)
from ray.rllib.utils.test_utils import check, framework_iterator

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

class TestActionDistributions(unittest.TestCase):
    """Tests ActionDistribution classes."""

    @classmethod
    def setUpClass(cls) -> None:
        # Set seeds for deterministic tests (make sure we don't fail
        # because of "bad" sampling).
        np.random.seed(42 + 1)
        torch.manual_seed(42 + 1)

    def _stability_test(
        self,
        distribution_cls,
        network_output_shape,
        fw,
        sess=None,
        bounds=None,
        extra_kwargs=None,
    ):
        extreme_values = [
            0.0,
            float(LARGE_INTEGER),
            -float(LARGE_INTEGER),
            1.1e-34,
            1.1e34,
            -1.1e-34,
            -1.1e34,
            SMALL_NUMBER,
            -SMALL_NUMBER,
        ]
        inputs = np.zeros(shape=network_output_shape, dtype=np.float32)
        for batch_item in range(network_output_shape[0]):
            half = len(inputs[batch_item]) // 2
            for num in range(half):
                inputs[batch_item][num] = np.random.choice(extreme_values)
            # For Gaussians, the second half of the vector holds the
            # log standard deviations and should therefore be the log
            # of a positive number >= 1.
            for num in range(half, len(inputs[batch_item])):
                inputs[batch_item][num] = np.log(
                    max(1.0, np.random.choice(extreme_values))
                )
        dist = distribution_cls(inputs, {}, **(extra_kwargs or {}))
        for _ in range(100):
            sample = dist.sample()
            if fw == "jax":
                sample_check = sample
            elif fw != "tf":
                sample_check = sample.numpy()
            else:
                sample_check = sess.run(sample)
            assert not np.any(np.isnan(sample_check))
            assert np.all(np.isfinite(sample_check))
            if bounds:
                assert np.min(sample_check) >= bounds[0]
                assert np.max(sample_check) <= bounds[1]
                # Make sure bounds make sense and are actually also being
                # sampled.
                if isinstance(bounds[0], int):
                    assert isinstance(bounds[1], int)
                    assert bounds[0] in sample_check
                    assert bounds[1] in sample_check
            logp = dist.logp(sample)
            if fw == "jax":
                logp_check = logp
            elif fw != "tf":
                logp_check = logp.numpy()
            else:
                logp_check = sess.run(logp)
            assert not np.any(np.isnan(logp_check))
            assert np.all(np.isfinite(logp_check))
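
    # Note on the stability inputs above (a hedged reading): the first half of
    # each NN output row is filled with raw extreme values such as +/-1.1e34,
    # while the second half goes through np.log(max(1.0, x)), which keeps the
    # fake log-stds finite and non-negative, e.g.
    # np.log(max(1.0, -1.1e34)) == 0.0 and np.log(max(1.0, 1.1e34)) ~= 78.4.
    # Sampling and logp'ing 100 times must then never yield NaN or +/-inf.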

    def test_categorical(self):
        batch_size = 10000
        num_categories = 4
        # Create a categorical distribution with n categories.
        inputs_space = Box(
            -1.0, 2.0, shape=(batch_size, num_categories), dtype=np.float32
        )
        inputs_space.seed(42)
        values_space = Box(0, num_categories - 1, shape=(batch_size,), dtype=np.int32)
        values_space.seed(42)
        inputs = inputs_space.sample()
        for fw, sess in framework_iterator(session=True):
            # Create the correct distribution object.
            cls = (
                JAXCategorical
                if fw == "jax"
                else Categorical
                if fw != "torch"
                else TorchCategorical
            )
            categorical = cls(inputs, {})
            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(
                cls,
                inputs_space.shape,
                fw=fw,
                sess=sess,
                bounds=(0, num_categories - 1),
            )
            # Batch of size=n and deterministic.
            expected = np.transpose(np.argmax(inputs, axis=-1))
            # Sample, expect always max value
            # (max likelihood for deterministic draw).
            out = categorical.deterministic_sample()
            check(out, expected)
            # Batch of size=n and non-deterministic -> expect roughly the mean.
            out = categorical.sample()
            check(
                np.mean(out)
                if fw == "jax"
                else tf.reduce_mean(out)
                if fw != "torch"
                else torch.mean(out.float()),
                1.0,
                decimals=0,
            )
            # Test log-likelihood outputs.
            probs = softmax(inputs)
            values = values_space.sample()
            out = categorical.logp(values if fw != "torch" else torch.Tensor(values))
            expected = []
            for i in range(batch_size):
                expected.append(np.sum(np.log(np.array(probs[i][values[i]]))))
            check(out, expected, decimals=4)
            # Test entropy outputs.
            out = categorical.entropy()
            expected_entropy = -np.sum(probs * np.log(probs), -1)
            check(out, expected_entropy)
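
        # Quick numpy-only sanity sketch (not framework code): a uniform
        # categorical over `num_categories` classes attains the maximum
        # entropy log(num_categories), i.e. log(4) ~= 1.386 here.
        uniform_probs = softmax(np.zeros((1, num_categories), dtype=np.float32))
        check(
            -np.sum(uniform_probs * np.log(uniform_probs), -1),
            np.log(num_categories),
        )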

    def test_multi_categorical(self):
        batch_size = 100
        num_categories = 3
        num_sub_distributions = 5
        # Create 5 categorical distributions of 3 categories each.
        inputs_space = Box(
            -1.0, 2.0, shape=(batch_size, num_sub_distributions * num_categories)
        )
        inputs_space.seed(42)
        values_space = Box(
            0,
            num_categories - 1,
            shape=(num_sub_distributions, batch_size),
            dtype=np.int32,
        )
        values_space.seed(42)
        inputs = inputs_space.sample()
        input_lengths = [num_categories] * num_sub_distributions
        inputs_split = np.split(inputs, num_sub_distributions, axis=1)
        for fw, sess in framework_iterator(session=True):
            # Create the correct distribution object.
            cls = MultiCategorical if fw != "torch" else TorchMultiCategorical
            multi_categorical = cls(inputs, None, input_lengths)
            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(
                cls,
                inputs_space.shape,
                fw=fw,
                sess=sess,
                bounds=(0, num_categories - 1),
                extra_kwargs={"input_lens": input_lengths},
            )
            # Batch of size=n and deterministic.
            expected = np.transpose(np.argmax(inputs_split, axis=-1))
            # Sample, expect always max value
            # (max likelihood for deterministic draw).
            out = multi_categorical.deterministic_sample()
            check(out, expected)
            # Batch of size=n and non-deterministic -> expect roughly the mean.
            out = multi_categorical.sample()
            check(
                tf.reduce_mean(out) if fw != "torch" else torch.mean(out.float()),
                1.0,
                decimals=0,
            )
            # Test log-likelihood outputs.
            probs = softmax(inputs_split)
            values = values_space.sample()
            out = multi_categorical.logp(
                values
                if fw != "torch"
                else [torch.Tensor(values[i]) for i in range(num_sub_distributions)]
            )
            expected = []
            for i in range(batch_size):
                expected.append(
                    np.sum(
                        np.log(
                            np.array(
                                [
                                    probs[j][i][values[j][i]]
                                    for j in range(num_sub_distributions)
                                ]
                            )
                        )
                    )
                )
            check(out, expected, decimals=4)
            # Test entropy outputs.
            out = multi_categorical.entropy()
            expected_entropy = -np.sum(np.sum(probs * np.log(probs), 0), -1)
            check(out, expected_entropy)
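
        # Note (hedged): a MultiCategorical's entropy is simply the sum of
        # its independent sub-distributions' entropies; e.g., five uniform
        # 3-way heads would give 5 * log(3) ~= 5.493 nats per batch item.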

    def test_squashed_gaussian(self):
        """Tests the SquashedGaussian ActionDistribution for all frameworks."""
        input_space = Box(-2.0, 2.0, shape=(2000, 10))
        input_space.seed(42)
        low, high = -2.0, 1.0
        for fw, sess in framework_iterator(session=True):
            cls = SquashedGaussian if fw != "torch" else TorchSquashedGaussian
            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(
                cls, input_space.shape, fw=fw, sess=sess, bounds=(low, high)
            )
            # Batch of size=n and deterministic.
            inputs = input_space.sample()
            means, _ = np.split(inputs, 2, axis=-1)
            squashed_distribution = cls(inputs, {}, low=low, high=high)
            expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
            # Sample n times, expect always mean value (deterministic draw).
            out = squashed_distribution.deterministic_sample()
            check(out, expected)
            # Batch of size=n and non-deterministic -> expect roughly the mean.
            inputs = input_space.sample()
            means, log_stds = np.split(inputs, 2, axis=-1)
            squashed_distribution = cls(inputs, {}, low=low, high=high)
            expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
            values = squashed_distribution.sample()
            if sess:
                values = sess.run(values)
            else:
                values = values.numpy()
            self.assertTrue(np.max(values) <= high)
            self.assertTrue(np.min(values) >= low)
            check(np.mean(values), expected.mean(), decimals=1)
            # Test log-likelihood outputs.
            sampled_action_logp = squashed_distribution.logp(
                values if fw != "torch" else torch.Tensor(values)
            )
            if sess:
                sampled_action_logp = sess.run(sampled_action_logp)
            else:
                sampled_action_logp = sampled_action_logp.numpy()
            # Convert to parameters for distr.
            stds = np.exp(np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
            # Unsquash values, then get log-llh from regular gaussian.
            normed_values = (values - low) / (high - low) * 2.0 - 1.0
            safe_normed_values = np.clip(
                normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER
            )
            unsquashed_values = np.arctanh(safe_normed_values)
            log_prob_unsquashed = np.sum(
                np.log(norm.pdf(unsquashed_values, means, stds)), -1
            )
            log_prob = log_prob_unsquashed - np.sum(
                np.log(1 - np.tanh(unsquashed_values) ** 2), axis=-1
            )
            check(np.sum(sampled_action_logp), np.sum(log_prob), rtol=0.05)
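
            # Note (hedged): this mirrors the tanh-squash change-of-variables
            # correction log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2),
            # where u is the unsquashed Gaussian sample; clipping to
            # +/-(1 - SMALL_NUMBER) above keeps np.arctanh finite at the
            # action bounds.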
            # NN output.
            means = np.array(
                [[0.1, 0.2, 0.3, 0.4, 50.0], [-0.1, -0.2, -0.3, -0.4, -1.0]]
            )
            log_stds = np.array(
                [[0.8, -0.2, 0.3, -1.0, 2.0], [0.7, -0.3, 0.4, -0.9, 2.0]]
            )
            squashed_distribution = cls(
                inputs=np.concatenate([means, log_stds], axis=-1),
                model={},
                low=low,
                high=high,
            )
            # Convert to parameters for distr.
            stds = np.exp(log_stds)
            # Values to get log-likelihoods for.
            values = np.array(
                [[0.9, 0.2, 0.4, -0.1, -1.05], [-0.9, -0.2, 0.4, -0.1, -1.05]]
            )
            # Unsquash values, then get log-llh from regular gaussian.
            unsquashed_values = np.arctanh((values - low) / (high - low) * 2.0 - 1.0)
            log_prob_unsquashed = np.sum(
                np.log(norm.pdf(unsquashed_values, means, stds)), -1
            )
            log_prob = log_prob_unsquashed - np.sum(
                np.log(1 - np.tanh(unsquashed_values) ** 2), axis=-1
            )
            outs = squashed_distribution.logp(
                values if fw != "torch" else torch.Tensor(values)
            )
            if sess:
                outs = sess.run(outs)
            check(outs, log_prob, decimals=4)

    def test_diag_gaussian(self):
        """Tests the DiagGaussian ActionDistribution for all frameworks."""
        input_space = Box(-2.0, 1.0, shape=(2000, 10))
        input_space.seed(42)
        for fw, sess in framework_iterator(session=True):
            cls = DiagGaussian if fw != "torch" else TorchDiagGaussian
            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(cls, input_space.shape, fw=fw, sess=sess)
            # Batch of size=n and deterministic.
            inputs = input_space.sample()
            means, _ = np.split(inputs, 2, axis=-1)
            diag_distribution = cls(inputs, {})
            expected = means
            # Sample n times, expect always mean value (deterministic draw).
            out = diag_distribution.deterministic_sample()
            check(out, expected)
            # Batch of size=n and non-deterministic -> expect roughly the mean.
            inputs = input_space.sample()
            means, log_stds = np.split(inputs, 2, axis=-1)
            diag_distribution = cls(inputs, {})
            expected = means
            values = diag_distribution.sample()
            if sess:
                values = sess.run(values)
            else:
                values = values.numpy()
            check(np.mean(values), expected.mean(), decimals=1)
            # Test log-likelihood outputs.
            sampled_action_logp = diag_distribution.logp(
                values if fw != "torch" else torch.Tensor(values)
            )
            if sess:
                sampled_action_logp = sess.run(sampled_action_logp)
            else:
                sampled_action_logp = sampled_action_logp.numpy()
            # NN output.
            means = np.array(
                [[0.1, 0.2, 0.3, 0.4, 50.0], [-0.1, -0.2, -0.3, -0.4, -1.0]],
                dtype=np.float32,
            )
            log_stds = np.array(
                [[0.8, -0.2, 0.3, -1.0, 2.0], [0.7, -0.3, 0.4, -0.9, 2.0]],
                dtype=np.float32,
            )
            diag_distribution = cls(
                inputs=np.concatenate([means, log_stds], axis=-1), model={}
            )
            # Convert to parameters for distr.
            stds = np.exp(log_stds)
            # Values to get log-likelihoods for.
            values = np.array(
                [[0.9, 0.2, 0.4, -0.1, -1.05], [-0.9, -0.2, 0.4, -0.1, -1.05]]
            )
            # Get log-llh from regular gaussian.
            log_prob = np.sum(np.log(norm.pdf(values, means, stds)), -1)
            outs = diag_distribution.logp(
                values if fw != "torch" else torch.Tensor(values)
            )
            if sess:
                outs = sess.run(outs)
            check(outs, log_prob, decimals=4)
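
            # Quick scipy-only sanity sketch (not framework code): a standard
            # normal has log-density -0.5 * log(2 * pi) ~= -0.9189 at x=0,
            # which norm.pdf reproduces.
            check(np.log(norm.pdf(0.0, 0.0, 1.0)), -0.5 * np.log(2.0 * np.pi))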

    def test_beta(self):
        input_space = Box(-2.0, 1.0, shape=(2000, 10))
        input_space.seed(42)
        low, high = -1.0, 2.0
        plain_beta_value_space = Box(0.0, 1.0, shape=(2000, 5))
        plain_beta_value_space.seed(42)
        for fw, sess in framework_iterator(session=True):
            cls = TorchBeta if fw == "torch" else Beta
            inputs = input_space.sample()
            beta_distribution = cls(inputs, {}, low=low, high=high)
            inputs = beta_distribution.inputs
            if sess:
                inputs = sess.run(inputs)
            else:
                inputs = inputs.numpy()
            alpha, beta_ = np.split(inputs, 2, axis=-1)
            # Mean for a Beta distribution: 1 / [1 + (beta/alpha)].
            expected = (1.0 / (1.0 + beta_ / alpha)) * (high - low) + low
            # Sample n times, expect always mean value (deterministic draw).
            out = beta_distribution.deterministic_sample()
            check(out, expected, rtol=0.01)
            # Batch of size=n and non-deterministic -> expect roughly the mean.
            values = beta_distribution.sample()
            if sess:
                values = sess.run(values)
            else:
                values = values.numpy()
            self.assertTrue(np.max(values) <= high)
            self.assertTrue(np.min(values) >= low)
            check(np.mean(values), expected.mean(), decimals=1)
            # Test log-likelihood outputs (against scipy).
            inputs = input_space.sample()
            beta_distribution = cls(inputs, {}, low=low, high=high)
            inputs = beta_distribution.inputs
            if sess:
                inputs = sess.run(inputs)
            else:
                inputs = inputs.numpy()
            alpha, beta_ = np.split(inputs, 2, axis=-1)
            values = plain_beta_value_space.sample()
            values_scaled = values * (high - low) + low
            if fw == "torch":
                values_scaled = torch.Tensor(values_scaled)
            out = beta_distribution.logp(values_scaled)
            check(out, np.sum(np.log(beta.pdf(values, alpha, beta_)), -1), rtol=0.01)
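
            # Quick scipy-only sanity sketch: for alpha = beta = 2.0 the Beta
            # mean alpha / (alpha + beta) = 0.5 scales to
            # 0.5 * (high - low) + low = 0.5 with low=-1.0, high=2.0.
            check(beta.mean(2.0, 2.0) * (high - low) + low, 0.5)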
            # TODO(sven): Test entropy outputs (against scipy).

    def test_gumbel_softmax(self):
        """Tests the GumbelSoftmax ActionDistribution (tf + eager only)."""
        for fw, sess in framework_iterator(frameworks=("tf2", "tf"), session=True):
            batch_size = 1000
            num_categories = 5
            input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))
            input_space.seed(42)
            # Batch of size=n and deterministic.
            inputs = input_space.sample()
            gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)
            expected = softmax(inputs)
            # Sample n times, expect always mean value (deterministic draw).
            out = gumbel_softmax.deterministic_sample()
            check(out, expected)
            # Batch of size=n and non-deterministic -> expect roughly that
            # the max-likelihood (argmax) ints are output (most of the time).
            inputs = input_space.sample()
            gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)
            expected_mean = np.mean(np.argmax(inputs, -1)).astype(np.float32)
            outs = gumbel_softmax.sample()
            if sess:
                outs = sess.run(outs)
            check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)
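
        # Note (hedged reading of the check above): GumbelSoftmax draws
        # relaxed one-hot samples via softmax((logits + G) / temperature)
        # with G ~ Gumbel(0, 1), so at temperature=1.0 the per-sample argmax
        # tracks the argmax of the logits closely enough for the rtol=0.08
        # mean comparison to hold.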

    def test_multi_action_distribution(self):
        """Tests the MultiActionDistribution (across all frameworks)."""
        batch_size = 1000
        input_space = Tuple(
            [
                Box(-10.0, 10.0, shape=(batch_size, 4)),
                Box(-2.0, 2.0, shape=(batch_size, 6)),
                Dict({"a": Box(-1.0, 1.0, shape=(batch_size, 4))}),
            ]
        )
        input_space.seed(42)
        std_space = Box(-0.05, 0.05, shape=(batch_size, 3))
        std_space.seed(42)
        low, high = -1.0, 1.0
        value_space = Tuple(
            [
                Box(0, 3, shape=(batch_size,), dtype=np.int32),
                Box(-2.0, 2.0, shape=(batch_size, 3), dtype=np.float32),
                Dict({"a": Box(0.0, 1.0, shape=(batch_size, 2), dtype=np.float32)}),
            ]
        )
        value_space.seed(42)
        for fw, sess in framework_iterator(session=True):
            if fw == "torch":
                cls = TorchMultiActionDistribution
                child_distr_cls = [
                    TorchCategorical,
                    TorchDiagGaussian,
                    partial(TorchBeta, low=low, high=high),
                ]
            else:
                cls = MultiActionDistribution
                child_distr_cls = [
                    Categorical,
                    DiagGaussian,
                    partial(Beta, low=low, high=high),
                ]
            inputs = list(input_space.sample())
            distr = cls(
                np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1),
                model={},
                action_space=value_space,
                child_distributions=child_distr_cls,
                input_lens=[4, 6, 4],
            )
            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(
                inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER)
            )
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
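
            # Note (hedged): the clip-then-softplus transform above,
            # log(exp(x) + 1.0) + 1.0, maps any real NN output into (1, inf),
            # so both Beta concentration parameters stay > 1 and the density
            # remains unimodal, matching the Beta distribution's own
            # input handling.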
            # Sample deterministically.
            expected_det = [
                np.argmax(inputs[0], axis=-1),
                inputs[1][:, :3],  # [:3]=Mean values.
                # Mean for a Beta distribution:
                # 1 / [1 + (beta/alpha)] * range + low
                (1.0 / (1.0 + inputs[2]["a"][:, 2:] / inputs[2]["a"][:, :2]))
                * (high - low)
                + low,
            ]
            out = distr.deterministic_sample()
            if sess:
                out = sess.run(out)
            check(out[0], expected_det[0])
            check(out[1], expected_det[1])
            check(out[2]["a"], expected_det[2])
            # Stochastic sampling -> expect roughly the mean.
            inputs = list(input_space.sample())
            # Fix categorical inputs (not needed for distribution itself, but
            # for our expectation calculations).
            inputs[0] = softmax(inputs[0], -1)
            # Fix std inputs (shouldn't be too large for this test).
            inputs[1][:, 3:] = std_space.sample()
            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(
                inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER)
            )
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
            distr = cls(
                np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1),
                model={},
                action_space=value_space,
                child_distributions=child_distr_cls,
                input_lens=[4, 6, 4],
            )
            expected_mean = [
                np.mean(np.sum(inputs[0] * np.array([0, 1, 2, 3]), -1)),
                inputs[1][:, :3],  # [:3]=Mean values.
                # Mean for a Beta distribution:
                # 1 / [1 + (beta/alpha)] * range + low
                (1.0 / (1.0 + inputs[2]["a"][:, 2:] / inputs[2]["a"][:, :2]))
                * (high - low)
                + low,
            ]
            out = distr.sample()
            if sess:
                out = sess.run(out)
            out = list(out)
            if fw == "torch":
                out[0] = out[0].numpy()
                out[1] = out[1].numpy()
                out[2]["a"] = out[2]["a"].numpy()
            check(np.mean(out[0]), expected_mean[0], decimals=1)
            check(np.mean(out[1], 0), np.mean(expected_mean[1], 0), decimals=1)
            check(np.mean(out[2]["a"], 0), np.mean(expected_mean[2], 0), decimals=1)
            # Test log-likelihood outputs.
            # Make sure beta-values are within 0.0 and 1.0 for the numpy
            # calculation (which doesn't have scaling).
            inputs = list(input_space.sample())
            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(
                inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER)
            )
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
            distr = cls(
                np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1),
                model={},
                action_space=value_space,
                child_distributions=child_distr_cls,
                input_lens=[4, 6, 4],
            )
            inputs[0] = softmax(inputs[0], -1)
            values = list(value_space.sample())
            log_prob_beta = np.log(
                beta.pdf(values[2]["a"], inputs[2]["a"][:, :2], inputs[2]["a"][:, 2:])
            )
            # Now do the up-scaling for [2] (beta values) to be between
            # low/high.
            values[2]["a"] = values[2]["a"] * (high - low) + low
            inputs[1][:, 3:] = np.exp(inputs[1][:, 3:])
            expected_log_llh = np.sum(
                np.concatenate(
                    [
                        np.expand_dims(
                            np.log([i[values[0][j]] for j, i in enumerate(inputs[0])]),
                            -1,
                        ),
                        np.log(norm.pdf(values[1], inputs[1][:, :3], inputs[1][:, 3:])),
                        log_prob_beta,
                    ],
                    -1,
                ),
                -1,
            )
            values[0] = np.expand_dims(values[0], -1)
            if fw == "torch":
                values = tree.map_structure(lambda s: torch.Tensor(s), values)
            # Test all flattened input.
            concat = np.concatenate(tree.flatten(values), -1).astype(np.float32)
            out = distr.logp(concat)
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
            # Test structured input.
            out = distr.logp(values)
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
            # Test flattened input.
            out = distr.logp(tree.flatten(values))
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
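
            # Note (hedged): the three logp() calls above pass the same
            # actions in three accepted layouts (fully concatenated,
            # structured, and flattened); tree.flatten() yields the leaves in
            # a deterministic depth-first order (Categorical ints, Gaussian
            # vector, then the Dict's Beta values).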


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))