import numpy as np
import gym
from typing import Dict, Optional, Sequence

from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import TensorType, List, ModelConfigDict

tf1, tf, tfv = try_import_tf()
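# `tf1` is the `tf.compat.v1` module, `tf` the installed TensorFlow module
# (None when TensorFlow is unavailable, hence the `if tf else object` guard
# below), and `tfv` the major TF version (1 or 2).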


# TODO: (sven) obsolete this class once we only support native keras models.
class FullyConnectedNetwork(TFModelV2):
    """Generic fully connected network implemented in ModelV2 API."""

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        super(FullyConnectedNetwork, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)

        hiddens = list(model_config.get("fcnet_hiddens", [])) + \
            list(model_config.get("post_fcnet_hiddens", []))
        activation = model_config.get("fcnet_activation")
        if not model_config.get("fcnet_hiddens", []):
            activation = model_config.get("post_fcnet_activation")
        activation = get_activation_fn(activation)
        no_final_linear = model_config.get("no_final_linear")
        vf_share_layers = model_config.get("vf_share_layers")
        free_log_std = model_config.get("free_log_std")

        # Generate free-floating bias variables for the second half of
        # the outputs.
        if free_log_std:
            assert num_outputs % 2 == 0, (
                "num_outputs must be divisible by two", num_outputs)
            num_outputs = num_outputs // 2
            self.log_std_var = tf.Variable(
                [0.0] * num_outputs, dtype=tf.float32, name="log_std")

        # We are using obs_flat, so take the flattened shape as input.
        inputs = tf.keras.layers.Input(
            shape=(int(np.product(obs_space.shape)), ), name="observations")
        # Last hidden layer output (before logits outputs).
        last_layer = inputs
        # The action distribution outputs.
        logits_out = None
        i = 1

        # Create layers 0 to second-last.
        for size in hiddens[:-1]:
            last_layer = tf.keras.layers.Dense(
                size,
                name="fc_{}".format(i),
                activation=activation,
                kernel_initializer=normc_initializer(1.0))(last_layer)
            i += 1

        # The last layer is adjusted to be of size num_outputs, but it's a
        # layer with activation.
        if no_final_linear and num_outputs:
            logits_out = tf.keras.layers.Dense(
                num_outputs,
                name="fc_out",
                activation=activation,
                kernel_initializer=normc_initializer(1.0))(last_layer)
        # Finish the layers with the provided sizes (`hiddens`), plus -
        # iff num_outputs > 0 - a last linear layer of size num_outputs.
        else:
            if len(hiddens) > 0:
                last_layer = tf.keras.layers.Dense(
                    hiddens[-1],
                    name="fc_{}".format(i),
                    activation=activation,
                    kernel_initializer=normc_initializer(1.0))(last_layer)
            if num_outputs:
                logits_out = tf.keras.layers.Dense(
                    num_outputs,
                    name="fc_out",
                    activation=None,
                    kernel_initializer=normc_initializer(0.01))(last_layer)
            # Adjust num_outputs to be the number of nodes in the last layer.
            else:
                self.num_outputs = (
                    [int(np.product(obs_space.shape))] + hiddens[-1:])[-1]

        # Concat the log std vars to the end of the state-dependent means.
        if free_log_std and logits_out is not None:

            def tiled_log_std(x):
                return tf.tile(
                    tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1])

            log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs)
            logits_out = tf.keras.layers.Concatenate(axis=1)(
                [logits_out, log_std_out])

        last_vf_layer = None
        if not vf_share_layers:
            # Build a parallel set of hidden layers for the value net.
            last_vf_layer = inputs
            i = 1
            for size in hiddens:
                last_vf_layer = tf.keras.layers.Dense(
                    size,
                    name="fc_value_{}".format(i),
                    activation=activation,
                    kernel_initializer=normc_initializer(1.0))(last_vf_layer)
                i += 1

        value_out = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(
                last_vf_layer if last_vf_layer is not None else last_layer)

        self.base_model = tf.keras.Model(
            inputs, [(logits_out
                      if logits_out is not None else last_layer), value_out])

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
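        """Runs the base model on obs_flat and caches the value output."""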
        model_out, self._value_out = self.base_model(input_dict["obs_flat"])
        return model_out, state

    def value_function(self) -> TensorType:
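        """Returns the value-branch output of the most recent forward pass."""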
        return tf.reshape(self._value_out, [-1])
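

# Example usage (a minimal sketch, assuming RLlib's standard custom-model
# registration flow; "my_fcnet" and the config values are illustrative):
#
#     from ray.rllib.models import ModelCatalog
#     ModelCatalog.register_custom_model("my_fcnet", FullyConnectedNetwork)
#     config = {
#         "model": {
#             "custom_model": "my_fcnet",
#             "fcnet_hiddens": [256, 256],
#             "fcnet_activation": "tanh",
#         },
#     }


# Subclass `object` instead of `tf.keras.Model` when TF is not installed, so
# this module can still be imported without TensorFlow.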
class Keras_FullyConnectedNetwork(tf.keras.Model if tf else object):
    """Generic fully connected network implemented in tf Keras."""

    def __init__(
            self,
            input_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            num_outputs: Optional[int] = None,
            *,
            name: str = "",
            fcnet_hiddens: Optional[Sequence[int]] = (),
            fcnet_activation: Optional[str] = None,
            post_fcnet_hiddens: Optional[Sequence[int]] = (),
            post_fcnet_activation: Optional[str] = None,
            no_final_linear: bool = False,
            vf_share_layers: bool = False,
            free_log_std: bool = False,
            **kwargs,
    ):
        super().__init__(name=name)

        hiddens = list(fcnet_hiddens or ()) + \
            list(post_fcnet_hiddens or ())
        activation = fcnet_activation
        if not fcnet_hiddens:
            activation = post_fcnet_activation
        activation = get_activation_fn(activation)

        # Generate free-floating bias variables for the second half of
        # the outputs.
        if free_log_std:
            assert num_outputs % 2 == 0, (
                "num_outputs must be divisible by two", num_outputs)
            num_outputs = num_outputs // 2
            self.log_std_var = tf.Variable(
                [0.0] * num_outputs, dtype=tf.float32, name="log_std")

        # We are using obs_flat, so take the flattened shape as input.
        inputs = tf.keras.layers.Input(
            shape=(int(np.product(input_space.shape)), ), name="observations")
        # Last hidden layer output (before logits outputs).
        last_layer = inputs
        # The action distribution outputs.
        logits_out = None
        i = 1

        # Create layers 0 to second-last.
        for size in hiddens[:-1]:
            last_layer = tf.keras.layers.Dense(
                size,
                name="fc_{}".format(i),
                activation=activation,
                kernel_initializer=normc_initializer(1.0))(last_layer)
            i += 1

        # The last layer is adjusted to be of size num_outputs, but it's a
        # layer with activation.
        if no_final_linear and num_outputs:
            logits_out = tf.keras.layers.Dense(
                num_outputs,
                name="fc_out",
                activation=activation,
                kernel_initializer=normc_initializer(1.0))(last_layer)
        # Finish the layers with the provided sizes (`hiddens`), plus -
        # iff num_outputs > 0 - a last linear layer of size num_outputs.
        else:
            if len(hiddens) > 0:
                last_layer = tf.keras.layers.Dense(
                    hiddens[-1],
                    name="fc_{}".format(i),
                    activation=activation,
                    kernel_initializer=normc_initializer(1.0))(last_layer)
            if num_outputs:
                logits_out = tf.keras.layers.Dense(
                    num_outputs,
                    name="fc_out",
                    activation=None,
                    kernel_initializer=normc_initializer(0.01))(last_layer)

        # Concat the log std vars to the end of the state-dependent means.
        if free_log_std and logits_out is not None:

            def tiled_log_std(x):
                return tf.tile(
                    tf.expand_dims(self.log_std_var, 0), [tf.shape(x)[0], 1])

            log_std_out = tf.keras.layers.Lambda(tiled_log_std)(inputs)
            logits_out = tf.keras.layers.Concatenate(axis=1)(
                [logits_out, log_std_out])

        last_vf_layer = None
        if not vf_share_layers:
            # Build a parallel set of hidden layers for the value net.
            last_vf_layer = inputs
            i = 1
            for size in hiddens:
                last_vf_layer = tf.keras.layers.Dense(
                    size,
                    name="fc_value_{}".format(i),
                    activation=activation,
                    kernel_initializer=normc_initializer(1.0))(last_vf_layer)
                i += 1

        value_out = tf.keras.layers.Dense(
            1,
            name="value_out",
            activation=None,
            kernel_initializer=normc_initializer(0.01))(
                last_vf_layer if last_vf_layer is not None else last_layer)

        self.base_model = tf.keras.Model(
            inputs, [(logits_out
                      if logits_out is not None else last_layer), value_out])

    def call(self, input_dict: SampleBatch) -> \
            (TensorType, List[TensorType], Dict[str, TensorType]):
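        """Computes logits and value; returns value preds as an extra out."""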
        model_out, value_out = self.base_model(input_dict[SampleBatch.OBS])
        extra_outs = {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])}
        return model_out, [], extra_outs
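

# Example usage (a minimal sketch; the space, layer sizes, and batch below
# are illustrative, not part of this module):
#
#     obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ))
#     model = Keras_FullyConnectedNetwork(
#         input_space=obs_space,
#         action_space=gym.spaces.Discrete(2),
#         num_outputs=2,
#         fcnet_hiddens=[64, 64],
#         fcnet_activation="relu")
#     batch = SampleBatch({SampleBatch.OBS: np.zeros((32, 4), np.float32)})
#     logits, state_out, extra_outs = model(batch)
#     value_preds = extra_outs[SampleBatch.VF_PREDS]  # shape: (32, )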