from functools import partial
import gym
from gym.spaces import Box, Dict, Discrete, MultiDiscrete, Tuple
import logging
import numpy as np
import tree  # pip install dm_tree
from typing import List, Optional, Type, Union

from ray.tune.registry import RLLIB_MODEL, RLLIB_PREPROCESSOR, \
    RLLIB_ACTION_DIST, _global_registry
from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.jax.jax_action_dist import JAXCategorical
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor, Preprocessor
from ray.rllib.models.tf.tf_action_dist import Categorical, \
    Deterministic, DiagGaussian, Dirichlet, \
    MultiActionDistribution, MultiCategorical
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
    TorchDeterministic, TorchDiagGaussian, \
    TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import Deprecated, DEPRECATED_VALUE, \
    deprecation_warning
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.spaces.space_utils import flatten_space
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()

logger = logging.getLogger(__name__)
# yapf: disable
# __sphinx_doc_begin__
MODEL_DEFAULTS: ModelConfigDict = {
    # Experimental flag.
    # If True, try to use a native (tf.keras.Model or torch.Module) default
    # model instead of our built-in ModelV2 defaults.
    # If False (default), use "classic" ModelV2 default models.
    # Note that this currently only works for:
    # 1) framework != torch AND
    # 2) fully connected and CNN default networks as well as
    #    auto-wrapped LSTM- and attention nets.
    "_use_default_native_models": False,
    # Experimental flag.
    # If True (set via config._disable_preprocessor_api=True), no preprocessor
    # will be created and observations will arrive in the model exactly as
    # the env returns them.
    "_disable_preprocessor_api": False,

    # === Built-in options ===
    # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
    # These are used if no custom model is specified and the input space is 1D.
    # Sizes of the hidden layers to be used (one entry per layer).
    "fcnet_hiddens": [256, 256],
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu"),
    # "linear" (or None).
    "fcnet_activation": "tanh",

    # VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
    # These are used if no custom model is specified and the input space is 2D.
    # Filter config: List of [out_channels, kernel, stride] for each filter,
    # e.g. [[16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 1]].
    # Use None for making RLlib try to find a default filter setup given the
    # observation space.
    "conv_filters": None,
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu"),
    # "linear" (or None).
    "conv_activation": "relu",

    # Some default models support a final FC stack of n Dense layers with given
    # activation:
    # - Complex observation spaces: Image components are fed through
    #   VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then
    #   everything is concated and pushed through this final FC stack.
    # - VisionNets (CNNs), e.g. after the CNN stack, there may be
    #   additional Dense layers.
    # - FullyConnectedNetworks will have this additional FCStack as well
    #   (that's why it's empty by default).
    "post_fcnet_hiddens": [],
    "post_fcnet_activation": "relu",

    # For DiagGaussian action distributions, make the second half of the model
    # outputs floating bias variables instead of state-dependent. This only
    # has an effect if using the default fully connected net.
    "free_log_std": False,
    # Whether to skip the final linear layer used to resize the hidden layer
    # outputs to size `num_outputs`. If True, then the last hidden layer
    # should already match num_outputs.
    "no_final_linear": False,
    # Whether layers should be shared for the value function.
    "vf_share_layers": True,

    # == LSTM ==
    # Whether to wrap the model with an LSTM.
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20.
    "max_seq_len": 20,
    # Size of the LSTM cell.
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    "lstm_use_prev_action": False,
    # Whether to feed r_{t-1} to LSTM.
    "lstm_use_prev_reward": False,
    # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    "_time_major": False,

    # == Attention Nets (experimental: torch-version is untested) ==
    # Whether to use a GTrXL ("Gru transformer XL"; attention net) as the
    # wrapper Model around the default Model.
    "use_attention": False,
    # The number of transformer units within GTrXL.
    # A transformer unit in GTrXL consists of a) MultiHeadAttention module and
    # b) a position-wise MLP.
    "attention_num_transformer_units": 1,
    # The input and output size of each transformer unit.
    "attention_dim": 64,
    # The number of attention heads within the MultiHeadAttention units.
    "attention_num_heads": 1,
    # The dim of a single head (within the MultiHeadAttention units).
    "attention_head_dim": 32,
    # The memory sizes for inference and training.
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    # The output dim of the position-wise MLP.
    "attention_position_wise_mlp_dim": 32,
    # The initial bias values for the 2 GRU gates within a transformer unit.
    "attention_init_gru_gate_bias": 2.0,
    # How many previous actions a_{t-n:t-1} to feed to GTrXL
    # (one-hot encoded if discrete).
    "attention_use_n_prev_actions": 0,
    # How many previous rewards r_{t-n:t-1} to feed to GTrXL.
    "attention_use_n_prev_rewards": 0,

    # == Atari ==
    # Set to True to enable 4x stacking behavior.
    "framestack": True,
    # Final resized frame dimension.
    "dim": 84,
    # (deprecated) Converts ATARI frame to 1 Channel Grayscale image.
    "grayscale": False,
    # (deprecated) Changes frame to range from [-1, 1] if true.
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom model to use.
    "custom_model": None,
    # Extra options to pass to the custom classes. These will be available to
    # the Model's constructor in the model_config field. Also, they will be
    # attempted to be passed as **kwargs to ModelV2 models. For an example,
    # see rllib/models/[tf|torch]/attention_net.py.
    "custom_model_config": {},
    # Name of a custom action distribution to use.
    "custom_action_dist": None,
    # Custom preprocessors are deprecated. Please use a wrapper class around
    # your environment instead to preprocess observations.
    "custom_preprocessor": None,

    # Deprecated keys:
    # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
    "lstm_use_prev_action_reward": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
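
# A minimal usage sketch (illustrative only): these defaults are normally
# overridden via the "model" sub-dict of a Trainer config, e.g.:
#
#   config = {
#       "model": {
#           "fcnet_hiddens": [128, 128],
#           "use_lstm": True,
#           "lstm_cell_size": 64,
#       },
#   }
#
# Keys not given fall back to MODEL_DEFAULTS; unknown keys raise an error
# (see `ModelCatalog.get_preprocessor_for_space()` below).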


@PublicAPI
class ModelCatalog:
    """Registry of models, preprocessors, and action distributions for envs.

    Examples:
        >>> prep = ModelCatalog.get_preprocessor(env)
        >>> observation = prep.transform(raw_observation)

        >>> dist_class, dist_dim = ModelCatalog.get_action_dist(
        ...     env.action_space, {})
        >>> model = ModelCatalog.get_model_v2(
        ...     obs_space, action_space, num_outputs, options)
        >>> dist = dist_class(model.outputs, model)
        >>> action = dist.sample()
    """

    @staticmethod
    @DeveloperAPI
    def get_action_dist(
            action_space: gym.Space,
            config: ModelConfigDict,
            dist_type: Optional[Union[str, Type[ActionDistribution]]] = None,
            framework: str = "tf",
            **kwargs) -> (type, int):
        """Returns a distribution class and size for the given action space.

        Args:
            action_space (Space): Action space of the target gym env.
            config (Optional[dict]): Optional model config.
            dist_type (Optional[Union[str, Type[ActionDistribution]]]):
                Identifier of the action distribution (str) interpreted as a
                hint or the actual ActionDistribution class to use.
            framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
            kwargs (dict): Optional kwargs to pass on to the Distribution's
                constructor.

        Returns:
            Tuple:
                - dist_class (ActionDistribution): Python class of the
                    distribution.
                - dist_dim (int): The size of the input vector to the
                    distribution.
        """
        dist_cls = None
        config = config or MODEL_DEFAULTS
        # Custom distribution given.
        if config.get("custom_action_dist"):
            custom_action_config = config.copy()
            action_dist_name = custom_action_config.pop("custom_action_dist")
            logger.debug(
                "Using custom action distribution {}".format(action_dist_name))
            dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
                                            action_dist_name)
            return ModelCatalog._get_multi_action_distribution(
                dist_cls, action_space, custom_action_config, framework)

        # Dist_type is given directly as a class.
        elif type(dist_type) is type and \
                issubclass(dist_type, ActionDistribution) and \
                dist_type not in (
                    MultiActionDistribution, TorchMultiActionDistribution):
            dist_cls = dist_type
        # Box space -> DiagGaussian OR Deterministic.
        elif isinstance(action_space, Box):
            if action_space.dtype.name.startswith("int"):
                low_ = np.min(action_space.low)
                high_ = np.max(action_space.high)
                dist_cls = TorchMultiCategorical if framework == "torch" \
                    else MultiCategorical
                num_cats = int(np.product(action_space.shape))
                return partial(
                    dist_cls,
                    input_lens=[high_ - low_ + 1 for _ in range(num_cats)],
                    action_space=action_space), num_cats * (high_ - low_ + 1)
            else:
                if len(action_space.shape) > 1:
                    raise UnsupportedSpaceException(
                        "Action space has multiple dimensions "
                        "{}. ".format(action_space.shape) +
                        "Consider reshaping this into a single dimension, "
                        "using a custom action distribution, "
                        "using a Tuple action space, or the multi-agent API.")
                # TODO(sven): Check for bounds and return SquashedNormal, etc..
                if dist_type is None:
                    dist_cls = TorchDiagGaussian if framework == "torch" \
                        else DiagGaussian
                elif dist_type == "deterministic":
                    dist_cls = TorchDeterministic if framework == "torch" \
                        else Deterministic
        # Discrete Space -> Categorical.
        elif isinstance(action_space, Discrete):
            dist_cls = TorchCategorical if framework == "torch" else \
                JAXCategorical if framework == "jax" else Categorical
        # Tuple/Dict Spaces -> MultiAction.
        elif dist_type in (MultiActionDistribution,
                           TorchMultiActionDistribution) or \
                isinstance(action_space, (Tuple, Dict)):
            return ModelCatalog._get_multi_action_distribution(
                (MultiActionDistribution
                 if framework == "tf" else TorchMultiActionDistribution),
                action_space, config, framework)
        # Simplex -> Dirichlet.
        elif isinstance(action_space, Simplex):
            if framework == "torch":
                # TODO(sven): implement
                raise NotImplementedError(
                    "Simplex action spaces not supported for torch.")
            dist_cls = Dirichlet
        # MultiDiscrete -> MultiCategorical.
        elif isinstance(action_space, MultiDiscrete):
            dist_cls = TorchMultiCategorical if framework == "torch" else \
                MultiCategorical
            return partial(dist_cls, input_lens=action_space.nvec), \
                int(sum(action_space.nvec))
        # Unknown type -> Error.
        else:
            raise NotImplementedError("Unsupported args: {} {}".format(
                action_space, dist_type))

        return dist_cls, dist_cls.required_model_output_shape(
            action_space, config)
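
    # Example (illustrative): for a `Discrete(4)` action space under
    # framework "torch", the call below returns `(TorchCategorical, 4)`,
    # i.e. the model must output 4 logits:
    #
    #   dist_cls, dist_dim = ModelCatalog.get_action_dist(
    #       gym.spaces.Discrete(4), config={}, framework="torch")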

    @staticmethod
    @DeveloperAPI
    def get_action_shape(action_space: gym.Space,
                         framework: str = "tf") -> (np.dtype, List[int]):
        """Returns action tensor dtype and shape for the action space.

        Args:
            action_space (Space): Action space of the target gym env.
            framework (str): The framework identifier. One of "tf" or "torch".

        Returns:
            (dtype, shape): Dtype and shape of the actions tensor.
        """
        dl_lib = torch if framework == "torch" else tf
        if isinstance(action_space, Discrete):
            return action_space.dtype, (None, )
        elif isinstance(action_space, (Box, Simplex)):
            if np.issubdtype(action_space.dtype, np.floating):
                return dl_lib.float32, (None, ) + action_space.shape
            elif np.issubdtype(action_space.dtype, np.integer):
                return dl_lib.int32, (None, ) + action_space.shape
            else:
                raise ValueError(
                    "RLlib doesn't support non int or float box spaces")
        elif isinstance(action_space, MultiDiscrete):
            return action_space.dtype, (None, ) + action_space.shape
        elif isinstance(action_space, (Tuple, Dict)):
            flat_action_space = flatten_space(action_space)
            size = 0
            all_discrete = True
            for i in range(len(flat_action_space)):
                if isinstance(flat_action_space[i], Discrete):
                    size += 1
                else:
                    all_discrete = False
                    size += np.product(flat_action_space[i].shape)
            size = int(size)
            return dl_lib.int32 if all_discrete else dl_lib.float32, \
                (None, size)
        else:
            raise NotImplementedError(
                "Action space {} not supported".format(action_space))
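
    # Example (illustrative): a float `Box(-1.0, 1.0, shape=(3,))` yields
    # `(tf.float32, (None, 3))` under framework "tf"; the leading `None` is
    # the batch dimension:
    #
    #   dtype, shape = ModelCatalog.get_action_shape(
    #       gym.spaces.Box(-1.0, 1.0, shape=(3,)), framework="tf")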

    @staticmethod
    @DeveloperAPI
    def get_action_placeholder(action_space: gym.Space,
                               name: str = "action") -> TensorType:
        """Returns an action placeholder consistent with the action space.

        Args:
            action_space (Space): Action space of the target gym env.
            name (str): An optional string to name the placeholder by.
                Default: "action".

        Returns:
            action_placeholder (Tensor): A placeholder for the actions.
        """
        dtype, shape = ModelCatalog.get_action_shape(
            action_space, framework="tf")

        return tf1.placeholder(dtype, shape=shape, name=name)

    @staticmethod
    @DeveloperAPI
    def get_model_v2(obs_space: gym.Space,
                     action_space: gym.Space,
                     num_outputs: int,
                     model_config: ModelConfigDict,
                     framework: str = "tf",
                     name: str = "default_model",
                     model_interface: type = None,
                     default_model: type = None,
                     **model_kwargs) -> ModelV2:
        """Returns a suitable model compatible with given spaces and output.

        Args:
            obs_space (Space): Observation space of the target gym env. This
                may have an `original_space` attribute that specifies how to
                unflatten the tensor into a ragged tensor.
            action_space (Space): Action space of the target gym env.
            num_outputs (int): The size of the output vector of the model.
            model_config (ModelConfigDict): The "model" sub-config dict
                within the Trainer's config dict.
            framework (str): One of "tf2", "tf", "tfe", "torch", or "jax".
            name (str): Name (scope) for the model.
            model_interface (cls): Interface required for the model.
            default_model (cls): Override the default class for the model.
                This only has an effect when not using a custom model.
            model_kwargs (dict): Args to pass to the ModelV2 constructor.

        Returns:
            model (ModelV2): Model to use for the policy.
        """
        # Validate the given config dict.
        ModelCatalog._validate_config(config=model_config, framework=framework)

        if model_config.get("custom_model"):
            # Allow model kwargs to be overridden / augmented by
            # custom_model_config.
            customized_model_kwargs = dict(
                model_kwargs, **model_config.get("custom_model_config", {}))

            if isinstance(model_config["custom_model"], type):
                model_cls = model_config["custom_model"]
            else:
                model_cls = _global_registry.get(RLLIB_MODEL,
                                                 model_config["custom_model"])

            # Only allow ModelV2 or native keras Models.
            if not issubclass(model_cls, ModelV2):
                if framework not in ["tf", "tf2", "tfe"] or \
                        not issubclass(model_cls, tf.keras.Model):
                    raise ValueError(
                        "`model_cls` must be a ModelV2 sub-class, but is"
                        " {}!".format(model_cls))

            logger.info("Wrapping {} as {}".format(model_cls, model_interface))
            model_cls = ModelCatalog._wrap_if_needed(model_cls,
                                                     model_interface)

            if framework in ["tf2", "tf", "tfe"]:
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or \
                        model_config.get("use_attention"):
                    from ray.rllib.models.tf.attention_net import \
                        AttentionWrapper, Keras_AttentionWrapper
                    from ray.rllib.models.tf.recurrent_net import \
                        LSTMWrapper, Keras_LSTMWrapper

                    wrapped_cls = model_cls
                    # Wrapped (custom) model is itself a keras Model ->
                    # wrap with keras LSTM/GTrXL (attention) wrappers.
                    if issubclass(wrapped_cls, tf.keras.Model):
                        model_cls = Keras_LSTMWrapper if \
                            model_config.get("use_lstm") else \
                            Keras_AttentionWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    # Wrapped (custom) model is ModelV2 ->
                    # wrap with ModelV2 LSTM/GTrXL (attention) wrappers.
                    else:
                        forward = wrapped_cls.forward
                        model_cls = ModelCatalog._wrap_if_needed(
                            wrapped_cls, LSTMWrapper if
                            model_config.get("use_lstm") else AttentionWrapper)
                        model_cls._wrapped_forward = forward

                # Obsolete: Track and warn if vars were created but not
                # registered. Only still do this if users do register their
                # variables. If not (which they shouldn't), don't check here.
                created = set()

                def track_var_creation(next_creator, **kw):
                    v = next_creator(**kw)
                    created.add(v)
                    return v

                with tf.variable_creator_scope(track_var_creation):
                    if issubclass(model_cls, tf.keras.Model):
                        instance = model_cls(
                            input_space=obs_space,
                            action_space=action_space,
                            num_outputs=num_outputs,
                            name=name,
                            **customized_model_kwargs,
                        )
                    else:
                        # Try calling with kwargs first (custom ModelV2 should
                        # accept these as kwargs, not get them from
                        # config["custom_model_config"] anymore).
                        try:
                            instance = model_cls(
                                obs_space,
                                action_space,
                                num_outputs,
                                model_config,
                                name,
                                **customized_model_kwargs,
                            )
                        except TypeError as e:
                            # Keyword error: Try old way w/o kwargs.
                            if "__init__() got an unexpected " in e.args[0]:
                                instance = model_cls(
                                    obs_space,
                                    action_space,
                                    num_outputs,
                                    model_config,
                                    name,
                                    **model_kwargs,
                                )
                                logger.warning(
                                    "Custom ModelV2 should accept all custom "
                                    "options as **kwargs, instead of expecting"
                                    " them in config['custom_model_config']!")
                            # Other error -> re-raise.
                            else:
                                raise e

                # User still registered TFModelV2's variables: Check whether
                # that's ok.
                registered = []
                if not isinstance(instance, tf.keras.Model):
                    registered = set(instance.var_list)
                if len(registered) > 0:
                    not_registered = set()
                    for var in created:
                        if var not in registered:
                            not_registered.add(var)
                    if not_registered:
                        raise ValueError(
                            "It looks like you are still using "
                            "`{}.register_variables()` to register your "
                            "model's weights. This is no longer required, but "
                            "if you are still calling this method at least "
                            "once, you must make sure to register all created "
                            "variables properly. The missing variables are {},"
                            " and you only registered {}. "
                            "Did you forget to call `register_variables()` on "
                            "some of the variables in question?".format(
                                instance, not_registered, registered))
            elif framework == "torch":
                # Try wrapping custom model with LSTM/attention, if required.
                if model_config.get("use_lstm") or \
                        model_config.get("use_attention"):
                    from ray.rllib.models.torch.attention_net import \
                        AttentionWrapper
                    from ray.rllib.models.torch.recurrent_net import \
                        LSTMWrapper

                    wrapped_cls = model_cls
                    forward = wrapped_cls.forward
                    model_cls = ModelCatalog._wrap_if_needed(
                        wrapped_cls, LSTMWrapper
                        if model_config.get("use_lstm") else AttentionWrapper)
                    model_cls._wrapped_forward = forward

                # PyTorch automatically tracks nn.Modules inside the parent
                # nn.Module's constructor.
                # Try calling with kwargs first (custom ModelV2 should
                # accept these as kwargs, not get them from
                # config["custom_model_config"] anymore).
                try:
                    instance = model_cls(obs_space, action_space, num_outputs,
                                         model_config, name,
                                         **customized_model_kwargs)
                except TypeError as e:
                    # Keyword error: Try old way w/o kwargs.
                    if "__init__() got an unexpected " in e.args[0]:
                        instance = model_cls(obs_space, action_space,
                                             num_outputs, model_config, name,
                                             **model_kwargs)
                        logger.warning(
                            "Custom ModelV2 should accept all custom "
                            "options as **kwargs, instead of expecting"
                            " them in config['custom_model_config']!")
                    # Other error -> re-raise.
                    else:
                        raise e
            else:
                raise NotImplementedError(
                    "`framework` must be 'tf2|tf|tfe|torch', but is "
                    "{}!".format(framework))

            return instance

        # Find a default TFModelV2 and wrap with model_interface.
        if framework in ["tf", "tfe", "tf2"]:
            v2_class = None
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or \
                    model_config.get("use_attention"):
                from ray.rllib.models.tf.attention_net import \
                    AttentionWrapper, Keras_AttentionWrapper
                from ray.rllib.models.tf.recurrent_net import LSTMWrapper, \
                    Keras_LSTMWrapper

                wrapped_cls = v2_class
                if model_config.get("use_lstm"):
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_LSTMWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, LSTMWrapper)
                        v2_class._wrapped_forward = wrapped_cls.forward
                else:
                    if issubclass(wrapped_cls, tf.keras.Model):
                        v2_class = Keras_AttentionWrapper
                        model_config["wrapped_cls"] = wrapped_cls
                    else:
                        v2_class = ModelCatalog._wrap_if_needed(
                            wrapped_cls, AttentionWrapper)
                        v2_class._wrapped_forward = wrapped_cls.forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)

            if issubclass(wrapper, tf.keras.Model):
                model = wrapper(
                    input_space=obs_space,
                    action_space=action_space,
                    num_outputs=num_outputs,
                    name=name,
                    **dict(model_kwargs, **model_config),
                )
                return model

            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        # Find a default TorchModelV2 and wrap with model_interface.
        elif framework == "torch":
            # Try to get a default v2 model.
            if not model_config.get("custom_model"):
                v2_class = default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)

            if not v2_class:
                raise ValueError("ModelV2 class could not be determined!")

            if model_config.get("use_lstm") or \
                    model_config.get("use_attention"):
                from ray.rllib.models.torch.attention_net import \
                    AttentionWrapper
                from ray.rllib.models.torch.recurrent_net import LSTMWrapper

                wrapped_cls = v2_class
                forward = wrapped_cls.forward
                if model_config.get("use_lstm"):
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, LSTMWrapper)
                else:
                    v2_class = ModelCatalog._wrap_if_needed(
                        wrapped_cls, AttentionWrapper)
                v2_class._wrapped_forward = forward

            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        # Find a default JAXModelV2 and wrap with model_interface.
        elif framework == "jax":
            v2_class = \
                default_model or ModelCatalog._get_v2_model_class(
                    obs_space, model_config, framework=framework)
            # Wrap in the requested interface.
            wrapper = ModelCatalog._wrap_if_needed(v2_class, model_interface)
            return wrapper(obs_space, action_space, num_outputs, model_config,
                           name, **model_kwargs)
        else:
            raise NotImplementedError(
                "`framework` must be 'tf2|tf|tfe|torch', but is "
                "{}!".format(framework))
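
    # Example (illustrative): build the default torch FC net for a
    # CartPole-like setup; `num_outputs=2` matches a `Discrete(2)` action
    # space:
    #
    #   model = ModelCatalog.get_model_v2(
    #       obs_space=gym.spaces.Box(-1.0, 1.0, shape=(4,)),
    #       action_space=gym.spaces.Discrete(2),
    #       num_outputs=2,
    #       model_config=MODEL_DEFAULTS.copy(),
    #       framework="torch")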

    @staticmethod
    @DeveloperAPI
    def get_preprocessor(env: gym.Env,
                         options: Optional[dict] = None) -> Preprocessor:
        """Returns a suitable preprocessor for the given env.

        This is a wrapper for get_preprocessor_for_space().
        """
        return ModelCatalog.get_preprocessor_for_space(env.observation_space,
                                                       options)

    @staticmethod
    @DeveloperAPI
    def get_preprocessor_for_space(observation_space: gym.Space,
                                   options: dict = None) -> Preprocessor:
        """Returns a suitable preprocessor for the given observation space.

        Args:
            observation_space (Space): The input observation space.
            options (dict): Options to pass to the preprocessor.

        Returns:
            preprocessor (Preprocessor): Preprocessor for the observations.
        """
        options = options or MODEL_DEFAULTS
        for k in options.keys():
            if k not in MODEL_DEFAULTS:
                raise Exception("Unknown config key `{}`, all keys: {}".format(
                    k, list(MODEL_DEFAULTS)))

        if options.get("custom_preprocessor"):
            preprocessor = options["custom_preprocessor"]
            logger.info("Using custom preprocessor {}".format(preprocessor))
            logger.warning(
                "DeprecationWarning: Custom preprocessors are deprecated, "
                "since they sometimes conflict with the built-in "
                "preprocessors for handling complex observation spaces. "
                "Please use wrapper classes around your environment "
                "instead of preprocessors.")
            prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
                observation_space, options)
        else:
            cls = get_preprocessor(observation_space)
            prep = cls(observation_space, options)

        if prep is not None:
            logger.debug("Created preprocessor {}: {} -> {}".format(
                prep, observation_space, prep.shape))
        return prep
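
    # Example (illustrative): a `Discrete(3)` observation space gets a
    # one-hot preprocessor, so `prep.transform(1)` returns `[0., 1., 0.]`:
    #
    #   prep = ModelCatalog.get_preprocessor_for_space(gym.spaces.Discrete(3))
    #   one_hot = prep.transform(1)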

    @staticmethod
    @Deprecated(error=False)
    def register_custom_preprocessor(preprocessor_name: str,
                                     preprocessor_class: type) -> None:
        """Register a custom preprocessor class by name.

        The preprocessor can be later used by specifying
        {"custom_preprocessor": preprocessor_name} in the model config.

        Args:
            preprocessor_name (str): Name to register the preprocessor under.
            preprocessor_class (type): Python class of the preprocessor.
        """
        _global_registry.register(RLLIB_PREPROCESSOR, preprocessor_name,
                                  preprocessor_class)

    @staticmethod
    @PublicAPI
    def register_custom_model(model_name: str, model_class: type) -> None:
        """Register a custom model class by name.

        The model can be later used by specifying {"custom_model": model_name}
        in the model config.

        Args:
            model_name (str): Name to register the model under.
            model_class (type): Python class of the model.
        """
        if tf is not None:
            if issubclass(model_class, tf.keras.Model):
                deprecation_warning(old="register_custom_model", error=False)
        _global_registry.register(RLLIB_MODEL, model_name, model_class)
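
    # Example (illustrative; `MyModelV2` and `hidden_dim` are placeholders):
    # register a ModelV2 subclass once, then reference it by name in a
    # Trainer config:
    #
    #   ModelCatalog.register_custom_model("my_model", MyModelV2)
    #   config = {"model": {"custom_model": "my_model",
    #                       "custom_model_config": {"hidden_dim": 64}}}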

    @staticmethod
    @PublicAPI
    def register_custom_action_dist(action_dist_name: str,
                                    action_dist_class: type) -> None:
        """Register a custom action distribution class by name.

        The action distribution can be later used by specifying
        {"custom_action_dist": action_dist_name} in the model config.

        Args:
            action_dist_name (str): Name to register the action distribution
                under.
            action_dist_class (type): Python class of the action distribution.
        """
        _global_registry.register(RLLIB_ACTION_DIST, action_dist_name,
                                  action_dist_class)

    @staticmethod
    def _wrap_if_needed(model_cls: type, model_interface: type) -> type:
        if not model_interface or issubclass(model_cls, model_interface):
            return model_cls

        assert issubclass(model_cls, ModelV2), model_cls

        class wrapper(model_interface, model_cls):
            pass

        name = "{}_as_{}".format(model_cls.__name__, model_interface.__name__)
        wrapper.__name__ = name
        wrapper.__qualname__ = name
        return wrapper
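
    # Sketch of what `_wrap_if_needed` produces (illustrative; `MyInterface`
    # is a placeholder): `_wrap_if_needed(FullyConnectedNetwork, MyInterface)`
    # returns a class equivalent to
    #
    #   class FullyConnectedNetwork_as_MyInterface(MyInterface,
    #                                              FullyConnectedNetwork):
    #       pass
    #
    # i.e. a dynamic mixin whose MRO resolves `MyInterface` methods first.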

    @staticmethod
    def _get_v2_model_class(input_space: gym.Space,
                            model_config: ModelConfigDict,
                            framework: str = "tf") -> Type[ModelV2]:
        VisionNet = None
        ComplexNet = None
        Keras_FCNet = None
        Keras_VisionNet = None

        if framework in ["tf2", "tf", "tfe"]:
            from ray.rllib.models.tf.fcnet import \
                FullyConnectedNetwork as FCNet, \
                Keras_FullyConnectedNetwork as Keras_FCNet
            from ray.rllib.models.tf.visionnet import \
                VisionNetwork as VisionNet, \
                Keras_VisionNetwork as Keras_VisionNet
            from ray.rllib.models.tf.complex_input_net import \
                ComplexInputNetwork as ComplexNet
        elif framework == "torch":
            from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
                                                      FCNet)
            from ray.rllib.models.torch.visionnet import (VisionNetwork as
                                                          VisionNet)
            from ray.rllib.models.torch.complex_input_net import \
                ComplexInputNetwork as ComplexNet
        elif framework == "jax":
            from ray.rllib.models.jax.fcnet import (FullyConnectedNetwork as
                                                    FCNet)
        else:
            raise ValueError(
                "framework={} not supported in `ModelCatalog._get_v2_model_"
                "class`!".format(framework))

        # Complex space, where at least one sub-space is image.
        # -> Complex input model (which auto-flattens everything, but correctly
        # processes image components with default CNN stacks).
        space_to_check = input_space if not hasattr(
            input_space, "original_space") else input_space.original_space
        if isinstance(input_space, (Dict, Tuple)) or (isinstance(
                space_to_check, (Dict, Tuple)) and any(
                    isinstance(s, Box) and len(s.shape) >= 2
                    for s in tree.flatten(space_to_check.spaces))):
            return ComplexNet

        # Single, flattenable/one-hot-able space -> Simple FCNet.
        if isinstance(input_space, (Discrete, MultiDiscrete)) or \
                len(input_space.shape) == 1 or \
                len(input_space.shape) == 2:
            # Keras native requested AND no auto-rnn-wrapping.
            if model_config.get("_use_default_native_models") and Keras_FCNet:
                return Keras_FCNet
            # Classic ModelV2 FCNet.
            else:
                return FCNet
        elif framework == "jax":
            raise NotImplementedError("No non-FC default net for JAX yet!")

        # Last resort: Conv2D stack for single image spaces.
        if model_config.get("_use_default_native_models") and Keras_VisionNet:
            return Keras_VisionNet
        return VisionNet

    @staticmethod
    def _get_multi_action_distribution(dist_class, action_space, config,
                                       framework):
        # In case the custom distribution is a child of MultiActionDistr.
        # If users want to completely ignore the suggested child
        # distributions, they should simply do so in their custom class'
        # constructor.
        if issubclass(dist_class,
                      (MultiActionDistribution, TorchMultiActionDistribution)):
            flat_action_space = flatten_space(action_space)
            child_dists_and_in_lens = tree.map_structure(
                lambda s: ModelCatalog.get_action_dist(
                    s, config, framework=framework), flat_action_space)
            child_dists = [e[0] for e in child_dists_and_in_lens]
            input_lens = [int(e[1]) for e in child_dists_and_in_lens]
            return partial(
                dist_class,
                action_space=action_space,
                child_distributions=child_dists,
                input_lens=input_lens), int(sum(input_lens))
        return dist_class, dist_class.required_model_output_shape(
            action_space, config)

    @staticmethod
    def _validate_config(config: ModelConfigDict, framework: str) -> None:
        """Validates a given model config dict.

        Args:
            config (ModelConfigDict): The "model" sub-config dict
                within the Trainer's config dict.
            framework (str): One of "jax", "tf2", "tf", "tfe", or "torch".

        Raises:
            ValueError: If something is wrong with the given config.
        """
        # Soft-deprecate custom preprocessors.
        if config.get("custom_preprocessor") is not None:
            deprecation_warning(
                old="model.custom_preprocessor",
                new="gym.ObservationWrapper around your env or handle complex "
                "inputs inside your Model",
                error=False,
            )

        if config.get("use_attention") and config.get("use_lstm"):
            raise ValueError("Only one of `use_lstm` or `use_attention` may "
                             "be set to True!")
        if framework == "jax":
            if config.get("use_attention"):
                raise ValueError("`use_attention` not available for "
                                 "framework=jax so far!")
            elif config.get("use_lstm"):
                raise ValueError("`use_lstm` not available for "
                                 "framework=jax so far!")