import functools
import gym
from math import log
import numpy as np
import tree  # pip install dm_tree
from typing import Optional

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils import MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT, \
    SMALL_NUMBER
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.framework import try_import_tf, try_import_tfp
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorType, List, Union, \
    Tuple, ModelConfigDict

tf1, tf, tfv = try_import_tf()
tfp = try_import_tfp()

@DeveloperAPI
class TFActionDistribution(ActionDistribution):
    """TF-specific extensions for building action distributions."""

    @override(ActionDistribution)
    def __init__(self, inputs: List[TensorType], model: ModelV2):
        super().__init__(inputs, model)
        self.sample_op = self._build_sample_op()
        self.sampled_action_logp_op = self.logp(self.sample_op)

    @DeveloperAPI
    def _build_sample_op(self) -> TensorType:
        """Implement this instead of sample(), to enable op reuse.

        This is needed since the sample op is non-deterministic and is shared
        between sample() and sampled_action_logp().
        """
        raise NotImplementedError

    @override(ActionDistribution)
    def sample(self) -> TensorType:
        """Draws a sample from the action distribution."""
        return self.sample_op

    @override(ActionDistribution)
    def sampled_action_logp(self) -> TensorType:
        """Returns the log probability of the sampled action."""
        return self.sampled_action_logp_op

class Categorical(TFActionDistribution):
    """Categorical distribution for discrete action spaces."""

    @DeveloperAPI
    def __init__(self,
                 inputs: List[TensorType],
                 model: ModelV2 = None,
                 temperature: float = 1.0):
        assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
        # Allow softmax formula w/ temperature != 1.0:
        # Divide inputs by temperature.
        super().__init__(inputs / temperature, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return tf.math.argmax(self.inputs, axis=1)

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        return -tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=self.inputs, labels=tf.cast(x, tf.int32))

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.math.log(z0) - a0), axis=1)

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        a0 = self.inputs - tf.reduce_max(self.inputs, axis=1, keepdims=True)
        a1 = other.inputs - tf.reduce_max(other.inputs, axis=1, keepdims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=1, keepdims=True)
        z1 = tf.reduce_sum(ea1, axis=1, keepdims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(
            p0 * (a0 - tf.math.log(z0) - a1 + tf.math.log(z1)), axis=1)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return tf.squeeze(tf.random.categorical(self.inputs, 1), axis=1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return action_space.n
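
# Illustrative sketch (hypothetical helper, not in the original module):
# samples a discrete action from `Categorical`. Assumes TF2 eager execution;
# `model` may be None since the distribution only reads `inputs` here.
def _demo_categorical():
    logits = tf.constant([[0.5, 0.1, 2.0]])  # batch of 1, 3 discrete actions
    dist = Categorical(logits, model=None)
    action = dist.sample()  # int64 tensor of shape [1]
    return action, dist.logp(action), dist.entropy()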

class MultiCategorical(TFActionDistribution):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    def __init__(self,
                 inputs: List[TensorType],
                 model: ModelV2,
                 input_lens: Union[List[int], np.ndarray, Tuple[int, ...]],
                 action_space=None):
        # Skip TFActionDistribution.__init__: the sample op can only be
        # built once the child Categoricals exist.
        ActionDistribution.__init__(self, inputs, model)
        self.cats = [
            Categorical(input_, model)
            for input_ in tf.split(inputs, input_lens, axis=1)
        ]
        self.action_space = action_space
        if self.action_space is None:
            self.action_space = gym.spaces.MultiDiscrete(
                [c.inputs.shape[1] for c in self.cats])
        self.sample_op = self._build_sample_op()
        self.sampled_action_logp_op = self.logp(self.sample_op)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        sample_ = tf.stack(
            [cat.deterministic_sample() for cat in self.cats], axis=1)
        if isinstance(self.action_space, gym.spaces.Box):
            return tf.cast(
                tf.reshape(sample_, [-1] + list(self.action_space.shape)),
                self.action_space.dtype)
        return sample_

    @override(ActionDistribution)
    def logp(self, actions: TensorType) -> TensorType:
        # If tensor is provided, unstack it into list.
        if isinstance(actions, tf.Tensor):
            if isinstance(self.action_space, gym.spaces.Box):
                actions = tf.reshape(
                    actions, [-1, int(np.product(self.action_space.shape))])
            elif isinstance(self.action_space, gym.spaces.MultiDiscrete):
                actions.set_shape((None, len(self.cats)))
            actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
        logps = tf.stack(
            [cat.logp(act) for cat, act in zip(self.cats, actions)])
        return tf.reduce_sum(logps, axis=0)

    @override(ActionDistribution)
    def multi_entropy(self) -> TensorType:
        return tf.stack([cat.entropy() for cat in self.cats], axis=1)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        return tf.reduce_sum(self.multi_entropy(), axis=1)

    @override(ActionDistribution)
    def multi_kl(self, other: ActionDistribution) -> TensorType:
        return tf.stack(
            [cat.kl(oth_cat) for cat, oth_cat in zip(self.cats, other.cats)],
            axis=1)

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        return tf.reduce_sum(self.multi_kl(other), axis=1)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        sample_op = tf.stack([cat.sample() for cat in self.cats], axis=1)
        if isinstance(self.action_space, gym.spaces.Box):
            return tf.cast(
                tf.reshape(sample_op, [-1] + list(self.action_space.shape)),
                dtype=self.action_space.dtype)
        return sample_op

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        # Int Box.
        if isinstance(action_space, gym.spaces.Box):
            assert action_space.dtype.name.startswith("int")
            low_ = np.min(action_space.low)
            high_ = np.max(action_space.high)
            assert np.all(action_space.low == low_)
            assert np.all(action_space.high == high_)
            # One categorical per int dimension, each with
            # (high - low + 1) possible values.
            return np.product(action_space.shape) * (high_ - low_ + 1)
        # MultiDiscrete space.
        else:
            return np.sum(action_space.nvec)
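
# Illustrative sketch (hypothetical helper, not in the original module):
# two sub-distributions with 2 and 3 categories each. Assumes TF2 eager
# execution.
def _demo_multi_categorical():
    logits = tf.constant([[0.1, 0.9, 0.3, 0.3, 0.4]])  # 2 + 3 logits
    dist = MultiCategorical(logits, model=None, input_lens=[2, 3])
    action = dist.sample()  # int64 tensor of shape [1, 2]
    return action, dist.logp(action)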

class GumbelSoftmax(TFActionDistribution):
    """GumbelSoftmax distr. (for differentiable sampling in discr. actions).

    The Gumbel Softmax distribution [1] (also known as the Concrete [2]
    distribution) is a close cousin of the relaxed one-hot categorical
    distribution, whose tfp implementation we will use here plus
    adjusted `sample_...` and `log_prob` methods. See discussion at [0].

    [0] https://stackoverflow.com/questions/56226133/
        soft-actor-critic-with-discrete-action-space

    [1] Categorical Reparameterization with Gumbel-Softmax (Jang et al., 2017):
        https://arxiv.org/abs/1611.01144

    [2] The Concrete Distribution: A Continuous Relaxation of Discrete Random
        Variables (Maddison et al., 2017): https://arxiv.org/abs/1611.00712
    """

    @DeveloperAPI
    def __init__(self,
                 inputs: List[TensorType],
                 model: ModelV2 = None,
                 temperature: float = 1.0):
        """Initializes a GumbelSoftmax distribution.

        Args:
            temperature (float): Temperature parameter. For low temperatures,
                the expected value approaches a categorical random variable.
                For high temperatures, the expected value approaches a uniform
                distribution.
        """
        # Temperature must be strictly positive (tfp's
        # RelaxedOneHotCategorical divides the logits by it).
        assert temperature > 0.0, \
            "GumbelSoftmax `temperature` must be > 0.0!"
        self.dist = tfp.distributions.RelaxedOneHotCategorical(
            temperature=temperature, logits=inputs)
        self.probs = tf.nn.softmax(self.dist._distribution.logits)
        super().__init__(inputs, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        # Return the dist object's prob values.
        return self.probs

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # Override since the implementation of tfp.RelaxedOneHotCategorical
        # yields positive values.
        if x.shape != self.dist.logits.shape:
            # One-hot encode integer actions to match the logits' shape.
            values = tf.one_hot(
                x, self.dist.logits.shape.as_list()[-1], dtype=tf.float32)
            assert values.shape == self.dist.logits.shape, (
                values.shape, self.dist.logits.shape)
            x = values
        # [0]'s implementation (see line below) seems to be an approximation
        # to the actual Gumbel Softmax density.
        return -tf.reduce_sum(
            -x * tf.nn.log_softmax(self.dist.logits, axis=-1), axis=-1)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self.dist.sample()

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return action_space.n
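
# Illustrative sketch (hypothetical helper, not in the original module):
# draws a differentiable, relaxed one-hot sample. Assumes TF2 eager
# execution and that tensorflow_probability is installed.
def _demo_gumbel_softmax():
    logits = tf.constant([[0.5, 0.1, 2.0]])
    dist = GumbelSoftmax(logits, model=None, temperature=0.5)
    relaxed = dist.sample()  # ~one-hot float vector, shape [1, 3]
    return relaxed, dist.logp(relaxed)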

class DiagGaussian(TFActionDistribution):
    """Action distribution where each vector element is a Gaussian.

    The first half of the input vector defines the Gaussian means, and the
    second half their log standard deviations.
    """

    def __init__(self,
                 inputs: List[TensorType],
                 model: ModelV2,
                 *,
                 action_space: Optional[gym.spaces.Space] = None):
        mean, log_std = tf.split(inputs, 2, axis=1)
        self.mean = mean
        self.log_std = log_std
        self.std = tf.exp(log_std)
        # Remember to squeeze action samples in case action space is
        # Box(shape=()).
        self.zero_action_dim = action_space and action_space.shape == ()
        super().__init__(inputs, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return self.mean

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # Cover case where action space is Box(shape=()).
        if int(tf.shape(x).shape[0]) == 1:
            x = tf.expand_dims(x, axis=1)
        return -0.5 * tf.reduce_sum(
            tf.math.square((tf.cast(x, tf.float32) - self.mean) / self.std),
            axis=1
        ) - 0.5 * np.log(2.0 * np.pi) * tf.cast(tf.shape(x)[1], tf.float32) - \
            tf.reduce_sum(self.log_std, axis=1)

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        assert isinstance(other, DiagGaussian)
        return tf.reduce_sum(
            other.log_std - self.log_std +
            (tf.math.square(self.std) + tf.math.square(self.mean - other.mean))
            / (2.0 * tf.math.square(other.std)) - 0.5,
            axis=1)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        return tf.reduce_sum(
            self.log_std + .5 * np.log(2.0 * np.pi * np.e), axis=1)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        sample = self.mean + self.std * tf.random.normal(tf.shape(self.mean))
        if self.zero_action_dim:
            return tf.squeeze(sample, axis=-1)
        return sample

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape) * 2
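
# Illustrative sketch (hypothetical helper, not in the original module):
# the 4 inputs split into means [0., 0.] and log-stds [-1., -1.] for a
# 2-dim continuous action. Assumes TF2 eager execution.
def _demo_diag_gaussian():
    inputs = tf.constant([[0.0, 0.0, -1.0, -1.0]])
    dist = DiagGaussian(inputs, model=None)
    action = dist.sample()  # float32 tensor of shape [1, 2]
    return action, dist.logp(action), dist.entropy()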

class SquashedGaussian(TFActionDistribution):
    """A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

    The distribution will never return `low` or `high` exactly, but
    `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively.
    """

    def __init__(self,
                 inputs: List[TensorType],
                 model: ModelV2,
                 low: float = -1.0,
                 high: float = 1.0):
        """Parameterizes the distribution via `inputs`.

        Args:
            low (float): The lowest possible sampling value
                (excluding this value).
            high (float): The highest possible sampling value
                (excluding this value).
        """
        assert tfp is not None
        mean, log_std = tf.split(inputs, 2, axis=-1)
        # Clip `scale` values (coming from NN) to reasonable values.
        log_std = tf.clip_by_value(log_std, MIN_LOG_NN_OUTPUT,
                                   MAX_LOG_NN_OUTPUT)
        std = tf.exp(log_std)
        self.distr = tfp.distributions.Normal(loc=mean, scale=std)
        assert np.all(np.less(low, high))
        self.low = low
        self.high = high
        super().__init__(inputs, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        mean = self.distr.mean()
        return self._squash(mean)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self._squash(self.distr.sample())

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # Unsquash values (from [low, high] to (-inf, inf)).
        unsquashed_values = tf.cast(self._unsquash(x), self.inputs.dtype)
        # Get log-prob of unsquashed values from our Normal.
        log_prob_gaussian = self.distr.log_prob(unsquashed_values)
        # For numerical safety, clamp to a reasonable range first, only then
        # sum up.
        log_prob_gaussian = tf.clip_by_value(log_prob_gaussian, -100, 100)
        log_prob_gaussian = tf.reduce_sum(log_prob_gaussian, axis=-1)
        # Get log-prob for squashed Gaussian (tanh change-of-variables
        # correction).
        unsquashed_values_tanhd = tf.math.tanh(unsquashed_values)
        log_prob = log_prob_gaussian - tf.reduce_sum(
            tf.math.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER),
            axis=-1)
        return log_prob

    def sample_logp(self):
        z = self.distr.sample()
        actions = self._squash(z)
        return actions, tf.reduce_sum(
            self.distr.log_prob(z) -
            tf.math.log(1 - actions * actions + SMALL_NUMBER),
            axis=-1)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        raise ValueError("Entropy not defined for SquashedGaussian!")

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        raise ValueError("KL not defined for SquashedGaussian!")

    def _squash(self, raw_values: TensorType) -> TensorType:
        # Returned values are within [low, high] (including `low` and `high`).
        squashed = ((tf.math.tanh(raw_values) + 1.0) / 2.0) * \
            (self.high - self.low) + self.low
        return tf.clip_by_value(squashed, self.low, self.high)

    def _unsquash(self, values: TensorType) -> TensorType:
        normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \
            1.0
        # Stabilize input to atanh.
        safe_normed_values = tf.clip_by_value(
            normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER)
        unsquashed = tf.math.atanh(safe_normed_values)
        return unsquashed

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape) * 2
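
# Illustrative sketch (hypothetical helper, not in the original module):
# samples stay strictly inside (low, high) thanks to the tanh squashing.
# Assumes TF2 eager execution and tensorflow_probability.
def _demo_squashed_gaussian():
    inputs = tf.constant([[0.0, 0.0]])  # mean 0.0, log-std 0.0 (1-dim action)
    dist = SquashedGaussian(inputs, model=None, low=-2.0, high=2.0)
    action = dist.sample()  # shape [1, 1], within (-2.0, 2.0)
    return action, dist.logp(action)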

class Beta(TFActionDistribution):
    """
    A Beta distribution is defined on the interval [0, 1] and parameterized by
    shape parameters alpha and beta (also called concentration parameters).

    PDF(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z
        with Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta)
        and Gamma(n) = (n - 1)!
    """

    def __init__(self,
                 inputs: List[TensorType],
                 model: ModelV2,
                 low: float = 0.0,
                 high: float = 1.0):
        # Stabilize input parameters (possibly coming from a linear layer).
        inputs = tf.clip_by_value(inputs, log(SMALL_NUMBER),
                                  -log(SMALL_NUMBER))
        # Softplus plus 1 ensures alpha and beta are strictly > 1.
        inputs = tf.math.log(tf.math.exp(inputs) + 1.0) + 1.0
        self.low = low
        self.high = high
        alpha, beta = tf.split(inputs, 2, axis=-1)
        # Note: concentration0==beta, concentration1==alpha (!)
        self.dist = tfp.distributions.Beta(
            concentration1=alpha, concentration0=beta)
        super().__init__(inputs, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        mean = self.dist.mean()
        return self._squash(mean)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self._squash(self.dist.sample())

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        unsquashed_values = self._unsquash(x)
        return tf.math.reduce_sum(
            self.dist.log_prob(unsquashed_values), axis=-1)

    def _squash(self, raw_values: TensorType) -> TensorType:
        return raw_values * (self.high - self.low) + self.low

    def _unsquash(self, values: TensorType) -> TensorType:
        return (values - self.low) / (self.high - self.low)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape) * 2
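
# Illustrative sketch (hypothetical helper, not in the original module):
# the 2 inputs become the alpha and beta concentrations for a 1-dim action
# in [0, 1]. Assumes TF2 eager execution and tensorflow_probability.
def _demo_beta():
    inputs = tf.constant([[1.0, 1.0]])
    dist = Beta(inputs, model=None)
    action = dist.sample()  # shape [1, 1], within [low, high]
    return action, dist.logp(action)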

class Deterministic(TFActionDistribution):
    """Action distribution that returns the input values directly.

    This is similar to DiagGaussian with standard deviation zero (thus only
    requiring the "mean" values as NN output).
    """

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return self.inputs

    @override(TFActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        return tf.zeros_like(self.inputs)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self.inputs

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape)
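
# Illustrative sketch (hypothetical helper, not in the original module):
# sampling simply echoes the inputs, and logp is all-zeros.
def _demo_deterministic():
    inputs = tf.constant([[0.3, -0.7]])
    dist = Deterministic(inputs, model=None)
    return dist.sample(), dist.sampled_action_logp()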

class MultiActionDistribution(TFActionDistribution):
    """Action distribution that operates on a set of actions.

    Args:
        inputs (Tensor list): A list of tensors from which to compute samples.
    """

    def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        ActionDistribution.__init__(self, inputs, model)

        self.action_space_struct = get_base_struct_from_space(action_space)

        self.input_lens = np.array(input_lens, dtype=np.int32)
        split_inputs = tf.split(inputs, self.input_lens, axis=1)
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model), child_distributions,
            split_inputs)

    @override(ActionDistribution)
    def logp(self, x):
        # Single tensor input (all merged).
        if isinstance(x, (tf.Tensor, np.ndarray)):
            split_indices = []
            for dist in self.flat_child_distributions:
                if isinstance(dist, Categorical):
                    split_indices.append(1)
                elif isinstance(dist, MultiCategorical) and \
                        dist.action_space is not None:
                    split_indices.append(np.prod(dist.action_space.shape))
                else:
                    sample = dist.sample()
                    # Cover Box(shape=()) case.
                    if len(sample.shape) == 1:
                        split_indices.append(1)
                    else:
                        split_indices.append(tf.shape(sample)[1])
            split_x = tf.split(x, split_indices, axis=1)
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, Categorical):
                val = tf.cast(
                    tf.squeeze(val, axis=-1)
                    if len(val.shape) > 1 else val, tf.int32)
            return dist.logp(val)

        # Remove extra categorical dimension and take the logp of each
        # component.
        flat_logps = tree.map_structure(map_, split_x,
                                        self.flat_child_distributions)

        return functools.reduce(lambda a, b: a + b, flat_logps)

    @override(ActionDistribution)
    def kl(self, other):
        kl_list = [
            d.kl(o) for d, o in zip(self.flat_child_distributions,
                                    other.flat_child_distributions)
        ]
        return functools.reduce(lambda a, b: a + b, kl_list)

    @override(ActionDistribution)
    def entropy(self):
        entropy_list = [d.entropy() for d in self.flat_child_distributions]
        return functools.reduce(lambda a, b: a + b, entropy_list)

    @override(ActionDistribution)
    def sample(self):
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.sample(), child_distributions)

    @override(ActionDistribution)
    def deterministic_sample(self):
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.deterministic_sample(),
                                  child_distributions)

    @override(TFActionDistribution)
    def sampled_action_logp(self):
        p = self.flat_child_distributions[0].sampled_action_logp()
        for c in self.flat_child_distributions[1:]:
            p += c.sampled_action_logp()
        return p

    @override(ActionDistribution)
    def required_model_output_shape(self, action_space, model_config):
        return np.sum(self.input_lens)
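
# Illustrative sketch (hypothetical helper, not in the original module):
# a Tuple action space combining one Discrete(2) component (2 logits) and
# one 1-dim DiagGaussian component (mean + log-std = 2 inputs). Assumes
# TF2 eager execution.
def _demo_multi_action():
    space = gym.spaces.Tuple(
        (gym.spaces.Discrete(2), gym.spaces.Box(-1.0, 1.0, shape=(1,))))
    inputs = tf.constant([[0.1, 0.9, 0.0, -1.0]])
    dist = MultiActionDistribution(
        inputs, model=None,
        child_distributions=[Categorical, DiagGaussian],
        input_lens=[2, 2], action_space=space)
    return dist.sample(), dist.sampled_action_logp()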

class Dirichlet(TFActionDistribution):
    """Dirichlet distribution for continuous actions that lie in [0, 1] and
    sum to 1, e.g. actions that represent a resource allocation."""

    def __init__(self, inputs: List[TensorType], model: ModelV2):
        """Input is a tensor of logits. The exponential of the logits is used
        to parameterize the Dirichlet distribution, as all concentration
        parameters need to be positive. A small epsilon is added so that no
        concentration parameter becomes zero due to numerical error.

        See issue #4440 for more details.
        """
        self.epsilon = 1e-7
        concentration = tf.exp(inputs) + self.epsilon
        self.dist = tf1.distributions.Dirichlet(
            concentration=concentration,
            validate_args=True,
            allow_nan_stats=False,
        )
        super().__init__(concentration, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return tf.nn.softmax(self.dist.concentration)

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # The support of the Dirichlet are the positive real numbers. x is
        # already an array of positive numbers, but we clip it to avoid zeros
        # due to numerical errors.
        x = tf.maximum(x, self.epsilon)
        x = x / tf.reduce_sum(x, axis=-1, keepdims=True)
        return self.dist.log_prob(x)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        return self.dist.kl_divergence(other.dist)

    @override(TFActionDistribution)
    def _build_sample_op(self) -> TensorType:
        return self.dist.sample()

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape)
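
# Illustrative sketch (hypothetical helper, not in the original module):
# samples are non-negative and sum to 1 along the last axis, e.g. a 3-way
# resource allocation. Assumes TF2 eager execution (tf1 here is
# tf.compat.v1, whose deprecated distributions module still works eagerly).
def _demo_dirichlet():
    logits = tf.constant([[0.0, 0.0, 0.0]])
    dist = Dirichlet(logits, model=None)
    allocation = dist.sample()  # shape [1, 3], rows sum to 1
    return allocation, dist.logp(allocation)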