import numpy as np
from typing import Dict, List
import gym

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import normc_initializer, same_padding, \
    SlimConv2d, SlimFC
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, nn = try_import_torch()


class VisionNetwork(TorchModelV2, nn.Module):
    """Generic vision network.

    Builds a Conv2D stack from `model_config["conv_filters"]`, optionally
    followed by a post-conv FC stack (`post_fcnet_hiddens`). Depending on the
    config, the head is a (1,1)-Conv2D / FC logits layer, or a plain Flatten
    (when `num_outputs` is None, in which case `self.num_outputs` is inferred
    via a dummy forward pass). A value branch is built either on top of the
    shared conv trunk (`vf_share_layers`) or as a separate mirrored conv stack.
    """

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):
        """Initializes the conv trunk, logits head, and value branch.

        Args:
            obs_space: Observation space; unpacked below as (w, h, channels),
                i.e. channels-last image observations.
            action_space: Env action space (forwarded to TorchModelV2).
            num_outputs: Number of output logits, or None to end in a Flatten
                layer and infer `self.num_outputs` from a dummy pass.
            model_config: Model config dict; keys read here: `conv_filters`,
                `conv_activation`, `post_fcnet_hiddens`,
                `post_fcnet_activation`, `no_final_linear`, `vf_share_layers`.
            name: Name of this model.
        """
        # Fall back to a default Conv2D stack for this obs shape, if the user
        # did not provide `conv_filters`.
        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        activation = self.model_config.get("conv_activation")
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0,\
            "Must provide at least 1 entry in `conv_filters`!"

        # Post FC net config.
        post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
        post_fcnet_activation = get_activation_fn(
            model_config.get("post_fcnet_activation"), framework="torch")

        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        # Whether the last layer is the output of a Flattened (rather than
        # a n x (1,1) Conv2D).
        self.last_layer_is_flattened = False
        # Optional final logits layer (kept separate from `self._convs` so
        # `forward()` can skip it for the value path). None if not needed.
        self._logits = None

        layers = []
        # NOTE(review): assumes channels-last (w, h, c) observations — confirm
        # against the env/preprocessor.
        (w, h, in_channels) = obs_space.shape
        in_size = [w, h]
        # All filter specs except the last use "same" padding.
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, stride)
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    padding,
                    activation_fn=activation))
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = filters[-1]

        # No final linear: Last layer has activation function and exits with
        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
        # on `post_fcnet_...` settings).
        if no_final_linear and num_outputs:
            # If a post-FC stack follows, keep the conv's own channel count
            # and let the FC stack produce `num_outputs`; otherwise the conv
            # itself exits with `num_outputs` channels.
            out_channels = out_channels if post_fcnet_hiddens else num_outputs
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # Add (optional) post-fc-stack after last Conv2D layer.
            layer_sizes = post_fcnet_hiddens[:-1] + ([num_outputs]
                                                     if post_fcnet_hiddens
                                                     else [])
            for i, out_size in enumerate(layer_sizes):
                layers.append(
                    SlimFC(
                        in_size=out_channels,
                        out_size=out_size,
                        activation_fn=post_fcnet_activation,
                        initializer=normc_initializer(1.0)))
                out_channels = out_size

        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # num_outputs defined. Use that to create an exact
            # `num_output`-sized (1,1)-Conv2D.
            if num_outputs:
                # Spatial size after the "valid"-padded last conv above.
                # NOTE(review): assumes `stride` is a scalar here (divides
                # both dims) — confirm `conv_filters` spec format.
                in_size = [
                    np.ceil((in_size[0] - kernel[0]) / stride),
                    np.ceil((in_size[1] - kernel[1]) / stride)
                ]
                padding, _ = same_padding(in_size, [1, 1], [1, 1])
                if post_fcnet_hiddens:
                    layers.append(nn.Flatten())
                    in_size = out_channels
                    # Add (optional) post-fc-stack after last Conv2D layer.
                    for i, out_size in enumerate(post_fcnet_hiddens +
                                                 [num_outputs]):
                        layers.append(
                            SlimFC(
                                in_size=in_size,
                                out_size=out_size,
                                activation_fn=post_fcnet_activation
                                if i < len(post_fcnet_hiddens) - 1 else None,
                                initializer=normc_initializer(1.0)))
                        in_size = out_size
                    # Last layer is logits layer.
                    self._logits = layers.pop()
                else:
                    self._logits = SlimConv2d(
                        out_channels,
                        num_outputs, [1, 1],
                        1,
                        padding,
                        activation_fn=None)
            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                layers.append(nn.Flatten())

        self._convs = nn.Sequential(*layers)

        # If our num_outputs still unknown, we need to do a test pass to
        # figure out the output dimensions. This could be the case, if we have
        # the Flatten layer at the end.
        if self.num_outputs is None:
            # Create a B=1 dummy sample and push it through our conv-net.
            dummy_in = torch.from_numpy(self.obs_space.sample()).permute(
                2, 0, 1).unsqueeze(0).float()
            dummy_out = self._convs(dummy_in)
            self.num_outputs = dummy_out.shape[1]

        # Build the value layers.
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
            # Shared conv trunk: value head is a single linear layer on top
            # of the (squeezed) conv output.
            self._value_branch = SlimFC(
                out_channels,
                1,
                initializer=normc_initializer(0.01),
                activation_fn=None)
        else:
            # Separate value conv stack mirroring the policy trunk, ending in
            # a (1,1)-conv down to a single output channel.
            vf_layers = []
            (w, h, in_channels) = obs_space.shape
            in_size = [w, h]
            for out_channels, kernel, stride in filters[:-1]:
                padding, out_size = same_padding(in_size, kernel, stride)
                vf_layers.append(
                    SlimConv2d(
                        in_channels,
                        out_channels,
                        kernel,
                        stride,
                        padding,
                        activation_fn=activation))
                in_channels = out_channels
                in_size = out_size

            out_channels, kernel, stride = filters[-1]
            vf_layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,
                    activation_fn=activation))

            vf_layers.append(
                SlimConv2d(
                    in_channels=out_channels,
                    out_channels=1,
                    kernel=1,
                    stride=1,
                    padding=None,
                    activation_fn=None))
            self._value_branch_separate = nn.Sequential(*vf_layers)

        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(TorchModelV2)
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        """Runs a batch of observations through the conv net (+ logits head).

        Stores the input (or the conv output, if layers are shared with the
        value branch) in `self._features` for a later `value_function()` call.

        Args:
            input_dict: Input dict; only the "obs" key is read here.
            state: RNN state (passed through unchanged — this net is
                stateless).
            seq_lens: Sequence lengths (unused here).

        Returns:
            Tuple of (output tensor of shape [B, num_outputs], state).

        Raises:
            ValueError: If the conv stack does not reduce the spatial dims to
                (1, 1) when a 4D (non-flattened) output is produced.
        """
        self._features = input_dict["obs"].float()
        # Permute b/c data comes in as [B, dim, dim, channels]:
        self._features = self._features.permute(0, 3, 1, 2)
        conv_out = self._convs(self._features)
        # Store features to save forward pass when getting value_function out.
        if not self._value_branch_separate:
            self._features = conv_out

        if not self.last_layer_is_flattened:
            if self._logits:
                conv_out = self._logits(conv_out)
            if len(conv_out.shape) == 4:
                # Must be [B, num_outputs, 1, 1] to squeeze into
                # [B, num_outputs].
                if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
                    raise ValueError(
                        "Given `conv_filters` ({}) do not result in a [B, {} "
                        "(`num_outputs`), 1, 1] shape (but in {})! Please "
                        "adjust your Conv2D stack such that the last 2 dims "
                        "are both 1.".format(self.model_config["conv_filters"],
                                             self.num_outputs,
                                             list(conv_out.shape)))
                logits = conv_out.squeeze(3)
                logits = logits.squeeze(2)
            else:
                logits = conv_out
            return logits, state
        else:
            # Flattened case: conv_out is already [B, num_outputs].
            return conv_out, state

    @override(TorchModelV2)
    def value_function(self) -> TensorType:
        """Returns the value estimates [B] for the last `forward()` batch."""
        assert self._features is not None, "must call forward() first"
        if self._value_branch_separate:
            # `self._features` holds the permuted raw obs here; run it through
            # the separate value conv stack, then squeeze [B, 1, 1, 1] -> [B].
            value = self._value_branch_separate(self._features)
            value = value.squeeze(3)
            value = value.squeeze(2)
            return value.squeeze(1)
        else:
            # Shared layers: `self._features` holds the conv trunk output.
            if not self.last_layer_is_flattened:
                # Squeeze the two trailing (1, 1) spatial dims for the FC
                # value head.
                features = self._features.squeeze(3)
                features = features.squeeze(2)
            else:
                features = self._features
            return self._value_branch(features).squeeze(1)

    def _hidden_layers(self, obs: TensorType) -> TensorType:
        """Runs channels-last `obs` through the conv stack and squeezes the
        trailing (1, 1) spatial dims."""
        res = self._convs(obs.permute(0, 3, 1, 2))  # switch to channel-major
        res = res.squeeze(3)
        res = res.squeeze(2)
        return res