# autoregressive_action_model.py
from gym.spaces import Discrete, Tuple

from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.torch.misc import normc_initializer as normc_init_torch
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch

# Soft-import both frameworks so this file can be loaded with only one of
# TF / PyTorch installed; presumably the helpers return placeholders when a
# framework is missing — see ray.rllib.utils.framework.
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
  10. class AutoregressiveActionModel(TFModelV2):
  11. """Implements the `.action_model` branch required above."""
  12. def __init__(self, obs_space, action_space, num_outputs, model_config,
  13. name):
  14. super(AutoregressiveActionModel, self).__init__(
  15. obs_space, action_space, num_outputs, model_config, name)
  16. if action_space != Tuple([Discrete(2), Discrete(2)]):
  17. raise ValueError(
  18. "This model only supports the [2, 2] action space")
  19. # Inputs
  20. obs_input = tf.keras.layers.Input(
  21. shape=obs_space.shape, name="obs_input")
  22. a1_input = tf.keras.layers.Input(shape=(1, ), name="a1_input")
  23. ctx_input = tf.keras.layers.Input(
  24. shape=(num_outputs, ), name="ctx_input")
  25. # Output of the model (normally 'logits', but for an autoregressive
  26. # dist this is more like a context/feature layer encoding the obs)
  27. context = tf.keras.layers.Dense(
  28. num_outputs,
  29. name="hidden",
  30. activation=tf.nn.tanh,
  31. kernel_initializer=normc_initializer(1.0))(obs_input)
  32. # V(s)
  33. value_out = tf.keras.layers.Dense(
  34. 1,
  35. name="value_out",
  36. activation=None,
  37. kernel_initializer=normc_initializer(0.01))(context)
  38. # P(a1 | obs)
  39. a1_logits = tf.keras.layers.Dense(
  40. 2,
  41. name="a1_logits",
  42. activation=None,
  43. kernel_initializer=normc_initializer(0.01))(ctx_input)
  44. # P(a2 | a1)
  45. # --note: typically you'd want to implement P(a2 | a1, obs) as follows:
  46. # a2_context = tf.keras.layers.Concatenate(axis=1)(
  47. # [ctx_input, a1_input])
  48. a2_context = a1_input
  49. a2_hidden = tf.keras.layers.Dense(
  50. 16,
  51. name="a2_hidden",
  52. activation=tf.nn.tanh,
  53. kernel_initializer=normc_initializer(1.0))(a2_context)
  54. a2_logits = tf.keras.layers.Dense(
  55. 2,
  56. name="a2_logits",
  57. activation=None,
  58. kernel_initializer=normc_initializer(0.01))(a2_hidden)
  59. # Base layers
  60. self.base_model = tf.keras.Model(obs_input, [context, value_out])
  61. self.base_model.summary()
  62. # Autoregressive action sampler
  63. self.action_model = tf.keras.Model([ctx_input, a1_input],
  64. [a1_logits, a2_logits])
  65. self.action_model.summary()
  66. def forward(self, input_dict, state, seq_lens):
  67. context, self._value_out = self.base_model(input_dict["obs"])
  68. return context, state
  69. def value_function(self):
  70. return tf.reshape(self._value_out, [-1])
class TorchAutoregressiveActionModel(TorchModelV2, nn.Module):
    """PyTorch version of the AutoregressiveActionModel above."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # Hard requirement: two binary sub-actions.
        if action_space != Tuple([Discrete(2), Discrete(2)]):
            raise ValueError(
                "This model only supports the [2, 2] action space")
        # Output of the model (normally 'logits', but for an autoregressive
        # dist this is more like a context/feature layer encoding the obs)
        self.context_layer = SlimFC(
            in_size=obs_space.shape[0],
            out_size=num_outputs,
            initializer=normc_init_torch(1.0),
            activation_fn=nn.Tanh,
        )
        # V(s)
        self.value_branch = SlimFC(
            in_size=num_outputs,
            out_size=1,
            initializer=normc_init_torch(0.01),
            activation_fn=None,
        )
        # P(a1 | obs)
        self.a1_logits = SlimFC(
            in_size=num_outputs,
            out_size=2,
            activation_fn=None,
            initializer=normc_init_torch(0.01))

        class _ActionModel(nn.Module):
            # NOTE: in forward() below, the instance parameter is named
            # `self_`, so the bare name `self` resolves by closure to the
            # enclosing __init__'s `self` — i.e. the OUTER model. This is
            # how the a1 head (self.a1_logits, defined above on the outer
            # model) is shared with this submodule without re-registering it.
            def __init__(self):
                nn.Module.__init__(self)
                # Hidden layer for the a2 head; input is the raw a1 value.
                self.a2_hidden = SlimFC(
                    in_size=1,
                    out_size=16,
                    activation_fn=nn.Tanh,
                    initializer=normc_init_torch(1.0))
                self.a2_logits = SlimFC(
                    in_size=16,
                    out_size=2,
                    activation_fn=None,
                    initializer=normc_init_torch(0.01))

            def forward(self_, ctx_input, a1_input):
                # `self` here is the outer model (closure), `self_` is this
                # submodule.
                a1_logits = self.a1_logits(ctx_input)
                a2_logits = self_.a2_logits(self_.a2_hidden(a1_input))
                return a1_logits, a2_logits

        # P(a2 | a1)
        # --note: typically you'd want to implement P(a2 | a1, obs) as follows:
        # a2_context = tf.keras.layers.Concatenate(axis=1)(
        #     [ctx_input, a1_input])
        self.action_module = _ActionModel()
        # Cached context features from the last forward(); read by
        # value_function().
        self._context = None

    def forward(self, input_dict, state, seq_lens):
        # Encode the observation; cache it for the value branch.
        self._context = self.context_layer(input_dict["obs"])
        return self._context, state

    def value_function(self):
        # Flatten the [B, 1] value output to shape [B].
        return torch.reshape(self.value_branch(self._context), [-1])