trajectory_view_utilizing_models.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  2. from ray.rllib.models.torch.misc import SlimFC
  3. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  4. from ray.rllib.policy.view_requirement import ViewRequirement
  5. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  6. from ray.rllib.utils.tf_utils import one_hot
  7. from ray.rllib.utils.torch_utils import one_hot as torch_one_hot
  8. tf1, tf, tfv = try_import_tf()
  9. torch, nn = try_import_torch()
  10. # __sphinx_doc_begin__
  class FrameStackingCartPoleModel(TFModelV2):
      """A simple FC model that takes the last n observations as input.

      In addition to a window of `num_frames` observations, the network
      consumes a window of `num_frames` actions (one-hot encoded in
      `forward`) and `num_frames` rewards. These windows are requested via
      `self.view_requirements` under the keys "prev_n_obs",
      "prev_n_actions" and "prev_n_rewards".
      """

      def __init__(self,
                   obs_space,
                   action_space,
                   num_outputs,
                   model_config,
                   name,
                   num_frames=3):
          """Builds the Keras model and registers the view requirements.

          Args:
              obs_space: Observation space; its shape must be 1-D (asserted).
              action_space: Action space; its `n` attribute is used as the
                  per-step width of the action input (presumably a Discrete
                  space -- confirm against the env in use).
              num_outputs: Number of units of the output layer.
              model_config: Model config, forwarded to the parent class.
              name: Model name, forwarded to the parent class.
              num_frames: Number of timesteps stacked into one model input.
          """
          # The parent is given `None` for num_outputs; the actual value is
          # stored on the instance right below.
          super().__init__(obs_space, action_space, None, model_config, name)

          self.num_frames = num_frames
          self.num_outputs = num_outputs

          # Construct actual (very simple) FC model.
          assert len(obs_space.shape) == 1
          obs = tf.keras.layers.Input(
              shape=(self.num_frames, obs_space.shape[0]))
          obs_reshaped = tf.keras.layers.Reshape(
              [obs_space.shape[0] * self.num_frames])(obs)
          # Explicit 1-tuple: `(self.num_frames)` without the trailing comma
          # is just an int, not a shape tuple.
          rewards = tf.keras.layers.Input(shape=(self.num_frames, ))
          rewards_reshaped = tf.keras.layers.Reshape(
              [self.num_frames])(rewards)
          actions = tf.keras.layers.Input(
              shape=(self.num_frames, action_space.n))
          actions_reshaped = tf.keras.layers.Reshape(
              [action_space.n * self.num_frames])(actions)
          # Flattened obs, actions and rewards are joined into one vector.
          input_ = tf.keras.layers.Concatenate(axis=-1)(
              [obs_reshaped, actions_reshaped, rewards_reshaped])
          layer1 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(input_)
          layer2 = tf.keras.layers.Dense(256, activation=tf.nn.relu)(layer1)
          out = tf.keras.layers.Dense(self.num_outputs)(layer2)
          # NOTE: the value head branches off the first hidden layer, the
          # output head off the second one.
          values = tf.keras.layers.Dense(1)(layer1)
          self.base_model = tf.keras.models.Model([obs, actions, rewards],
                                                  [out, values])
          # Filled in by `forward`, read by `value_function`.
          self._last_value = None

          # Each requested window has `num_frames` entries, matching the
          # Keras input shapes declared above.
          self.view_requirements["prev_n_obs"] = ViewRequirement(
              data_col="obs",
              shift="-{}:0".format(num_frames - 1),
              space=obs_space)
          self.view_requirements["prev_n_rewards"] = ViewRequirement(
              data_col="rewards", shift="-{}:-1".format(self.num_frames))
          self.view_requirements["prev_n_actions"] = ViewRequirement(
              data_col="actions",
              shift="-{}:-1".format(self.num_frames),
              space=self.action_space)

      def forward(self, input_dict, states, seq_lens):
          """Runs the base model on the stacked obs/actions/rewards.

          Args:
              input_dict: Must contain "prev_n_obs", "prev_n_rewards" and
                  "prev_n_actions".
              states: Unused.
              seq_lens: Unused.

          Returns:
              Tuple of the base model's first output and an empty state
              list. The base model's second output is cached in
              `self._last_value` for `value_function`.
          """
          obs = tf.cast(input_dict["prev_n_obs"], tf.float32)
          rewards = tf.cast(input_dict["prev_n_rewards"], tf.float32)
          actions = one_hot(input_dict["prev_n_actions"], self.action_space)
          out, self._last_value = self.base_model([obs, actions, rewards])
          return out, []

      def value_function(self):
          """Returns the value output cached by the most recent `forward`.

          The trailing axis (size 1, from the `Dense(1)` value head) is
          squeezed away. `forward` must have been called before, since
          `self._last_value` is `None` until then.
          """
          return tf.squeeze(self._last_value, -1)
  63. # __sphinx_doc_end__
  class TorchFrameStackingCartPoleModel(TorchModelV2, nn.Module):
      """A simple FC model that takes the last n observations as input.

      In addition to a window of `num_frames` observations, the network
      consumes a window of `num_frames` actions (one-hot encoded in
      `forward`) and `num_frames` rewards. These windows are requested via
      `self.view_requirements` under the keys "prev_n_obs",
      "prev_n_actions" and "prev_n_rewards".
      """

      def __init__(self,
                   obs_space,
                   action_space,
                   num_outputs,
                   model_config,
                   name,
                   num_frames=3):
          """Builds the FC layers and registers the view requirements.

          Args:
              obs_space: Observation space; its shape must be 1-D (asserted).
              action_space: Action space; its `n` attribute is used as the
                  per-step width of the action input (presumably a Discrete
                  space -- confirm against the env in use).
              num_outputs: Number of units of the output layer.
              model_config: Model config, forwarded to the parent class.
              name: Model name, forwarded to the parent class.
              num_frames: Number of timesteps stacked into one model input.
          """
          # Initialize nn.Module first so that the SlimFC submodules assigned
          # below get registered on this module.
          nn.Module.__init__(self)
          # The parent is given `None` for num_outputs; the actual value is
          # stored on the instance right below.
          super().__init__(obs_space, action_space, None, model_config, name)

          self.num_frames = num_frames
          self.num_outputs = num_outputs

          # Construct actual (very simple) FC model.
          assert len(obs_space.shape) == 1
          # Per stacked frame: the observation vector, `action_space.n`
          # action entries and one reward scalar.
          in_size = self.num_frames * (
              obs_space.shape[0] + action_space.n + 1)
          self.layer1 = SlimFC(
              in_size=in_size, out_size=256, activation_fn="relu")
          self.layer2 = SlimFC(
              in_size=256, out_size=256, activation_fn="relu")
          self.out = SlimFC(
              in_size=256, out_size=self.num_outputs, activation_fn="linear")
          self.values = SlimFC(
              in_size=256, out_size=1, activation_fn="linear")

          # Filled in by `forward`, read by `value_function`.
          self._last_value = None

          # Each requested window has `num_frames` entries, matching the
          # reshapes done in `forward`.
          self.view_requirements["prev_n_obs"] = ViewRequirement(
              data_col="obs",
              shift="-{}:0".format(num_frames - 1),
              space=obs_space)
          self.view_requirements["prev_n_rewards"] = ViewRequirement(
              data_col="rewards", shift="-{}:-1".format(self.num_frames))
          self.view_requirements["prev_n_actions"] = ViewRequirement(
              data_col="actions",
              shift="-{}:-1".format(self.num_frames),
              space=self.action_space)

      def forward(self, input_dict, states, seq_lens):
          """Runs the FC stack on the stacked obs/actions/rewards.

          Args:
              input_dict: Must contain "prev_n_obs", "prev_n_rewards" and
                  "prev_n_actions" (tensors).
              states: Unused.
              seq_lens: Unused.

          Returns:
              Tuple of the output layer's result and an empty state list.
              The value head's result is cached in `self._last_value` for
              `value_function`.
          """
          # Cast to float32 (as the TF version of this model does) so that
          # integer or double inputs do not clash with the layer weights.
          # `.float()` returns the tensor unchanged if it already is float32.
          obs = torch.reshape(
              input_dict["prev_n_obs"].float(),
              [-1, self.obs_space.shape[0] * self.num_frames])
          rewards = torch.reshape(
              input_dict["prev_n_rewards"].float(), [-1, self.num_frames])
          actions = torch_one_hot(input_dict["prev_n_actions"],
                                  self.action_space)
          # Flatten the per-frame one-hot vectors into a single axis.
          actions = torch.reshape(actions,
                                  [-1, self.num_frames * actions.shape[-1]])
          input_ = torch.cat([obs, actions, rewards], dim=-1)
          features = self.layer1(input_)
          features = self.layer2(features)
          out = self.out(features)
          # NOTE: the value head reads the output of the second hidden layer.
          self._last_value = self.values(features)
          return out, []

      def value_function(self):
          """Returns the value output cached by the most recent `forward`.

          The trailing axis (size 1, from the value head) is squeezed away.
          `forward` must have been called before, since `self._last_value`
          is `None` until then.
          """
          return torch.squeeze(self._last_value, -1)