test_model_imports.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. #!/usr/bin/env python
  2. import h5py
  3. import numpy as np
  4. from pathlib import Path
  5. import unittest
  6. import ray
  7. from ray.rllib.agents.registry import get_trainer_class
  8. from ray.rllib.models.catalog import ModelCatalog
  9. from ray.rllib.models.tf.misc import normc_initializer
  10. from ray.rllib.models.tf.tf_modelv2 import TFModelV2
  11. from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  12. from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
  13. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  14. from ray.rllib.utils.test_utils import check, framework_iterator
  15. tf1, tf, tfv = try_import_tf()
  16. torch, nn = try_import_torch()
  17. class MyKerasModel(TFModelV2):
  18. """Custom model for policy gradient algorithms."""
  19. def __init__(self, obs_space, action_space, num_outputs, model_config,
  20. name):
  21. super(MyKerasModel, self).__init__(obs_space, action_space,
  22. num_outputs, model_config, name)
  23. self.inputs = tf.keras.layers.Input(
  24. shape=obs_space.shape, name="observations")
  25. layer_1 = tf.keras.layers.Dense(
  26. 16,
  27. name="layer1",
  28. activation=tf.nn.relu,
  29. kernel_initializer=normc_initializer(1.0))(self.inputs)
  30. layer_out = tf.keras.layers.Dense(
  31. num_outputs,
  32. name="out",
  33. activation=None,
  34. kernel_initializer=normc_initializer(0.01))(layer_1)
  35. if self.model_config["vf_share_layers"]:
  36. value_out = tf.keras.layers.Dense(
  37. 1,
  38. name="value",
  39. activation=None,
  40. kernel_initializer=normc_initializer(0.01))(layer_1)
  41. self.base_model = tf.keras.Model(self.inputs,
  42. [layer_out, value_out])
  43. else:
  44. self.base_model = tf.keras.Model(self.inputs, layer_out)
  45. def forward(self, input_dict, state, seq_lens):
  46. if self.model_config["vf_share_layers"]:
  47. model_out, self._value_out = self.base_model(input_dict["obs"])
  48. else:
  49. model_out = self.base_model(input_dict["obs"])
  50. self._value_out = tf.zeros(
  51. shape=(tf.shape(input_dict["obs"])[0], ))
  52. return model_out, state
  53. def value_function(self):
  54. return tf.reshape(self._value_out, [-1])
  55. def import_from_h5(self, import_file):
  56. # Override this to define custom weight loading behavior from h5 files.
  57. self.base_model.load_weights(import_file)
  58. class MyTorchModel(TorchModelV2, nn.Module):
  59. """Generic vision network."""
  60. def __init__(self, obs_space, action_space, num_outputs, model_config,
  61. name):
  62. TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
  63. model_config, name)
  64. nn.Module.__init__(self)
  65. self.device = torch.device("cuda"
  66. if torch.cuda.is_available() else "cpu")
  67. self.layer_1 = nn.Linear(obs_space.shape[0], 16).to(self.device)
  68. self.layer_out = nn.Linear(16, num_outputs).to(self.device)
  69. self.value_branch = nn.Linear(16, 1).to(self.device)
  70. self.cur_value = None
  71. def forward(self, input_dict, state, seq_lens):
  72. layer_1_out = self.layer_1(input_dict["obs"])
  73. logits = self.layer_out(layer_1_out)
  74. self.cur_value = self.value_branch(layer_1_out).squeeze(1)
  75. return logits, state
  76. def value_function(self):
  77. assert self.cur_value is not None, "Must call `forward()` first!"
  78. return self.cur_value
  79. def import_from_h5(self, import_file):
  80. # Override this to define custom weight loading behavior from h5 files.
  81. f = h5py.File(import_file)
  82. layer1 = f["layer1"][DEFAULT_POLICY_ID]["layer1"]
  83. out = f["out"][DEFAULT_POLICY_ID]["out"]
  84. value = f["value"][DEFAULT_POLICY_ID]["value"]
  85. try:
  86. self.layer_1.load_state_dict({
  87. "weight": torch.Tensor(np.transpose(layer1["kernel:0"])),
  88. "bias": torch.Tensor(np.transpose(layer1["bias:0"])),
  89. })
  90. self.layer_out.load_state_dict({
  91. "weight": torch.Tensor(np.transpose(out["kernel:0"])),
  92. "bias": torch.Tensor(np.transpose(out["bias:0"])),
  93. })
  94. self.value_branch.load_state_dict({
  95. "weight": torch.Tensor(np.transpose(value["kernel:0"])),
  96. "bias": torch.Tensor(np.transpose(value["bias:0"])),
  97. })
  98. except AttributeError:
  99. self.layer_1.load_state_dict({
  100. "weight": torch.Tensor(np.transpose(layer1["kernel:0"].value)),
  101. "bias": torch.Tensor(np.transpose(layer1["bias:0"].value)),
  102. })
  103. self.layer_out.load_state_dict({
  104. "weight": torch.Tensor(np.transpose(out["kernel:0"].value)),
  105. "bias": torch.Tensor(np.transpose(out["bias:0"].value)),
  106. })
  107. self.value_branch.load_state_dict({
  108. "weight": torch.Tensor(np.transpose(value["kernel:0"].value)),
  109. "bias": torch.Tensor(np.transpose(value["bias:0"].value)),
  110. })
  111. def model_import_test(algo, config, env):
  112. # Get the abs-path to use (bazel-friendly).
  113. rllib_dir = Path(__file__).parent.parent
  114. import_file = str(rllib_dir) + "/tests/data/model_weights/weights.h5"
  115. agent_cls = get_trainer_class(algo)
  116. for fw in framework_iterator(config, ["tf", "torch"]):
  117. config["model"]["custom_model"] = "keras_model" if fw != "torch" else \
  118. "torch_model"
  119. agent = agent_cls(config, env)
  120. def current_weight(agent):
  121. if fw == "tf":
  122. return agent.get_weights()[DEFAULT_POLICY_ID][
  123. "default_policy/value/kernel"][0]
  124. elif fw == "torch":
  125. return float(agent.get_weights()[DEFAULT_POLICY_ID][
  126. "value_branch.weight"][0][0])
  127. else:
  128. return agent.get_weights()[DEFAULT_POLICY_ID][4][0]
  129. # Import weights for our custom model from an h5 file.
  130. weight_before_import = current_weight(agent)
  131. agent.import_model(import_file=import_file)
  132. weight_after_import = current_weight(agent)
  133. check(weight_before_import, weight_after_import, false=True)
  134. # Train for a while.
  135. for _ in range(1):
  136. agent.train()
  137. weight_after_train = current_weight(agent)
  138. # Weights should have changed.
  139. check(weight_before_import, weight_after_train, false=True)
  140. check(weight_after_import, weight_after_train, false=True)
  141. # We can save the entire Agent and restore, weights should remain the
  142. # same.
  143. file = agent.save("after_train")
  144. check(weight_after_train, current_weight(agent))
  145. agent.restore(file)
  146. check(weight_after_train, current_weight(agent))
  147. # Import (untrained) weights again.
  148. agent.import_model(import_file=import_file)
  149. check(current_weight(agent), weight_after_import)
  150. class TestModelImport(unittest.TestCase):
  151. def setUp(self):
  152. ray.init()
  153. ModelCatalog.register_custom_model("keras_model", MyKerasModel)
  154. ModelCatalog.register_custom_model("torch_model", MyTorchModel)
  155. def tearDown(self):
  156. ray.shutdown()
  157. def test_ppo(self):
  158. model_import_test(
  159. "PPO",
  160. config={
  161. "num_workers": 0,
  162. "model": {
  163. "vf_share_layers": True,
  164. },
  165. },
  166. env="CartPole-v0")
  167. if __name__ == "__main__":
  168. import pytest
  169. import sys
  170. sys.exit(pytest.main(["-v", __file__]))