modelv3.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import numpy as np
  2. from ray.rllib.policy.sample_batch import SampleBatch
  3. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  # RLlib's framework import helpers: try_import_tf() is unpacked into three
  # names (tf1, tf, tfv) and try_import_torch() into two (torch, nn). `tf` can
  # be falsy when TensorFlow is unavailable, which is why RNNModel's base-class
  # expression guards on it.
  # NOTE(review): torch / nn (and tf1 / tfv) are not referenced by the code
  # visible in this file; confirm nothing imports them from here before removing.
  tf1, tf, tfv = try_import_tf()
  torch, nn = try_import_torch()
  class RNNModel(tf.keras.models.Model if tf else object):
      """Example of a subclassed Keras model implementing an LSTM network.

      Layers are created in ``__init__`` and wired together imperatively in
      ``call`` (Keras model subclassing, not the functional API):

          obs -> Dense(hiddens_size, relu) -> LSTM(cell_size)
              -> Dense(num_outputs)  (logits)
              -> Dense(1)            (value predictions)
      """

      def __init__(self,
                   input_space,
                   action_space,
                   num_outputs,
                   *,
                   name="",
                   hiddens_size=256,
                   cell_size=64):
          """Create the model's layers.

          Args:
              input_space: Observation space. Not used by this model.
              action_space: Action space. Not used by this model.
              num_outputs: Number of units of the linear ``logits`` layer.
              name: Name passed to the Keras base model.
              hiddens_size: Units of the ReLU dense layer applied to the
                  observations before the LSTM.
              cell_size: Number of LSTM units; also the length of each
                  state vector returned by ``get_initial_state``.

          NOTE(review): ``input_space`` / ``action_space`` are presumably
          accepted only so the signature matches what the caller passes --
          verify before removing them.
          """
          super().__init__(name=name)
          self.cell_size = cell_size

          # Preprocess observations with a hidden layer before the LSTM cell.
          self.dense = tf.keras.layers.Dense(
              hiddens_size, activation=tf.nn.relu, name="dense1")
          # return_sequences: one output per time step (needed for per-step
          # logits/values); return_state: also hand back the final (h, c).
          self.lstm = tf.keras.layers.LSTM(
              cell_size, return_sequences=True, return_state=True, name="lstm")

          # Output heads applied to every LSTM time step.
          self.logits = tf.keras.layers.Dense(
              num_outputs, activation=tf.keras.activations.linear, name="logits")
          self.values = tf.keras.layers.Dense(1, activation=None, name="values")

      def call(self, sample_batch):
          """Run a forward pass over a batch of padded sequences.

          Args:
              sample_batch: Mapping providing ``"obs"``,
                  ``SampleBatch.SEQ_LENS``, ``"state_in_0"`` and
                  ``"state_in_1"``. The code requires the rows of ``"obs"`` to
                  divide evenly into ``len(seq_lens)`` sequences of equal
                  (padded) length ``T``. It presumably expects ``obs`` to be
                  rank 2, ``[B * T, obs_dim]`` -- TODO confirm with the caller.

          Returns:
              ``(logits, [h, c], extra)``:
              * ``logits``: ``[B * T, num_outputs]``.
              * ``[h, c]``: final LSTM hidden and cell states.
              * ``extra``: ``{SampleBatch.VF_PREDS: values}``, where ``values``
                is flattened to ``[B * T]``.
          """
          seq_lens = sample_batch[SampleBatch.SEQ_LENS]
          dense_out = self.dense(sample_batch["obs"])

          # One sequence per entry of seq_lens; fold the flat rows back into
          # [B, T, hiddens_size] for the LSTM.
          batch_size = tf.shape(seq_lens)[0]
          lstm_in = tf.reshape(
              dense_out, [batch_size, -1, dense_out.shape.as_list()[1]])

          # Size the mask by the actual padded time dimension. Without
          # `maxlen`, tf.sequence_mask uses max(seq_lens), which mismatches
          # lstm_in whenever the batch is padded past its longest sequence.
          mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(lstm_in)[1])

          lstm_out, h, c = self.lstm(
              inputs=lstm_in,
              mask=mask,
              initial_state=[
                  sample_batch["state_in_0"], sample_batch["state_in_1"]
              ],
          )

          # Flatten back to [B * T, cell_size] so the heads see one row per step.
          lstm_out = tf.reshape(lstm_out, [-1, lstm_out.shape.as_list()[2]])
          logits = self.logits(lstm_out)
          values = tf.reshape(self.values(lstm_out), [-1])
          return logits, [h, c], {SampleBatch.VF_PREDS: values}

      def get_initial_state(self):
          """Return zero-filled initial LSTM states ``[h, c]``.

          Returns:
              List of two independent float32 numpy arrays of shape
              ``(cell_size,)``.
          """
          return [np.zeros(self.cell_size, np.float32) for _ in range(2)]