import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

TF2_GLOBAL_SHARED_LAYER = None


class TF2SharedWeightsModel(TFModelV2):
    """Example of weight sharing between two different TFModelV2s.

    NOTE: This will only work for tf2.x. When running with
    config.framework=tf, use SharedWeightsModel1 and SharedWeightsModel2
    below, instead!

    The shared (single) layer is simply defined outside of the two Models,
    then used by both Models in their forward pass.
    """

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        global TF2_GLOBAL_SHARED_LAYER
        # The global, shared layer to be used by both models.
        if TF2_GLOBAL_SHARED_LAYER is None:
            TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1")

        inputs = tf.keras.layers.Input(observation_space.shape)
        last_layer = TF2_GLOBAL_SHARED_LAYER(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
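

# A minimal sanity-check sketch for the tf2 variant above: because both model
# instances call the very same global Keras layer object, their "fc1" weights
# are one and the same variable. The gymnasium spaces and instance names used
# here are illustrative assumptions, not part of RLlib's API.
#
#   from gymnasium.spaces import Box, Discrete
#   obs, act = Box(-1.0, 1.0, (4,)), Discrete(2)
#   m1 = TF2SharedWeightsModel(obs, act, 2, {}, "m1")
#   m2 = TF2SharedWeightsModel(obs, act, 2, {}, "m2")
#   assert m1.base_model.get_layer("fc1") is m2.base_model.get_layer("fc1")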


class SharedWeightsModel1(TFModelV2):
    """Example of weight sharing between two different TFModelV2s.

    NOTE: This will only work for tf1 (static graph). When running with
    config.framework=tf2, use TF2SharedWeightsModel, instead!

    Here, we share the variables defined in the 'shared' variable scope
    by entering it explicitly with tf1.AUTO_REUSE. This creates the
    variables for the 'fc1' layer in a global scope called 'shared'
    (outside of the Policy's normal variable scope).
    """

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        inputs = tf.keras.layers.Input(observation_space.shape)

        with tf1.variable_scope(
                tf1.VariableScope(tf1.AUTO_REUSE, "shared"),
                reuse=tf1.AUTO_REUSE,
                auxiliary_name_scope=False):
            last_layer = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1")(inputs)

        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])


class SharedWeightsModel2(TFModelV2):
    """The "other" TFModelV2 using the same shared space as the one above."""

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        inputs = tf.keras.layers.Input(observation_space.shape)

        # Weights shared with SharedWeightsModel1.
        with tf1.variable_scope(
                tf1.VariableScope(tf1.AUTO_REUSE, "shared"),
                reuse=tf1.AUTO_REUSE,
                auxiliary_name_scope=False):
            last_layer = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1")(inputs)

        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
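

# For the tf1 (static-graph) pair above, sharing happens at the variable
# level: with tf1.AUTO_REUSE, building SharedWeightsModel1 and then
# SharedWeightsModel2 creates only one "shared/fc1" kernel/bias pair in the
# graph; the second model reuses those variables instead of creating new
# ones. A hedged sketch of how this could be inspected in graph mode:
#
#   shared_vars = tf1.get_collection(
#       tf1.GraphKeys.GLOBAL_VARIABLES, scope="shared")
#   # Expect a single kernel and bias, however many models were built.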


TORCH_GLOBAL_SHARED_LAYER = None
if torch:
    # The global, shared layer to be used by both models.
    TORCH_GLOBAL_SHARED_LAYER = SlimFC(
        64,
        64,
        activation_fn=nn.ReLU,
        initializer=torch.nn.init.xavier_uniform_,
    )


class TorchSharedWeightsModel(TorchModelV2, nn.Module):
    """Example of weight sharing between two different TorchModelV2s.

    The shared (single) layer is simply defined outside of the two Models,
    then used by both Models in their forward pass.
    """

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        TorchModelV2.__init__(self, observation_space, action_space,
                              num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Non-shared initial layer (np.prod, as np.product is gone in
        # NumPy 2.x).
        self.first_layer = SlimFC(
            int(np.prod(observation_space.shape)),
            64,
            activation_fn=nn.ReLU,
            initializer=torch.nn.init.xavier_uniform_)

        # Non-shared final layer.
        self.last_layer = SlimFC(
            64,
            self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_)
        self.vf = SlimFC(
            64,
            1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self._global_shared_layer = TORCH_GLOBAL_SHARED_LAYER
        self._output = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out = self.first_layer(input_dict["obs"])
        self._output = self._global_shared_layer(out)
        model_out = self.last_layer(self._output)
        return model_out, []

    @override(ModelV2)
    def value_function(self):
        assert self._output is not None, "must call forward first!"
        return torch.reshape(self.vf(self._output), [-1])
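

# Usage sketch: with the ModelV2 API, these classes are typically registered
# via ModelCatalog and then referenced by name in a policy's model config.
# The registration keys ("tf2_shared", "torch_shared") are arbitrary example
# names, not ones RLlib defines:
#
#   from ray.rllib.models import ModelCatalog
#   ModelCatalog.register_custom_model("tf2_shared", TF2SharedWeightsModel)
#   ModelCatalog.register_custom_model("torch_shared", TorchSharedWeightsModel)
#   # In the policy spec: {"model": {"custom_model": "torch_shared"}}

if __name__ == "__main__":
    # Quick identity check for the torch variant, assuming `gymnasium` is
    # available (older RLlib versions used `gym` spaces instead).
    from gymnasium.spaces import Box, Discrete

    if torch:
        obs_space = Box(-1.0, 1.0, (4,), dtype=np.float32)
        m1 = TorchSharedWeightsModel(obs_space, Discrete(2), 2, {}, "m1")
        m2 = TorchSharedWeightsModel(obs_space, Discrete(2), 2, {}, "m2")
        # Both models hold a reference to the exact same SlimFC module, so
        # a single parameter set sits behind both policies.
        assert m1._global_shared_layer is m2._global_shared_layer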
|