utils.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. from typing import Optional
  2. from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
  3. try_import_torch
  4. def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
  5. """Returns a framework specific activation function, given a name string.
  6. Args:
  7. name (Optional[str]): One of "relu" (default), "tanh", "elu",
  8. "swish", or "linear" (same as None).
  9. framework (str): One of "jax", "tf|tfe|tf2" or "torch".
  10. Returns:
  11. A framework-specific activtion function. e.g. tf.nn.tanh or
  12. torch.nn.ReLU. None if name in ["linear", None].
  13. Raises:
  14. ValueError: If name is an unknown activation function.
  15. """
  16. # Already a callable, return as-is.
  17. if callable(name):
  18. return name
  19. # Infer the correct activation function from the string specifier.
  20. if framework == "torch":
  21. if name in ["linear", None]:
  22. return None
  23. if name == "swish":
  24. from ray.rllib.utils.torch_utils import Swish
  25. return Swish
  26. _, nn = try_import_torch()
  27. if name == "relu":
  28. return nn.ReLU
  29. elif name == "tanh":
  30. return nn.Tanh
  31. elif name == "elu":
  32. return nn.ELU
  33. elif framework == "jax":
  34. if name in ["linear", None]:
  35. return None
  36. jax, _ = try_import_jax()
  37. if name == "swish":
  38. return jax.nn.swish
  39. if name == "relu":
  40. return jax.nn.relu
  41. elif name == "tanh":
  42. return jax.nn.hard_tanh
  43. elif name == "elu":
  44. return jax.nn.elu
  45. else:
  46. assert framework in ["tf", "tfe", "tf2"],\
  47. "Unsupported framework `{}`!".format(framework)
  48. if name in ["linear", None]:
  49. return None
  50. tf1, tf, tfv = try_import_tf()
  51. fn = getattr(tf.nn, name, None)
  52. if fn is not None:
  53. return fn
  54. raise ValueError("Unknown activation ({}) for framework={}!".format(
  55. name, framework))
  56. def get_filter_config(shape):
  57. """Returns a default Conv2D filter config (list) for a given image shape.
  58. Args:
  59. shape (Tuple[int]): The input (image) shape, e.g. (84,84,3).
  60. Returns:
  61. List[list]: The Conv2D filter configuration usable as `conv_filters`
  62. inside a model config dict.
  63. """
  64. # VizdoomGym (large 480x640).
  65. filters_480x640 = [
  66. [16, [24, 32], [14, 18]],
  67. [32, [6, 6], 4],
  68. [256, [9, 9], 1],
  69. ]
  70. # VizdoomGym (small 240x320).
  71. filters_240x320 = [
  72. [16, [12, 16], [7, 9]],
  73. [32, [6, 6], 4],
  74. [256, [9, 9], 1],
  75. ]
  76. # 96x96x3 (e.g. CarRacing-v0).
  77. filters_96x96 = [
  78. [16, [8, 8], 4],
  79. [32, [4, 4], 2],
  80. [256, [11, 11], 2],
  81. ]
  82. # Atari.
  83. filters_84x84 = [
  84. [16, [8, 8], 4],
  85. [32, [4, 4], 2],
  86. [256, [11, 11], 1],
  87. ]
  88. # Small (1/2) Atari.
  89. filters_42x42 = [
  90. [16, [4, 4], 2],
  91. [32, [4, 4], 2],
  92. [256, [11, 11], 1],
  93. ]
  94. # Test image (10x10).
  95. filters_10x10 = [
  96. [16, [5, 5], 2],
  97. [32, [5, 5], 2],
  98. ]
  99. shape = list(shape)
  100. if len(shape) in [2, 3] and (shape[:2] == [480, 640]
  101. or shape[1:] == [480, 640]):
  102. return filters_480x640
  103. elif len(shape) in [2, 3] and (shape[:2] == [240, 320]
  104. or shape[1:] == [240, 320]):
  105. return filters_240x320
  106. elif len(shape) in [2, 3] and (shape[:2] == [96, 96]
  107. or shape[1:] == [96, 96]):
  108. return filters_96x96
  109. elif len(shape) in [2, 3] and (shape[:2] == [84, 84]
  110. or shape[1:] == [84, 84]):
  111. return filters_84x84
  112. elif len(shape) in [2, 3] and (shape[:2] == [42, 42]
  113. or shape[1:] == [42, 42]):
  114. return filters_42x42
  115. elif len(shape) in [2, 3] and (shape[:2] == [10, 10]
  116. or shape[1:] == [10, 10]):
  117. return filters_10x10
  118. else:
  119. raise ValueError(
  120. "No default configuration for obs shape {}".format(shape) +
  121. ", you must specify `conv_filters` manually as a model option. "
  122. "Default configurations are only available for inputs of shape "
  123. "[42, 42, K] and [84, 84, K]. You may alternatively want "
  124. "to use a custom model or preprocessor.")
  125. def get_initializer(name, framework="tf"):
  126. """Returns a framework specific initializer, given a name string.
  127. Args:
  128. name (str): One of "xavier_uniform" (default), "xavier_normal".
  129. framework (str): One of "jax", "tf|tfe|tf2" or "torch".
  130. Returns:
  131. A framework-specific initializer function, e.g.
  132. tf.keras.initializers.GlorotUniform or
  133. torch.nn.init.xavier_uniform_.
  134. Raises:
  135. ValueError: If name is an unknown initializer.
  136. """
  137. # Already a callable, return as-is.
  138. if callable(name):
  139. return name
  140. if framework == "jax":
  141. _, flax = try_import_jax()
  142. assert flax is not None,\
  143. "`flax` not installed. Try `pip install jax flax`."
  144. import flax.linen as nn
  145. if name in [None, "default", "xavier_uniform"]:
  146. return nn.initializers.xavier_uniform()
  147. elif name == "xavier_normal":
  148. return nn.initializers.xavier_normal()
  149. if framework == "torch":
  150. _, nn = try_import_torch()
  151. assert nn is not None,\
  152. "`torch` not installed. Try `pip install torch`."
  153. if name in [None, "default", "xavier_uniform"]:
  154. return nn.init.xavier_uniform_
  155. elif name == "xavier_normal":
  156. return nn.init.xavier_normal_
  157. else:
  158. assert framework in ["tf", "tfe", "tf2"],\
  159. "Unsupported framework `{}`!".format(framework)
  160. tf1, tf, tfv = try_import_tf()
  161. assert tf is not None,\
  162. "`tensorflow` not installed. Try `pip install tensorflow`."
  163. if name in [None, "default", "xavier_uniform"]:
  164. return tf.keras.initializers.GlorotUniform
  165. elif name == "xavier_normal":
  166. return tf.keras.initializers.GlorotNormal
  167. raise ValueError("Unknown activation ({}) for framework={}!".format(
  168. name, framework))