rnn_model.py

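"""Examples of custom recurrent (LSTM) models for RLlib: one built with the
Keras functional API (TF) and one as a PyTorch nn.Module."""
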
import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()


class RNNModel(RecurrentNetwork):
    """Example of using the Keras functional API to define an RNN model."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 hiddens_size=256,
                 cell_size=64):
        super(RNNModel, self).__init__(obs_space, action_space, num_outputs,
                                       model_config, name)
        self.cell_size = cell_size

        # Define input layers.
        input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]), name="inputs")
        state_in_h = tf.keras.layers.Input(shape=(cell_size, ), name="h")
        state_in_c = tf.keras.layers.Input(shape=(cell_size, ), name="c")
        seq_in = tf.keras.layers.Input(shape=(), name="seq_in", dtype=tf.int32)

        # Preprocess observation with a hidden layer and send to LSTM cell.
        dense1 = tf.keras.layers.Dense(
            hiddens_size, activation=tf.nn.relu, name="dense1")(input_layer)
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            cell_size, return_sequences=True, return_state=True, name="lstm")(
                inputs=dense1,
                mask=tf.sequence_mask(seq_in),
                initial_state=[state_in_h, state_in_c])

        # Postprocess LSTM output with another hidden layer and compute values.
        logits = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name="logits")(lstm_out)
        values = tf.keras.layers.Dense(
            1, activation=None, name="values")(lstm_out)

        # Create the RNN model.
        self.rnn_model = tf.keras.Model(
            inputs=[input_layer, seq_in, state_in_h, state_in_c],
            outputs=[logits, values, state_h, state_c])
        self.rnn_model.summary()

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] +
                                                          state)
        return model_out, [h, c]

    @override(ModelV2)
    def get_initial_state(self):
        return [
            np.zeros(self.cell_size, np.float32),
            np.zeros(self.cell_size, np.float32),
        ]

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])
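
# A minimal usage sketch (not part of the original file): RLlib custom models
# are registered under a name and selected via the `custom_model` config key.
# The trainer class and env below are illustrative assumptions; `max_seq_len`
# caps the sequence length RLlib feeds into `forward_rnn()`.
#
#   from ray.rllib.models import ModelCatalog
#   from ray.rllib.agents.ppo import PPOTrainer
#
#   ModelCatalog.register_custom_model("rnn", RNNModel)
#   trainer = PPOTrainer(env="CartPole-v0", config={
#       "framework": "tf",
#       "model": {"custom_model": "rnn", "max_seq_len": 20},
#   })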


class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=64,
                 lstm_state_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        # Build the Module from fc + LSTM + 2xfc (action + value outs).
        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Place hidden states on same device as model.
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the LSTM.

        Returns the resulting outputs as a sequence (B x T x ...).
        The LSTM output is stored in `self._features`, from which
        `value_function()` later computes values in simple (B) shape (where
        B contains both the B and T dims!).

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (h- and c-states).
        """
        x = nn.functional.relu(self.fc1(inputs))
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0),
                torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
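

# A minimal shape-check sketch (not part of the original file). The toy
# spaces, sizes, and the bare `model_config` dict below are illustrative
# assumptions; exact constructor requirements can vary across RLlib versions.
if __name__ == "__main__":
    import gym  # or gymnasium, depending on your setup

    obs_space = gym.spaces.Box(-1.0, 1.0, (4, ), np.float32)
    action_space = gym.spaces.Discrete(2)
    model = TorchRNNModel(
        obs_space, action_space, 2, {"max_seq_len": 5}, "torch_rnn_test")

    B, T = 3, 5
    obs = torch.randn(B, T, 4)
    # get_initial_state() returns per-env states; stack them into a batch.
    state = [torch.stack([s] * B) for s in model.get_initial_state()]
    logits, state_out = model.forward_rnn(obs, state, seq_lens=None)
    assert logits.shape == (B, T, 2)
    # value_function() flattens the B and T dims into one.
    assert model.value_function().shape == (B * T, )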