  1. """PyTorch model for DQN"""
  2. from typing import Sequence
  3. import gymnasium as gym
  4. from ray.rllib.models.torch.misc import SlimFC
  5. from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer
  6. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  7. from ray.rllib.utils.framework import try_import_torch
  8. from ray.rllib.utils.typing import ModelConfigDict
  9. torch, nn = try_import_torch()
  10. class DQNTorchModel(TorchModelV2, nn.Module):
  11. """Extension of standard TorchModelV2 to provide dueling-Q functionality."""
  12. def __init__(
  13. self,
  14. obs_space: gym.spaces.Space,
  15. action_space: gym.spaces.Space,
  16. num_outputs: int,
  17. model_config: ModelConfigDict,
  18. name: str,
  19. *,
  20. q_hiddens: Sequence[int] = (256,),
  21. dueling: bool = False,
  22. dueling_activation: str = "relu",
  23. num_atoms: int = 1,
  24. use_noisy: bool = False,
  25. v_min: float = -10.0,
  26. v_max: float = 10.0,
  27. sigma0: float = 0.5,
  28. # TODO(sven): Move `add_layer_norm` into ModelCatalog as
  29. # generic option, then error if we use ParameterNoise as
  30. # Exploration type and do not have any LayerNorm layers in
  31. # the net.
  32. add_layer_norm: bool = False
  33. ):
  34. """Initialize variables of this model.
  35. Extra model kwargs:
  36. q_hiddens (Sequence[int]): List of layer-sizes after(!) the
  37. Advantages(A)/Value(V)-split. Hence, each of the A- and V-
  38. branches will have this structure of Dense layers. To define
  39. the NN before this A/V-split, use - as always -
  40. config["model"]["fcnet_hiddens"].
  41. dueling: Whether to build the advantage(A)/value(V) heads
  42. for DDQN. If True, Q-values are calculated as:
  43. Q = (A - mean[A]) + V. If False, raw NN output is interpreted
  44. as Q-values.
  45. dueling_activation: The activation to use for all dueling
  46. layers (A- and V-branch). One of "relu", "tanh", "linear".
  47. num_atoms: If >1, enables distributional DQN.
  48. use_noisy: Use noisy layers.
  49. v_min: Min value support for distributional DQN.
  50. v_max: Max value support for distributional DQN.
  51. sigma0 (float): Initial value of noisy layers.
  52. add_layer_norm: Enable layer norm (for param noise).
  53. """
        nn.Module.__init__(self)
        super(DQNTorchModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name
        )

        self.dueling = dueling
        self.num_atoms = num_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.sigma0 = sigma0
        ins = num_outputs

        advantage_module = nn.Sequential()
        value_module = nn.Sequential()

        # Dueling case: Build the shared (advantages and value) fc-network.
        for i, n in enumerate(q_hiddens):
            if use_noisy:
                advantage_module.add_module(
                    "dueling_A_{}".format(i),
                    NoisyLayer(
                        ins, n, sigma0=self.sigma0, activation=dueling_activation
                    ),
                )
                value_module.add_module(
                    "dueling_V_{}".format(i),
                    NoisyLayer(
                        ins, n, sigma0=self.sigma0, activation=dueling_activation
                    ),
                )
            else:
                advantage_module.add_module(
                    "dueling_A_{}".format(i),
                    SlimFC(ins, n, activation_fn=dueling_activation),
                )
                value_module.add_module(
                    "dueling_V_{}".format(i),
                    SlimFC(ins, n, activation_fn=dueling_activation),
                )
                # Add LayerNorm after each Dense.
                if add_layer_norm:
                    advantage_module.add_module(
                        "LayerNorm_A_{}".format(i), nn.LayerNorm(n)
                    )
                    value_module.add_module(
                        "LayerNorm_V_{}".format(i), nn.LayerNorm(n)
                    )
            ins = n

        # Actual Advantages layer (nodes=num-actions).
        if use_noisy:
            advantage_module.add_module(
                "A",
                NoisyLayer(
                    ins, self.action_space.n * self.num_atoms, sigma0, activation=None
                ),
            )
        elif q_hiddens:
            advantage_module.add_module(
                "A", SlimFC(ins, action_space.n * self.num_atoms, activation_fn=None)
            )

        self.advantage_module = advantage_module

        # Value layer (nodes=1).
        if self.dueling:
            if use_noisy:
                value_module.add_module(
                    "V", NoisyLayer(ins, self.num_atoms, sigma0, activation=None)
                )
            elif q_hiddens:
                value_module.add_module(
                    "V", SlimFC(ins, self.num_atoms, activation_fn=None)
                )
            self.value_module = value_module
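
    # Note (illustrative comment, not part of the original file): this model only
    # builds the two heads. The shared embedding produced by the configured
    # fcnet/conv net is fed into get_q_value_distributions() and
    # get_state_value() from outside this class (in RLlib's DQN policy/loss
    # code), and for dueling=True the results are combined there as
    # Q = V + (A - mean[A]), exactly as described in the docstring above.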

    def get_q_value_distributions(self, model_out):
        """Returns distributional values for Q(s, a) given a state embedding.

        Override this in your custom model to customize the Q output head.

        Args:
            model_out: Embedding from the model layers.

        Returns:
            (action_scores, logits, dist) if num_atoms == 1, otherwise
            (action_scores, z, support_logits_per_action, logits, dist)
        """
        action_scores = self.advantage_module(model_out)

        if self.num_atoms > 1:
            # Distributional Q-learning uses a discrete support z
            # to represent the action value distribution.
            z = torch.arange(0.0, self.num_atoms, dtype=torch.float32).to(
                action_scores.device
            )
            z = self.v_min + z * (self.v_max - self.v_min) / float(self.num_atoms - 1)

            support_logits_per_action = torch.reshape(
                action_scores, shape=(-1, self.action_space.n, self.num_atoms)
            )
            support_prob_per_action = nn.functional.softmax(
                support_logits_per_action, dim=-1
            )
            action_scores = torch.sum(z * support_prob_per_action, dim=-1)
            logits = support_logits_per_action
            probs = support_prob_per_action
            return action_scores, z, support_logits_per_action, logits, probs
        else:
            logits = torch.unsqueeze(torch.ones_like(action_scores), -1)
            return action_scores, logits, logits
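
    # Note (illustrative comment, not part of the original file): with the C51
    # settings num_atoms=51, v_min=-10.0, v_max=10.0, the support z built above
    # consists of 51 atoms spaced (10 - (-10)) / 50 = 0.4 apart, and the returned
    # `action_scores` are the expectations sum_i z_i * p_i(s, a) of the per-action
    # categorical distributions, i.e. the scalar Q-values used for action selection.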

    def get_state_value(self, model_out):
        """Returns the state value prediction for the given state embedding."""
        return self.value_module(model_out)
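

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original RLlib file). It assumes
    # DQNTorchModel can be instantiated directly with a plain gymnasium Box
    # observation space and Discrete action space; the batch size, embedding
    # size, and layer sizes below are arbitrary and for illustration only.
    obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4,))
    action_space = gym.spaces.Discrete(2)
    embedding_size = 8  # plays the role of `num_outputs` fed into the heads

    model = DQNTorchModel(
        obs_space,
        action_space,
        num_outputs=embedding_size,
        model_config={},
        name="dqn_sketch",
        q_hiddens=(16,),
        dueling=True,
    )

    # Fake "embedding" for a batch of 5 observations (normally produced by the
    # network configured via config["model"]["fcnet_hiddens"]).
    embedding = torch.randn(5, embedding_size)

    # With num_atoms == 1, get_q_value_distributions() returns
    # (action_scores, logits, dist); action_scores are the advantages A(s, a).
    advantages, _, _ = model.get_q_value_distributions(embedding)
    state_value = model.get_state_value(embedding)

    # Dueling combination as described in the docstring: Q = V + (A - mean[A]).
    # In RLlib itself this step is done in the DQN policy/loss code.
    q_values = state_value + advantages - advantages.mean(dim=1, keepdim=True)
    print(q_values.shape)  # torch.Size([5, 2])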