parametric_actions_model.py

from gym.spaces import Box

from ray.rllib.agents.dqn.distributional_q_tf_model import \
    DistributionalQTFModel
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
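# NOTE (not in the original file): the import paths above follow the Ray 1.x
# layout. In Ray 2.x, ray.rllib.agents.dqn moved to ray.rllib.algorithms.dqn,
# so the two DQN model imports need adjusting on newer Ray versions.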

class ParametricActionsModel(DistributionalQTFModel):
    """Parametric action model that handles the dot product and masking.

    This assumes the outputs are logits for a single Categorical action dist.
    Getting this to work with a more complex output (e.g., if the action space
    is a tuple of several distributions) is also possible but left as an
    exercise to the reader.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        super(ParametricActionsModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        self.action_embed_model = FullyConnectedNetwork(
            Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
            model_config, name + "_action_embed")
    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding.
        action_embed, _ = self.action_embed_model({
            "obs": input_dict["obs"]["cart"]
        })

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # Mask out invalid actions (use tf.float32.min for stability).
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state
    def value_function(self):
        return self.action_embed_model.value_function()
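
# Usage sketch (an assumption, not part of the original file): this model is
# meant to be registered with RLlib's ModelCatalog and paired with an env
# whose Dict observation space provides the "action_mask", "avail_actions",
# and "cart" keys read in forward() above, e.g.:
#
#     from ray.rllib.models import ModelCatalog
#     ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
#     config = {"model": {"custom_model": "pa_model"}}
#
# See rllib/examples/parametric_actions_cartpole.py in the Ray repo for the
# runnable script this model is intended for.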

class TorchParametricActionsModel(DQNTorchModel):
    """PyTorch version of above ParametricActionsModel."""

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)

        self.action_embed_model = TorchFC(
            Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
            model_config, name + "_action_embed")
    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding.
        action_embed, _ = self.action_embed_model({
            "obs": input_dict["obs"]["cart"]
        })

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = torch.unsqueeze(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = torch.sum(avail_actions * intent_vector, dim=2)

        # Mask out invalid actions (use -inf to tag invalid).
        # These are then recognized by the EpsilonGreedy exploration component
        # as invalid actions that are not to be chosen.
        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)

        return action_logits + inf_mask, state

    def value_function(self):
        return self.action_embed_model.value_function()
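
# Torch-side wiring would mirror the TF sketch above (again an assumption,
# not part of this file): register the model and select the torch stack, e.g.
#
#     ModelCatalog.register_custom_model(
#         "pa_model", TorchParametricActionsModel)
#     config = {"framework": "torch",
#               "model": {"custom_model": "pa_model"}}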

class ParametricActionsModelThatLearnsEmbeddings(DistributionalQTFModel):
    """Same as the above ParametricActionsModel.

    However, this version also learns the action embeddings.
    """

    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 true_obs_shape=(4, ),
                 action_embed_size=2,
                 **kw):
        super(ParametricActionsModelThatLearnsEmbeddings, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)

        action_ids_shifted = tf.constant(
            list(range(1, num_outputs + 1)), dtype=tf.float32)

        obs_cart = tf.keras.layers.Input(shape=true_obs_shape, name="obs_cart")
        # Note: shape must be a tuple; `(num_outputs)` would just be an int.
        valid_avail_actions_mask = tf.keras.layers.Input(
            shape=(num_outputs, ), name="valid_avail_actions_mask")
        self.pred_action_embed_model = FullyConnectedNetwork(
            Box(-1, 1, shape=true_obs_shape), action_space, action_embed_size,
            model_config, name + "_pred_action_embed")

        # Compute the predicted action embedding.
        pred_action_embed, _ = self.pred_action_embed_model({"obs": obs_cart})

        _value_out = self.pred_action_embed_model.value_function()

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(pred_action_embed, 1)

        valid_avail_actions = action_ids_shifted * valid_avail_actions_mask

        # Embedding for valid available actions which will be learned.
        # Embedding vector for 0 is an invalid embedding (a "dummy embedding").
        valid_avail_actions_embed = tf.keras.layers.Embedding(
            input_dim=num_outputs + 1,
            output_dim=action_embed_size,
            name="action_embed_matrix")(valid_avail_actions)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(
            valid_avail_actions_embed * intent_vector, axis=2)

        # Mask out invalid actions (use tf.float32.min for stability).
        inf_mask = tf.maximum(
            tf.math.log(valid_avail_actions_mask), tf.float32.min)
        action_logits = action_logits + inf_mask

        self.param_actions_model = tf.keras.Model(
            inputs=[obs_cart, valid_avail_actions_mask],
            outputs=[action_logits, _value_out])
        self.param_actions_model.summary()
    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions mask tensor from the observation.
        valid_avail_actions_mask = input_dict["obs"][
            "valid_avail_actions_mask"]

        action_logits, self._value_out = self.param_actions_model(
            [input_dict["obs"]["cart"], valid_avail_actions_mask])
        return action_logits, state

    def value_function(self):
        return self._value_out
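
# Minimal sanity check (an added sketch, not part of the original example;
# needs only numpy): demonstrates the masking arithmetic shared by all three
# models above. Clamping log(mask) to the most negative *finite* float keeps
# invalid actions at ~0 softmax probability without introducing NaNs or infs.
if __name__ == "__main__":
    import numpy as np

    logits = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    mask = np.array([1.0, 0.0, 1.0], dtype=np.float32)  # action 1 is invalid

    with np.errstate(divide="ignore"):  # log(0) -> -inf, clamped right after
        inf_mask = np.maximum(np.log(mask), np.finfo(np.float32).min)

    masked_logits = logits + inf_mask
    probs = np.exp(masked_logits - masked_logits.max())
    probs /= probs.sum()
    print("masked logits:", masked_logits)  # finite everywhere
    print("softmax probs:", probs)  # invalid action gets ~0.0 probability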