# batch_norm_model.py

import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import SlimFC, normc_initializer as \
    torch_normc_initializer
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


class KerasBatchNormModel(TFModelV2):
    """Keras version of the BatchNormModel below (exactly the same structure).

    IMPORTANT NOTE: This model will not work with PPO due to a bug in keras
    that surfaces when there is more than one input placeholder (here:
    `inputs` and `is_training`) AND the `make_tf_callable` helper is used
    (e.g. by PPO). In that case, auto-placeholders are generated and then
    passed through the tf.keras.models.Model. In this last step, the
    connection between 1) the value provided in the auto-placeholder and
    2) the keras `is_training` Input is broken and keras complains.

    Use the `BatchNormModel` below (a non-keras based TFModelV2) instead.
    """
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        inputs = tf.keras.layers.Input(shape=obs_space.shape, name="inputs")
        # Have to batch the is_training flag (its batch size will always
        # be 1).
        is_training = tf.keras.layers.Input(
            shape=(), dtype=tf.bool, batch_size=1, name="is_training")
        last_layer = inputs
        hiddens = [256, 256]
        for i, size in enumerate(hiddens):
            label = "fc{}".format(i)
            last_layer = tf.keras.layers.Dense(
                units=size,
                kernel_initializer=normc_initializer(1.0),
                activation=tf.nn.tanh,
                name=label)(last_layer)
            # Add a batch norm layer.
            last_layer = tf.keras.layers.BatchNormalization()(
                last_layer, training=is_training[0])
        output = tf.keras.layers.Dense(
            units=self.num_outputs,
            kernel_initializer=normc_initializer(0.01),
            activation=None,
            name="fc_out")(last_layer)
        value_out = tf.keras.layers.Dense(
            units=1,
            kernel_initializer=normc_initializer(0.01),
            activation=None,
            name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(
            inputs=[inputs, is_training], outputs=[output, value_out])
    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Have to batch the is_training flag (B=1).
        out, self._value_out = self.base_model(
            [input_dict["obs"],
             tf.expand_dims(input_dict["is_training"], 0)])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])


class BatchNormModel(TFModelV2):
    """Example of a TFModelV2 that is built w/o using tf.keras.

    NOTE: The keras-based example model above does not work with PPO (due to
    a bug in keras related to missing values for input placeholders, even
    though these input values have been provided in a forward pass through
    the actual keras Model).

    All model logic (layers) is defined in the `forward` method (incl.
    the batch_normalization layers). Also, all variables are registered
    (only once) at the end of `forward`, so an optimizer knows which tensors
    to train on. A standard `value_function` override is used.
    """
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        # Have we registered our vars yet (see `forward`)?
        self._registered = False
    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        last_layer = input_dict["obs"]
        hiddens = [256, 256]
        with tf1.variable_scope("model", reuse=tf1.AUTO_REUSE):
            for i, size in enumerate(hiddens):
                last_layer = tf1.layers.dense(
                    last_layer,
                    size,
                    kernel_initializer=normc_initializer(1.0),
                    activation=tf.nn.tanh,
                    name="fc{}".format(i))
                # Add a batch norm layer.
                last_layer = tf1.layers.batch_normalization(
                    last_layer,
                    training=input_dict["is_training"],
                    name="bn_{}".format(i))
            output = tf1.layers.dense(
                last_layer,
                self.num_outputs,
                kernel_initializer=normc_initializer(0.01),
                activation=None,
                name="out")
            self._value_out = tf1.layers.dense(
                last_layer,
                1,
                kernel_initializer=normc_initializer(1.0),
                activation=None,
                name="vf")

        # Register variables.
        # NOTE: This is not the recommended way of doing things. We would
        # prefer creating keras-style Layers like it's done in the
        # `KerasBatchNormModel` class above and then have TFModelV2
        # auto-detect the created vars. However, since there is a bug in
        # keras/tf that prevents us from using that KerasBatchNormModel
        # example (see comments above), we do variable registration the old,
        # manual way for this example Model here.
        if not self._registered:
            # Register already auto-detected variables (from the wrapping
            # Model, e.g. DQNTFModel).
            self.register_variables(self.variables())
            # Then register everything we added to the graph in this
            # `forward` call.
            self.register_variables(
                tf1.get_collection(
                    tf1.GraphKeys.TRAINABLE_VARIABLES, scope=".+/model/.+"))
            self._registered = True

        return output, []
    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
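

# NOTE: The helper below is NOT part of the original example file. It is a
# minimal usage sketch, assuming RLlib's `ModelCatalog.register_custom_model`
# and `PPOTrainer` APIs of the RLlib version this example targets: register
# the non-keras `BatchNormModel` above under a custom-model key and reference
# it from the trainer config. The key "bn_tf_model" and the helper name are
# illustrative only.
def _example_tf_usage():
    import ray
    from ray.rllib.agents.ppo import PPOTrainer
    from ray.rllib.models import ModelCatalog

    ray.init(ignore_reinit_error=True)
    # Make the custom model known to RLlib under a (hypothetical) key.
    ModelCatalog.register_custom_model("bn_tf_model", BatchNormModel)
    # Point the trainer's model config at that key.
    trainer = PPOTrainer(
        env="CartPole-v0",
        config={
            "framework": "tf",
            "model": {"custom_model": "bn_tf_model"},
        },
    )
    # Run a single training iteration as a sanity check.
    print(trainer.train())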


class TorchBatchNormModel(TorchModelV2, nn.Module):
    """Example of a TorchModelV2 using batch normalization."""
    capture_index = 0

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        layers = []
        prev_layer_size = int(np.prod(obs_space.shape))
        self._logits = None

        # Create layers 0 to second-last.
        for size in [256, 256]:
            layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=size,
                    initializer=torch_normc_initializer(1.0),
                    activation_fn=nn.ReLU))
            prev_layer_size = size
            # Add a batch norm layer.
            layers.append(nn.BatchNorm1d(prev_layer_size))

        self._logits = SlimFC(
            in_size=prev_layer_size,
            out_size=self.num_outputs,
            initializer=torch_normc_initializer(0.01),
            activation_fn=None)
        self._value_branch = SlimFC(
            in_size=prev_layer_size,
            out_size=1,
            initializer=torch_normc_initializer(1.0),
            activation_fn=None)
        self._hidden_layers = nn.Sequential(*layers)
        self._hidden_out = None
    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Set the correct train-mode for our hidden module (only important
        # b/c we have some batch-norm layers).
        self._hidden_layers.train(
            mode=bool(input_dict.get("is_training", False)))
        self._hidden_out = self._hidden_layers(input_dict["obs"])
        logits = self._logits(self._hidden_out)
        return logits, []

    @override(ModelV2)
    def value_function(self):
        assert self._hidden_out is not None, "must call forward first!"
        return torch.reshape(self._value_branch(self._hidden_out), [-1])
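

# NOTE: Like the TF helper above, the helper below is NOT part of the
# original example file. It is an analogous sketch for the torch model,
# assuming the same ModelCatalog / trainer APIs and the `framework: torch`
# config switch. Helper and key names are illustrative only.
def _example_torch_usage():
    import ray
    from ray.rllib.agents.ppo import PPOTrainer
    from ray.rllib.models import ModelCatalog

    ray.init(ignore_reinit_error=True)
    # Register the torch model under a (hypothetical) custom-model key.
    ModelCatalog.register_custom_model("bn_torch_model", TorchBatchNormModel)
    trainer = PPOTrainer(
        env="CartPole-v0",
        config={
            "framework": "torch",
            "model": {"custom_model": "bn_torch_model"},
        },
    )
    # Run a single training iteration as a sanity check.
    print(trainer.train())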