import logging
import numpy as np
import time

from ray.rllib.models.jax.jax_modelv2 import JAXModelV2
from ray.rllib.models.jax.misc import SlimFC
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_jax

jax, flax = try_import_jax()

logger = logging.getLogger(__name__)


class FullyConnectedNetwork(JAXModelV2):
    """Generic fully connected network."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        # Seed a JAX PRNG key from the current wall-clock time.
        self.key = jax.random.PRNGKey(int(time.time()))

        activation = model_config.get("fcnet_activation")
        hiddens = model_config.get("fcnet_hiddens", [])
        no_final_linear = model_config.get("no_final_linear")
        self.vf_share_layers = model_config.get("vf_share_layers")
        self.free_log_std = model_config.get("free_log_std")

        # Generate free-floating bias variables for the second half of
        # the outputs.
        if self.free_log_std:
            assert num_outputs % 2 == 0, (
                "num_outputs must be divisible by two", num_outputs)
            num_outputs = num_outputs // 2

        self._hidden_layers = []
        prev_layer_size = int(np.prod(obs_space.shape))
        self._logits = None

        # Create layers 0 to second-last.
        for size in hiddens[:-1]:
            self._hidden_layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=size,
                    activation_fn=activation))
            prev_layer_size = size

        # The last layer is adjusted to be of size num_outputs, but it
        # still gets the configured activation (no extra linear layer).
        if no_final_linear and num_outputs:
            self._hidden_layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=num_outputs,
                    activation_fn=activation))
            prev_layer_size = num_outputs
        # Otherwise, finish the layers with the provided sizes (`hiddens`)
        # and, iff num_outputs > 0, add a final linear layer of size
        # num_outputs on top.
        else:
            if len(hiddens) > 0:
                self._hidden_layers.append(
                    SlimFC(
                        in_size=prev_layer_size,
                        out_size=hiddens[-1],
                        activation_fn=activation))
                prev_layer_size = hiddens[-1]
            if num_outputs:
                self._logits = SlimFC(
                    in_size=prev_layer_size,
                    out_size=num_outputs,
                    activation_fn=None)
            else:
                self.num_outputs = (
                    [int(np.prod(obs_space.shape))] + hiddens[-1:])[-1]
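
        # Worked example (illustrative sizes, not catalog defaults): with
        # obs_space.shape == (4, ), hiddens == [256, 256], num_outputs == 2
        # and no_final_linear == False, the trunk built above consists of
        # SlimFC(4 -> 256) and SlimFC(256 -> 256), both using `activation`,
        # while `self._logits` is a purely linear SlimFC(256 -> 2).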

        # Layer to add the log std vars to the state-dependent means.
        if self.free_log_std and self._logits:
            raise ValueError("`free_log_std` not supported for JAX yet!")

        self._value_branch_separate = None
        if not self.vf_share_layers:
            # Build a parallel set of hidden layers for the value net.
            prev_vf_layer_size = int(np.prod(obs_space.shape))
            vf_layers = []
            for size in hiddens:
                vf_layers.append(
                    SlimFC(
                        in_size=prev_vf_layer_size,
                        out_size=size,
                        activation_fn=activation,
                    ))
                prev_vf_layer_size = size
            self._value_branch_separate = vf_layers

        self._value_branch = SlimFC(
            in_size=prev_layer_size, out_size=1, activation_fn=None)

        # Holds the current "base" output (before logits layer).
        self._features = None
        # Holds the last input, in case value branch is separate.
        self._last_flat_in = None

    @override(JAXModelV2)
    def forward(self, input_dict, state, seq_lens):
        self._last_flat_in = input_dict["obs_flat"]
        x = self._last_flat_in
        for layer in self._hidden_layers:
            x = layer(x)
        self._features = x
        logits = (self._logits(self._features)
                  if self._logits else self._features)
        if self.free_log_std:
            logits = self._append_free_log_std(logits)
        return logits, state

    @override(JAXModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        if self._value_branch_separate:
            # `_value_branch_separate` is a plain list of SlimFC layers, so
            # apply them sequentially rather than calling the list itself.
            out = self._last_flat_in
            for layer in self._value_branch_separate:
                out = layer(out)
            return self._value_branch(out).squeeze(1)
        else:
            return self._value_branch(self._features).squeeze(1)
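

# Hedged usage sketch (not from the original module): a rough illustration of
# how this model might be constructed and queried by hand. It assumes `gym`
# is available, that SlimFC layers accept plain numpy/JAX arrays, and that a
# flat observation batch under "obs_flat" is all `forward()` needs; the
# config values and example spaces below are illustrative only.
if __name__ == "__main__":
    import gym

    obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4, ))
    action_space = gym.spaces.Discrete(2)
    model_config = {
        "fcnet_activation": "tanh",
        "fcnet_hiddens": [64, 64],
        "no_final_linear": False,
        "vf_share_layers": True,
        "free_log_std": False,
    }
    model = FullyConnectedNetwork(
        obs_space, action_space, num_outputs=2,
        model_config=model_config, name="jax_fcnet_example")

    obs_batch = np.zeros((8, 4), dtype=np.float32)
    logits, _ = model.forward(
        {"obs_flat": obs_batch}, state=[], seq_lens=None)
    values = model.value_function()
    print(logits.shape, values.shape)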