# visionnet.py

import numpy as np
from typing import Dict, List
import gym

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import normc_initializer, same_padding, \
    SlimConv2d, SlimFC
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

torch, nn = try_import_torch()
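# NOTE: If torch is not installed, `try_import_torch()` falls back to stub
# objects instead of raising, so merely importing this module should not
# hard-require torch.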


class VisionNetwork(TorchModelV2, nn.Module):
    """Generic vision network."""

    def __init__(self, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name: str):

        if not model_config.get("conv_filters"):
            model_config["conv_filters"] = get_filter_config(obs_space.shape)
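            # For example, for 84x84 image observations, `get_filter_config`
            # returns RLlib's usual Atari default stack (illustrative; see
            # that function for the authoritative values):
            # [[16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 1]].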

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)

        activation = self.model_config.get("conv_activation")
        filters = self.model_config["conv_filters"]
        assert len(filters) > 0, \
            "Must provide at least 1 entry in `conv_filters`!"

        # Post-FC net config.
        post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", [])
        post_fcnet_activation = get_activation_fn(
            model_config.get("post_fcnet_activation"), framework="torch")

        no_final_linear = self.model_config.get("no_final_linear")
        vf_share_layers = self.model_config.get("vf_share_layers")

        # Whether the last layer is the output of a Flatten (rather than
        # an n x (1,1) Conv2D).
        self.last_layer_is_flattened = False
        self._logits = None

        layers = []
        (w, h, in_channels) = obs_space.shape
        in_size = [w, h]
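        # Build the conv stack for all but the last filter spec. "Same"
        # padding keeps the spatial math simple: each layer shrinks its
        # input only by its stride (e.g., 84x84 -> 21x21 for stride 4).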
        for out_channels, kernel, stride in filters[:-1]:
            padding, out_size = same_padding(in_size, kernel, stride)
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    padding,
                    activation_fn=activation))
            in_channels = out_channels
            in_size = out_size

        out_channels, kernel, stride = filters[-1]
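        # The last filter spec is applied with "valid" padding below; when
        # `num_outputs` is fixed, it must shrink the spatial dims to 1x1
        # (forward() raises a shape error otherwise).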

        # No final linear: Last layer has activation function and exits with
        # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending
        # on the `post_fcnet_...` settings).
        if no_final_linear and num_outputs:
            out_channels = out_channels if post_fcnet_hiddens else num_outputs
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # Add the (optional) post-FC stack after the last Conv2D layer.
            layer_sizes = post_fcnet_hiddens[:-1] + (
                [num_outputs] if post_fcnet_hiddens else [])
            for i, out_size in enumerate(layer_sizes):
                layers.append(
                    SlimFC(
                        in_size=out_channels,
                        out_size=out_size,
                        activation_fn=post_fcnet_activation,
                        initializer=normc_initializer(1.0)))
                out_channels = out_size

        # Finish network normally (w/o overriding last layer size with
        # `num_outputs`), then add another linear one of size `num_outputs`.
        else:
            layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,  # padding=valid
                    activation_fn=activation))

            # num_outputs defined. Use that to create an exact
            # `num_outputs`-sized (1,1)-Conv2D.
            if num_outputs:
                in_size = [
                    np.ceil((in_size[0] - kernel[0]) / stride),
                    np.ceil((in_size[1] - kernel[1]) / stride)
                ]
                padding, _ = same_padding(in_size, [1, 1], [1, 1])
                if post_fcnet_hiddens:
                    layers.append(nn.Flatten())
                    in_size = out_channels
                    # Add the (optional) post-FC stack after the last Conv2D
                    # layer.
                    for i, out_size in enumerate(post_fcnet_hiddens +
                                                 [num_outputs]):
                        layers.append(
                            SlimFC(
                                in_size=in_size,
                                out_size=out_size,
                                activation_fn=post_fcnet_activation
                                if i < len(post_fcnet_hiddens) - 1 else None,
                                initializer=normc_initializer(1.0)))
                        in_size = out_size
                    # The last layer is the logits layer.
                    self._logits = layers.pop()
                else:
                    self._logits = SlimConv2d(
                        out_channels,
                        num_outputs, [1, 1],
                        1,
                        padding,
                        activation_fn=None)
            # num_outputs not known -> Flatten, then set self.num_outputs
            # to the resulting number of nodes.
            else:
                self.last_layer_is_flattened = True
                layers.append(nn.Flatten())

        self._convs = nn.Sequential(*layers)

        # If our num_outputs is still unknown, we need to do a test pass to
        # figure out the output dimensions. This could be the case if we have
        # the Flatten layer at the end.
        if self.num_outputs is None:
            # Create a B=1 dummy sample and push it through our conv net.
            dummy_in = torch.from_numpy(self.obs_space.sample()).permute(
                2, 0, 1).unsqueeze(0).float()
            dummy_out = self._convs(dummy_in)
            self.num_outputs = dummy_out.shape[1]

        # Build the value layers.
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
            self._value_branch = SlimFC(
                out_channels,
                1,
                initializer=normc_initializer(0.01),
                activation_fn=None)
        else:
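            # Separate value branch: mirror the policy conv stack (same
            # filter specs) and finish with a 1x1 conv that emits a single
            # value output per sample.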
            vf_layers = []
            (w, h, in_channels) = obs_space.shape
            in_size = [w, h]
            for out_channels, kernel, stride in filters[:-1]:
                padding, out_size = same_padding(in_size, kernel, stride)
                vf_layers.append(
                    SlimConv2d(
                        in_channels,
                        out_channels,
                        kernel,
                        stride,
                        padding,
                        activation_fn=activation))
                in_channels = out_channels
                in_size = out_size

            out_channels, kernel, stride = filters[-1]
            vf_layers.append(
                SlimConv2d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride,
                    None,
                    activation_fn=activation))

            vf_layers.append(
                SlimConv2d(
                    in_channels=out_channels,
                    out_channels=1,
                    kernel=1,
                    stride=1,
                    padding=None,
                    activation_fn=None))
            self._value_branch_separate = nn.Sequential(*vf_layers)

        # Holds the current "base" output (before the logits layer).
        self._features = None

    @override(TorchModelV2)
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        self._features = input_dict["obs"].float()
        # Permute b/c data comes in as [B, dim, dim, channels]:
        self._features = self._features.permute(0, 3, 1, 2)
        conv_out = self._convs(self._features)
        # Store features to save a forward pass when value_function() is
        # called later.
        if not self._value_branch_separate:
            self._features = conv_out

        if not self.last_layer_is_flattened:
            if self._logits:
                conv_out = self._logits(conv_out)
            if len(conv_out.shape) == 4:
                if conv_out.shape[2] != 1 or conv_out.shape[3] != 1:
                    raise ValueError(
                        "Given `conv_filters` ({}) do not result in a [B, {} "
                        "(`num_outputs`), 1, 1] shape (but in {})! Please "
                        "adjust your Conv2D stack such that the last 2 dims "
                        "are both 1.".format(self.model_config["conv_filters"],
                                             self.num_outputs,
                                             list(conv_out.shape)))
                logits = conv_out.squeeze(3)
                logits = logits.squeeze(2)
            else:
                logits = conv_out
            return logits, state
        else:
            return conv_out, state

    @override(TorchModelV2)
    def value_function(self) -> TensorType:
        assert self._features is not None, "must call forward() first"
        if self._value_branch_separate:
            value = self._value_branch_separate(self._features)
            value = value.squeeze(3)
            value = value.squeeze(2)
            return value.squeeze(1)
        else:
            if not self.last_layer_is_flattened:
                features = self._features.squeeze(3)
                features = features.squeeze(2)
            else:
                features = self._features
            return self._value_branch(features).squeeze(1)

    def _hidden_layers(self, obs: TensorType) -> TensorType:
        res = self._convs(obs.permute(0, 3, 1, 2))  # Switch to channel-major.
        res = res.squeeze(3)
        res = res.squeeze(2)
        return res
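

# ---------------------------------------------------------------------------
# Usage sketch: build the model directly for an 84x84 RGB observation space
# and push one dummy batch through it. The spaces and filter spec below are
# illustrative assumptions (the filters mirror RLlib's usual Atari defaults);
# in normal use, RLlib's ModelCatalog constructs this model automatically.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    obs_space = gym.spaces.Box(0.0, 1.0, shape=(84, 84, 3), dtype=np.float32)
    action_space = gym.spaces.Discrete(4)
    model = VisionNetwork(
        obs_space,
        action_space,
        num_outputs=action_space.n,
        model_config={
            "conv_filters": [[16, [8, 8], 4], [32, [4, 4], 2],
                             [256, [11, 11], 1]],
            "conv_activation": "relu",
            "vf_share_layers": False,
        },
        name="vision_net")
    # forward() expects a batched NHWC float tensor under the "obs" key.
    obs = torch.from_numpy(obs_space.sample()[None, ...])
    logits, _ = model.forward({"obs": obs}, [], None)
    print("logits:", logits.shape)                 # torch.Size([1, 4])
    print("value:", model.value_function().shape)  # torch.Size([1])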