import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()

class RNNModel(RecurrentNetwork):
    """Example of using the Keras functional API to define an RNN model."""
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 hiddens_size=256,
                 cell_size=64):
        super(RNNModel, self).__init__(obs_space, action_space, num_outputs,
                                       model_config, name)
        self.cell_size = cell_size

        # Define input layers.
        input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]), name="inputs")
        state_in_h = tf.keras.layers.Input(shape=(cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(cell_size, ), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Preprocess observation with a hidden layer and send to LSTM cell.
        dense1 = tf.keras.layers.Dense(
            hiddens_size, activation=tf.nn.relu, name="dense1")(input_layer)
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True, name="lstm")(
                inputs=dense1,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])

        # Postprocess LSTM output with another hidden layer and compute values.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(lstm_out)
        values = tf.keras.layers.Dense(
            1, activation=None, name="values")(lstm_out)

        # Create the RNN model.
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c])
        self.rnn_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] +
                                                          state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
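
# Usage sketch (an assumption, not part of the model code above): RLlib's
# standard hook for custom models is `ModelCatalog.register_custom_model`,
# after which a config can reference the model by name via the `custom_model`
# key. The name "rnn_tf_model" and the config values here are illustrative.
from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("rnn_tf_model", RNNModel)
tf_config = {
    "framework": "tf",
    "model": {
        "custom_model": "rnn_tf_model",
        # Truncate trajectories to this length when training the RNN.
        "max_seq_len": 20,
    },
}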

class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=64,
                 lstm_state_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Build the Module from fc + LSTM + 2xfc (action + value outs).
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Place hidden states on same device as model.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])
    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the LSTM.

        Values for the value branch are computed lazily from
        `self._features` in `value_function()` and returned in a simple
        (B,) shape (where B contains both the B and T dims!).

        Returns:
            NN outputs (B x T x ...) as a sequence.
            The state batches as a list of two items (h- and c-states).
        """
        x = nn.functional.relu(self.fc1(inputs))
        # nn.LSTM expects its hidden state as an (h, c) tuple with a leading
        # (num_layers * num_directions) axis, hence the unsqueeze/squeeze.
        self._features, (h, c) = self.lstm(
            x, (torch.unsqueeze(state[0], 0),
                torch.unsqueeze(state[1], 0)))
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
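
# The torch model can be registered the same way (again a sketch; the name
# "rnn_torch_model" is arbitrary). Such a config could then be passed to a
# trainer, e.g. `ray.tune.run("PPO", config={**torch_config, "env": ...})`.
ModelCatalog.register_custom_model("rnn_torch_model", TorchRNNModel)
torch_config = {
    "framework": "torch",
    "model": {
        "custom_model": "rnn_torch_model",
        "max_seq_len": 20,
    },
}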