centralized_critic_models.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from gym.spaces import Box
  2. from ray.rllib.models.modelv2 import ModelV2
  3. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  4. from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
  5. from ray.rllib.models.torch.misc import SlimFC
  6. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  7. from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
  8. from ray.rllib.utils.annotations import override
  9. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  10. tf1, tf, tfv = try_import_tf()
  11. torch, nn = try_import_torch()
  12. class CentralizedCriticModel(TFModelV2):
  13. """Multi-agent model that implements a centralized value function."""
  14. def __init__(self, obs_space, action_space, num_outputs, model_config,
  15. name):
  16. super(CentralizedCriticModel, self).__init__(
  17. obs_space, action_space, num_outputs, model_config, name)
  18. # Base of the model
  19. self.model = FullyConnectedNetwork(obs_space, action_space,
  20. num_outputs, model_config, name)
  21. # Central VF maps (obs, opp_obs, opp_act) -> vf_pred
  22. obs = tf.keras.layers.Input(shape=(6, ), name="obs")
  23. opp_obs = tf.keras.layers.Input(shape=(6, ), name="opp_obs")
  24. opp_act = tf.keras.layers.Input(shape=(2, ), name="opp_act")
  25. concat_obs = tf.keras.layers.Concatenate(axis=1)(
  26. [obs, opp_obs, opp_act])
  27. central_vf_dense = tf.keras.layers.Dense(
  28. 16, activation=tf.nn.tanh, name="c_vf_dense")(concat_obs)
  29. central_vf_out = tf.keras.layers.Dense(
  30. 1, activation=None, name="c_vf_out")(central_vf_dense)
  31. self.central_vf = tf.keras.Model(
  32. inputs=[obs, opp_obs, opp_act], outputs=central_vf_out)
  33. @override(ModelV2)
  34. def forward(self, input_dict, state, seq_lens):
  35. return self.model.forward(input_dict, state, seq_lens)
  36. def central_value_function(self, obs, opponent_obs, opponent_actions):
  37. return tf.reshape(
  38. self.central_vf([
  39. obs, opponent_obs,
  40. tf.one_hot(tf.cast(opponent_actions, tf.int32), 2)
  41. ]), [-1])
  42. @override(ModelV2)
  43. def value_function(self):
  44. return self.model.value_function() # not used
  45. class YetAnotherCentralizedCriticModel(TFModelV2):
  46. """Multi-agent model that implements a centralized value function.
  47. It assumes the observation is a dict with 'own_obs' and 'opponent_obs', the
  48. former of which can be used for computing actions (i.e., decentralized
  49. execution), and the latter for optimization (i.e., centralized learning).
  50. This model has two parts:
  51. - An action model that looks at just 'own_obs' to compute actions
  52. - A value model that also looks at the 'opponent_obs' / 'opponent_action'
  53. to compute the value (it does this by using the 'obs_flat' tensor).
  54. """
  55. def __init__(self, obs_space, action_space, num_outputs, model_config,
  56. name):
  57. super(YetAnotherCentralizedCriticModel, self).__init__(
  58. obs_space, action_space, num_outputs, model_config, name)
  59. self.action_model = FullyConnectedNetwork(
  60. Box(low=0, high=1, shape=(6, )), # one-hot encoded Discrete(6)
  61. action_space,
  62. num_outputs,
  63. model_config,
  64. name + "_action")
  65. self.value_model = FullyConnectedNetwork(obs_space, action_space, 1,
  66. model_config, name + "_vf")
  67. def forward(self, input_dict, state, seq_lens):
  68. self._value_out, _ = self.value_model({
  69. "obs": input_dict["obs_flat"]
  70. }, state, seq_lens)
  71. return self.action_model({
  72. "obs": input_dict["obs"]["own_obs"]
  73. }, state, seq_lens)
  74. def value_function(self):
  75. return tf.reshape(self._value_out, [-1])
  76. class TorchCentralizedCriticModel(TorchModelV2, nn.Module):
  77. """Multi-agent model that implements a centralized VF."""
  78. def __init__(self, obs_space, action_space, num_outputs, model_config,
  79. name):
  80. TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
  81. model_config, name)
  82. nn.Module.__init__(self)
  83. # Base of the model
  84. self.model = TorchFC(obs_space, action_space, num_outputs,
  85. model_config, name)
  86. # Central VF maps (obs, opp_obs, opp_act) -> vf_pred
  87. input_size = 6 + 6 + 2 # obs + opp_obs + opp_act
  88. self.central_vf = nn.Sequential(
  89. SlimFC(input_size, 16, activation_fn=nn.Tanh),
  90. SlimFC(16, 1),
  91. )
  92. @override(ModelV2)
  93. def forward(self, input_dict, state, seq_lens):
  94. model_out, _ = self.model(input_dict, state, seq_lens)
  95. return model_out, []
  96. def central_value_function(self, obs, opponent_obs, opponent_actions):
  97. input_ = torch.cat([
  98. obs, opponent_obs,
  99. torch.nn.functional.one_hot(opponent_actions.long(), 2).float()
  100. ], 1)
  101. return torch.reshape(self.central_vf(input_), [-1])
  102. @override(ModelV2)
  103. def value_function(self):
  104. return self.model.value_function() # not used
  105. class YetAnotherTorchCentralizedCriticModel(TorchModelV2, nn.Module):
  106. """Multi-agent model that implements a centralized value function.
  107. It assumes the observation is a dict with 'own_obs' and 'opponent_obs', the
  108. former of which can be used for computing actions (i.e., decentralized
  109. execution), and the latter for optimization (i.e., centralized learning).
  110. This model has two parts:
  111. - An action model that looks at just 'own_obs' to compute actions
  112. - A value model that also looks at the 'opponent_obs' / 'opponent_action'
  113. to compute the value (it does this by using the 'obs_flat' tensor).
  114. """
  115. def __init__(self, obs_space, action_space, num_outputs, model_config,
  116. name):
  117. TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
  118. model_config, name)
  119. nn.Module.__init__(self)
  120. self.action_model = TorchFC(
  121. Box(low=0, high=1, shape=(6, )), # one-hot encoded Discrete(6)
  122. action_space,
  123. num_outputs,
  124. model_config,
  125. name + "_action")
  126. self.value_model = TorchFC(obs_space, action_space, 1, model_config,
  127. name + "_vf")
  128. self._model_in = None
  129. def forward(self, input_dict, state, seq_lens):
  130. # Store model-input for possible `value_function()` call.
  131. self._model_in = [input_dict["obs_flat"], state, seq_lens]
  132. return self.action_model({
  133. "obs": input_dict["obs"]["own_obs"]
  134. }, state, seq_lens)
  135. def value_function(self):
  136. value_out, _ = self.value_model({
  137. "obs": self._model_in[0]
  138. }, self._model_in[1], self._model_in[2])
  139. return torch.reshape(value_out, [-1])