# custom_model_api.py
  1. from gym.spaces import Box
  2. from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
  3. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  4. from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as \
  5. TorchFullyConnectedNetwork
  6. from ray.rllib.models.torch.misc import SlimFC
  7. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  8. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  9. tf1, tf, tfv = try_import_tf()
  10. torch, nn = try_import_torch()
  11. # __sphinx_doc_model_api_1_begin__
  12. class DuelingQModel(TFModelV2): # or: TorchModelV2
  13. """A simple, hard-coded dueling head model."""
  14. def __init__(self, obs_space, action_space, num_outputs, model_config,
  15. name):
  16. # Pass num_outputs=None into super constructor (so that no action/
  17. # logits output layer is built).
  18. # Alternatively, you can pass in num_outputs=[last layer size of
  19. # config[model][fcnet_hiddens]] AND set no_last_linear=True, but
  20. # this seems more tedious as you will have to explain users of this
  21. # class that num_outputs is NOT the size of your Q-output layer.
  22. super(DuelingQModel, self).__init__(obs_space, action_space, None,
  23. model_config, name)
  24. # Now: self.num_outputs contains the last layer's size, which
  25. # we can use to construct the dueling head (see torch: SlimFC
  26. # below).
  27. # Construct advantage head ...
  28. self.A = tf.keras.layers.Dense(num_outputs)
  29. # torch:
  30. # self.A = SlimFC(
  31. # in_size=self.num_outputs, out_size=num_outputs)
  32. # ... and value head.
  33. self.V = tf.keras.layers.Dense(1)
  34. # torch:
  35. # self.V = SlimFC(in_size=self.num_outputs, out_size=1)
  36. def get_q_values(self, underlying_output):
  37. # Calculate q-values following dueling logic:
  38. v = self.V(underlying_output) # value
  39. a = self.A(underlying_output) # advantages (per action)
  40. advantages_mean = tf.reduce_mean(a, 1)
  41. advantages_centered = a - tf.expand_dims(advantages_mean, 1)
  42. return v + advantages_centered # q-values
  43. # __sphinx_doc_model_api_1_end__
  44. class TorchDuelingQModel(TorchModelV2):
  45. """A simple, hard-coded dueling head model."""
  46. def __init__(self, obs_space, action_space, num_outputs, model_config,
  47. name):
  48. # Pass num_outputs=None into super constructor (so that no action/
  49. # logits output layer is built).
  50. # Alternatively, you can pass in num_outputs=[last layer size of
  51. # config[model][fcnet_hiddens]] AND set no_last_linear=True, but
  52. # this seems more tedious as you will have to explain users of this
  53. # class that num_outputs is NOT the size of your Q-output layer.
  54. nn.Module.__init__(self)
  55. super(TorchDuelingQModel, self).__init__(obs_space, action_space, None,
  56. model_config, name)
  57. # Now: self.num_outputs contains the last layer's size, which
  58. # we can use to construct the dueling head (see torch: SlimFC
  59. # below).
  60. # Construct advantage head ...
  61. self.A = SlimFC(in_size=self.num_outputs, out_size=num_outputs)
  62. # ... and value head.
  63. self.V = SlimFC(in_size=self.num_outputs, out_size=1)
  64. def get_q_values(self, underlying_output):
  65. # Calculate q-values following dueling logic:
  66. v = self.V(underlying_output) # value
  67. a = self.A(underlying_output) # advantages (per action)
  68. advantages_mean = torch.mean(a, 1)
  69. advantages_centered = a - torch.unsqueeze(advantages_mean, 1)
  70. return v + advantages_centered # q-values
  71. class ContActionQModel(TFModelV2):
  72. """A simple, q-value-from-cont-action model (for e.g. SAC type algos)."""
  73. def __init__(self, obs_space, action_space, num_outputs, model_config,
  74. name):
  75. # Pass num_outputs=None into super constructor (so that no action/
  76. # logits output layer is built).
  77. # Alternatively, you can pass in num_outputs=[last layer size of
  78. # config[model][fcnet_hiddens]] AND set no_last_linear=True, but
  79. # this seems more tedious as you will have to explain users of this
  80. # class that num_outputs is NOT the size of your Q-output layer.
  81. super(ContActionQModel, self).__init__(obs_space, action_space, None,
  82. model_config, name)
  83. # Now: self.num_outputs contains the last layer's size, which
  84. # we can use to construct the single q-value computing head.
  85. # Nest an RLlib FullyConnectedNetwork (torch or tf) into this one here
  86. # to be used for Q-value calculation.
  87. # Use the current value of self.num_outputs, which is the wrapped
  88. # model's output layer size.
  89. combined_space = Box(-1.0, 1.0,
  90. (self.num_outputs + action_space.shape[0], ))
  91. self.q_head = FullyConnectedNetwork(combined_space, action_space, 1,
  92. model_config, "q_head")
  93. # Missing here: Probably still have to provide action output layer
  94. # and value layer and make sure self.num_outputs is correctly set.
  95. def get_single_q_value(self, underlying_output, action):
  96. # Calculate the q-value after concating the underlying output with
  97. # the given action.
  98. input_ = tf.concat([underlying_output, action], axis=-1)
  99. # Construct a simple input_dict (needed for self.q_head as it's an
  100. # RLlib ModelV2).
  101. input_dict = {"obs": input_}
  102. # Ignore state outputs.
  103. q_values, _ = self.q_head(input_dict)
  104. return q_values
  105. # __sphinx_doc_model_api_2_begin__
  106. class TorchContActionQModel(TorchModelV2):
  107. """A simple, q-value-from-cont-action model (for e.g. SAC type algos)."""
  108. def __init__(self, obs_space, action_space, num_outputs, model_config,
  109. name):
  110. nn.Module.__init__(self)
  111. # Pass num_outputs=None into super constructor (so that no action/
  112. # logits output layer is built).
  113. # Alternatively, you can pass in num_outputs=[last layer size of
  114. # config[model][fcnet_hiddens]] AND set no_last_linear=True, but
  115. # this seems more tedious as you will have to explain users of this
  116. # class that num_outputs is NOT the size of your Q-output layer.
  117. super(TorchContActionQModel, self).__init__(obs_space, action_space,
  118. None, model_config, name)
  119. # Now: self.num_outputs contains the last layer's size, which
  120. # we can use to construct the single q-value computing head.
  121. # Nest an RLlib FullyConnectedNetwork (torch or tf) into this one here
  122. # to be used for Q-value calculation.
  123. # Use the current value of self.num_outputs, which is the wrapped
  124. # model's output layer size.
  125. combined_space = Box(-1.0, 1.0,
  126. (self.num_outputs + action_space.shape[0], ))
  127. self.q_head = TorchFullyConnectedNetwork(combined_space, action_space,
  128. 1, model_config, "q_head")
  129. # Missing here: Probably still have to provide action output layer
  130. # and value layer and make sure self.num_outputs is correctly set.
  131. def get_single_q_value(self, underlying_output, action):
  132. # Calculate the q-value after concating the underlying output with
  133. # the given action.
  134. input_ = torch.cat([underlying_output, action], dim=-1)
  135. # Construct a simple input_dict (needed for self.q_head as it's an
  136. # RLlib ModelV2).
  137. input_dict = {"obs": input_}
  138. # Ignore state outputs.
  139. q_values, _ = self.q_head(input_dict)
  140. return q_values
  141. # __sphinx_doc_model_api_2_end__