# catalog.py

from functools import partial
import gym
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import logging
import numpy as np
import tree  # pip install dm_tree
from typing import List, Optional, Type, Union

from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
    RLLIB_ACTION_DIST, _global_registry
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.jax.jax_action_dist import JAXCategorical
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models.tf.tf_action_dist import Categorical, \
    Deterministic, DiagGaussian, Dirichlet, \
    MultiActionDistribution, MultiCategorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
    TorchDeterministic, TorchDiagGaussian, \
    TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \
    deprecation_warning
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)

# yapf: disable
# __sphinx_doc_begin__
MODEL_DEFAULTS: ModelConfigDict = {
    # Experimental flag.
    # If True, try to use a native (tf.keras.Model or torch.Module) default
    # model instead of our built-in ModelV2 defaults.
    # If False (default), use "classic" ModelV2 default models.
    # Note that this currently only works for:
    # 1) framework != torch AND
    # 2) fully connected and CNN default networks as well as
    # auto-wrapped LSTM- and attention nets.
    "_use_default_native_models": False,
    # Experimental flag.
    # If True (set via `config._disable_preprocessor_api=True`), no
    # Preprocessor will be created and observations will arrive in the model
    # exactly as the env returns them.
    "_disable_preprocessor_api": False,
    # Experimental flag.
    # If True, RLlib will no longer flatten the policy-computed actions into
    # a single tensor (for storage in SampleCollectors/output files/etc..),
    # but leave (possibly nested) actions as-is. Disabling flattening affects:
    # - SampleCollectors: Have to store possibly nested action structs.
    # - Models that have the previous action(s) as part of their input.
    # - Algorithms reading from offline files (incl. action information).
    "_disable_action_flattening": False,

    # === Built-in options ===
    # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
    # These are used if no custom model is specified and the input space is 1D.
    # Sizes of the hidden layers to be used.
    "fcnet_hiddens": [256, 256],
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu"),
    # "linear" (or None).
    "fcnet_activation": "tanh",

    # VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
    # These are used if no custom model is specified and the input space is 2D.
    # Filter config: List of [out_channels, kernel, stride] for each filter
    # (an example spec follows below).
    # Use None to make RLlib try to find a default filter setup given the
    # observation space.
    "conv_filters": None,
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu"),
    # "linear" (or None).
    "conv_activation": "relu",

    # Some default models support a final FC stack of n Dense layers with given
    # activation:
    # - Complex observation spaces: Image components are fed through
    #   VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then
    #   everything is concated and pushed through this final FC stack.
    # - VisionNets (CNNs), e.g. after the CNN stack, there may be
    #   additional Dense layers.
    # - FullyConnectedNetworks will have this additional FCStack as well
    #   (that's why it's empty by default).
    "post_fcnet_hiddens": [],
    "post_fcnet_activation": "relu",

    # For DiagGaussian action distributions, make the second half of the model
    # outputs floating bias variables instead of state-dependent. This only
    # has an effect when using the default fully connected net.
    "free_log_std": False,
    # Whether to skip the final linear layer used to resize the hidden layer
    # outputs to size `num_outputs`. If True, then the last hidden layer
    # should already match num_outputs.
    "no_final_linear": False,
    # Whether layers should be shared for the value function.
    "vf_share_layers": True,

    # == LSTM ==
    # Whether to wrap the model with an LSTM.
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20.
    "max_seq_len": 20,
    # Size of the LSTM cell.
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    "lstm_use_prev_action": False,
    # Whether to feed r_{t-1} to LSTM.
    "lstm_use_prev_reward": False,
    # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    "_time_major": False,

    # == Attention Nets (experimental: torch-version is untested) ==
    # Whether to use a GTrXL ("Gru transformer XL"; attention net) as the
    # wrapper Model around the default Model.
    "use_attention": False,
    # The number of transformer units within GTrXL.
    # A transformer unit in GTrXL consists of a) MultiHeadAttention module and
    # b) a position-wise MLP.
    "attention_num_transformer_units": 1,
    # The input and output size of each transformer unit.
    "attention_dim": 64,
    # The number of attention heads within the MultiHeadAttention units.
    "attention_num_heads": 1,
    # The dim of a single head (within the MultiHeadAttention units).
    "attention_head_dim": 32,
    # The memory sizes for inference and training.
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    # The output dim of the position-wise MLP.
    "attention_position_wise_mlp_dim": 32,
    # The initial bias values for the 2 GRU gates within a transformer unit.
    "attention_init_gru_gate_bias": 2.0,
    # The number of previous actions a_{t-n:t-1} to feed to GTrXL
    # (one-hot encoded if discrete).
    "attention_use_n_prev_actions": 0,
    # The number of previous rewards r_{t-n:t-1} to feed to GTrXL.
    "attention_use_n_prev_rewards": 0,

    # == Atari ==
    # Set to True to enable 4x stacking behavior.
    "framestack": True,
    # Final resized frame dimension.
    "dim": 84,
    # (deprecated) Converts ATARI frame to 1 Channel Grayscale image.
    "grayscale": False,
    # (deprecated) Changes frame to range from [-1, 1] if true.
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom model to use.
    "custom_model": None,
    # Extra options to pass to the custom classes. These will be available to
    # the Model's constructor in the model_config field. Also, they will be
    # attempted to be passed as **kwargs to ModelV2 models. For an example,
    # see rllib/models/[tf|torch]/attention_net.py.
    "custom_model_config": {},
    # Name of a custom action distribution to use.
    "custom_action_dist": None,
    # Custom preprocessors are deprecated. Please use a wrapper class around
    # your environment instead to preprocess observations.
    "custom_preprocessor": None,

    # Deprecated keys:
    # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
    "lstm_use_prev_action_reward": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
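
# A hedged usage sketch (not part of this module): these defaults are
# usually overridden via the "model" sub-dict of a Trainer config, e.g.:
#   config = {"model": {"fcnet_hiddens": [128, 128], "use_lstm": True}}
# Keys left unspecified keep the values from MODEL_DEFAULTS above.
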
@PublicAPI
class ModelCatalog:
    """Registry of models, preprocessors, and action distributions for envs.

    Examples:
        >>> prep = ModelCatalog.get_preprocessor(env)
        >>> observation = prep.transform(raw_observation)

        >>> dist_class, dist_dim = ModelCatalog.get_action_dist(
        ...     env.action_space, {})
        >>> model = ModelCatalog.get_model_v2(
        ...     obs_space, action_space, num_outputs, options)
        >>> dist = dist_class(model.outputs, model)
        >>> action = dist.sample()
    """

    @staticmethod
    @DeveloperAPI
    def get_action_dist(
            action_space: gym.Space,
            config: ModelConfigDict,
            dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
            framework: str = "tf",
            **kwargs) -> (type, int):
        """Returns a distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            config (Optional[dict]): Optional model config.
            dist_type (Optional[Union[str, Type[ActionDistribution]]]):
                Identifier of the action distribution (str) interpreted as a
                hint or the actual ActionDistribution class to use.
            framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
            kwargs (dict): Optional kwargs to pass on to the Distribution's
                constructor.

        Returns:
            Tuple:
                - dist_class (ActionDistribution): Python class of the
                    distribution.
                - dist_dim (int): The size of the input vector to the
                    distribution.
        """
        dist_cls = None
        config = config or MODEL_DEFAULTS
        # Custom distribution given.
        if config.get("custom_action_dist"):
            custom_action_config = config.copy()
            action_dist_name = custom_action_config.pop("custom_action_dist")
            logger.debug(
                "Using custom action distribution {}".format(action_dist_name))
            dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
                                            action_dist_name)
            return ModelCatalog._get_multi_action_distribution(
                dist_cls, action_space, custom_action_config, framework)

        # Dist_type is given directly as a class.
        elif type(dist_type) is type and \
                issubclass(dist_type, ActionDistribution) and \
                dist_type not in (
                MultiActionDistribution, TorchMultiActionDistribution):
            dist_cls = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, Box):
            if action_space.dtype.name.startswith("int"):
                low_ = np.min(action_space.low)
                high_ = np.max(action_space.high)
                dist_cls = TorchMultiCategorical if framework == "torch" \
                    else MultiCategorical
                num_cats = int(np.product(action_space.shape))
                return partial(
                    dist_cls,
                    input_lens=[high_ - low_ + 1 for _ in range(num_cats)],
                    action_space=action_space), num_cats * (high_ - low_ + 1)
            else:
                if len(action_space.shape) > 1:
                    raise UnsupportedSpaceException(
                        "Action space has multiple dimensions "
                        "{}. ".format(action_space.shape) +
                        "Consider reshaping this into a single dimension, "
                        "using a custom action distribution, "
                        "using a Tuple action space, or the multi-agent API.")
                # TODO(sven): Check for bounds and return SquashedNormal, etc..
                if dist_type is None:
                    return partial(
                        TorchDiagGaussian if framework == "torch" else
                        DiagGaussian, action_space=action_space), \
                        DiagGaussian.required_model_output_shape(
                            action_space, config)
                elif dist_type == "deterministic":
                    dist_cls = TorchDeterministic if framework == "torch" \
                        else Deterministic
        # Discrete Space -> Categorical.
        elif isinstance(action_space, Discrete):
            dist_cls = TorchCategorical if framework == "torch" else \
                JAXCategorical if framework == "jax" else Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif dist_type in (MultiActionDistribution,
                           TorchMultiActionDistribution) or \
                isinstance(action_space, (Tuple, Dict)):
            return ModelCatalog._get_multi_action_distribution(
                (MultiActionDistribution
                 if framework == "tf" else TorchMultiActionDistribution),
                action_space, config, framework)
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch.")
            dist_cls = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, MultiDiscrete):
            dist_cls = TorchMultiCategorical if framework == "torch" else \
                MultiCategorical
            return partial(dist_cls, input_lens=action_space.nvec), \
                int(sum(action_space.nvec))
        # Unknown type -> Error.
        else:
            raise NotImplementedError("Unsupported args: {} {}".format(
                action_space, dist_type))

        return dist_cls, dist_cls.required_model_output_shape(
            action_space, config)
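
    # A hedged usage sketch for `get_action_dist` (for a 4-action Discrete
    # space under tf, Categorical with dist_dim=4 is the expected result):
    #   dist_class, dist_dim = ModelCatalog.get_action_dist(
    #       gym.spaces.Discrete(4), MODEL_DEFAULTS, framework="tf")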

    @staticmethod
    @DeveloperAPI
    def get_action_shape(action_space: gym.Space,
                         framework: str = "tf") -> (np.dtype, List[int]):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space (Space): Action space of the target gym env.
            framework (str): The framework identifier. One of "tf" or "torch".

        Returns:
            (dtype, shape): Dtype and shape of the actions tensor.
        """
        dl_lib = torch if framework == "torch" else tf
        if isinstance(action_space, Discrete):
            return action_space.dtype, (None, )
        elif isinstance(action_space, (Box, Simplex)):
            if np.issubdtype(action_space.dtype, np.floating):
                return dl_lib.float32, (None, ) + action_space.shape
            elif np.issubdtype(action_space.dtype, np.integer):
                return dl_lib.int32, (None, ) + action_space.shape
            else:
                raise ValueError(
                    "RLlib doesn't support non-int or non-float Box spaces")
        elif isinstance(action_space, MultiDiscrete):
            return action_space.dtype, (None, ) + action_space.shape
        elif isinstance(action_space, (Tuple, Dict)):
            flat_action_space = flatten_space(action_space)
            size = 0
            all_discrete = True
            for i in range(len(flat_action_space)):
                if isinstance(flat_action_space[i], Discrete):
                    size += 1
                else:
                    all_discrete = False
                    size += np.product(flat_action_space[i].shape)
            size = int(size)
            return dl_lib.int32 if all_discrete else dl_lib.float32, \
                (None, size)
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space))

    @staticmethod
    @DeveloperAPI
    def get_action_placeholder(action_space: gym.Space,
                               name: str = "action") -> TensorType:
        """Returns an action placeholder consistent with the action space.

        Args:
            action_space (Space): Action space of the target gym env.
            name (str): An optional string to name the placeholder by.
                Default: "action".

        Returns:
            action_placeholder (Tensor): A placeholder for the actions.
        """
        dtype, shape = ModelCatalog.get_action_shape(
            action_space, framework="tf")

        return tf1.placeholder(dtype, shape=shape, name=name)
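
    # Note (an assumption worth stating): this helper is only meaningful in
    # tf1/graph mode; `tf1.placeholder` raises under eager execution, and
    # torch has no placeholder concept.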

    @staticmethod
    @DeveloperAPI
    def get_model_v2(obs_space: gym.Space,
                     action_space: gym.Space,
                     num_outputs: int,
                     model_config: ModelConfigDict,
                     framework: str = "tf",
                     name: str = "default_model",
                     model_interface: type = None,
                     default_model: type = None,
                     **model_kwargs) -> ModelV2:
        """Returns a suitable model compatible with given spaces and output.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (ModelConfigDict): The "model" sub-config dict
                within the Trainer's config dict.
            framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model.
            default_model (cls): Override the default class for the model.
                This only has an effect when not using a custom model.
            model_kwargs (dict): Args to pass to the ModelV2 constructor.

        Returns:
            model (ModelV2): Model to use for the policy.
        """
        # Validate the given config dict.
        ModelCatalog._validate_config(
            config=model_config,
            action_space=action_space,
            framework=framework)

        if model_config.get("custom_model"):
            # Allow model kwargs to be overridden / augmented by
            # custom_model_config.
            customized_model_kwargs = dict(
                model_kwargs, **model_config.get("custom_model_config", {}))

            if isinstance(model_config["custom_model"], type):
                model_cls = model_config["custom_model"]
            else:
                model_cls = _global_registry.get(RLLIB_MODEL,
                                                 model_config["custom_model"])

            # Only allow ModelV2 or native keras Models.
            if not issubclass(model_cls, ModelV2):
                if framework not in ["tf", "tf2", "tfe"] or \
                        not issubclass(model_cls, tf.keras.Model):
                    raise ValueError(
                        "`model_cls` must be a ModelV2 sub-class, but is"
                        " {}!".format(model_cls))

            logger.info("Wrapping {} as {}".format(model_cls, model_interface))
            model_cls = ModelCatalog._wrap_if_needed(model_cls,
                                                     model_interface)

            if framework in ["tf2", "tf", "tfe"]:
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or \
                        model_config.get("use_attention"):
                    from ray.rllib.models.tf.attention_net import \
                        AttentionWrapper, Keras_AttentionWrapper
                    from ray.rllib.models.tf.recurrent_net import \
                        LSTMWrapper, Keras_LSTMWrapper

                    wrapped_cls = model_cls
                    # Wrapped (custom) model is itself a keras Model ->
                    # wrap with keras LSTM/GTrXL (attention) wrappers.
                    if issubclass(wrapped_cls, tf.keras.Model):
                        model_cls = Keras_LSTMWrapper if \
                            model_config.get("use_lstm") else \
                            Keras_AttentionWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    # Wrapped (custom) model is ModelV2 ->
                    # wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
                    else:
                        forward = wrapped_cls.forward
                        model_cls = ModelCatalog._wrap_if_needed(
                            wrapped_cls, LSTMWrapper if
                            model_config.get("use_lstm") else AttentionWrapper)
                        model_cls._wrapped_forward = forward

                # Obsolete: Track and warn if vars were created but not
                # registered. Only still do this, if users do register their
                # variables. If not (which they shouldn't), don't check here.
                created = set()

                def track_var_creation(next_creator, **kw):
                    v = next_creator(**kw)
                    created.add(v)
                    return v

                with tf.variable_creator_scope(track_var_creation):
                    if issubclass(model_cls, tf.keras.Model):
                        instance = model_cls(
                            input_space=obs_space,
                            action_space=action_space,
                            num_outputs=num_outputs,
                            name=name,
                            **customized_model_kwargs,
                        )
                    else:
                        # Try calling with kwargs first (custom ModelV2 should
                        # accept these as kwargs, not get them from
                        # config["custom_model_config"] anymore).
                        try:
                            instance = model_cls(
                                obs_space,
                                action_space,
                                num_outputs,
                                model_config,
                                name,
                                **customized_model_kwargs,
                            )
                        except TypeError as e:
                            # Keyword error: Try old way w/o kwargs.
                            if "__init__() got an unexpected " in e.args[0]:
                                instance = model_cls(
                                    obs_space,
                                    action_space,
                                    num_outputs,
                                    model_config,
                                    name,
                                    **model_kwargs,
                                )
                                logger.warning(
                                    "Custom ModelV2 should accept all custom "
                                    "options as **kwargs, instead of expecting"
                                    " them in config['custom_model_config']!")
                            # Other error -> re-raise.
                            else:
                                raise e

                # User still registered TFModelV2's variables: Check, whether
                # ok.
                registered = []
                if not isinstance(instance, tf.keras.Model):
                    registered = set(instance.var_list)
                if len(registered) > 0:
                    not_registered = set()
                    for var in created:
                        if var not in registered:
                            not_registered.add(var)
                    if not_registered:
                        raise ValueError(
                            "It looks like you are still using "
                            "`{}.register_variables()` to register your "
                            "model's weights. This is no longer required, but "
                            "if you are still calling this method at least "
                            "once, you must make sure to register all created "
                            "variables properly. The missing variables are {},"
                            " and you only registered {}. "
                            "Did you forget to call `register_variables()` on "
                            "some of the variables in question?".format(
                                instance, not_registered, registered))
            elif framework == "torch":
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or \
                        model_config.get("use_attention"):
                    from ray.rllib.models.torch.attention_net import \
                        AttentionWrapper
                    from ray.rllib.models.torch.recurrent_net import \
                        LSTMWrapper

                    wrapped_cls = model_cls
                    forward = wrapped_cls.forward
                    model_cls = ModelCatalog._wrap_if_needed(
                        wrapped_cls, LSTMWrapper
                        if model_config.get("use_lstm") else AttentionWrapper)
                    model_cls._wrapped_forward = forward

                # PyTorch automatically tracks nn.Modules inside the parent
                # nn.Module's constructor.
                # Try calling with kwargs first (custom ModelV2 should
                # accept these as kwargs, not get them from
                # config["custom_model_config"] anymore).
                try:
                    instance = model_cls(obs_space, action_space, num_outputs,
                                         model_config, name,
                                         **customized_model_kwargs)
                except TypeError as e:
                    # Keyword error: Try old way w/o kwargs.
                    if "__init__() got an unexpected " in e.args[0]:
                        instance = model_cls(obs_space, action_space,
                                             num_outputs, model_config, name,
                                             **model_kwargs)
                        logger.warning(
                            "Custom ModelV2 should accept all custom "
                            "options as **kwargs, instead of expecting"
                            " them in config['custom_model_config']!")
                    # Other error -> re-raise.
                    else:
                        raise e
            else:
                raise NotImplementedError(
                    "`framework` must be 'tf2|tf|tfe|torch', but is "
                    "{}!".format(framework))

            return instance

        # Find a default TFModelV2 and wrap with model_interface.
        if framework in ["tf", "tfe", "tf2"]:
            v2_class = None
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or \
                    model_config.get("use_attention"):

                from ray.rllib.models.tf.attention_net import \
                    AttentionWrapper, Keras_AttentionWrapper
                from ray.rllib.models.tf.recurrent_net import LSTMWrapper, \
                    Keras_LSTMWrapper

                wrapped_cls = v2_class
                if model_config.get("use_lstm"):
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_LSTMWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, LSTMWrapper)
                        v2_class._wrapped_forward = wrapped_cls.forward
                else:
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_AttentionWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, AttentionWrapper)
                        v2_class._wrapped_forward = wrapped_cls.forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

            if issubclass(wrapper, tf.keras.Model):
                model = wrapper(
                    input_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    name=name,
                    **dict(model_kwargs, **model_config),
                )
                return model

            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)

        # Find a default TorchModelV2 and wrap with model_interface.
        elif framework == "torch":
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or \
                    model_config.get("use_attention"):

                from ray.rllib.models.torch.attention_net import \
                    AttentionWrapper
                from ray.rllib.models.torch.recurrent_net import LSTMWrapper

                wrapped_cls = v2_class
                forward = wrapped_cls.forward
                if model_config.get("use_lstm"):
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, LSTMWrapper)
                else:
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, AttentionWrapper)
                v2_class._wrapped_forward = forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)

        # Find a default JAXModelV2 and wrap with model_interface.
        elif framework == "jax":
            v2_class = \
                default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        else:
            raise NotImplementedError(
                "`framework` must be 'tf2|tf|tfe|torch|jax', but is "
                "{}!".format(framework))

    @staticmethod
    @DeveloperAPI
    def get_preprocessor(env: gym.Env,
                         options: Optional[dict] = None) -> Preprocessor:
        """Returns a suitable preprocessor for the given env.

        This is a wrapper for get_preprocessor_for_space().
        """
        return ModelCatalog.get_preprocessor_for_space(env.observation_space,
                                                       options)

    @staticmethod
    @DeveloperAPI
    def get_preprocessor_for_space(observation_space: gym.Space,
                                   options: dict = None) -> Preprocessor:
        """Returns a suitable preprocessor for the given observation space.

        Args:
            observation_space (Space): The input observation space.
            options (dict): Options to pass to the preprocessor.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the observations.
        """
        options = options or MODEL_DEFAULTS
        for k in options.keys():
            if k not in MODEL_DEFAULTS:
                raise Exception("Unknown config key `{}`, all keys: {}".format(
                    k, list(MODEL_DEFAULTS)))

        if options.get("custom_preprocessor"):
            preprocessor = options["custom_preprocessor"]
            logger.info("Using custom preprocessor {}".format(preprocessor))
            logger.warning(
                "DeprecationWarning: Custom preprocessors are deprecated, "
                "since they sometimes conflict with the built-in "
                "preprocessors for handling complex observation spaces. "
                "Please use wrapper classes around your environment "
                "instead of preprocessors.")
            prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
                observation_space, options)
        else:
            cls = get_preprocessor(observation_space)
            prep = cls(observation_space, options)

        if prep is not None:
            logger.debug("Created preprocessor {}: {} -> {}".format(
                prep, observation_space, prep.shape))
        return prep
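
    # A hedged example: for Discrete(5), the built-in OneHotPreprocessor is
    # the expected pick, so something along these lines should hold:
    #   prep = ModelCatalog.get_preprocessor_for_space(gym.spaces.Discrete(5))
    #   prep.transform(2)  # -> one-hot array [0., 0., 1., 0., 0.]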

    @staticmethod
    @Deprecated(error=False)
    def register_custom_preprocessor(preprocessor_name: str,
                                     preprocessor_class: type) -> None:
        """Register a custom preprocessor class by name.

        The preprocessor can be later used by specifying
        {"custom_preprocessor": preprocessor_name} in the model config.

        Args:
            preprocessor_name (str): Name to register the preprocessor under.
            preprocessor_class (type): Python class of the preprocessor.
        """
        _global_registry.register(RLLIB_PREPROCESSOR, preprocessor_name,
                                  preprocessor_class)

    @staticmethod
    @PublicAPI
    def register_custom_model(model_name: str, model_class: type) -> None:
        """Register a custom model class by name.

        The model can be later used by specifying {"custom_model": model_name}
        in the model config.

        Args:
            model_name (str): Name to register the model under.
            model_class (type): Python class of the model.
        """
        if tf is not None:
            if issubclass(model_class, tf.keras.Model):
                deprecation_warning(old="register_custom_model", error=False)
        _global_registry.register(RLLIB_MODEL, model_name, model_class)
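
    # A hedged usage sketch (`MyTorchModel` is illustrative, not defined
    # here):
    #   ModelCatalog.register_custom_model("my_model", MyTorchModel)
    #   config = {"model": {"custom_model": "my_model",
    #                       "custom_model_config": {"hidden_size": 64}}}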

    @staticmethod
    @PublicAPI
    def register_custom_action_dist(action_dist_name: str,
                                    action_dist_class: type) -> None:
        """Register a custom action distribution class by name.

        The action distribution can later be used by specifying
        {"custom_action_dist": action_dist_name} in the model config.

        Args:
            action_dist_name (str): Name to register the action distribution
                under.
            action_dist_class (type): Python class of the action distribution.
        """
        _global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
                                  action_dist_class)

    @staticmethod
    def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
        if not model_interface or issubclass(model_cls, model_interface):
            return model_cls

        assert issubclass(model_cls, ModelV2), model_cls

        class wrapper(model_interface, model_cls):
            pass

        name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
        wrapper.__name__ = name
        wrapper.__qualname__ = name

        return wrapper
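
    # Note (hedged): the dynamic `wrapper` class relies on Python's MRO so
    # that `model_interface` methods take precedence over those of
    # `model_cls`; e.g., wrapping FullyConnectedNetwork with a hypothetical
    # `MyInterface` yields a class named
    # "FullyConnectedNetwork_as_MyInterface".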

    @staticmethod
    def _get_v2_model_class(input_space: gym.Space,
                            model_config: ModelConfigDict,
                            framework: str = "tf") -> Type[ModelV2]:
        VisionNet = None
        ComplexNet = None
        Keras_FCNet = None
        Keras_VisionNet = None

        if framework in ["tf2", "tf", "tfe"]:
            from ray.rllib.models.tf.fcnet import \
                FullyConnectedNetwork as FCNet, \
                Keras_FullyConnectedNetwork as Keras_FCNet
            from ray.rllib.models.tf.visionnet import \
                VisionNetwork as VisionNet, \
                Keras_VisionNetwork as Keras_VisionNet
            from ray.rllib.models.tf.complex_input_net import \
                ComplexInputNetwork as ComplexNet
        elif framework == "torch":
            from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
                                                      FCNet)
            from ray.rllib.models.torch.visionnet import (VisionNetwork as
                                                          VisionNet)
            from ray.rllib.models.torch.complex_input_net import \
                ComplexInputNetwork as ComplexNet
        elif framework == "jax":
            from ray.rllib.models.jax.fcnet import (FullyConnectedNetwork as
                                                    FCNet)
        else:
            raise ValueError(
                "framework={} not supported in `ModelCatalog._get_v2_model_"
                "class`!".format(framework))

        orig_space = input_space if not hasattr(
            input_space, "original_space") else input_space.original_space

        # `input_space` is 3D Box -> VisionNet.
        if isinstance(input_space, Box) and len(input_space.shape) == 3:
            if framework == "jax":
                raise NotImplementedError("No non-FC default net for JAX yet!")
            elif model_config.get("_use_default_native_models") and \
                    Keras_VisionNet:
                return Keras_VisionNet
            return VisionNet
        # `input_space` is 1D Box -> FCNet.
        elif isinstance(input_space, Box) and len(input_space.shape) == 1 and \
                (not isinstance(orig_space, (Dict, Tuple)) or not any(
                    isinstance(s, Box) and len(s.shape) >= 2
                    for s in tree.flatten(orig_space.spaces))):
            # Keras native requested AND no auto-rnn-wrapping.
            if model_config.get("_use_default_native_models") and Keras_FCNet:
                return Keras_FCNet
            # Classic ModelV2 FCNet.
            else:
                return FCNet
        # Complex (Dict, Tuple, 2D Box (flatten), Discrete, MultiDiscrete).
        else:
            if framework == "jax":
                raise NotImplementedError("No non-FC default net for JAX yet!")
            return ComplexNet

    @staticmethod
    def _get_multi_action_distribution(dist_class, action_space, config,
                                       framework):
        # In case the custom distribution is a child of MultiActionDistr.
        # If users want to completely ignore the suggested child
        # distributions, they should simply do so in their custom class'
        # constructor.
        if issubclass(dist_class,
                      (MultiActionDistribution,
                       TorchMultiActionDistribution)):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(
                    s, config, framework=framework), flat_action_space)
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return partial(
                dist_class,
                action_space=action_space,
                child_distributions=child_dists,
                input_lens=input_lens), int(sum(input_lens))

        return dist_class, dist_class.required_model_output_shape(
            action_space, config)

    @staticmethod
    def _validate_config(config: ModelConfigDict,
                         action_space: gym.spaces.Space,
                         framework: str) -> None:
        """Validates a given model config dict.

        Args:
            config: The "model" sub-config dict
                within the Trainer's config dict.
            action_space: The action space of the model whose config is being
                validated.
            framework: One of "jax", "tf2", "tf", "tfe", or "torch".

        Raises:
            ValueError: If something is wrong with the given config.
        """
        # Soft-deprecate custom preprocessors.
        if config.get("custom_preprocessor") is not None:
            deprecation_warning(
                old="model.custom_preprocessor",
                new="gym.ObservationWrapper around your env or handle complex "
                "inputs inside your Model",
                error=False,
            )

        if config.get("use_attention") and config.get("use_lstm"):
            raise ValueError("Only one of `use_lstm` or `use_attention` may "
                             "be set to True!")

        # For complex action spaces, only allow prev-action inputs to
        # LSTMs and attention nets iff `_disable_action_flattening=True`.
        # TODO: `_disable_action_flattening=True` will be the default in
        #  the future.
        if (config.get("lstm_use_prev_action") or
                config.get("attention_use_n_prev_actions", 0) > 0) and not \
                config.get("_disable_action_flattening") and \
                isinstance(action_space, (Tuple, Dict)):
            raise ValueError(
                "For your complex action space (Tuple|Dict) and your model's "
                "`prev-actions` setup, you must set "
                "`_disable_action_flattening=True` in your main config dict!")

        if framework == "jax":
            if config.get("use_attention"):
                raise ValueError("`use_attention` not available for "
                                 "framework=jax so far!")
            elif config.get("use_lstm"):
                raise ValueError("`use_lstm` not available for "
                                 "framework=jax so far!")