torch_utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. import gym
  2. from gym.spaces import Discrete, MultiDiscrete
  3. import numpy as np
  4. import os
  5. import tree # pip install dm_tree
  6. from typing import Dict, List, Optional, TYPE_CHECKING
  7. import warnings
  8. from ray.rllib.models.repeated_values import RepeatedValues
  9. from ray.rllib.utils.deprecation import Deprecated
  10. from ray.rllib.utils.framework import try_import_torch
  11. from ray.rllib.utils.numpy import SMALL_NUMBER
  12. from ray.rllib.utils.typing import LocalOptimizer, SpaceStruct, TensorType, \
  13. TensorStructType
  14. if TYPE_CHECKING:
  15. from ray.rllib.policy.torch_policy import TorchPolicy
# Import torch through RLlib's helper rather than a hard `import torch`.
# NOTE(review): `torch` / `nn` are presumably placeholders when torch is not
# installed — confirm against `try_import_torch`.
torch, nn = try_import_torch()

# Finite stand-ins for -inf / +inf, suitable e.g. as "masked-out" logit
# values. Using them instead of real -inf / inf avoids the NaNs that
# infinities cause during backprop.
FLOAT_MIN = -3.4e38
FLOAT_MAX = 3.4e38
  21. def apply_grad_clipping(policy: "TorchPolicy", optimizer: LocalOptimizer,
  22. loss: TensorType) -> Dict[str, TensorType]:
  23. """Applies gradient clipping to already computed grads inside `optimizer`.
  24. Args:
  25. policy: The TorchPolicy, which calculated `loss`.
  26. optimizer: A local torch optimizer object.
  27. loss: The torch loss tensor.
  28. Returns:
  29. An info dict containing the "grad_norm" key and the resulting clipped
  30. gradients.
  31. """
  32. info = {}
  33. if policy.config["grad_clip"]:
  34. for param_group in optimizer.param_groups:
  35. # Make sure we only pass params with grad != None into torch
  36. # clip_grad_norm_. Would fail otherwise.
  37. params = list(
  38. filter(lambda p: p.grad is not None, param_group["params"]))
  39. if params:
  40. grad_gnorm = nn.utils.clip_grad_norm_(
  41. params, policy.config["grad_clip"])
  42. if isinstance(grad_gnorm, torch.Tensor):
  43. grad_gnorm = grad_gnorm.cpu().numpy()
  44. info["grad_gnorm"] = grad_gnorm
  45. return info
  46. @Deprecated(
  47. old="ray.rllib.utils.torch_utils.atanh",
  48. new="torch.math.atanh",
  49. error=False)
  50. def atanh(x: TensorType) -> TensorType:
  51. """Atanh function for PyTorch."""
  52. return 0.5 * torch.log(
  53. (1 + x).clamp(min=SMALL_NUMBER) / (1 - x).clamp(min=SMALL_NUMBER))
  54. def concat_multi_gpu_td_errors(policy: "TorchPolicy") -> Dict[str, TensorType]:
  55. """Concatenates multi-GPU (per-tower) TD error tensors given TorchPolicy.
  56. TD-errors are extracted from the TorchPolicy via its tower_stats property.
  57. Args:
  58. policy: The TorchPolicy to extract the TD-error values from.
  59. Returns:
  60. A dict mapping strings "td_error" and "mean_td_error" to the
  61. corresponding concatenated and mean-reduced values.
  62. """
  63. td_error = torch.cat(
  64. [
  65. t.tower_stats.get("td_error", torch.tensor([0.0])).to(
  66. policy.device) for t in policy.model_gpu_towers
  67. ],
  68. dim=0)
  69. policy.td_error = td_error
  70. return {
  71. "td_error": td_error,
  72. "mean_td_error": torch.mean(td_error),
  73. }
  74. @Deprecated(new="ray/rllib/utils/numpy.py::convert_to_numpy", error=False)
  75. def convert_to_non_torch_type(stats: TensorStructType) -> TensorStructType:
  76. """Converts values in `stats` to non-Tensor numpy or python types.
  77. Args:
  78. stats (any): Any (possibly nested) struct, the values in which will be
  79. converted and returned as a new struct with all torch tensors
  80. being converted to numpy types.
  81. Returns:
  82. Any: A new struct with the same structure as `stats`, but with all
  83. values converted to non-torch Tensor types.
  84. """
  85. # The mapping function used to numpyize torch Tensors.
  86. def mapping(item):
  87. if isinstance(item, torch.Tensor):
  88. return item.cpu().item() if len(item.size()) == 0 else \
  89. item.detach().cpu().numpy()
  90. else:
  91. return item
  92. return tree.map_structure(mapping, stats)
  93. def convert_to_torch_tensor(x: TensorStructType, device: Optional[str] = None):
  94. """Converts any struct to torch.Tensors.
  95. x (any): Any (possibly nested) struct, the values in which will be
  96. converted and returned as a new struct with all leaves converted
  97. to torch tensors.
  98. Returns:
  99. Any: A new struct with the same structure as `stats`, but with all
  100. values converted to torch Tensor types.
  101. """
  102. def mapping(item):
  103. # Already torch tensor -> make sure it's on right device.
  104. if torch.is_tensor(item):
  105. return item if device is None else item.to(device)
  106. # Special handling of "Repeated" values.
  107. elif isinstance(item, RepeatedValues):
  108. return RepeatedValues(
  109. tree.map_structure(mapping, item.values), item.lengths,
  110. item.max_len)
  111. # Numpy arrays.
  112. if isinstance(item, np.ndarray):
  113. # Object type (e.g. info dicts in train batch): leave as-is.
  114. if item.dtype == object:
  115. return item
  116. # Non-writable numpy-arrays will cause PyTorch warning.
  117. elif item.flags.writeable is False:
  118. with warnings.catch_warnings():
  119. warnings.simplefilter("ignore")
  120. tensor = torch.from_numpy(item)
  121. # Already numpy: Wrap as torch tensor.
  122. else:
  123. tensor = torch.from_numpy(item)
  124. # Everything else: Convert to numpy, then wrap as torch tensor.
  125. else:
  126. tensor = torch.from_numpy(np.asarray(item))
  127. # Floatify all float64 tensors.
  128. if tensor.dtype == torch.double:
  129. tensor = tensor.float()
  130. return tensor if device is None else tensor.to(device)
  131. return tree.map_structure(mapping, x)
  132. def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
  133. """Computes the explained variance for a pair of labels and predictions.
  134. The formula used is:
  135. max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2))
  136. Args:
  137. y: The labels.
  138. pred: The predictions.
  139. Returns:
  140. The explained variance given a pair of labels and predictions.
  141. """
  142. y_var = torch.var(y, dim=[0])
  143. diff_var = torch.var(y - pred, dim=[0])
  144. min_ = torch.tensor([-1.0]).to(pred.device)
  145. return torch.max(min_, 1 - (diff_var / y_var))[0]
  146. def flatten_inputs_to_1d_tensor(inputs: TensorStructType,
  147. spaces_struct: Optional[SpaceStruct] = None,
  148. time_axis: bool = False) -> TensorType:
  149. """Flattens arbitrary input structs according to the given spaces struct.
  150. Returns a single 1D tensor resulting from the different input
  151. components' values.
  152. Thereby:
  153. - Boxes (any shape) get flattened to (B, [T]?, -1). Note that image boxes
  154. are not treated differently from other types of Boxes and get
  155. flattened as well.
  156. - Discrete (int) values are one-hot'd, e.g. a batch of [1, 0, 3] (B=3 with
  157. Discrete(4) space) results in [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]].
  158. - MultiDiscrete values are multi-one-hot'd, e.g. a batch of
  159. [[0, 2], [1, 4]] (B=2 with MultiDiscrete([2, 5]) space) results in
  160. [[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 1]].
  161. Args:
  162. inputs: The inputs to be flattened.
  163. spaces_struct: The structure of the spaces that behind the input
  164. time_axis: Whether all inputs have a time-axis (after the batch axis).
  165. If True, will keep not only the batch axis (0th), but the time axis
  166. (1st) as-is and flatten everything from the 2nd axis up.
  167. Returns:
  168. A single 1D tensor resulting from concatenating all
  169. flattened/one-hot'd input components. Depending on the time_axis flag,
  170. the shape is (B, n) or (B, T, n).
  171. Examples:
  172. >>> # B=2
  173. >>> out = flatten_inputs_to_1d_tensor(
  174. ... {"a": [1, 0], "b": [[[0.0], [0.1]], [1.0], [1.1]]},
  175. ... spaces_struct=dict(a=Discrete(2), b=Box(shape=(2, 1)))
  176. ... )
  177. >>> print(out)
  178. ... [[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]] # B=2 n=4
  179. >>> # B=2; T=2
  180. >>> out = flatten_inputs_to_1d_tensor(
  181. ... ([[1, 0], [0, 1]],
  182. ... [[[0.0, 0.1], [1.0, 1.1]], [[2.0, 2.1], [3.0, 3.1]]]),
  183. ... spaces_struct=tuple([Discrete(2), Box(shape=(2, ))]),
  184. ... time_axis=True
  185. ... )
  186. >>> print(out)
  187. ... [[[0.0, 1.0, 0.0, 0.1], [1.0, 0.0, 1.0, 1.1]],
  188. ... [[1.0, 0.0, 2.0, 2.1], [0.0, 1.0, 3.0, 3.1]]] # B=2 T=2 n=4
  189. """
  190. flat_inputs = tree.flatten(inputs)
  191. flat_spaces = tree.flatten(spaces_struct) if spaces_struct is not None \
  192. else [None] * len(flat_inputs)
  193. B = None
  194. T = None
  195. out = []
  196. for input_, space in zip(flat_inputs, flat_spaces):
  197. # Store batch and (if applicable) time dimension.
  198. if B is None:
  199. B = input_.shape[0]
  200. if time_axis:
  201. T = input_.shape[1]
  202. # One-hot encoding.
  203. if isinstance(space, Discrete):
  204. if time_axis:
  205. input_ = torch.reshape(input_, [B * T])
  206. out.append(one_hot(input_, space).float())
  207. # Multi one-hot encoding.
  208. elif isinstance(space, MultiDiscrete):
  209. if time_axis:
  210. input_ = torch.reshape(input_, [B * T, -1])
  211. out.append(one_hot(input_, space).float())
  212. # Box: Flatten.
  213. else:
  214. if time_axis:
  215. input_ = torch.reshape(input_, [B * T, -1])
  216. else:
  217. input_ = torch.reshape(input_, [B, -1])
  218. out.append(input_.float())
  219. merged = torch.cat(out, dim=-1)
  220. # Restore the time-dimension, if applicable.
  221. if time_axis:
  222. merged = torch.reshape(merged, [B, T, -1])
  223. return merged
  224. def global_norm(tensors: List[TensorType]) -> TensorType:
  225. """Returns the global L2 norm over a list of tensors.
  226. output = sqrt(SUM(t ** 2 for t in tensors)),
  227. where SUM reduces over all tensors and over all elements in tensors.
  228. Args:
  229. tensors: The list of tensors to calculate the global norm over.
  230. Returns:
  231. The global L2 norm over the given tensor list.
  232. """
  233. # List of single tensors' L2 norms: SQRT(SUM(xi^2)) over all xi in tensor.
  234. single_l2s = [
  235. torch.pow(torch.sum(torch.pow(t, 2.0)), 0.5) for t in tensors
  236. ]
  237. # Compute global norm from all single tensors' L2 norms.
  238. return torch.pow(sum(torch.pow(l2, 2.0) for l2 in single_l2s), 0.5)
  239. def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType:
  240. """Computes the huber loss for a given term and delta parameter.
  241. Reference: https://en.wikipedia.org/wiki/Huber_loss
  242. Note that the factor of 0.5 is implicitly included in the calculation.
  243. Formula:
  244. L = 0.5 * x^2 for small abs x (delta threshold)
  245. L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold)
  246. Args:
  247. x: The input term, e.g. a TD error.
  248. delta: The delta parmameter in the above formula.
  249. Returns:
  250. The Huber loss resulting from `x` and `delta`.
  251. """
  252. return torch.where(
  253. torch.abs(x) < delta,
  254. torch.pow(x, 2.0) * 0.5, delta * (torch.abs(x) - 0.5 * delta))
  255. def l2_loss(x: TensorType) -> TensorType:
  256. """Computes half the L2 norm over a tensor's values without the sqrt.
  257. output = 0.5 * sum(x ** 2)
  258. Args:
  259. x: The input tensor.
  260. Returns:
  261. 0.5 times the L2 norm over the given tensor's values (w/o sqrt).
  262. """
  263. return 0.5 * torch.sum(torch.pow(x, 2.0))
  264. def minimize_and_clip(optimizer: "torch.optim.Optimizer",
  265. clip_val: float = 10.0) -> None:
  266. """Clips grads found in `optimizer.param_groups` to given value in place.
  267. Ensures the norm of the gradients for each variable is clipped to
  268. `clip_val`.
  269. Args:
  270. optimizer: The torch.optim.Optimizer to get the variables from.
  271. clip_val: The global norm clip value. Will clip around -clip_val and
  272. +clip_val.
  273. """
  274. # Loop through optimizer's variables and norm per variable.
  275. for param_group in optimizer.param_groups:
  276. for p in param_group["params"]:
  277. if p.grad is not None:
  278. torch.nn.utils.clip_grad_norm_(p.grad, clip_val)
  279. def one_hot(x: TensorType, space: gym.Space) -> TensorType:
  280. """Returns a one-hot tensor, given and int tensor and a space.
  281. Handles the MultiDiscrete case as well.
  282. Args:
  283. x: The input tensor.
  284. space: The space to use for generating the one-hot tensor.
  285. Returns:
  286. The resulting one-hot tensor.
  287. Raises:
  288. ValueError: If the given space is not a discrete one.
  289. Examples:
  290. >>> x = torch.IntTensor([0, 3]) # batch-dim=2
  291. >>> # Discrete space with 4 (one-hot) slots per batch item.
  292. >>> s = gym.spaces.Discrete(4)
  293. >>> one_hot(x, s)
  294. tensor([[1, 0, 0, 0], [0, 0, 0, 1]])
  295. >>> x = torch.IntTensor([[0, 1, 2, 3]]) # batch-dim=1
  296. >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
  297. >>> # per batch item.
  298. >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
  299. >>> one_hot(x, s)
  300. tensor([[1, 0, 0, 0, 0,
  301. 0, 1, 0, 0,
  302. 0, 0, 1, 0,
  303. 0, 0, 0, 1, 0, 0, 0]])
  304. """
  305. if isinstance(space, Discrete):
  306. return nn.functional.one_hot(x.long(), space.n)
  307. elif isinstance(space, MultiDiscrete):
  308. return torch.cat(
  309. [
  310. nn.functional.one_hot(x[:, i].long(), n)
  311. for i, n in enumerate(space.nvec)
  312. ],
  313. dim=-1)
  314. else:
  315. raise ValueError("Unsupported space for `one_hot`: {}".format(space))
  316. def reduce_mean_ignore_inf(x: TensorType,
  317. axis: Optional[int] = None) -> TensorType:
  318. """Same as torch.mean() but ignores -inf values.
  319. Args:
  320. x: The input tensor to reduce mean over.
  321. axis: The axis over which to reduce. None for all axes.
  322. Returns:
  323. The mean reduced inputs, ignoring inf values.
  324. """
  325. mask = torch.ne(x, float("-inf"))
  326. x_zeroed = torch.where(mask, x, torch.zeros_like(x))
  327. return torch.sum(x_zeroed, axis) / torch.sum(mask.float(), axis)
  328. def sequence_mask(
  329. lengths: TensorType,
  330. maxlen: Optional[int] = None,
  331. dtype=None,
  332. time_major: bool = False,
  333. ) -> TensorType:
  334. """Offers same behavior as tf.sequence_mask for torch.
  335. Thanks to Dimitris Papatheodorou
  336. (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
  337. 39036).
  338. Args:
  339. lengths: The tensor of individual lengths to mask by.
  340. maxlen: The maximum length to use for the time axis. If None, use
  341. the max of `lengths`.
  342. dtype: The torch dtype to use for the resulting mask.
  343. time_major: Whether to return the mask as [B, T] (False; default) or
  344. as [T, B] (True).
  345. Returns:
  346. The sequence mask resulting from the given input and parameters.
  347. """
  348. # If maxlen not given, use the longest lengths in the `lengths` tensor.
  349. if maxlen is None:
  350. maxlen = int(lengths.max())
  351. mask = ~(torch.ones(
  352. (len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths)
  353. # Time major transformation.
  354. if not time_major:
  355. mask = mask.t()
  356. # By default, set the mask to be boolean.
  357. mask.type(dtype or torch.bool)
  358. return mask
  359. def set_torch_seed(seed: Optional[int] = None) -> None:
  360. """Sets the torch random seed to the given value.
  361. Args:
  362. seed: The seed to use or None for no seeding.
  363. """
  364. if seed is not None and torch:
  365. torch.manual_seed(seed)
  366. # See https://github.com/pytorch/pytorch/issues/47672.
  367. cuda_version = torch.version.cuda
  368. if cuda_version is not None and float(torch.version.cuda) >= 10.2:
  369. os.environ["CUBLAS_WORKSPACE_CONFIG"] = "4096:8"
  370. else:
  371. # Not all Operations support this.
  372. torch.use_deterministic_algorithms(True)
  373. # This is only for Convolution no problem.
  374. torch.backends.cudnn.deterministic = True
  375. def softmax_cross_entropy_with_logits(
  376. logits: TensorType,
  377. labels: TensorType,
  378. ) -> TensorType:
  379. """Same behavior as tf.nn.softmax_cross_entropy_with_logits.
  380. Args:
  381. x: The input predictions.
  382. labels: The labels corresponding to `x`.
  383. Returns:
  384. The resulting softmax cross-entropy given predictions and labels.
  385. """
  386. return torch.sum(-labels * nn.functional.log_softmax(logits, -1), -1)
  387. class Swish(nn.Module):
  388. def __init__(self):
  389. super().__init__()
  390. self._beta = nn.Parameter(torch.tensor(1.0))
  391. def forward(self, input_tensor):
  392. return input_tensor * torch.sigmoid(self._beta * input_tensor)