- from gym import Env
- from gym.spaces import Box, Dict, Discrete, Tuple
- import numpy as np
- import re
- import unittest
- import ray
- import ray.rllib.agents.sac as sac
- from ray.rllib.agents.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
- from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
- loss_torch
- from ray.rllib.examples.env.random_env import RandomEnv
- from ray.rllib.examples.models.batch_norm_model import KerasBatchNormModel, \
- TorchBatchNormModel
- from ray.rllib.models.catalog import ModelCatalog
- from ray.rllib.models.tf.tf_action_dist import Dirichlet
- from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
- from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
- MultiAgentReplayBuffer
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
- from ray.rllib.utils.numpy import fc, huber_loss, relu
- from ray.rllib.utils.spaces.simplex import Simplex
- from ray.rllib.utils.test_utils import check, check_compute_single_action, \
- check_train_results, framework_iterator
- from ray.rllib.utils.torch_utils import convert_to_torch_tensor
- from ray import tune
- tf1, tf, tfv = try_import_tf()
- torch, _ = try_import_torch()
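- # Toy env with a 1D Box (or 2-simplex) action space; the reward is 1.0 minus
- # the distance between the (max) action and the current 1D state. Used by the
- # loss-function test below.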
- class SimpleEnv(Env):
- def __init__(self, config):
- if config.get("simplex_actions", False):
- self.action_space = Simplex((2, ))
- else:
- self.action_space = Box(0.0, 1.0, (1, ))
- self.observation_space = Box(0.0, 1.0, (1, ))
- self.max_steps = config.get("max_steps", 100)
- self.state = None
- self.steps = None
- def reset(self):
- self.state = self.observation_space.sample()
- self.steps = 0
- return self.state
- def step(self, action):
- self.steps += 1
- # Reward is 1.0 - |max(action) - state|.
- [r] = 1.0 - np.abs(np.max(action) - self.state)
- d = self.steps >= self.max_steps
- self.state = self.observation_space.sample()
- return self.state, r, d, {}
- class TestSAC(unittest.TestCase):
- @classmethod
- def setUpClass(cls) -> None:
- np.random.seed(42)
- torch.manual_seed(42)
- ray.init()
- @classmethod
- def tearDownClass(cls) -> None:
- ray.shutdown()
- def test_sac_compilation(self):
- """Tests whether an SACTrainer can be built with all frameworks."""
- config = sac.DEFAULT_CONFIG.copy()
- config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy()
- config["num_workers"] = 0 # Run locally.
- config["n_step"] = 3
- config["twin_q"] = True
- config["learning_starts"] = 0
- config["prioritized_replay"] = True
- config["rollout_fragment_length"] = 10
- config["train_batch_size"] = 10
- # If we use default buffer size (1e6), the buffer will take up
- # 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
- # available system memory (8.34816 GB).
- config["buffer_size"] = 40000
- # Test with saved replay buffer.
- config["store_buffer_in_checkpoints"] = True
- num_iterations = 1
- ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
- ModelCatalog.register_custom_model("batch_norm_torch",
- TorchBatchNormModel)
- image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
- simple_space = Box(-1.0, 1.0, shape=(3, ))
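- # Register two dummy envs with complex observation spaces (a Dict and a Tuple
- # of a flat Box, a Discrete, and an image Box).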
- tune.register_env(
- "random_dict_env", lambda _: RandomEnv({
- "observation_space": Dict({
- "a": simple_space,
- "b": Discrete(2),
- "c": image_space, }),
- "action_space": Box(-1.0, 1.0, shape=(1, )), }))
- tune.register_env(
- "random_tuple_env", lambda _: RandomEnv({
- "observation_space": Tuple([
- simple_space, Discrete(2), image_space]),
- "action_space": Box(-1.0, 1.0, shape=(1, )), }))
- for fw in framework_iterator(config, with_eager_tracing=True):
- # Test for different env types (discrete w/ and w/o image, + cont).
- for env in [
- "random_dict_env",
- "random_tuple_env",
- # "MsPacmanNoFrameskip-v4",
- "CartPole-v0",
- ]:
- print("Env={}".format(env))
- # Test making the Q-model a custom one for CartPole; otherwise,
- # use the default model.
- config["Q_model"]["custom_model"] = "batch_norm{}".format(
- "_torch"
- if fw == "torch" else "") if env == "CartPole-v0" else None
- trainer = sac.SACTrainer(config=config, env=env)
- for i in range(num_iterations):
- results = trainer.train()
- check_train_results(results)
- print(results)
- check_compute_single_action(trainer)
- # Test whether the replay buffer is saved along with
- # a checkpoint (no point in doing it for all frameworks since
- # this is framework agnostic).
- if fw == "tf" and env == "CartPole-v0":
- checkpoint = trainer.save()
- new_trainer = sac.SACTrainer(config, env=env)
- new_trainer.restore(checkpoint)
- # Get some data from the buffer and compare.
- data = trainer.local_replay_buffer.replay_buffers[
- "default_policy"]._storage[:42 + 42]
- new_data = new_trainer.local_replay_buffer.replay_buffers[
- "default_policy"]._storage[:42 + 42]
- check(data, new_data)
- new_trainer.stop()
- trainer.stop()
- def test_sac_loss_function(self):
- """Tests SAC loss function results across all frameworks."""
- config = sac.DEFAULT_CONFIG.copy()
- # Run locally.
- config["num_workers"] = 0
- config["learning_starts"] = 0
- config["twin_q"] = False
- config["gamma"] = 0.99
- # Switch on deterministic loss so we can compare the loss values.
- config["_deterministic_loss"] = True
- # Use very simple nets.
- config["Q_model"]["fcnet_hiddens"] = [10]
- config["policy_model"]["fcnet_hiddens"] = [10]
- # Make sure timing differences do not affect trainer.train().
- config["min_time_s_per_reporting"] = 0
- # Test SAC with Simplex action space.
- config["env_config"] = {"simplex_actions": True}
- map_ = {
- # Action net.
- "default_policy/fc_1/kernel": "action_model._hidden_layers.0."
- "_model.0.weight",
- "default_policy/fc_1/bias": "action_model._hidden_layers.0."
- "_model.0.bias",
- "default_policy/fc_out/kernel": "action_model."
- "_logits._model.0.weight",
- "default_policy/fc_out/bias": "action_model._logits._model.0.bias",
- "default_policy/value_out/kernel": "action_model."
- "_value_branch._model.0.weight",
- "default_policy/value_out/bias": "action_model."
- "_value_branch._model.0.bias",
- # Q-net.
- "default_policy/fc_1_1/kernel": "q_net."
- "_hidden_layers.0._model.0.weight",
- "default_policy/fc_1_1/bias": "q_net."
- "_hidden_layers.0._model.0.bias",
- "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
- "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
- "default_policy/value_out_1/kernel": "q_net."
- "_value_branch._model.0.weight",
- "default_policy/value_out_1/bias": "q_net."
- "_value_branch._model.0.bias",
- "default_policy/log_alpha": "log_alpha",
- # Target action-net.
- "default_policy/fc_1_2/kernel": "action_model."
- "_hidden_layers.0._model.0.weight",
- "default_policy/fc_1_2/bias": "action_model."
- "_hidden_layers.0._model.0.bias",
- "default_policy/fc_out_2/kernel": "action_model."
- "_logits._model.0.weight",
- "default_policy/fc_out_2/bias": "action_model."
- "_logits._model.0.bias",
- "default_policy/value_out_2/kernel": "action_model."
- "_value_branch._model.0.weight",
- "default_policy/value_out_2/bias": "action_model."
- "_value_branch._model.0.bias",
- # Target Q-net.
- "default_policy/fc_1_3/kernel": "q_net."
- "_hidden_layers.0._model.0.weight",
- "default_policy/fc_1_3/bias": "q_net."
- "_hidden_layers.0._model.0.bias",
- "default_policy/fc_out_3/kernel": "q_net."
- "_logits._model.0.weight",
- "default_policy/fc_out_3/bias": "q_net."
- "_logits._model.0.bias",
- "default_policy/value_out_3/kernel": "q_net."
- "_value_branch._model.0.weight",
- "default_policy/value_out_3/bias": "q_net."
- "_value_branch._model.0.bias",
- "default_policy/log_alpha_1": "log_alpha",
- }
- env = SimpleEnv
- batch_size = 100
- obs_size = (batch_size, 1)
- actions = np.random.random(size=(batch_size, 2))
- # Batch of size=n.
- input_ = self._get_batch_helper(obs_size, actions, batch_size)
- # Simply compare loss values AND grads of all frameworks with each
- # other.
- prev_fw_loss = weights_dict = None
- expect_c, expect_a, expect_e, expect_t = None, None, None, None
- # History of tf-updated NN-weights over n training steps.
- tf_updated_weights = []
- # History of input batches used.
- tf_inputs = []
- for fw, sess in framework_iterator(
- config, frameworks=("tf", "torch"), session=True):
- # Generate Trainer and get its default Policy object.
- trainer = sac.SACTrainer(config=config, env=env)
- policy = trainer.get_policy()
- p_sess = None
- if sess:
- p_sess = policy.get_session()
- # Set all weights (of all nets) to fixed values.
- if weights_dict is None:
- # Start with the tf vars-dict.
- assert fw in ["tf2", "tf", "tfe"]
- weights_dict = policy.get_weights()
- if fw == "tfe":
- log_alpha = weights_dict[10]
- weights_dict = self._translate_tfe_weights(
- weights_dict, map_)
- else:
- assert fw == "torch" # Then transfer that to torch Model.
- model_dict = self._translate_weights_to_torch(
- weights_dict, map_)
- # Have to add this here (not a parameter in tf, but must be
- # one in torch, so it gets properly copied to the GPU(s)).
- model_dict["target_entropy"] = policy.model.target_entropy
- policy.model.load_state_dict(model_dict)
- policy.target_model.load_state_dict(model_dict)
- if fw == "tf":
- log_alpha = weights_dict["default_policy/log_alpha"]
- elif fw == "torch":
- # Actually convert to torch tensors (by accessing everything).
- input_ = policy._lazy_tensor_dict(input_)
- input_ = {k: input_[k] for k in input_.keys()}
- log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]
- # Only run the expectation once; it should be the same anyway
- # for all frameworks.
- if expect_c is None:
- expect_c, expect_a, expect_e, expect_t = \
- self._sac_loss_helper(input_, weights_dict,
- sorted(weights_dict.keys()),
- log_alpha, fw,
- gamma=config["gamma"], sess=sess)
- # Get actual outs and compare to expectation AND previous
- # framework. c=critic, a=actor, e=entropy, t=td-error.
- if fw == "tf":
- c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = \
- p_sess.run([
- policy.critic_loss,
- policy.actor_loss,
- policy.alpha_loss,
- policy.td_error,
- policy.optimizer().compute_gradients(
- policy.critic_loss[0],
- [v for v in policy.model.q_variables() if
- "value_" not in v.name]),
- policy.optimizer().compute_gradients(
- policy.actor_loss,
- [v for v in policy.model.policy_variables() if
- "value_" not in v.name]),
- policy.optimizer().compute_gradients(
- policy.alpha_loss, policy.model.log_alpha)],
- feed_dict=policy._get_loss_inputs_dict(
- input_, shuffle=False))
- tf_c_grads = [g for g, v in tf_c_grads]
- tf_a_grads = [g for g, v in tf_a_grads]
- tf_e_grads = [g for g, v in tf_e_grads]
- elif fw == "tfe":
- with tf.GradientTape() as tape:
- tf_loss(policy, policy.model, None, input_)
- c, a, e, t = policy.critic_loss, policy.actor_loss, \
- policy.alpha_loss, policy.td_error
- vars = tape.watched_variables()
- tf_c_grads = tape.gradient(c[0], vars[6:10])
- tf_a_grads = tape.gradient(a, vars[2:6])
- tf_e_grads = tape.gradient(e, vars[10])
- elif fw == "torch":
- loss_torch(policy, policy.model, None, input_)
- c, a, e, t = policy.get_tower_stats("critic_loss")[0], \
- policy.get_tower_stats("actor_loss")[0], \
- policy.get_tower_stats("alpha_loss")[0], \
- policy.get_tower_stats("td_error")[0]
- # Test actor gradients.
- policy.actor_optim.zero_grad()
- assert all(v.grad is None for v in policy.model.q_variables())
- assert all(
- v.grad is None for v in policy.model.policy_variables())
- assert policy.model.log_alpha.grad is None
- a.backward()
- # `actor_loss` depends on Q-net vars (but these grads must
- # be ignored and overridden in critic_loss.backward!).
- assert not all(
- torch.mean(v.grad) == 0
- for v in policy.model.policy_variables())
- assert not all(
- torch.min(v.grad) == 0
- for v in policy.model.policy_variables())
- assert policy.model.log_alpha.grad is None
- # Compare with tf ones.
- torch_a_grads = [
- v.grad for v in policy.model.policy_variables()
- if v.grad is not None
- ]
- check(tf_a_grads[2],
- np.transpose(torch_a_grads[0].detach().cpu()))
- # Test critic gradients.
- policy.critic_optims[0].zero_grad()
- assert all(
- torch.mean(v.grad) == 0.0
- for v in policy.model.q_variables() if v.grad is not None)
- assert all(
- torch.min(v.grad) == 0.0
- for v in policy.model.q_variables() if v.grad is not None)
- assert policy.model.log_alpha.grad is None
- c[0].backward()
- assert not all(
- torch.mean(v.grad) == 0
- for v in policy.model.q_variables() if v.grad is not None)
- assert not all(
- torch.min(v.grad) == 0 for v in policy.model.q_variables()
- if v.grad is not None)
- assert policy.model.log_alpha.grad is None
- # Compare with tf ones.
- torch_c_grads = [v.grad for v in policy.model.q_variables()]
- check(tf_c_grads[0],
- np.transpose(torch_c_grads[2].detach().cpu()))
- # Compare (unchanged(!) actor grads) with tf ones.
- torch_a_grads = [
- v.grad for v in policy.model.policy_variables()
- ]
- check(tf_a_grads[2],
- np.transpose(torch_a_grads[0].detach().cpu()))
- # Test alpha gradient.
- policy.alpha_optim.zero_grad()
- assert policy.model.log_alpha.grad is None
- e.backward()
- assert policy.model.log_alpha.grad is not None
- check(policy.model.log_alpha.grad, tf_e_grads)
- check(c, expect_c)
- check(a, expect_a)
- check(e, expect_e)
- check(t, expect_t)
- # Store this framework's losses in prev_fw_loss to compare with
- # next framework's outputs.
- if prev_fw_loss is not None:
- check(c, prev_fw_loss[0])
- check(a, prev_fw_loss[1])
- check(e, prev_fw_loss[2])
- check(t, prev_fw_loss[3])
- prev_fw_loss = (c, a, e, t)
- # Update weights from our batch (n times).
- for update_iteration in range(5):
- print("train iteration {}".format(update_iteration))
- if fw == "tf":
- in_ = self._get_batch_helper(obs_size, actions, batch_size)
- tf_inputs.append(in_)
- # Set a fake-batch to use
- # (instead of sampling from replay buffer).
- buf = MultiAgentReplayBuffer.get_instance_for_testing()
- buf._fake_batch = in_
- trainer.train()
- updated_weights = policy.get_weights()
- # Net must have changed.
- if tf_updated_weights:
- check(
- updated_weights["default_policy/fc_1/kernel"],
- tf_updated_weights[-1][
- "default_policy/fc_1/kernel"],
- false=True)
- tf_updated_weights.append(updated_weights)
- # Compare with updated tf-weights. Must all be the same.
- else:
- tf_weights = tf_updated_weights[update_iteration]
- in_ = tf_inputs[update_iteration]
- # Set a fake-batch to use
- # (instead of sampling from replay buffer).
- buf = MultiAgentReplayBuffer.get_instance_for_testing()
- buf._fake_batch = in_
- trainer.train()
- # Compare updated model.
- for tf_key in sorted(tf_weights.keys()):
- if re.search("_[23]|alpha", tf_key):
- continue
- tf_var = tf_weights[tf_key]
- torch_var = policy.model.state_dict()[map_[tf_key]]
- if tf_var.shape != torch_var.shape:
- check(
- tf_var,
- np.transpose(torch_var.detach().cpu()),
- atol=0.003)
- else:
- check(tf_var, torch_var, atol=0.003)
- # And alpha.
- check(policy.model.log_alpha,
- tf_weights["default_policy/log_alpha"])
- # Compare target nets.
- for tf_key in sorted(tf_weights.keys()):
- if not re.search("_[23]", tf_key):
- continue
- tf_var = tf_weights[tf_key]
- torch_var = policy.target_model.state_dict()[map_[
- tf_key]]
- if tf_var.shape != torch_var.shape:
- check(
- tf_var,
- np.transpose(torch_var.detach().cpu()),
- atol=0.003)
- else:
- check(tf_var, torch_var, atol=0.003)
- trainer.stop()
- def _get_batch_helper(self, obs_size, actions, batch_size):
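- """Returns a SampleBatch of random obs/rewards/dones and the given actions."""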
- return SampleBatch({
- SampleBatch.CUR_OBS: np.random.random(size=obs_size),
- SampleBatch.ACTIONS: actions,
- SampleBatch.REWARDS: np.random.random(size=(batch_size, )),
- SampleBatch.DONES: np.random.choice(
- [True, False], size=(batch_size, )),
- SampleBatch.NEXT_OBS: np.random.random(size=obs_size),
- "weights": np.random.random(size=(batch_size, )),
- })
- def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma,
- sess):
- """Emulates SAC loss functions for tf and torch."""
- # ks:
- # 0=log_alpha
- # 1=target log-alpha (not used)
- # 2=action hidden bias
- # 3=action hidden kernel
- # 4=action out bias
- # 5=action out kernel
- # 6=Q hidden bias
- # 7=Q hidden kernel
- # 8=Q out bias
- # 9=Q out kernel
- # 14=target Q hidden bias
- # 15=target Q hidden kernel
- # 16=target Q out bias
- # 17=target Q out kernel
- alpha = np.exp(log_alpha)
- # cls = TorchSquashedGaussian if fw == "torch" else SquashedGaussian
- cls = TorchDirichlet if fw == "torch" else Dirichlet
- model_out_t = train_batch[SampleBatch.CUR_OBS]
- model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
- target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
- # get_policy_output
- action_dist_t = cls(
- fc(
- relu(
- fc(model_out_t,
- weights[ks[1]],
- weights[ks[0]],
- framework=fw)), weights[ks[9]], weights[ks[8]]), None)
- policy_t = action_dist_t.deterministic_sample()
- log_pis_t = action_dist_t.logp(policy_t)
- if sess:
- log_pis_t = sess.run(log_pis_t)
- policy_t = sess.run(policy_t)
- log_pis_t = np.expand_dims(log_pis_t, -1)
- # Get policy output for t+1.
- action_dist_tp1 = cls(
- fc(
- relu(
- fc(model_out_tp1,
- weights[ks[1]],
- weights[ks[0]],
- framework=fw)), weights[ks[9]], weights[ks[8]]), None)
- policy_tp1 = action_dist_tp1.deterministic_sample()
- log_pis_tp1 = action_dist_tp1.logp(policy_tp1)
- if sess:
- log_pis_tp1 = sess.run(log_pis_tp1)
- policy_tp1 = sess.run(policy_tp1)
- log_pis_tp1 = np.expand_dims(log_pis_tp1, -1)
- # Q-values for the actually selected actions.
- # get_q_values
- q_t = fc(
- relu(
- fc(np.concatenate(
- [model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
- weights[ks[3]],
- weights[ks[2]],
- framework=fw)),
- weights[ks[11]],
- weights[ks[10]],
- framework=fw)
- # Q-values for current policy in given current state.
- # get_q_values
- q_t_det_policy = fc(
- relu(
- fc(np.concatenate([model_out_t, policy_t], -1),
- weights[ks[3]],
- weights[ks[2]],
- framework=fw)),
- weights[ks[11]],
- weights[ks[10]],
- framework=fw)
- # Target q network evaluation.
- # target_model.get_q_values
- if fw == "tf":
- q_tp1 = fc(
- relu(
- fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
- weights[ks[7]],
- weights[ks[6]],
- framework=fw)),
- weights[ks[15]],
- weights[ks[14]],
- framework=fw)
- else:
- assert fw == "tfe"
- q_tp1 = fc(
- relu(
- fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
- weights[ks[7]],
- weights[ks[6]],
- framework=fw)),
- weights[ks[9]],
- weights[ks[8]],
- framework=fw)
- q_t_selected = np.squeeze(q_t, axis=-1)
- q_tp1 -= alpha * log_pis_tp1
- q_tp1_best = np.squeeze(q_tp1, axis=-1)
- dones = train_batch[SampleBatch.DONES]
- rewards = train_batch[SampleBatch.REWARDS]
- if fw == "torch":
- dones = dones.float().numpy()
- rewards = rewards.numpy()
- q_tp1_best_masked = (1.0 - dones) * q_tp1_best
- q_t_selected_target = rewards + gamma * q_tp1_best_masked
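- # Emulated SAC loss terms: TD error, importance-weighted Huber critic loss,
- # entropy (alpha) loss, and actor loss.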
- base_td_error = np.abs(q_t_selected - q_t_selected_target)
- td_error = base_td_error
- critic_loss = [
- np.mean(train_batch["weights"] *
- huber_loss(q_t_selected_target - q_t_selected))
- ]
- target_entropy = -np.prod((1, ))
- alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))
- actor_loss = np.mean(alpha * log_pis_t - q_t_det_policy)
- return critic_loss, actor_loss, alpha_loss, td_error
- def _translate_weights_to_torch(self, weights_dict, map_):
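- """Translates the (non-target) tf weights into a torch state dict via map_."""
- # Kernels are transposed; log_alpha is wrapped into a 1-element array.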
- model_dict = {
- map_[k]: convert_to_torch_tensor(
- np.transpose(v) if re.search("kernel", k) else np.array([v])
- if re.search("log_alpha", k) else v)
- for i, (k, v) in enumerate(weights_dict.items()) if i < 13
- }
- return model_dict
- def _translate_tfe_weights(self, weights_dict, map_):
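- """Reorders the flat tf-eager weights list into a dict keyed by tf1 variable names."""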
- model_dict = {
- "default_policy/log_alpha": None,
- "default_policy/log_alpha_target": None,
- "default_policy/sequential/action_1/kernel": weights_dict[2],
- "default_policy/sequential/action_1/bias": weights_dict[3],
- "default_policy/sequential/action_out/kernel": weights_dict[4],
- "default_policy/sequential/action_out/bias": weights_dict[5],
- "default_policy/sequential_1/q_hidden_0/kernel": weights_dict[6],
- "default_policy/sequential_1/q_hidden_0/bias": weights_dict[7],
- "default_policy/sequential_1/q_out/kernel": weights_dict[8],
- "default_policy/sequential_1/q_out/bias": weights_dict[9],
- "default_policy/value_out/kernel": weights_dict[0],
- "default_policy/value_out/bias": weights_dict[1],
- }
- return model_dict
- if __name__ == "__main__":
- import pytest
- import sys
- sys.exit(pytest.main(["-v", __file__]))