shared_weights_model.py

import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

TF2_GLOBAL_SHARED_LAYER = None


class TF2SharedWeightsModel(TFModelV2):
    """Example of weight sharing between two different TFModelV2s.

    NOTE: This will only work for tf2.x. When running with
    config.framework=tf, use SharedWeightsModel1 and SharedWeightsModel2
    below, instead!

    The shared (single) layer is simply defined outside of the two Models,
    then used by both Models in their forward pass.
    """

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        global TF2_GLOBAL_SHARED_LAYER
        # The global, shared layer to be used by both models.
        if TF2_GLOBAL_SHARED_LAYER is None:
            TF2_GLOBAL_SHARED_LAYER = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1")

        inputs = tf.keras.layers.Input(observation_space.shape)
        last_layer = TF2_GLOBAL_SHARED_LAYER(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])


class SharedWeightsModel1(TFModelV2):
    """Example of weight sharing between two different TFModelV2s.

    NOTE: This will only work for tf1 (static graph). When running with
    config.framework=tf2, use TF2SharedWeightsModel, instead!

    Here, we share the variables defined in the 'shared' variable scope
    by entering it explicitly with tf1.AUTO_REUSE. This creates the
    variables for the 'fc1' layer in a global scope called 'shared'
    (outside of the Policy's normal variable scope).
    """

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        inputs = tf.keras.layers.Input(observation_space.shape)
        # Enter the shared variable scope, so that "fc1" is created (or
        # reused) globally rather than inside this Policy's own scope.
        with tf1.variable_scope(
                tf1.VariableScope(tf1.AUTO_REUSE, "shared"),
                reuse=tf1.AUTO_REUSE,
                auxiliary_name_scope=False):
            last_layer = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1")(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])


class SharedWeightsModel2(TFModelV2):
    """The "other" TFModelV2 using the same shared scope as the one above."""

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        super().__init__(observation_space, action_space, num_outputs,
                         model_config, name)

        inputs = tf.keras.layers.Input(observation_space.shape)
        # Weights shared with SharedWeightsModel1.
        with tf1.variable_scope(
                tf1.VariableScope(tf1.AUTO_REUSE, "shared"),
                reuse=tf1.AUTO_REUSE,
                auxiliary_name_scope=False):
            last_layer = tf.keras.layers.Dense(
                units=64, activation=tf.nn.relu, name="fc1")(inputs)
        output = tf.keras.layers.Dense(
            units=num_outputs, activation=None, name="fc_out")(last_layer)
        vf = tf.keras.layers.Dense(
            units=1, activation=None, name="value_out")(last_layer)
        self.base_model = tf.keras.models.Model(inputs, [output, vf])

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out, self._value_out = self.base_model(input_dict["obs"])
        return out, []

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
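

# NOTE: Illustrative usage sketch, not part of the original example file.
# With config.framework=tf (static graph), the two models above could be
# registered under custom-model names and assigned to two different
# policies; because both build their "fc1" layer inside the same AUTO_REUSE
# "shared" variable scope, tf1 resolves them to the same underlying
# variables. The model names here are made up for the example:
#
#   from ray.rllib.models import ModelCatalog
#   ModelCatalog.register_custom_model("shared_model_1", SharedWeightsModel1)
#   ModelCatalog.register_custom_model("shared_model_2", SharedWeightsModel2)
#   # Then point one policy's model config at "custom_model":
#   # "shared_model_1" and the other's at "shared_model_2".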

TORCH_GLOBAL_SHARED_LAYER = None
if torch:
    # The global, shared layer to be used by both models.
    TORCH_GLOBAL_SHARED_LAYER = SlimFC(
        64,
        64,
        activation_fn=nn.ReLU,
        initializer=torch.nn.init.xavier_uniform_,
    )


class TorchSharedWeightsModel(TorchModelV2, nn.Module):
    """Example of weight sharing between two different TorchModelV2s.

    The shared (single) layer is simply defined outside of the two Models,
    then used by both Models in their forward pass.
    """

    def __init__(self, observation_space, action_space, num_outputs,
                 model_config, name):
        TorchModelV2.__init__(self, observation_space, action_space,
                              num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Non-shared initial layer.
        self.first_layer = SlimFC(
            int(np.prod(observation_space.shape)),
            64,
            activation_fn=nn.ReLU,
            initializer=torch.nn.init.xavier_uniform_)

        # Non-shared final layer.
        self.last_layer = SlimFC(
            64,
            self.num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_)
        # Non-shared value branch.
        self.vf = SlimFC(
            64,
            1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self._global_shared_layer = TORCH_GLOBAL_SHARED_LAYER
        self._output = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        out = self.first_layer(input_dict["obs"])
        self._output = self._global_shared_layer(out)
        model_out = self.last_layer(self._output)
        return model_out, []

    @override(ModelV2)
    def value_function(self):
        assert self._output is not None, "must call forward first!"
        return torch.reshape(self.vf(self._output), [-1])
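

if __name__ == "__main__":
    # Minimal sanity-check sketch, added for illustration and not part of
    # the original RLlib example. It builds two TorchSharedWeightsModel
    # instances and verifies that both reference the exact same SlimFC
    # parameters, which is the point of TORCH_GLOBAL_SHARED_LAYER. The
    # spaces and num_outputs below are arbitrary assumptions; on older
    # RLlib/gym versions, use `import gym` instead of gymnasium.
    import gymnasium as gym

    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
    act_space = gym.spaces.Discrete(2)

    model_a = TorchSharedWeightsModel(obs_space, act_space, 2, {}, "model_a")
    model_b = TorchSharedWeightsModel(obs_space, act_space, 2, {}, "model_b")

    # The shared layer's parameter tensors are identical objects in both
    # models, so a gradient step through either model updates both.
    params_a = list(model_a._global_shared_layer.parameters())
    params_b = list(model_b._global_shared_layer.parameters())
    assert all(p_a is p_b for p_a, p_b in zip(params_a, params_b))
    print("Both models share the same 64x64 layer parameters.")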