tf_utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. import gym
  2. from gym.spaces import Discrete, MultiDiscrete
  3. import numpy as np
  4. import tree # pip install dm_tree
  5. from typing import Any, Callable, List, Optional, Type, TYPE_CHECKING, Union
  6. from ray.rllib.utils.deprecation import Deprecated
  7. from ray.rllib.utils.framework import try_import_tf
  8. from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
  9. from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \
  10. PartialTrainerConfigDict, TensorStructType, TensorType
  11. if TYPE_CHECKING:
  12. from ray.rllib.policy.tf_policy import TFPolicy
  13. tf1, tf, tfv = try_import_tf()
  14. @Deprecated(new="ray.rllib.utils.numpy.convert_to_numpy()", error=True)
  15. def convert_to_non_tf_type(x: TensorStructType) -> TensorStructType:
  16. """Converts values in `stats` to non-Tensor numpy or python types.
  17. Args:
  18. x: Any (possibly nested) struct, the values in which will be
  19. converted and returned as a new struct with all tf (eager) tensors
  20. being converted to numpy types.
  21. Returns:
  22. A new struct with the same structure as `x`, but with all
  23. values converted to non-tf Tensor types.
  24. """
  25. # The mapping function used to numpyize torch Tensors.
  26. def mapping(item):
  27. if isinstance(item, (tf.Tensor, tf.Variable)):
  28. return item.numpy()
  29. else:
  30. return item
  31. return tree.map_structure(mapping, x)
  32. def explained_variance(y: TensorType, pred: TensorType) -> TensorType:
  33. """Computes the explained variance for a pair of labels and predictions.
  34. The formula used is:
  35. max(-1.0, 1.0 - (std(y - pred)^2 / std(y)^2))
  36. Args:
  37. y: The labels.
  38. pred: The predictions.
  39. Returns:
  40. The explained variance given a pair of labels and predictions.
  41. """
  42. _, y_var = tf.nn.moments(y, axes=[0])
  43. _, diff_var = tf.nn.moments(y - pred, axes=[0])
  44. return tf.maximum(-1.0, 1 - (diff_var / y_var))
  45. def get_gpu_devices() -> List[str]:
  46. """Returns a list of GPU device names, e.g. ["/gpu:0", "/gpu:1"].
  47. Supports both tf1.x and tf2.x.
  48. Returns:
  49. List of GPU device names (str).
  50. """
  51. if tfv == 1:
  52. from tensorflow.python.client import device_lib
  53. devices = device_lib.list_local_devices()
  54. else:
  55. try:
  56. devices = tf.config.list_physical_devices()
  57. except Exception:
  58. devices = tf.config.experimental.list_physical_devices()
  59. # Expect "GPU", but also stuff like: "XLA_GPU".
  60. return [d.name for d in devices if "GPU" in d.device_type]
  61. def get_placeholder(*,
  62. space: Optional[gym.Space] = None,
  63. value: Optional[Any] = None,
  64. name: Optional[str] = None,
  65. time_axis: bool = False,
  66. flatten: bool = True) -> "tf1.placeholder":
  67. """Returns a tf1.placeholder object given optional hints, such as a space.
  68. Note that the returned placeholder will always have a leading batch
  69. dimension (None).
  70. Args:
  71. space: An optional gym.Space to hint the shape and dtype of the
  72. placeholder.
  73. value: An optional value to hint the shape and dtype of the
  74. placeholder.
  75. name: An optional name for the placeholder.
  76. time_axis: Whether the placeholder should also receive a time
  77. dimension (None).
  78. flatten: Whether to flatten the given space into a plain Box space
  79. and then create the placeholder from the resulting space.
  80. Returns:
  81. The tf1 placeholder.
  82. """
  83. from ray.rllib.models.catalog import ModelCatalog
  84. if space is not None:
  85. if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
  86. if flatten:
  87. return ModelCatalog.get_action_placeholder(space, None)
  88. else:
  89. return tree.map_structure_with_path(
  90. lambda path, component: get_placeholder(
  91. space=component,
  92. name=name + "." + ".".join([str(p) for p in path]),
  93. ),
  94. get_base_struct_from_space(space),
  95. )
  96. return tf1.placeholder(
  97. shape=(None, ) + ((None, ) if time_axis else ()) + space.shape,
  98. dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
  99. name=name,
  100. )
  101. else:
  102. assert value is not None
  103. shape = value.shape[1:]
  104. return tf1.placeholder(
  105. shape=(None, ) + ((None, )
  106. if time_axis else ()) + (shape if isinstance(
  107. shape, tuple) else tuple(shape.as_list())),
  108. dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
  109. name=name,
  110. )
  111. def get_tf_eager_cls_if_necessary(
  112. orig_cls: Type["TFPolicy"],
  113. config: PartialTrainerConfigDict) -> Type["TFPolicy"]:
  114. """Returns the corresponding tf-eager class for a given TFPolicy class.
  115. Args:
  116. orig_cls: The original TFPolicy class to get the corresponding tf-eager
  117. class for.
  118. config: The Trainer config dict.
  119. Returns:
  120. The tf eager policy class corresponding to the given TFPolicy class.
  121. """
  122. cls = orig_cls
  123. framework = config.get("framework", "tf")
  124. if framework in ["tf2", "tf", "tfe"]:
  125. if not tf1:
  126. raise ImportError("Could not import tensorflow!")
  127. if framework in ["tf2", "tfe"]:
  128. assert tf1.executing_eagerly()
  129. from ray.rllib.policy.tf_policy import TFPolicy
  130. # Create eager-class.
  131. if hasattr(orig_cls, "as_eager"):
  132. cls = orig_cls.as_eager()
  133. if config.get("eager_tracing"):
  134. cls = cls.with_tracing()
  135. # Could be some other type of policy.
  136. elif not issubclass(orig_cls, TFPolicy):
  137. pass
  138. else:
  139. raise ValueError("This policy does not support eager "
  140. "execution: {}".format(orig_cls))
  141. return cls
  142. def huber_loss(x: TensorType, delta: float = 1.0) -> TensorType:
  143. """Computes the huber loss for a given term and delta parameter.
  144. Reference: https://en.wikipedia.org/wiki/Huber_loss
  145. Note that the factor of 0.5 is implicitly included in the calculation.
  146. Formula:
  147. L = 0.5 * x^2 for small abs x (delta threshold)
  148. L = delta * (abs(x) - 0.5*delta) for larger abs x (delta threshold)
  149. Args:
  150. x: The input term, e.g. a TD error.
  151. delta: The delta parmameter in the above formula.
  152. Returns:
  153. The Huber loss resulting from `x` and `delta`.
  154. """
  155. return tf.where(
  156. tf.abs(x) < delta, # for small x -> apply the Huber correction
  157. tf.math.square(x) * 0.5,
  158. delta * (tf.abs(x) - 0.5 * delta),
  159. )
  160. def make_tf_callable(session_or_none: Optional["tf1.Session"],
  161. dynamic_shape: bool = False) -> Callable:
  162. """Returns a function that can be executed in either graph or eager mode.
  163. The function must take only positional args.
  164. If eager is enabled, this will act as just a function. Otherwise, it
  165. will build a function that executes a session run with placeholders
  166. internally.
  167. Args:
  168. session_or_none: tf.Session if in graph mode, else None.
  169. dynamic_shape: True if the placeholders should have a dynamic
  170. batch dimension. Otherwise they will be fixed shape.
  171. Returns:
  172. A function that can be called in either eager or static-graph mode.
  173. """
  174. if tf.executing_eagerly():
  175. assert session_or_none is None
  176. else:
  177. assert session_or_none is not None
  178. def make_wrapper(fn):
  179. # Static-graph mode: Create placeholders and make a session call each
  180. # time the wrapped function is called. Returns the output of this
  181. # session call.
  182. if session_or_none is not None:
  183. args_placeholders = []
  184. kwargs_placeholders = {}
  185. symbolic_out = [None]
  186. def call(*args, **kwargs):
  187. args_flat = []
  188. for a in args:
  189. if type(a) is list:
  190. args_flat.extend(a)
  191. else:
  192. args_flat.append(a)
  193. args = args_flat
  194. # We have not built any placeholders yet: Do this once here,
  195. # then reuse the same placeholders each time we call this
  196. # function again.
  197. if symbolic_out[0] is None:
  198. with session_or_none.graph.as_default():
  199. def _create_placeholders(path, value):
  200. if dynamic_shape:
  201. if len(value.shape) > 0:
  202. shape = (None, ) + value.shape[1:]
  203. else:
  204. shape = ()
  205. else:
  206. shape = value.shape
  207. return tf1.placeholder(
  208. dtype=value.dtype,
  209. shape=shape,
  210. name=".".join([str(p) for p in path]),
  211. )
  212. placeholders = tree.map_structure_with_path(
  213. _create_placeholders, args)
  214. for ph in tree.flatten(placeholders):
  215. args_placeholders.append(ph)
  216. placeholders = tree.map_structure_with_path(
  217. _create_placeholders, kwargs)
  218. for k, ph in placeholders.items():
  219. kwargs_placeholders[k] = ph
  220. symbolic_out[0] = fn(*args_placeholders,
  221. **kwargs_placeholders)
  222. feed_dict = dict(zip(args_placeholders, tree.flatten(args)))
  223. tree.map_structure(lambda ph, v: feed_dict.__setitem__(ph, v),
  224. kwargs_placeholders, kwargs)
  225. ret = session_or_none.run(symbolic_out[0], feed_dict)
  226. return ret
  227. return call
  228. # Eager mode (call function as is).
  229. else:
  230. return fn
  231. return make_wrapper
  232. def minimize_and_clip(
  233. optimizer: LocalOptimizer,
  234. objective: TensorType,
  235. var_list: List["tf.Variable"],
  236. clip_val: float = 10.0,
  237. ) -> ModelGradients:
  238. """Computes, then clips gradients using objective, optimizer and var list.
  239. Ensures the norm of the gradients for each variable is clipped to
  240. `clip_val`.
  241. Args:
  242. optimizer: Either a shim optimizer (tf eager) containing a
  243. tf.GradientTape under `self.tape` or a tf1 local optimizer
  244. object.
  245. objective: The loss tensor to calculate gradients on.
  246. var_list: The list of tf.Variables to compute gradients over.
  247. clip_val: The global norm clip value. Will clip around -clip_val and
  248. +clip_val.
  249. Returns:
  250. The resulting model gradients (list or tuples of grads + vars)
  251. corresponding to the input `var_list`.
  252. """
  253. # Accidentally passing values < 0.0 will break all gradients.
  254. assert clip_val is None or clip_val > 0.0, clip_val
  255. if tf.executing_eagerly():
  256. tape = optimizer.tape
  257. grads_and_vars = list(
  258. zip(list(tape.gradient(objective, var_list)), var_list))
  259. else:
  260. grads_and_vars = optimizer.compute_gradients(
  261. objective, var_list=var_list)
  262. return [(tf.clip_by_norm(g, clip_val) if clip_val is not None else g, v)
  263. for (g, v) in grads_and_vars if g is not None]
  264. def one_hot(x: TensorType, space: gym.Space) -> TensorType:
  265. """Returns a one-hot tensor, given and int tensor and a space.
  266. Handles the MultiDiscrete case as well.
  267. Args:
  268. x: The input tensor.
  269. space: The space to use for generating the one-hot tensor.
  270. Returns:
  271. The resulting one-hot tensor.
  272. Raises:
  273. ValueError: If the given space is not a discrete one.
  274. Examples:
  275. >>> x = tf.Variable([0, 3], dtype=tf.int32) # batch-dim=2
  276. >>> # Discrete space with 4 (one-hot) slots per batch item.
  277. >>> s = gym.spaces.Discrete(4)
  278. >>> one_hot(x, s)
  279. <tf.Tensor 'one_hot:0' shape=(2, 4) dtype=float32>
  280. >>> x = tf.Variable([[0, 1, 2, 3]], dtype=tf.int32) # batch-dim=1
  281. >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
  282. >>> # per batch item.
  283. >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
  284. >>> one_hot(x, s)
  285. <tf.Tensor 'concat:0' shape=(1, 20) dtype=float32>
  286. """
  287. if isinstance(space, Discrete):
  288. return tf.one_hot(x, space.n, dtype=tf.float32)
  289. elif isinstance(space, MultiDiscrete):
  290. return tf.concat(
  291. [
  292. tf.one_hot(x[:, i], n, dtype=tf.float32)
  293. for i, n in enumerate(space.nvec)
  294. ],
  295. axis=-1)
  296. else:
  297. raise ValueError("Unsupported space for `one_hot`: {}".format(space))
  298. def reduce_mean_ignore_inf(x: TensorType,
  299. axis: Optional[int] = None) -> TensorType:
  300. """Same as tf.reduce_mean() but ignores -inf values.
  301. Args:
  302. x: The input tensor to reduce mean over.
  303. axis: The axis over which to reduce. None for all axes.
  304. Returns:
  305. The mean reduced inputs, ignoring inf values.
  306. """
  307. mask = tf.not_equal(x, tf.float32.min)
  308. x_zeroed = tf.where(mask, x, tf.zeros_like(x))
  309. return (tf.math.reduce_sum(x_zeroed, axis) / tf.math.reduce_sum(
  310. tf.cast(mask, tf.float32), axis))
  311. def scope_vars(scope: Union[str, "tf1.VariableScope"],
  312. trainable_only: bool = False) -> List["tf.Variable"]:
  313. """Get variables inside a given scope.
  314. Args:
  315. scope: Scope in which the variables reside.
  316. trainable_only: Whether or not to return only the variables that were
  317. marked as trainable.
  318. Returns:
  319. The list of variables in the given `scope`.
  320. """
  321. return tf1.get_collection(
  322. tf1.GraphKeys.TRAINABLE_VARIABLES
  323. if trainable_only else tf1.GraphKeys.VARIABLES,
  324. scope=scope if isinstance(scope, str) else scope.name)
  325. def zero_logps_from_actions(actions: TensorStructType) -> TensorType:
  326. """Helper function useful for returning dummy logp's (0) for some actions.
  327. Args:
  328. actions: The input actions. This can be any struct
  329. of complex action components or a simple tensor of different
  330. dimensions, e.g. [B], [B, 2], or {"a": [B, 4, 5], "b": [B]}.
  331. Returns:
  332. A 1D tensor of 0.0 (dummy logp's) matching the batch
  333. dim of `actions` (shape=[B]).
  334. """
  335. # Need to flatten `actions` in case we have a complex action space.
  336. # Take the 0th component to extract the batch dim.
  337. action_component = tree.flatten(actions)[0]
  338. logp_ = tf.zeros_like(action_component, dtype=tf.float32)
  339. # Logp's should be single values (but with the same batch dim as
  340. # `deterministic_actions` or `stochastic_actions`). In case
  341. # actions are just [B], zeros_like works just fine here, but if
  342. # actions are [B, ...], we have to reduce logp back to just [B].
  343. while len(logp_.shape) > 1:
  344. logp_ = logp_[:, 0]
  345. return logp_