fcnet.py

import logging
import time

import numpy as np

from ray.rllib.models.jax.jax_modelv2 import JAXModelV2
from ray.rllib.models.jax.misc import SlimFC
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_jax

jax, flax = try_import_jax()

logger = logging.getLogger(__name__)


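# JAX port of RLlib's default fully connected model; the layer construction
# below mirrors `ray.rllib.models.torch.fcnet.FullyConnectedNetwork`.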
class FullyConnectedNetwork(JAXModelV2):
    """Generic fully connected network."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        # Seed JAX's PRNG from wall-clock time; weight initialization is
        # therefore not reproducible across runs.
        self.key = jax.random.PRNGKey(int(time.time()))

        activation = model_config.get("fcnet_activation")
        hiddens = model_config.get("fcnet_hiddens", [])
        no_final_linear = model_config.get("no_final_linear")
        self.vf_share_layers = model_config.get("vf_share_layers")
        self.free_log_std = model_config.get("free_log_std")

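        # Note: RLlib's MODEL_DEFAULTS typically supply these keys (e.g.,
        # `fcnet_hiddens=[256, 256]`, `fcnet_activation="tanh"`); the
        # `.get()` defaults above only matter for hand-built configs.
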
        # Generate free-floating bias variables for the second half of
        # the outputs.
        if self.free_log_std:
            assert num_outputs % 2 == 0, (
                "num_outputs must be divisible by two", num_outputs)
            num_outputs = num_outputs // 2

        self._hidden_layers = []
        # `np.prod` (formerly `np.product`, removed in NumPy 2.0).
        prev_layer_size = int(np.prod(obs_space.shape))
        self._logits = None

        # Create layers 0 to second-last.
        for size in hiddens[:-1]:
            self._hidden_layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=size,
                    activation_fn=activation))
            prev_layer_size = size

        # The last layer is adjusted to be of size num_outputs, but it's a
        # layer with activation.
        if no_final_linear and num_outputs:
            self._hidden_layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=num_outputs,
                    activation_fn=activation))
            prev_layer_size = num_outputs
        # Finish the layers with the provided sizes (`hiddens`), plus -
        # iff num_outputs > 0 - a last linear layer of size num_outputs.
        else:
            if len(hiddens) > 0:
                self._hidden_layers.append(
                    SlimFC(
                        in_size=prev_layer_size,
                        out_size=hiddens[-1],
                        activation_fn=activation))
                prev_layer_size = hiddens[-1]
            if num_outputs:
                self._logits = SlimFC(
                    in_size=prev_layer_size,
                    out_size=num_outputs,
                    activation_fn=None)
            else:
                self.num_outputs = (
                    [int(np.prod(obs_space.shape))] + hiddens[-1:])[-1]
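
        # Worked example (illustrative, not part of the original file): with
        # `hiddens=[256, 256]`, `num_outputs=4`, `no_final_linear=False`,
        # the code above builds:
        #     SlimFC(obs_dim -> 256, activation)  # from the hiddens[:-1] loop
        #     SlimFC(256 -> 256, activation)      # last hidden size
        #     self._logits = SlimFC(256 -> 4, linear)
        # With `no_final_linear=True`, the stack instead ends with
        # SlimFC(256 -> 4, activation) and `self._logits` stays None.
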
        # Layer to add the log std vars to the state-dependent means.
        if self.free_log_std and self._logits:
            raise ValueError("`free_log_std` not supported for JAX yet!")

        self._value_branch_separate = None
        if not self.vf_share_layers:
            # Build a parallel set of hidden layers for the value net.
            prev_vf_layer_size = int(np.prod(obs_space.shape))
            vf_layers = []
            for size in hiddens:
                vf_layers.append(
                    SlimFC(
                        in_size=prev_vf_layer_size,
                        out_size=size,
                        activation_fn=activation,
                    ))
                prev_vf_layer_size = size
            self._value_branch_separate = vf_layers

        self._value_branch = SlimFC(
            in_size=prev_layer_size, out_size=1, activation_fn=None)

        # Holds the current "base" output (before logits layer).
        self._features = None
        # Holds the last input, in case value branch is separate.
        self._last_flat_in = None

    @override(JAXModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Cache the flattened obs batch; a separate value branch (if any)
        # re-reads it in `value_function()`.
        self._last_flat_in = input_dict["obs_flat"]
        x = self._last_flat_in
        for layer in self._hidden_layers:
            x = layer(x)
        self._features = x
        logits = self._logits(self._features) if self._logits else \
            self._features
        if self.free_log_std:
            logits = self._append_free_log_std(logits)
        return logits, state

    @override(JAXModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        if self._value_branch_separate:
            # `_value_branch_separate` is a plain list of SlimFC layers (not
            # itself callable), so apply them one at a time.
            x = self._last_flat_in
            for layer in self._value_branch_separate:
                x = layer(x)
            return self._value_branch(x).squeeze(1)
        else:
            return self._value_branch(self._features).squeeze(1)
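

# Minimal usage sketch (illustrative; not part of the original file). It
# assumes the experimental RLlib JAX stack is importable and that SlimFC
# layers are directly callable on batched float32 arrays; exact behavior
# may differ across Ray versions. For Trainer configs, the model would
# normally be registered via
# `ModelCatalog.register_custom_model("jax_fcnet", FullyConnectedNetwork)`.
if __name__ == "__main__":
    import gym

    obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)
    model_config = {
        "fcnet_hiddens": [256, 256],
        "fcnet_activation": "relu",
        "no_final_linear": False,
        "vf_share_layers": True,  # value head shares the hidden stack
        "free_log_std": False,
    }
    model = FullyConnectedNetwork(
        obs_space, action_space, num_outputs=2,
        model_config=model_config, name="jax_fcnet")

    obs_batch = np.zeros((32, 4), dtype=np.float32)
    logits, _ = model.forward({"obs_flat": obs_batch}, state=[], seq_lens=None)
    values = model.value_function()
    print(logits.shape, values.shape)  # expected: (32, 2) (32,)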