123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- #!/usr/bin/env python
- import h5py
- import numpy as np
- from pathlib import Path
- import unittest
- import ray
- from ray.rllib.agents.registry import get_trainer_class
- from ray.rllib.models.catalog import ModelCatalog
- from ray.rllib.models.tf.misc import normc_initializer
- from ray.rllib.models.tf.tf_modelv2 import TFModelV2
- from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
- from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
- from ray.rllib.utils.test_utils import check, framework_iterator
- tf1, tf, tfv = try_import_tf()
- torch, nn = try_import_torch()
- class MyKerasModel(TFModelV2):
- """Custom model for policy gradient algorithms."""
- def __init__(self, obs_space, action_space, num_outputs, model_config,
- name):
- super(MyKerasModel, self).__init__(obs_space, action_space,
- num_outputs, model_config, name)
- self.inputs = tf.keras.layers.Input(
- shape=obs_space.shape, name="observations")
- layer_1 = tf.keras.layers.Dense(
- 16,
- name="layer1",
- activation=tf.nn.relu,
- kernel_initializer=normc_initializer(1.0))(self.inputs)
- layer_out = tf.keras.layers.Dense(
- num_outputs,
- name="out",
- activation=None,
- kernel_initializer=normc_initializer(0.01))(layer_1)
- if self.model_config["vf_share_layers"]:
- value_out = tf.keras.layers.Dense(
- 1,
- name="value",
- activation=None,
- kernel_initializer=normc_initializer(0.01))(layer_1)
- self.base_model = tf.keras.Model(self.inputs,
- [layer_out, value_out])
- else:
- self.base_model = tf.keras.Model(self.inputs, layer_out)
- def forward(self, input_dict, state, seq_lens):
- if self.model_config["vf_share_layers"]:
- model_out, self._value_out = self.base_model(input_dict["obs"])
- else:
- model_out = self.base_model(input_dict["obs"])
- self._value_out = tf.zeros(
- shape=(tf.shape(input_dict["obs"])[0], ))
- return model_out, state
- def value_function(self):
- return tf.reshape(self._value_out, [-1])
- def import_from_h5(self, import_file):
- # Override this to define custom weight loading behavior from h5 files.
- self.base_model.load_weights(import_file)
- class MyTorchModel(TorchModelV2, nn.Module):
- """Generic vision network."""
- def __init__(self, obs_space, action_space, num_outputs, model_config,
- name):
- TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
- model_config, name)
- nn.Module.__init__(self)
- self.device = torch.device("cuda"
- if torch.cuda.is_available() else "cpu")
- self.layer_1 = nn.Linear(obs_space.shape[0], 16).to(self.device)
- self.layer_out = nn.Linear(16, num_outputs).to(self.device)
- self.value_branch = nn.Linear(16, 1).to(self.device)
- self.cur_value = None
- def forward(self, input_dict, state, seq_lens):
- layer_1_out = self.layer_1(input_dict["obs"])
- logits = self.layer_out(layer_1_out)
- self.cur_value = self.value_branch(layer_1_out).squeeze(1)
- return logits, state
- def value_function(self):
- assert self.cur_value is not None, "Must call `forward()` first!"
- return self.cur_value
- def import_from_h5(self, import_file):
- # Override this to define custom weight loading behavior from h5 files.
- f = h5py.File(import_file)
- layer1 = f["layer1"][DEFAULT_POLICY_ID]["layer1"]
- out = f["out"][DEFAULT_POLICY_ID]["out"]
- value = f["value"][DEFAULT_POLICY_ID]["value"]
- try:
- self.layer_1.load_state_dict({
- "weight": torch.Tensor(np.transpose(layer1["kernel:0"])),
- "bias": torch.Tensor(np.transpose(layer1["bias:0"])),
- })
- self.layer_out.load_state_dict({
- "weight": torch.Tensor(np.transpose(out["kernel:0"])),
- "bias": torch.Tensor(np.transpose(out["bias:0"])),
- })
- self.value_branch.load_state_dict({
- "weight": torch.Tensor(np.transpose(value["kernel:0"])),
- "bias": torch.Tensor(np.transpose(value["bias:0"])),
- })
- except AttributeError:
- self.layer_1.load_state_dict({
- "weight": torch.Tensor(np.transpose(layer1["kernel:0"].value)),
- "bias": torch.Tensor(np.transpose(layer1["bias:0"].value)),
- })
- self.layer_out.load_state_dict({
- "weight": torch.Tensor(np.transpose(out["kernel:0"].value)),
- "bias": torch.Tensor(np.transpose(out["bias:0"].value)),
- })
- self.value_branch.load_state_dict({
- "weight": torch.Tensor(np.transpose(value["kernel:0"].value)),
- "bias": torch.Tensor(np.transpose(value["bias:0"].value)),
- })
- def model_import_test(algo, config, env):
- # Get the abs-path to use (bazel-friendly).
- rllib_dir = Path(__file__).parent.parent
- import_file = str(rllib_dir) + "/tests/data/model_weights/weights.h5"
- agent_cls = get_trainer_class(algo)
- for fw in framework_iterator(config, ["tf", "torch"]):
- config["model"]["custom_model"] = "keras_model" if fw != "torch" else \
- "torch_model"
- agent = agent_cls(config, env)
- def current_weight(agent):
- if fw == "tf":
- return agent.get_weights()[DEFAULT_POLICY_ID][
- "default_policy/value/kernel"][0]
- elif fw == "torch":
- return float(agent.get_weights()[DEFAULT_POLICY_ID][
- "value_branch.weight"][0][0])
- else:
- return agent.get_weights()[DEFAULT_POLICY_ID][4][0]
- # Import weights for our custom model from an h5 file.
- weight_before_import = current_weight(agent)
- agent.import_model(import_file=import_file)
- weight_after_import = current_weight(agent)
- check(weight_before_import, weight_after_import, false=True)
- # Train for a while.
- for _ in range(1):
- agent.train()
- weight_after_train = current_weight(agent)
- # Weights should have changed.
- check(weight_before_import, weight_after_train, false=True)
- check(weight_after_import, weight_after_train, false=True)
- # We can save the entire Agent and restore, weights should remain the
- # same.
- file = agent.save("after_train")
- check(weight_after_train, current_weight(agent))
- agent.restore(file)
- check(weight_after_train, current_weight(agent))
- # Import (untrained) weights again.
- agent.import_model(import_file=import_file)
- check(current_weight(agent), weight_after_import)
- class TestModelImport(unittest.TestCase):
- def setUp(self):
- ray.init()
- ModelCatalog.register_custom_model("keras_model", MyKerasModel)
- ModelCatalog.register_custom_model("torch_model", MyTorchModel)
- def tearDown(self):
- ray.shutdown()
- def test_ppo(self):
- model_import_test(
- "PPO",
- config={
- "num_workers": 0,
- "model": {
- "vf_share_layers": True,
- },
- },
- env="CartPole-v0")
- if __name__ == "__main__":
- import pytest
- import sys
- sys.exit(pytest.main(["-v", __file__]))
|