torch_action_dist.py

import functools
import gym
from math import log
import numpy as np
import tree  # pip install dm_tree
from typing import Optional

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, \
    MAX_LOG_NN_OUTPUT
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorType, List, Union, \
    Tuple, ModelConfigDict

torch, nn = try_import_torch()

class TorchDistributionWrapper(ActionDistribution):
    """Wrapper class for torch.distributions."""

    @override(ActionDistribution)
    def __init__(self, inputs: List[TensorType], model: TorchModelV2):
        # If inputs are not a torch Tensor, make them one and make sure they
        # are on the correct device.
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.from_numpy(inputs)
            if isinstance(model, TorchModelV2):
                inputs = inputs.to(next(model.parameters()).device)
        super().__init__(inputs, model)
        # Store the last sample here.
        self.last_sample = None

    @override(ActionDistribution)
    def logp(self, actions: TensorType) -> TensorType:
        return self.dist.log_prob(actions)

    @override(ActionDistribution)
    def entropy(self) -> TensorType:
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other: ActionDistribution) -> TensorType:
        return torch.distributions.kl.kl_divergence(self.dist, other.dist)

    @override(ActionDistribution)
    def sample(self) -> TensorType:
        self.last_sample = self.dist.sample()
        return self.last_sample

    @override(ActionDistribution)
    def sampled_action_logp(self) -> TensorType:
        assert self.last_sample is not None
        return self.logp(self.last_sample)

class TorchCategorical(TorchDistributionWrapper):
    """Wrapper class for PyTorch Categorical distribution."""

    @override(ActionDistribution)
    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2 = None,
                 temperature: float = 1.0):
        if temperature != 1.0:
            assert temperature > 0.0, \
                "Categorical `temperature` must be > 0.0!"
            inputs /= temperature
        super().__init__(inputs, model)
        self.dist = torch.distributions.categorical.Categorical(
            logits=self.inputs)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        self.last_sample = self.dist.probs.argmax(dim=1)
        return self.last_sample

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return action_space.n
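
# A minimal usage sketch for the categorical wrapper (shapes assumed here:
# a [batch, num_actions] logits tensor; `model` may be None when used
# standalone, outside a full RLlib policy):
#
#   logits = torch.tensor([[2.0, 0.5, -1.0]])
#   dist = TorchCategorical(logits, None, temperature=0.5)
#   action = dist.sample()             # int64 tensor of shape [1]
#   logp = dist.sampled_action_logp()  # log-prob of that sample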

class TorchMultiCategorical(TorchDistributionWrapper):
    """MultiCategorical distribution for MultiDiscrete action spaces."""

    @override(TorchDistributionWrapper)
    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2,
                 input_lens: Union[List[int], np.ndarray, Tuple[int, ...]],
                 action_space=None):
        super().__init__(inputs, model)
        # If input_lens is np.ndarray or list, force-make it a tuple.
        inputs_split = self.inputs.split(tuple(input_lens), dim=1)
        self.cats = [
            torch.distributions.categorical.Categorical(logits=input_)
            for input_ in inputs_split
        ]
        # Used in case we are dealing with an Int Box.
        self.action_space = action_space

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        arr = [cat.sample() for cat in self.cats]
        sample_ = torch.stack(arr, dim=1)
        if isinstance(self.action_space, gym.spaces.Box):
            sample_ = torch.reshape(sample_,
                                    [-1] + list(self.action_space.shape))
        self.last_sample = sample_
        return sample_

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        arr = [torch.argmax(cat.probs, -1) for cat in self.cats]
        sample_ = torch.stack(arr, dim=1)
        if isinstance(self.action_space, gym.spaces.Box):
            sample_ = torch.reshape(sample_,
                                    [-1] + list(self.action_space.shape))
        self.last_sample = sample_
        return sample_

    @override(TorchDistributionWrapper)
    def logp(self, actions: TensorType) -> TensorType:
        # If a single Tensor is provided, unstack it into a list.
        if isinstance(actions, torch.Tensor):
            if isinstance(self.action_space, gym.spaces.Box):
                actions = torch.reshape(
                    actions, [-1, int(np.prod(self.action_space.shape))])
            actions = torch.unbind(actions, dim=1)
        logps = torch.stack(
            [cat.log_prob(act) for cat, act in zip(self.cats, actions)])
        return torch.sum(logps, dim=0)

    @override(ActionDistribution)
    def multi_entropy(self) -> TensorType:
        return torch.stack([cat.entropy() for cat in self.cats], dim=1)

    @override(TorchDistributionWrapper)
    def entropy(self) -> TensorType:
        return torch.sum(self.multi_entropy(), dim=1)

    @override(ActionDistribution)
    def multi_kl(self, other: ActionDistribution) -> TensorType:
        return torch.stack(
            [
                torch.distributions.kl.kl_divergence(cat, oth_cat)
                for cat, oth_cat in zip(self.cats, other.cats)
            ],
            dim=1,
        )

    @override(TorchDistributionWrapper)
    def kl(self, other: ActionDistribution) -> TensorType:
        return torch.sum(self.multi_kl(other), dim=1)
    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        # Int Box.
        if isinstance(action_space, gym.spaces.Box):
            assert action_space.dtype.name.startswith("int")
            low_ = np.min(action_space.low)
            high_ = np.max(action_space.high)
            assert np.all(action_space.low == low_)
            assert np.all(action_space.high == high_)
            # One categorical per int dimension, each with
            # (high - low + 1) classes.
            return np.prod(action_space.shape) * (high_ - low_ + 1)
        # MultiDiscrete space.
        else:
            return np.sum(action_space.nvec)
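
# A minimal usage sketch for a MultiDiscrete([3, 5]) space (shapes assumed
# here; logits for all sub-spaces are concatenated along the last axis):
#
#   logits = torch.randn(4, 3 + 5)          # batch of 4
#   dist = TorchMultiCategorical(logits, None, input_lens=[3, 5])
#   actions = dist.sample()                  # shape: [4, 2]
#   logp = dist.logp(actions)                # summed over components -> [4]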

class TorchDiagGaussian(TorchDistributionWrapper):
    """Wrapper class for PyTorch Normal distribution."""

    @override(ActionDistribution)
    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2,
                 *,
                 action_space: Optional[gym.spaces.Space] = None):
        super().__init__(inputs, model)
        mean, log_std = torch.chunk(self.inputs, 2, dim=1)
        self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))
        # Remember to squeeze action samples in case action space is
        # Box(shape=()).
        self.zero_action_dim = action_space and action_space.shape == ()

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        sample = super().sample()
        if self.zero_action_dim:
            return torch.squeeze(sample, dim=-1)
        return sample

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        self.last_sample = self.dist.mean
        return self.last_sample

    @override(TorchDistributionWrapper)
    def logp(self, actions: TensorType) -> TensorType:
        return super().logp(actions).sum(-1)

    @override(TorchDistributionWrapper)
    def entropy(self) -> TensorType:
        return super().entropy().sum(-1)

    @override(TorchDistributionWrapper)
    def kl(self, other: ActionDistribution) -> TensorType:
        return super().kl(other).sum(-1)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape) * 2
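
# A minimal usage sketch (shapes assumed here): for a 2-dim Box action
# space, the model must output 2 * 2 = 4 values -- the means followed by
# the log-stds:
#
#   inputs = torch.tensor([[0.0, 1.0, -0.5, -0.5]])  # [means | log-stds]
#   dist = TorchDiagGaussian(inputs, None)
#   action = dist.sample()                           # shape: [1, 2]
#   logp = dist.logp(action)                         # summed over dims -> [1]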

class TorchSquashedGaussian(TorchDistributionWrapper):
    """A tanh-squashed Gaussian distribution defined by: mean, std, low, high.

    The distribution will never return `low` or `high` exactly, but
    `low` + SMALL_NUMBER or `high` - SMALL_NUMBER respectively.
    """

    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2,
                 low: float = -1.0,
                 high: float = 1.0):
        """Parameterizes the distribution via `inputs`.

        Args:
            low (float): The lowest possible sampling value
                (excluding this value).
            high (float): The highest possible sampling value
                (excluding this value).
        """
        super().__init__(inputs, model)
        # Split inputs into mean and log(std).
        mean, log_std = torch.chunk(self.inputs, 2, dim=-1)
        # Clip `scale` values (coming from NN) to reasonable values.
        log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT)
        std = torch.exp(log_std)
        self.dist = torch.distributions.normal.Normal(mean, std)
        assert np.all(np.less(low, high))
        self.low = low
        self.high = high
        self.mean = mean
        self.std = std

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        self.last_sample = self._squash(self.dist.mean)
        return self.last_sample

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        # Use the reparameterization version of `dist.sample` to allow for
        # the results to be backprop'able e.g. in a loss term.
        normal_sample = self.dist.rsample()
        self.last_sample = self._squash(normal_sample)
        return self.last_sample

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        # Unsquash values (from [low, high] to (-inf, inf)).
        unsquashed_values = self._unsquash(x)
        # Get log-prob of unsquashed values from our Normal.
        log_prob_gaussian = self.dist.log_prob(unsquashed_values)
        # For numerical safety, clamp first, only then sum up.
        log_prob_gaussian = torch.clamp(log_prob_gaussian, -100, 100)
        log_prob_gaussian = torch.sum(log_prob_gaussian, dim=-1)
        # Get log-prob for the squashed Gaussian by subtracting the
        # log-determinant of the tanh Jacobian: log(1 - tanh(u)^2).
        unsquashed_values_tanhd = torch.tanh(unsquashed_values)
        log_prob = log_prob_gaussian - torch.sum(
            torch.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), dim=-1)
        return log_prob

    def sample_logp(self):
        z = self.dist.rsample()
        actions = self._squash(z)
        return actions, torch.sum(
            self.dist.log_prob(z) -
            torch.log(1 - actions * actions + SMALL_NUMBER),
            dim=-1)

    @override(TorchDistributionWrapper)
    def entropy(self) -> TensorType:
        raise ValueError("Entropy not defined for SquashedGaussian!")

    @override(TorchDistributionWrapper)
    def kl(self, other: ActionDistribution) -> TensorType:
        raise ValueError("KL not defined for SquashedGaussian!")
    def _squash(self, raw_values: TensorType) -> TensorType:
        # Returned values are within [low, high] (including `low` and `high`).
        squashed = ((torch.tanh(raw_values) + 1.0) / 2.0) * \
            (self.high - self.low) + self.low
        return torch.clamp(squashed, self.low, self.high)

    def _unsquash(self, values: TensorType) -> TensorType:
        normed_values = (values - self.low) / (self.high - self.low) * 2.0 - \
            1.0
        # Stabilize input to atanh.
        safe_normed_values = torch.clamp(normed_values, -1.0 + SMALL_NUMBER,
                                         1.0 - SMALL_NUMBER)
        unsquashed = torch.atanh(safe_normed_values)
        return unsquashed

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape) * 2
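
# A minimal numerical check of the change-of-variables correction used in
# `logp` above (shapes assumed here; with low=-1 and high=1, squashing
# reduces to a plain tanh):
#
#   inputs = torch.zeros(1, 2)   # mean=0, log_std=0 -> std=1 for one dim
#   dist = TorchSquashedGaussian(inputs, None)
#   a = dist.sample()            # squashed into (-1, 1)
#   # logp(a) == Normal.log_prob(atanh(a)) - log(1 - tanh(atanh(a))**2)
#   print(dist.logp(a))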

class TorchBeta(TorchDistributionWrapper):
    """A Beta distribution, defined on the interval [0, 1] and parameterized
    by shape parameters alpha and beta (also called concentration parameters).

    PDF(x; alpha, beta) = x**(alpha - 1) * (1 - x)**(beta - 1) / Z
        with Z = Gamma(alpha) * Gamma(beta) / Gamma(alpha + beta)
        and Gamma(n) = (n - 1)!
    """

    def __init__(self,
                 inputs: List[TensorType],
                 model: TorchModelV2,
                 low: float = 0.0,
                 high: float = 1.0):
        super().__init__(inputs, model)
        # Stabilize input parameters (possibly coming from a linear layer).
        self.inputs = torch.clamp(self.inputs, log(SMALL_NUMBER),
                                  -log(SMALL_NUMBER))
        # Softplus (+1.0) so that all concentration parameters are > 1.
        self.inputs = torch.log(torch.exp(self.inputs) + 1.0) + 1.0
        self.low = low
        self.high = high
        alpha, beta = torch.chunk(self.inputs, 2, dim=-1)
        # Note: concentration0==beta, concentration1==alpha (!).
        self.dist = torch.distributions.Beta(
            concentration1=alpha, concentration0=beta)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        self.last_sample = self._squash(self.dist.mean)
        return self.last_sample

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        # Use the reparameterization version of `dist.sample` to allow for
        # the results to be backprop'able e.g. in a loss term.
        sample_ = self.dist.rsample()
        self.last_sample = self._squash(sample_)
        return self.last_sample

    @override(ActionDistribution)
    def logp(self, x: TensorType) -> TensorType:
        unsquashed_values = self._unsquash(x)
        return torch.sum(self.dist.log_prob(unsquashed_values), dim=-1)

    def _squash(self, raw_values: TensorType) -> TensorType:
        return raw_values * (self.high - self.low) + self.low

    def _unsquash(self, values: TensorType) -> TensorType:
        return (values - self.low) / (self.high - self.low)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape) * 2
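
# A minimal usage sketch (shapes assumed here): one action dim needs two
# NN outputs (an alpha-logit and a beta-logit); samples land in [low, high]:
#
#   inputs = torch.zeros(1, 2)
#   dist = TorchBeta(inputs, None, low=0.0, high=10.0)
#   action = dist.sample()       # in [0, 10]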

class TorchDeterministic(TorchDistributionWrapper):
    """Action distribution that returns the input values directly.

    This is similar to DiagGaussian with standard deviation zero (thus only
    requiring the "mean" values as NN output).
    """

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        return self.inputs

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self) -> TensorType:
        return torch.zeros((self.inputs.size()[0], ), dtype=torch.float32)

    @override(TorchDistributionWrapper)
    def sample(self) -> TensorType:
        return self.deterministic_sample()

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space,
            model_config: ModelConfigDict) -> Union[int, np.ndarray]:
        return np.prod(action_space.shape)
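
# A minimal usage sketch (shapes assumed here): the "distribution" simply
# echoes the model output, as used for deterministic-policy algorithms:
#
#   inputs = torch.tensor([[0.3, -0.7]])
#   dist = TorchDeterministic(inputs, None)
#   print(dist.sample())         # tensor([[0.3, -0.7]])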

class TorchMultiActionDistribution(TorchDistributionWrapper):
    """Action distribution that operates on multiple, possibly nested actions.
    """

    def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        """Initializes a TorchMultiActionDistribution object.

        Args:
            inputs (torch.Tensor): A single tensor of shape [BATCH, size].
            model (TorchModelV2): The TorchModelV2 object used to produce
                inputs for this distribution.
            child_distributions (any): Any struct that contains the child
                distribution classes to use to instantiate the child
                distributions from `inputs`. This could be an already
                flattened list or a struct according to `action_space`.
            input_lens (any[int]): A flat list or a nested struct of input
                split lengths used to split `inputs`.
            action_space (Union[gym.spaces.Dict, gym.spaces.Tuple]): The
                complex and possibly nested action space.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.from_numpy(inputs)
            if isinstance(model, TorchModelV2):
                inputs = inputs.to(next(model.parameters()).device)
        super().__init__(inputs, model)
        self.action_space_struct = get_base_struct_from_space(action_space)
        self.input_lens = tree.flatten(input_lens)
        flat_child_distributions = tree.flatten(child_distributions)
        split_inputs = torch.split(inputs, self.input_lens, dim=1)
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model), flat_child_distributions,
            list(split_inputs))

    @override(ActionDistribution)
    def logp(self, x):
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        # Single tensor input (all merged).
        if isinstance(x, torch.Tensor):
            split_indices = []
            for dist in self.flat_child_distributions:
                if isinstance(dist, TorchCategorical):
                    split_indices.append(1)
                elif isinstance(dist, TorchMultiCategorical) and \
                        dist.action_space is not None:
                    split_indices.append(int(np.prod(dist.action_space.shape)))
                else:
                    sample = dist.sample()
                    # Cover Box(shape=()) case.
                    if len(sample.shape) == 1:
                        split_indices.append(1)
                    else:
                        split_indices.append(sample.size()[1])
            split_x = list(torch.split(x, split_indices, dim=1))
        # Structured or flattened (by single action component) input.
        else:
            split_x = tree.flatten(x)

        def map_(val, dist):
            # Remove extra categorical dimension.
            if isinstance(dist, TorchCategorical):
                val = (torch.squeeze(val, dim=-1)
                       if len(val.shape) > 1 else val).int()
            return dist.logp(val)

        # Remove extra categorical dimension and take the logp of each
        # component.
        flat_logps = tree.map_structure(map_, split_x,
                                        self.flat_child_distributions)
        return functools.reduce(lambda a, b: a + b, flat_logps)

    @override(ActionDistribution)
    def kl(self, other):
        kl_list = [
            d.kl(o) for d, o in zip(self.flat_child_distributions,
                                    other.flat_child_distributions)
        ]
        return functools.reduce(lambda a, b: a + b, kl_list)

    @override(ActionDistribution)
    def entropy(self):
        entropy_list = [d.entropy() for d in self.flat_child_distributions]
        return functools.reduce(lambda a, b: a + b, entropy_list)

    @override(ActionDistribution)
    def sample(self):
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.sample(), child_distributions)

    @override(ActionDistribution)
    def deterministic_sample(self):
        child_distributions = tree.unflatten_as(self.action_space_struct,
                                                self.flat_child_distributions)
        return tree.map_structure(lambda s: s.deterministic_sample(),
                                  child_distributions)

    @override(TorchDistributionWrapper)
    def sampled_action_logp(self):
        p = self.flat_child_distributions[0].sampled_action_logp()
        for c in self.flat_child_distributions[1:]:
            p += c.sampled_action_logp()
        return p

    @override(ActionDistribution)
    def required_model_output_shape(self, action_space, model_config):
        return np.sum(self.input_lens)
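
# A minimal usage sketch for a Tuple(Discrete(2), Box(-1, 1, (3,))) action
# space (sizes assumed here: 2 logits for the Discrete part plus 2 * 3
# outputs for the DiagGaussian part -> 8 model outputs total):
#
#   space = gym.spaces.Tuple(
#       [gym.spaces.Discrete(2),
#        gym.spaces.Box(-1.0, 1.0, (3,), np.float32)])
#   inputs = torch.randn(4, 2 + 6)
#   dist = TorchMultiActionDistribution(
#       inputs, None,
#       child_distributions=[TorchCategorical, TorchDiagGaussian],
#       input_lens=[2, 6], action_space=space)
#   actions = dist.sample()      # (int tensor [4], float tensor [4, 3])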

class TorchDirichlet(TorchDistributionWrapper):
    """Dirichlet distribution for continuous actions that lie in [0, 1] and
    sum to 1, e.g. actions that represent a resource allocation."""

    def __init__(self, inputs, model):
        """Input is a tensor of logits. The exponential of the logits is used
        to parametrize the Dirichlet distribution, as all parameters need to
        be positive. An arbitrarily small epsilon is added so that the
        concentration parameters do not become zero due to numerical error.

        See issue #4440 for more details.
        """
        self.epsilon = torch.tensor(1e-7).to(inputs.device)
        concentration = torch.exp(inputs) + self.epsilon
        self.dist = torch.distributions.dirichlet.Dirichlet(
            concentration=concentration,
            validate_args=True,
        )
        super().__init__(concentration, model)

    @override(ActionDistribution)
    def deterministic_sample(self) -> TensorType:
        self.last_sample = nn.functional.softmax(
            self.dist.concentration, dim=-1)
        return self.last_sample

    @override(ActionDistribution)
    def logp(self, x):
        # Support of Dirichlet are positive real numbers. x is already
        # an array of positive numbers, but we clip to avoid zeros due to
        # numerical errors.
        x = torch.max(x, self.epsilon)
        x = x / torch.sum(x, dim=-1, keepdim=True)
        return self.dist.log_prob(x)

    @override(ActionDistribution)
    def entropy(self):
        return self.dist.entropy()

    @override(ActionDistribution)
    def kl(self, other):
        return torch.distributions.kl.kl_divergence(self.dist, other.dist)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(action_space, model_config):
        return np.prod(action_space.shape)
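
# A minimal usage sketch (shapes assumed here): three allocation fractions
# that must sum to 1; the model outputs one logit per fraction:
#
#   logits = torch.zeros(1, 3)
#   dist = TorchDirichlet(logits, None)
#   action = dist.sample()          # shape: [1, 3], each row sums to 1
#   print(action.sum(dim=-1))       # ~tensor([1.0])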