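"""Unit tests for RLlib's ActionDistribution classes.

Covers Categorical, MultiCategorical, SquashedGaussian, DiagGaussian, Beta,
GumbelSoftmax, and MultiActionDistribution across the frameworks yielded by
`framework_iterator` (and JAX for the categorical case), checking
deterministic and stochastic sampling, log-likelihoods against scipy,
entropies, and numerical stability under extreme NN outputs.
"""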
from functools import partial
import unittest

from gymnasium.spaces import Box, Dict, Tuple
import numpy as np
from scipy.stats import beta, norm
import tree  # pip install dm_tree

from ray.rllib.models.jax.jax_action_dist import JAXCategorical
from ray.rllib.models.tf.tf_action_dist import (
    Beta,
    Categorical,
    DiagGaussian,
    GumbelSoftmax,
    MultiActionDistribution,
    MultiCategorical,
    SquashedGaussian,
)
from ray.rllib.models.torch.torch_action_dist import (
    TorchBeta,
    TorchCategorical,
    TorchDiagGaussian,
    TorchMultiActionDistribution,
    TorchMultiCategorical,
    TorchSquashedGaussian,
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import (
    MIN_LOG_NN_OUTPUT,
    MAX_LOG_NN_OUTPUT,
    softmax,
    SMALL_NUMBER,
    LARGE_INTEGER,
)
from ray.rllib.utils.test_utils import check, framework_iterator
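
# The try_import_* helpers below return the framework modules when they are
# installed (and None placeholders when they are not), so importing this
# module does not hard-require every framework.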
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class TestActionDistributions(unittest.TestCase):
    """Tests ActionDistribution classes."""

    @classmethod
    def setUpClass(cls) -> None:
        # Set seeds for deterministic tests (make sure we don't fail
        # because of "bad" sampling).
        np.random.seed(42 + 1)
        torch.manual_seed(42 + 1)

    def _stability_test(
        self,
        distribution_cls,
        network_output_shape,
        fw,
        sess=None,
        bounds=None,
        extra_kwargs=None,
    ):
        extreme_values = [
            0.0,
            float(LARGE_INTEGER),
            -float(LARGE_INTEGER),
            1.1e-34,
            1.1e34,
            -1.1e-34,
            -1.1e34,
            SMALL_NUMBER,
            -SMALL_NUMBER,
        ]
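        # These extremes mix zero, huge, tiny, and borderline-representable
        # float32 magnitudes; sampling and logp'ing on such NN outputs must
        # never produce NaN or +/-inf values.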
        inputs = np.zeros(shape=network_output_shape, dtype=np.float32)
        for batch_item in range(network_output_shape[0]):
            # First half of the output vector: raw (possibly extreme) values.
            for num in range(len(inputs[batch_item]) // 2):
                inputs[batch_item][num] = np.random.choice(extreme_values)
            # For Gaussians, the second half of the vector holds log standard
            # deviations and should therefore be the log of a positive
            # number >= 1.
            for num in range(len(inputs[batch_item]) // 2, len(inputs[batch_item])):
                inputs[batch_item][num] = np.log(
                    max(1.0, np.random.choice(extreme_values))
                )

        dist = distribution_cls(inputs, {}, **(extra_kwargs or {}))
        for _ in range(100):
            sample = dist.sample()
            if fw == "jax":
                sample_check = sample
            elif fw != "tf":
                sample_check = sample.numpy()
            else:
                sample_check = sess.run(sample)
            assert not np.any(np.isnan(sample_check))
            assert np.all(np.isfinite(sample_check))
            if bounds:
                assert np.min(sample_check) >= bounds[0]
                assert np.max(sample_check) <= bounds[1]
                # Make sure bounds make sense and are actually also being
                # sampled.
                if isinstance(bounds[0], int):
                    assert isinstance(bounds[1], int)
                    assert bounds[0] in sample_check
                    assert bounds[1] in sample_check
            logp = dist.logp(sample)
            if fw == "jax":
                logp_check = logp
            elif fw != "tf":
                logp_check = logp.numpy()
            else:
                logp_check = sess.run(logp)
            assert not np.any(np.isnan(logp_check))
            assert np.all(np.isfinite(logp_check))

    def test_categorical(self):
        batch_size = 10000
        num_categories = 4
        # Create categorical distribution with n categories.
        inputs_space = Box(
            -1.0, 2.0, shape=(batch_size, num_categories), dtype=np.float32
        )
        inputs_space.seed(42)
        values_space = Box(0, num_categories - 1, shape=(batch_size,), dtype=np.int32)
        values_space.seed(42)

        inputs = inputs_space.sample()

        for fw, sess in framework_iterator(session=True):
            # Create the correct distribution object.
            cls = (
                JAXCategorical
                if fw == "jax"
                else Categorical
                if fw != "torch"
                else TorchCategorical
            )
            categorical = cls(inputs, {})

            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(
                cls,
                inputs_space.shape,
                fw=fw,
                sess=sess,
                bounds=(0, num_categories - 1),
            )

            # Deterministic sampling: expect always the argmax
            # (max likelihood) category.
            expected = np.transpose(np.argmax(inputs, axis=-1))
            out = categorical.deterministic_sample()
            check(out, expected)

            # Stochastic sampling -> expect roughly the mean.
            out = categorical.sample()
            check(
                np.mean(out)
                if fw == "jax"
                else tf.reduce_mean(out)
                if fw != "torch"
                else torch.mean(out.float()),
                1.0,
                decimals=0,
            )

            # Test log-likelihood outputs.
            probs = softmax(inputs)
            values = values_space.sample()

            out = categorical.logp(values if fw != "torch" else torch.Tensor(values))
            expected = []
            for i in range(batch_size):
                expected.append(np.sum(np.log(np.array(probs[i][values[i]]))))
            check(out, expected, decimals=4)

            # Test entropy outputs.
            out = categorical.entropy()
            expected_entropy = -np.sum(probs * np.log(probs), -1)
            check(out, expected_entropy)

    def test_multi_categorical(self):
        batch_size = 100
        num_categories = 3
        num_sub_distributions = 5
        # Create 5 categorical distributions of 3 categories each.
        inputs_space = Box(
            -1.0, 2.0, shape=(batch_size, num_sub_distributions * num_categories)
        )
        inputs_space.seed(42)
        values_space = Box(
            0,
            num_categories - 1,
            shape=(num_sub_distributions, batch_size),
            dtype=np.int32,
        )
        values_space.seed(42)

        inputs = inputs_space.sample()
        input_lengths = [num_categories] * num_sub_distributions
        inputs_split = np.split(inputs, num_sub_distributions, axis=1)

        for fw, sess in framework_iterator(session=True):
            # Create the correct distribution object.
            cls = MultiCategorical if fw != "torch" else TorchMultiCategorical
            multi_categorical = cls(inputs, None, input_lengths)

            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(
                cls,
                inputs_space.shape,
                fw=fw,
                sess=sess,
                bounds=(0, num_categories - 1),
                extra_kwargs={"input_lens": input_lengths},
            )

            # Deterministic sampling: expect always the argmax
            # (max likelihood) category per sub-distribution.
            expected = np.transpose(np.argmax(inputs_split, axis=-1))
            out = multi_categorical.deterministic_sample()
            check(out, expected)

            # Stochastic sampling -> expect roughly the mean.
            out = multi_categorical.sample()
            check(
                tf.reduce_mean(out) if fw != "torch" else torch.mean(out.float()),
                1.0,
                decimals=0,
            )

            # Test log-likelihood outputs.
            probs = softmax(inputs_split)
            values = values_space.sample()

            out = multi_categorical.logp(
                values
                if fw != "torch"
                else [torch.Tensor(values[i]) for i in range(num_sub_distributions)]
            )
            expected = []
            for i in range(batch_size):
                expected.append(
                    np.sum(
                        np.log(
                            np.array(
                                [
                                    probs[j][i][values[j][i]]
                                    for j in range(num_sub_distributions)
                                ]
                            )
                        )
                    )
                )
            check(out, expected, decimals=4)

            # Test entropy outputs.
            out = multi_categorical.entropy()
            expected_entropy = -np.sum(np.sum(probs * np.log(probs), 0), -1)
            check(out, expected_entropy)

    def test_squashed_gaussian(self):
        """Tests the SquashedGaussian ActionDistribution for all frameworks."""
        input_space = Box(-2.0, 2.0, shape=(2000, 10))
        input_space.seed(42)
        low, high = -2.0, 1.0

        for fw, sess in framework_iterator(session=True):
            cls = SquashedGaussian if fw != "torch" else TorchSquashedGaussian

            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(
                cls, input_space.shape, fw=fw, sess=sess, bounds=(low, high)
            )

            # Batch of size=n and deterministic.
            inputs = input_space.sample()
            means, _ = np.split(inputs, 2, axis=-1)
            squashed_distribution = cls(inputs, {}, low=low, high=high)
            expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
            # Sample n times, expect always mean value (deterministic draw).
            out = squashed_distribution.deterministic_sample()
            check(out, expected)

            # Batch of size=n and non-deterministic -> expect roughly the mean.
            inputs = input_space.sample()
            means, log_stds = np.split(inputs, 2, axis=-1)
            squashed_distribution = cls(inputs, {}, low=low, high=high)
            expected = ((np.tanh(means) + 1.0) / 2.0) * (high - low) + low
            values = squashed_distribution.sample()
            if sess:
                values = sess.run(values)
            else:
                values = values.numpy()
            self.assertTrue(np.max(values) <= high)
            self.assertTrue(np.min(values) >= low)
            check(np.mean(values), expected.mean(), decimals=1)

            # Test log-likelihood outputs.
            sampled_action_logp = squashed_distribution.logp(
                values if fw != "torch" else torch.Tensor(values)
            )
            if sess:
                sampled_action_logp = sess.run(sampled_action_logp)
            else:
                sampled_action_logp = sampled_action_logp.numpy()
            # Convert log-stds to parameters for the distribution.
            stds = np.exp(np.clip(log_stds, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT))
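            # Change-of-variables sanity check: for an action
            # a = low + (high - low) * (tanh(u) + 1) / 2 with u ~ N(mean, std),
            # log p(a) = log p(u) - sum_i log(1 - tanh(u_i)^2), i.e. the
            # Gaussian log-likelihood minus the log-det-Jacobian of the tanh
            # squashing. The constant Jacobian term of the affine rescaling is
            # omitted below, matching the implementation this test compares
            # against.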
            # Unsquash values, then get the log-likelihood from the regular
            # Gaussian.
            normed_values = (values - low) / (high - low) * 2.0 - 1.0
            safe_normed_values = np.clip(
                normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER
            )
            unsquashed_values = np.arctanh(safe_normed_values)
            log_prob_unsquashed = np.sum(
                np.log(norm.pdf(unsquashed_values, means, stds)), -1
            )
            log_prob = log_prob_unsquashed - np.sum(
                np.log(1 - np.tanh(unsquashed_values) ** 2), axis=-1
            )
            check(np.sum(sampled_action_logp), np.sum(log_prob), rtol=0.05)

            # NN output.
            means = np.array(
                [[0.1, 0.2, 0.3, 0.4, 50.0], [-0.1, -0.2, -0.3, -0.4, -1.0]]
            )
            log_stds = np.array(
                [[0.8, -0.2, 0.3, -1.0, 2.0], [0.7, -0.3, 0.4, -0.9, 2.0]]
            )

            squashed_distribution = cls(
                inputs=np.concatenate([means, log_stds], axis=-1),
                model={},
                low=low,
                high=high,
            )
            # Convert to parameters for distr.
            stds = np.exp(log_stds)
            # Values to get log-likelihoods for.
            values = np.array(
                [[0.9, 0.2, 0.4, -0.1, -1.05], [-0.9, -0.2, 0.4, -0.1, -1.05]]
            )

            # Unsquash values, then get log-llh from regular gaussian.
            unsquashed_values = np.arctanh((values - low) / (high - low) * 2.0 - 1.0)
            log_prob_unsquashed = np.sum(
                np.log(norm.pdf(unsquashed_values, means, stds)), -1
            )
            log_prob = log_prob_unsquashed - np.sum(
                np.log(1 - np.tanh(unsquashed_values) ** 2), axis=-1
            )

            outs = squashed_distribution.logp(
                values if fw != "torch" else torch.Tensor(values)
            )
            if sess:
                outs = sess.run(outs)
            check(outs, log_prob, decimals=4)

    def test_diag_gaussian(self):
        """Tests the DiagGaussian ActionDistribution for all frameworks."""
        input_space = Box(-2.0, 1.0, shape=(2000, 10))
        input_space.seed(42)

        for fw, sess in framework_iterator(session=True):
            cls = DiagGaussian if fw != "torch" else TorchDiagGaussian

            # Do a stability test using extreme NN outputs to see whether
            # sampling and logp'ing result in NaN or +/-inf values.
            self._stability_test(cls, input_space.shape, fw=fw, sess=sess)

            # Batch of size=n and deterministic.
            inputs = input_space.sample()
            means, _ = np.split(inputs, 2, axis=-1)
            diag_distribution = cls(inputs, {})
            expected = means
            # Sample n times, expect always mean value (deterministic draw).
            out = diag_distribution.deterministic_sample()
            check(out, expected)

            # Batch of size=n and non-deterministic -> expect roughly the mean.
            inputs = input_space.sample()
            means, log_stds = np.split(inputs, 2, axis=-1)
            diag_distribution = cls(inputs, {})
            expected = means
            values = diag_distribution.sample()
            if sess:
                values = sess.run(values)
            else:
                values = values.numpy()
            check(np.mean(values), expected.mean(), decimals=1)

            # Test log-likelihood outputs.
            sampled_action_logp = diag_distribution.logp(
                values if fw != "torch" else torch.Tensor(values)
            )
            if sess:
                sampled_action_logp = sess.run(sampled_action_logp)
            else:
                sampled_action_logp = sampled_action_logp.numpy()

            # NN output.
            means = np.array(
                [[0.1, 0.2, 0.3, 0.4, 50.0], [-0.1, -0.2, -0.3, -0.4, -1.0]],
                dtype=np.float32,
            )
            log_stds = np.array(
                [[0.8, -0.2, 0.3, -1.0, 2.0], [0.7, -0.3, 0.4, -0.9, 2.0]],
                dtype=np.float32,
            )

            diag_distribution = cls(
                inputs=np.concatenate([means, log_stds], axis=-1), model={}
            )
            # Convert to parameters for distr.
            stds = np.exp(log_stds)
            # Values to get log-likelihoods for.
            values = np.array(
                [[0.9, 0.2, 0.4, -0.1, -1.05], [-0.9, -0.2, 0.4, -0.1, -1.05]]
            )
            # Get log-likelihoods from a regular Gaussian (scipy).
            log_prob = np.sum(np.log(norm.pdf(values, means, stds)), -1)

            outs = diag_distribution.logp(
                values if fw != "torch" else torch.Tensor(values)
            )
            if sess:
                outs = sess.run(outs)
            check(outs, log_prob, decimals=4)

    def test_beta(self):
        input_space = Box(-2.0, 1.0, shape=(2000, 10))
        input_space.seed(42)
        low, high = -1.0, 2.0
        plain_beta_value_space = Box(0.0, 1.0, shape=(2000, 5))
        plain_beta_value_space.seed(42)

        for fw, sess in framework_iterator(session=True):
            cls = TorchBeta if fw == "torch" else Beta
            inputs = input_space.sample()
            beta_distribution = cls(inputs, {}, low=low, high=high)

            inputs = beta_distribution.inputs
            if sess:
                inputs = sess.run(inputs)
            else:
                inputs = inputs.numpy()
            alpha, beta_ = np.split(inputs, 2, axis=-1)

            # Mean for a Beta distribution: 1 / [1 + (beta/alpha)]
            expected = (1.0 / (1.0 + beta_ / alpha)) * (high - low) + low
            # Sample n times, expect always mean value (deterministic draw).
            out = beta_distribution.deterministic_sample()
            check(out, expected, rtol=0.01)

            # Batch of size=n and non-deterministic -> expect roughly the mean.
            values = beta_distribution.sample()
            if sess:
                values = sess.run(values)
            else:
                values = values.numpy()
            self.assertTrue(np.max(values) <= high)
            self.assertTrue(np.min(values) >= low)
            check(np.mean(values), expected.mean(), decimals=1)

            # Test log-likelihood outputs (against scipy).
            inputs = input_space.sample()
            beta_distribution = cls(inputs, {}, low=low, high=high)
            inputs = beta_distribution.inputs
            if sess:
                inputs = sess.run(inputs)
            else:
                inputs = inputs.numpy()
            alpha, beta_ = np.split(inputs, 2, axis=-1)

            values = plain_beta_value_space.sample()
            values_scaled = values * (high - low) + low
            if fw == "torch":
                values_scaled = torch.Tensor(values_scaled)
            out = beta_distribution.logp(values_scaled)
            check(out, np.sum(np.log(beta.pdf(values, alpha, beta_)), -1), rtol=0.01)

            # TODO(sven): Test entropy outputs (against scipy).
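            # A possible cross-check for the TODO above (sketch only, not
            # enabled): scipy's beta.entropy(alpha, beta_) yields per-dimension
            # entropies of the unscaled Beta; summing along the last axis
            # (plus a log(high - low) offset per dimension for the rescaling)
            # would be the value to compare entropy() against.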

    def test_gumbel_softmax(self):
        """Tests the GumbelSoftmax ActionDistribution (tf + eager only)."""
        for fw, sess in framework_iterator(frameworks=("tf2", "tf"), session=True):
            batch_size = 1000
            num_categories = 5
            input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))
            input_space.seed(42)
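            # GumbelSoftmax is a differentiable relaxation of a categorical
            # distribution: samples are (near) one-hot vectors over the
            # `num_categories` classes, with the temperature controlling how
            # sharp the relaxation is (lower -> closer to one-hot).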

            # Batch of size=n and deterministic.
            inputs = input_space.sample()
            gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)

            expected = softmax(inputs)
            # Sample n times, expect always mean value (deterministic draw).
            out = gumbel_softmax.deterministic_sample()
            check(out, expected)

            # Batch of size=n and non-deterministic -> expect roughly that
            # the max-likelihood (argmax) ints are output (most of the time).
            inputs = input_space.sample()
            gumbel_softmax = GumbelSoftmax(inputs, {}, temperature=1.0)
            expected_mean = np.mean(np.argmax(inputs, -1)).astype(np.float32)
            outs = gumbel_softmax.sample()
            if sess:
                outs = sess.run(outs)
            check(np.mean(np.argmax(outs, -1)), expected_mean, rtol=0.08)

    def test_multi_action_distribution(self):
        """Tests the MultiActionDistribution (across all frameworks)."""
        batch_size = 1000
        input_space = Tuple(
            [
                Box(-10.0, 10.0, shape=(batch_size, 4)),
                Box(-2.0, 2.0, shape=(batch_size, 6)),
                Dict({"a": Box(-1.0, 1.0, shape=(batch_size, 4))}),
            ]
        )
        input_space.seed(42)
        std_space = Box(-0.05, 0.05, shape=(batch_size, 3))
        std_space.seed(42)

        low, high = -1.0, 1.0
        value_space = Tuple(
            [
                Box(0, 3, shape=(batch_size,), dtype=np.int32),
                Box(-2.0, 2.0, shape=(batch_size, 3), dtype=np.float32),
                Dict({"a": Box(0.0, 1.0, shape=(batch_size, 2), dtype=np.float32)}),
            ]
        )
        value_space.seed(42)
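        # Flat input layout (input_lens=[4, 6, 4]): 4 Categorical logits,
        # 6 DiagGaussian params (3 means + 3 log-stds), and 4 Beta params
        # (2 alphas + 2 betas).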

        for fw, sess in framework_iterator(session=True):
            if fw == "torch":
                cls = TorchMultiActionDistribution
                child_distr_cls = [
                    TorchCategorical,
                    TorchDiagGaussian,
                    partial(TorchBeta, low=low, high=high),
                ]
            else:
                cls = MultiActionDistribution
                child_distr_cls = [
                    Categorical,
                    DiagGaussian,
                    partial(Beta, low=low, high=high),
                ]

            inputs = list(input_space.sample())
            distr = cls(
                np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1),
                model={},
                action_space=value_space,
                child_distributions=child_distr_cls,
                input_lens=[4, 6, 4],
            )

            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(
                inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER)
            )
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
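            # log(exp(x) + 1) + 1 is softplus(x) + 1: after the clipping, both
            # Beta concentration parameters end up strictly greater than 1,
            # which keeps the distribution unimodal.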

            # Sample deterministically.
            expected_det = [
                np.argmax(inputs[0], axis=-1),
                inputs[1][:, :3],  # [:3]=Mean values.
                # Mean for a Beta distribution:
                # 1 / [1 + (beta/alpha)] * range + low
                (1.0 / (1.0 + inputs[2]["a"][:, 2:] / inputs[2]["a"][:, 0:2]))
                * (high - low)
                + low,
            ]
            out = distr.deterministic_sample()
            if sess:
                out = sess.run(out)
            check(out[0], expected_det[0])
            check(out[1], expected_det[1])
            check(out[2]["a"], expected_det[2])

            # Stochastic sampling -> expect roughly the mean.
            inputs = list(input_space.sample())
            # Fix categorical inputs (not needed for distribution itself, but
            # for our expectation calculations).
            inputs[0] = softmax(inputs[0], -1)
            # Fix std inputs (shouldn't be too large for this test).
            inputs[1][:, 3:] = std_space.sample()
            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(
                inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER)
            )
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
            distr = cls(
                np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1),
                model={},
                action_space=value_space,
                child_distributions=child_distr_cls,
                input_lens=[4, 6, 4],
            )
            expected_mean = [
                np.mean(np.sum(inputs[0] * np.array([0, 1, 2, 3]), -1)),
                inputs[1][:, :3],  # [:3]=Mean values.
                # Mean for a Beta distribution:
                # 1 / [1 + (beta/alpha)] * range + low
                (1.0 / (1.0 + inputs[2]["a"][:, 2:] / inputs[2]["a"][:, :2]))
                * (high - low)
                + low,
            ]
            out = distr.sample()
            if sess:
                out = sess.run(out)
            out = list(out)
            if fw == "torch":
                out[0] = out[0].numpy()
                out[1] = out[1].numpy()
                out[2]["a"] = out[2]["a"].numpy()
            check(np.mean(out[0]), expected_mean[0], decimals=1)
            check(np.mean(out[1], 0), np.mean(expected_mean[1], 0), decimals=1)
            check(np.mean(out[2]["a"], 0), np.mean(expected_mean[2], 0), decimals=1)

            # Test log-likelihood outputs.
            # Make sure beta-values are within 0.0 and 1.0 for the numpy
            # calculation (which doesn't have scaling).
            inputs = list(input_space.sample())
            # Adjust inputs for the Beta distr just as Beta itself does.
            inputs[2]["a"] = np.clip(
                inputs[2]["a"], np.log(SMALL_NUMBER), -np.log(SMALL_NUMBER)
            )
            inputs[2]["a"] = np.log(np.exp(inputs[2]["a"]) + 1.0) + 1.0
            distr = cls(
                np.concatenate([inputs[0], inputs[1], inputs[2]["a"]], axis=1),
                model={},
                action_space=value_space,
                child_distributions=child_distr_cls,
                input_lens=[4, 6, 4],
            )
            inputs[0] = softmax(inputs[0], -1)
            values = list(value_space.sample())
            log_prob_beta = np.log(
                beta.pdf(values[2]["a"], inputs[2]["a"][:, :2], inputs[2]["a"][:, 2:])
            )
            # Now do the up-scaling for [2] (beta values) to be between
            # low/high.
            values[2]["a"] = values[2]["a"] * (high - low) + low
            inputs[1][:, 3:] = np.exp(inputs[1][:, 3:])
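            # The Gaussian slots store log-stds in the NN output; exponentiate
            # them here so scipy's norm.pdf receives actual std-deviations.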
            expected_log_llh = np.sum(
                np.concatenate(
                    [
                        np.expand_dims(
                            np.log([i[values[0][j]] for j, i in enumerate(inputs[0])]),
                            -1,
                        ),
                        np.log(norm.pdf(values[1], inputs[1][:, :3], inputs[1][:, 3:])),
                        log_prob_beta,
                    ],
                    -1,
                ),
                -1,
            )

            values[0] = np.expand_dims(values[0], -1)
            if fw == "torch":
                values = tree.map_structure(lambda s: torch.Tensor(s), values)
            # Test all flattened input.
            concat = np.concatenate(tree.flatten(values), -1).astype(np.float32)
            out = distr.logp(concat)
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
            # Test structured input.
            out = distr.logp(values)
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)
            # Test flattened input.
            out = distr.logp(tree.flatten(values))
            if sess:
                out = sess.run(out)
            check(out, expected_log_llh, atol=15)


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))