framework.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. import logging
  2. import numpy as np
  3. import os
  4. import sys
  5. from typing import Any, Optional
  6. from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
  7. from ray.rllib.utils.deprecation import Deprecated
  8. from ray.rllib.utils.typing import TensorShape, TensorType
  9. logger = logging.getLogger(__name__)
  10. @PublicAPI
  11. def try_import_jax(error: bool = False):
  12. """Tries importing JAX and FLAX and returns both modules (or Nones).
  13. Args:
  14. error: Whether to raise an error if JAX/FLAX cannot be imported.
  15. Returns:
  16. Tuple containing the jax- and the flax modules.
  17. Raises:
  18. ImportError: If error=True and JAX is not installed.
  19. """
  20. if "RLLIB_TEST_NO_JAX_IMPORT" in os.environ:
  21. logger.warning("Not importing JAX for test purposes.")
  22. return None, None
  23. try:
  24. import jax
  25. import flax
  26. except ImportError:
  27. if error:
  28. raise ImportError(
  29. "Could not import JAX! RLlib requires you to "
  30. "install at least one deep-learning framework: "
  31. "`pip install [torch|tensorflow|jax]`."
  32. )
  33. return None, None
  34. return jax, flax
  35. @PublicAPI
  36. def try_import_tf(error: bool = False):
  37. """Tries importing tf and returns the module (or None).
  38. Args:
  39. error: Whether to raise an error if tf cannot be imported.
  40. Returns:
  41. Tuple containing
  42. 1) tf1.x module (either from tf2.x.compat.v1 OR as tf1.x).
  43. 2) tf module (resulting from `import tensorflow`). Either tf1.x or
  44. 2.x. 3) The actually installed tf version as int: 1 or 2.
  45. Raises:
  46. ImportError: If error=True and tf is not installed.
  47. """
  48. tf_stub = _TFStub()
  49. # Make sure, these are reset after each test case
  50. # that uses them: del os.environ["RLLIB_TEST_NO_TF_IMPORT"]
  51. if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
  52. logger.warning("Not importing TensorFlow for test purposes")
  53. return None, tf_stub, None
  54. if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
  55. os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
  56. # Try to reuse already imported tf module. This will avoid going through
  57. # the initial import steps below and thereby switching off v2_behavior
  58. # (switching off v2 behavior twice breaks all-framework tests for eager).
  59. was_imported = False
  60. if "tensorflow" in sys.modules:
  61. tf_module = sys.modules["tensorflow"]
  62. was_imported = True
  63. else:
  64. try:
  65. import tensorflow as tf_module
  66. except ImportError:
  67. if error:
  68. raise ImportError(
  69. "Could not import TensorFlow! RLlib requires you to "
  70. "install at least one deep-learning framework: "
  71. "`pip install [torch|tensorflow|jax]`."
  72. )
  73. return None, tf_stub, None
  74. # Try "reducing" tf to tf.compat.v1.
  75. try:
  76. tf1_module = tf_module.compat.v1
  77. tf1_module.logging.set_verbosity(tf1_module.logging.ERROR)
  78. if not was_imported:
  79. tf1_module.disable_v2_behavior()
  80. tf1_module.enable_resource_variables()
  81. tf1_module.logging.set_verbosity(tf1_module.logging.WARN)
  82. # No compat.v1 -> return tf as is.
  83. except AttributeError:
  84. tf1_module = tf_module
  85. if not hasattr(tf_module, "__version__"):
  86. version = 1 # sphinx doc gen
  87. else:
  88. version = 2 if "2." in tf_module.__version__[:2] else 1
  89. return tf1_module, tf_module, version
  90. # Fake module for tf.
  91. class _TFStub:
  92. def __init__(self) -> None:
  93. self.keras = _KerasStub()
  94. def __bool__(self):
  95. # if tf should return False
  96. return False
  97. # Fake module for tf.keras.
  98. class _KerasStub:
  99. def __init__(self) -> None:
  100. self.Model = _FakeTfClassStub
  101. # Fake classes under keras (e.g for tf.keras.Model)
  102. class _FakeTfClassStub:
  103. def __init__(self, *a, **kw):
  104. raise ImportError("Could not import `tensorflow`. Try pip install tensorflow.")
  105. @DeveloperAPI
  106. def tf_function(tf_module):
  107. """Conditional decorator for @tf.function.
  108. Use @tf_function(tf) instead to avoid errors if tf is not installed."""
  109. # The actual decorator to use (pass in `tf` (which could be None)).
  110. def decorator(func):
  111. # If tf not installed -> return function as is (won't be used anyways).
  112. if tf_module is None or tf_module.executing_eagerly():
  113. return func
  114. # If tf installed, return @tf.function-decorated function.
  115. return tf_module.function(func)
  116. return decorator
  117. @PublicAPI
  118. def try_import_tfp(error: bool = False):
  119. """Tries importing tfp and returns the module (or None).
  120. Args:
  121. error: Whether to raise an error if tfp cannot be imported.
  122. Returns:
  123. The tfp module.
  124. Raises:
  125. ImportError: If error=True and tfp is not installed.
  126. """
  127. if "RLLIB_TEST_NO_TF_IMPORT" in os.environ:
  128. logger.warning("Not importing TensorFlow Probability for test purposes.")
  129. return None
  130. try:
  131. import tensorflow_probability as tfp
  132. return tfp
  133. except ImportError as e:
  134. if error:
  135. raise e
  136. return None
  137. # Fake module for torch.nn.
  138. class _NNStub:
  139. def __init__(self, *a, **kw):
  140. # Fake nn.functional module within torch.nn.
  141. self.functional = None
  142. self.Module = _FakeTorchClassStub
  143. self.parallel = _ParallelStub()
  144. # Fake class for e.g. torch.nn.Module to allow it to be inherited from.
  145. class _FakeTorchClassStub:
  146. def __init__(self, *a, **kw):
  147. raise ImportError("Could not import `torch`. Try pip install torch.")
  148. class _ParallelStub:
  149. def __init__(self, *a, **kw):
  150. self.DataParallel = _FakeTorchClassStub
  151. self.DistributedDataParallel = _FakeTorchClassStub
  152. @PublicAPI
  153. def try_import_torch(error: bool = False):
  154. """Tries importing torch and returns the module (or None).
  155. Args:
  156. error: Whether to raise an error if torch cannot be imported.
  157. Returns:
  158. Tuple consisting of the torch- AND torch.nn modules.
  159. Raises:
  160. ImportError: If error=True and PyTorch is not installed.
  161. """
  162. if "RLLIB_TEST_NO_TORCH_IMPORT" in os.environ:
  163. logger.warning("Not importing PyTorch for test purposes.")
  164. return _torch_stubs()
  165. try:
  166. import torch
  167. import torch.nn as nn
  168. return torch, nn
  169. except ImportError:
  170. if error:
  171. raise ImportError(
  172. "Could not import PyTorch! RLlib requires you to "
  173. "install at least one deep-learning framework: "
  174. "`pip install [torch|tensorflow|jax]`."
  175. )
  176. return _torch_stubs()
  177. def _torch_stubs():
  178. nn = _NNStub()
  179. return None, nn
  180. @DeveloperAPI
  181. def get_variable(
  182. value: Any,
  183. framework: str = "tf",
  184. trainable: bool = False,
  185. tf_name: str = "unnamed-variable",
  186. torch_tensor: bool = False,
  187. device: Optional[str] = None,
  188. shape: Optional[TensorShape] = None,
  189. dtype: Optional[TensorType] = None,
  190. ) -> Any:
  191. """Creates a tf variable, a torch tensor, or a python primitive.
  192. Args:
  193. value: The initial value to use. In the non-tf case, this will
  194. be returned as is. In the tf case, this could be a tf-Initializer
  195. object.
  196. framework: One of "tf", "torch", or None.
  197. trainable: Whether the generated variable should be
  198. trainable (tf)/require_grad (torch) or not (default: False).
  199. tf_name: For framework="tf": An optional name for the
  200. tf.Variable.
  201. torch_tensor: For framework="torch": Whether to actually create
  202. a torch.tensor, or just a python value (default).
  203. device: An optional torch device to use for
  204. the created torch tensor.
  205. shape: An optional shape to use iff `value`
  206. does not have any (e.g. if it's an initializer w/o explicit value).
  207. dtype: An optional dtype to use iff `value` does
  208. not have any (e.g. if it's an initializer w/o explicit value).
  209. This should always be a numpy dtype (e.g. np.float32, np.int64).
  210. Returns:
  211. A framework-specific variable (tf.Variable, torch.tensor, or
  212. python primitive).
  213. """
  214. if framework in ["tf2", "tf"]:
  215. import tensorflow as tf
  216. dtype = dtype or getattr(
  217. value,
  218. "dtype",
  219. tf.float32
  220. if isinstance(value, float)
  221. else tf.int32
  222. if isinstance(value, int)
  223. else None,
  224. )
  225. return tf.compat.v1.get_variable(
  226. tf_name,
  227. initializer=value,
  228. dtype=dtype,
  229. trainable=trainable,
  230. **({} if shape is None else {"shape": shape})
  231. )
  232. elif framework == "torch" and torch_tensor is True:
  233. torch, _ = try_import_torch()
  234. if not isinstance(value, np.ndarray):
  235. value = np.array(value)
  236. var_ = torch.from_numpy(value)
  237. if dtype in [torch.float32, np.float32]:
  238. var_ = var_.float()
  239. elif dtype in [torch.int32, np.int32]:
  240. var_ = var_.int()
  241. elif dtype in [torch.float64, np.float64]:
  242. var_ = var_.double()
  243. if device:
  244. var_ = var_.to(device)
  245. var_.requires_grad = trainable
  246. return var_
  247. # torch or None: Return python primitive.
  248. return value
  249. @Deprecated(
  250. old="rllib/utils/framework.py::get_activation_fn",
  251. new="rllib/models/utils.py::get_activation_fn",
  252. error=True,
  253. )
  254. def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
  255. pass