123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- from gym.spaces import Box
- from ray.rllib.models.modelv2 import ModelV2
- from ray.rllib.models.tf.tf_modelv2 import TFModelV2
- from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
- from ray.rllib.models.torch.misc import SlimFC
- from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
- from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
- from ray.rllib.utils.annotations import override
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
- tf1, tf, tfv = try_import_tf()
- torch, nn = try_import_torch()
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized value function.

    Actions come from a standard fully connected network over the agent's
    observation. A separate Keras model, reached through
    ``central_value_function``, predicts the value from the agent's
    observation, the opponent's observation and the opponent's action.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        # Base of the model: produces the action outputs (and a per-agent
        # value head that this class does not use).
        self.model = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name)

        # Central VF maps (obs, opp_obs, opp_act) -> vf_pred.
        # Input widths are hard-coded: 6 for each observation and 2 for the
        # one-hot opponent action.
        keras = tf.keras
        vf_inputs = [
            keras.layers.Input(shape=(6, ), name="obs"),
            keras.layers.Input(shape=(6, ), name="opp_obs"),
            keras.layers.Input(shape=(2, ), name="opp_act"),
        ]
        hidden = keras.layers.Concatenate(axis=1)(vf_inputs)
        hidden = keras.layers.Dense(
            16, activation=tf.nn.tanh, name="c_vf_dense")(hidden)
        vf_pred = keras.layers.Dense(
            1, activation=None, name="c_vf_out")(hidden)
        self.central_vf = keras.Model(inputs=vf_inputs, outputs=vf_pred)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Delegate the action computation to the base network."""
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Return centralized value predictions as a flat 1-D tensor.

        Args:
            obs: The agent's observations. They are fed to a Keras
                ``Input(shape=(6,))``, so presumably (batch, 6).
            opponent_obs: The opponent's observations, with the same
                expected layout as ``obs``.
            opponent_actions: Opponent action indices. They are cast to
                int32 and one-hot encoded with depth 2.
        """
        opp_act_one_hot = tf.one_hot(tf.cast(opponent_actions, tf.int32), 2)
        vf_pred = self.central_vf([obs, opponent_obs, opp_act_one_hot])
        return tf.reshape(vf_pred, [-1])

    @override(ModelV2)
    def value_function(self):
        """Return the base network's own value output.

        This output is marked as not used; training relies on
        ``central_value_function`` instead.
        """
        return self.model.value_function()  # not used
class YetAnotherCentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized value function.

    It assumes the observation is a dict that contains at least 'own_obs',
    which ``forward`` reads directly to compute actions (i.e., decentralized
    execution). The other entries (e.g. 'opponent_obs' / 'opponent_action')
    reach only the value model, through the flattened 'obs_flat' tensor,
    and are used for optimization (i.e., centralized learning).

    This model has two parts:
    - An action model that looks at just 'own_obs' to compute actions
    - A value model that also looks at the 'opponent_obs' / 'opponent_action'
      to compute the value (it does this by using the 'obs_flat' tensor).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super(YetAnotherCentralizedCriticModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)

        self.action_model = FullyConnectedNetwork(
            Box(low=0, high=1, shape=(6, )),  # one-hot encoded Discrete(6)
            action_space,
            num_outputs,
            model_config,
            name + "_action")

        # Value model sees the full (flattened) observation and emits 1 output.
        self.value_model = FullyConnectedNetwork(obs_space, action_space, 1,
                                                 model_config, name + "_vf")

        # Value-model output of the most recent forward() call. It stays None
        # until forward() has run, so value_function() can detect misuse.
        self._value_out = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Run the value model on 'obs_flat' and the action model on 'own_obs'.

        The value output is cached for ``value_function``. The action model's
        (output, state) result is returned as-is.
        """
        self._value_out, _ = self.value_model({
            "obs": input_dict["obs_flat"]
        }, state, seq_lens)
        return self.action_model({
            "obs": input_dict["obs"]["own_obs"]
        }, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        """Return the cached value output of the last forward() as a 1-D tensor.

        Raises:
            ValueError: If called before ``forward``.
        """
        if self._value_out is None:
            raise ValueError("value_function() called before forward()")
        return tf.reshape(self._value_out, [-1])
class TorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Multi-agent model that implements a centralized VF.

    Actions come from a standard fully connected torch network. A separate
    two-layer MLP, reached through ``central_value_function``, predicts the
    value from the agent's observation, the opponent's observation and the
    opponent's one-hot action.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        # Both bases are initialized explicitly, TorchModelV2 first.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        # Base of the model.
        self.model = TorchFC(obs_space, action_space, num_outputs,
                             model_config, name)

        # Central VF maps (obs, opp_obs, opp_act) -> vf_pred.
        # Widths are hard-coded: 6 (obs) + 6 (opp_obs) + 2 (one-hot opp_act).
        obs_dim, opp_obs_dim, opp_act_dim = 6, 6, 2
        self.central_vf = nn.Sequential(
            SlimFC(obs_dim + opp_obs_dim + opp_act_dim, 16,
                   activation_fn=nn.Tanh),
            SlimFC(16, 1),
        )

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Return the base network's output with an empty state list.

        The base network's returned state is discarded.
        """
        logits, _ = self.model(input_dict, state, seq_lens)
        return logits, []

    def central_value_function(self, obs, opponent_obs, opponent_actions):
        """Return centralized value predictions as a flat 1-D tensor.

        Args:
            obs: The agent's observations. They are concatenated along dim 1,
                so presumably (batch, 6).
            opponent_obs: The opponent's observations, with the same
                expected layout as ``obs``.
            opponent_actions: Opponent action indices. They are cast to long
                and one-hot encoded with 2 classes.
        """
        opp_act_one_hot = torch.nn.functional.one_hot(
            opponent_actions.long(), 2).float()
        vf_in = torch.cat([obs, opponent_obs, opp_act_one_hot], 1)
        return self.central_vf(vf_in).reshape(-1)

    @override(ModelV2)
    def value_function(self):
        """Return the base network's own value output.

        This output is marked as not used; training relies on
        ``central_value_function`` instead.
        """
        return self.model.value_function()  # not used
class YetAnotherTorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Multi-agent model that implements a centralized value function.

    It assumes the observation is a dict that contains at least 'own_obs',
    which ``forward`` reads directly to compute actions (i.e., decentralized
    execution). The other entries (e.g. 'opponent_obs' / 'opponent_action')
    reach only the value model, through the flattened 'obs_flat' tensor,
    and are used for optimization (i.e., centralized learning).

    This model has two parts:
    - An action model that looks at just 'own_obs' to compute actions
    - A value model that also looks at the 'opponent_obs' / 'opponent_action'
      to compute the value (it does this by using the 'obs_flat' tensor).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        self.action_model = TorchFC(
            Box(low=0, high=1, shape=(6, )),  # one-hot encoded Discrete(6)
            action_space,
            num_outputs,
            model_config,
            name + "_action")

        # Value model sees the full (flattened) observation and emits 1 output.
        self.value_model = TorchFC(obs_space, action_space, 1, model_config,
                                   name + "_vf")

        # [obs_flat, state, seq_lens] from the most recent forward() call.
        # It stays None until forward() has run.
        self._model_in = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        """Compute actions from 'own_obs' only.

        The flattened observation, state and seq_lens are stored so that
        ``value_function`` can run the value model lazily.
        """
        # Store model-input for possible `value_function()` call.
        self._model_in = [input_dict["obs_flat"], state, seq_lens]
        return self.action_model({
            "obs": input_dict["obs"]["own_obs"]
        }, state, seq_lens)

    @override(ModelV2)
    def value_function(self):
        """Run the value model on the inputs stored by the last forward().

        Returns:
            A 1-D tensor of value predictions.

        Raises:
            ValueError: If called before ``forward``.
        """
        if self._model_in is None:
            raise ValueError("value_function() called before forward()")
        obs_flat, state, seq_lens = self._model_in
        value_out, _ = self.value_model({"obs": obs_flat}, state, seq_lens)
        return torch.reshape(value_out, [-1])
|