tf_modelv2.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import contextlib
  2. import gym
  3. import re
  4. from typing import Dict, List, Union
  5. from ray.util import log_once
  6. from ray.rllib.models.modelv2 import ModelV2
  7. from ray.rllib.utils.annotations import override, PublicAPI
  8. from ray.rllib.utils.deprecation import deprecation_warning
  9. from ray.rllib.utils.framework import try_import_tf
  10. from ray.rllib.utils.typing import ModelConfigDict, TensorType
  11. tf1, tf, tfv = try_import_tf()
  12. @PublicAPI
  13. class TFModelV2(ModelV2):
  14. """TF version of ModelV2, which is always also a keras Model.
  15. Note that this class by itself is not a valid model unless you
  16. implement forward() in a subclass."""
  17. def __init__(self, obs_space: gym.spaces.Space,
  18. action_space: gym.spaces.Space, num_outputs: int,
  19. model_config: ModelConfigDict, name: str):
  20. """Initialize a TFModelV2.
  21. Here is an example implementation for a subclass
  22. ``MyModelClass(TFModelV2)``::
  23. def __init__(self, *args, **kwargs):
  24. super(MyModelClass, self).__init__(*args, **kwargs)
  25. input_layer = tf.keras.layers.Input(...)
  26. hidden_layer = tf.keras.layers.Dense(...)(input_layer)
  27. output_layer = tf.keras.layers.Dense(...)(hidden_layer)
  28. value_layer = tf.keras.layers.Dense(...)(hidden_layer)
  29. self.base_model = tf.keras.Model(
  30. input_layer, [output_layer, value_layer])
  31. """
  32. super().__init__(
  33. obs_space,
  34. action_space,
  35. num_outputs,
  36. model_config,
  37. name,
  38. framework="tf")
  39. # Deprecated: TFModelV2 now automatically track their variables.
  40. self.var_list = []
  41. if tf1.executing_eagerly():
  42. self.graph = None
  43. else:
  44. self.graph = tf1.get_default_graph()
  45. def context(self) -> contextlib.AbstractContextManager:
  46. """Returns a contextmanager for the current TF graph."""
  47. if self.graph:
  48. return self.graph.as_default()
  49. else:
  50. return ModelV2.context(self)
  51. def update_ops(self) -> List[TensorType]:
  52. """Return the list of update ops for this model.
  53. For example, this should include any BatchNorm update ops."""
  54. return []
  55. def register_variables(self, variables: List[TensorType]) -> None:
  56. """Register the given list of variables with this model."""
  57. if log_once("deprecated_tfmodelv2_register_variables"):
  58. deprecation_warning(
  59. old="TFModelV2.register_variables", error=False)
  60. self.var_list.extend(variables)
  61. @override(ModelV2)
  62. def variables(self, as_dict: bool = False) -> \
  63. Union[List[TensorType], Dict[str, TensorType]]:
  64. if as_dict:
  65. # Old way using `register_variables`.
  66. if self.var_list:
  67. return {v.name: v for v in self.var_list}
  68. # New way: Automatically determine the var tree.
  69. else:
  70. return self._find_sub_modules("", self.__dict__)
  71. # Old way using `register_variables`.
  72. if self.var_list:
  73. return list(self.var_list)
  74. # New way: Automatically determine the var tree.
  75. else:
  76. return list(self.variables(as_dict=True).values())
  77. @override(ModelV2)
  78. def trainable_variables(self, as_dict: bool = False) -> \
  79. Union[List[TensorType], Dict[str, TensorType]]:
  80. if as_dict:
  81. return {
  82. k: v
  83. for k, v in self.variables(as_dict=True).items() if v.trainable
  84. }
  85. return [v for v in self.variables() if v.trainable]
  86. @staticmethod
  87. def _find_sub_modules(current_key, struct):
  88. # Keras Model: key=k + "." + var-name (replace '/' by '.').
  89. if isinstance(struct, tf.keras.models.Model):
  90. ret = {}
  91. for var in struct.variables:
  92. name = re.sub("/", ".", var.name)
  93. key = current_key + "." + name
  94. ret[key] = var
  95. return ret
  96. # Other TFModelV2: Include its vars into ours.
  97. elif isinstance(struct, TFModelV2):
  98. return {
  99. current_key + "." + key: var
  100. for key, var in struct.variables(as_dict=True).items()
  101. }
  102. # tf.Variable
  103. elif isinstance(struct, tf.Variable):
  104. return {current_key: struct}
  105. # List/Tuple.
  106. elif isinstance(struct, (tuple, list)):
  107. ret = {}
  108. for i, value in enumerate(struct):
  109. sub_vars = TFModelV2._find_sub_modules(
  110. current_key + "_{}".format(i), value)
  111. ret.update(sub_vars)
  112. return ret
  113. # Dict.
  114. elif isinstance(struct, dict):
  115. if current_key:
  116. current_key += "_"
  117. ret = {}
  118. for key, value in struct.items():
  119. sub_vars = TFModelV2._find_sub_modules(current_key + str(key),
  120. value)
  121. ret.update(sub_vars)
  122. return ret
  123. return {}