# test_sac.py

from gym import Env
from gym.spaces import Box, Dict, Discrete, Tuple
import numpy as np
import re
import unittest

import ray
import ray.rllib.agents.sac as sac
from ray.rllib.agents.sac.sac_tf_policy import sac_actor_critic_loss as tf_loss
from ray.rllib.agents.sac.sac_torch_policy import actor_critic_loss as \
    loss_torch
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.examples.models.batch_norm_model import KerasBatchNormModel, \
    TorchBatchNormModel
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.tf.tf_action_dist import Dirichlet
from ray.rllib.models.torch.torch_action_dist import TorchDirichlet
from ray.rllib.execution.buffers.multi_agent_replay_buffer import \
    MultiAgentReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import fc, huber_loss, relu
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
    check_train_results, framework_iterator
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray import tune

tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()


class SimpleEnv(Env):
    def __init__(self, config):
        if config.get("simplex_actions", False):
            self.action_space = Simplex((2, ))
        else:
            self.action_space = Box(0.0, 1.0, (1, ))
        self.observation_space = Box(0.0, 1.0, (1, ))
        self.max_steps = config.get("max_steps", 100)
        self.state = None
        self.steps = None

    def reset(self):
        self.state = self.observation_space.sample()
        self.steps = 0
        return self.state

    def step(self, action):
        self.steps += 1
        # Reward is 1.0 - (max(actions) - state).
        [r] = 1.0 - np.abs(np.max(action) - self.state)
        d = self.steps >= self.max_steps
        self.state = self.observation_space.sample()
        return self.state, r, d, {}
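
# A minimal usage sketch for SimpleEnv (kept as a comment so that importing
# this test module stays side-effect free). The reward is largest when the
# (max) action component matches the current scalar observation:
#
#   env = SimpleEnv({"max_steps": 3})
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())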


class TestSAC(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        np.random.seed(42)
        torch.manual_seed(42)
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_sac_compilation(self):
        """Tests whether an SACTrainer can be built with all frameworks."""
        config = sac.DEFAULT_CONFIG.copy()
        config["Q_model"] = sac.DEFAULT_CONFIG["Q_model"].copy()
        config["num_workers"] = 0  # Run locally.
        config["n_step"] = 3
        config["twin_q"] = True
        config["learning_starts"] = 0
        config["prioritized_replay"] = True
        config["rollout_fragment_length"] = 10
        config["train_batch_size"] = 10
        # If we used the default buffer size (1e6), the buffer would take up
        # 169.445 GB of memory, which is beyond travis-ci's current
        # (Mar 19, 2021) available system memory (8.34816 GB).
        config["buffer_size"] = 40000
        # Test with saved replay buffer.
        config["store_buffer_in_checkpoints"] = True
        num_iterations = 1

        ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
        ModelCatalog.register_custom_model("batch_norm_torch",
                                           TorchBatchNormModel)

        image_space = Box(-1.0, 1.0, shape=(84, 84, 3))
        simple_space = Box(-1.0, 1.0, shape=(3, ))

        tune.register_env(
            "random_dict_env", lambda _: RandomEnv({
                "observation_space": Dict({
                    "a": simple_space,
                    "b": Discrete(2),
                    "c": image_space, }),
                "action_space": Box(-1.0, 1.0, shape=(1, )), }))
        tune.register_env(
            "random_tuple_env", lambda _: RandomEnv({
                "observation_space": Tuple([
                    simple_space, Discrete(2), image_space]),
                "action_space": Box(-1.0, 1.0, shape=(1, )), }))
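
        # NOTE: RandomEnv simply emits random samples from the configured
        # observation space (plus random rewards), so the two envs registered
        # above only exercise SAC's handling of Dict and Tuple observations.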

        for fw in framework_iterator(config, with_eager_tracing=True):
            # Test for different env types (discrete w/ and w/o image, + cont).
            for env in [
                    "random_dict_env",
                    "random_tuple_env",
                    # "MsPacmanNoFrameskip-v4",
                    "CartPole-v0",
            ]:
                print("Env={}".format(env))
                # Use a custom Q-model for CartPole; otherwise use the
                # default model.
                config["Q_model"]["custom_model"] = "batch_norm{}".format(
                    "_torch"
                    if fw == "torch" else "") if env == "CartPole-v0" else None
                trainer = sac.SACTrainer(config=config, env=env)
                for i in range(num_iterations):
                    results = trainer.train()
                    check_train_results(results)
                    print(results)
                check_compute_single_action(trainer)

                # Test whether the replay buffer is saved along with
                # a checkpoint (no point in doing it for all frameworks since
                # this is framework agnostic).
                if fw == "tf" and env == "CartPole-v0":
                    checkpoint = trainer.save()
                    new_trainer = sac.SACTrainer(config, env=env)
                    new_trainer.restore(checkpoint)
                    # Get some data from the buffer and compare.
                    data = trainer.local_replay_buffer.replay_buffers[
                        "default_policy"]._storage[:42 + 42]
                    new_data = new_trainer.local_replay_buffer.replay_buffers[
                        "default_policy"]._storage[:42 + 42]
                    check(data, new_data)
                    new_trainer.stop()

                trainer.stop()

    def test_sac_loss_function(self):
        """Tests SAC loss function results across all frameworks."""
        config = sac.DEFAULT_CONFIG.copy()
        # Run locally.
        config["num_workers"] = 0
        config["learning_starts"] = 0
        config["twin_q"] = False
        config["gamma"] = 0.99
        # Switch on deterministic loss so we can compare the loss values.
        config["_deterministic_loss"] = True
        # Use very simple nets.
        config["Q_model"]["fcnet_hiddens"] = [10]
        config["policy_model"]["fcnet_hiddens"] = [10]
        # Make sure timing differences do not affect trainer.train().
        config["min_time_s_per_reporting"] = 0
        # Test SAC with Simplex action space.
        config["env_config"] = {"simplex_actions": True}

        map_ = {
            # Action net.
            "default_policy/fc_1/kernel": "action_model._hidden_layers.0."
            "_model.0.weight",
            "default_policy/fc_1/bias": "action_model._hidden_layers.0."
            "_model.0.bias",
            "default_policy/fc_out/kernel": "action_model."
            "_logits._model.0.weight",
            "default_policy/fc_out/bias": "action_model._logits._model.0.bias",
            "default_policy/value_out/kernel": "action_model."
            "_value_branch._model.0.weight",
            "default_policy/value_out/bias": "action_model."
            "_value_branch._model.0.bias",
            # Q-net.
            "default_policy/fc_1_1/kernel": "q_net."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_1/bias": "q_net."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
            "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
            "default_policy/value_out_1/kernel": "q_net."
            "_value_branch._model.0.weight",
            "default_policy/value_out_1/bias": "q_net."
            "_value_branch._model.0.bias",
            "default_policy/log_alpha": "log_alpha",
            # Target action-net.
            "default_policy/fc_1_2/kernel": "action_model."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_2/bias": "action_model."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_2/kernel": "action_model."
            "_logits._model.0.weight",
            "default_policy/fc_out_2/bias": "action_model."
            "_logits._model.0.bias",
            "default_policy/value_out_2/kernel": "action_model."
            "_value_branch._model.0.weight",
            "default_policy/value_out_2/bias": "action_model."
            "_value_branch._model.0.bias",
            # Target Q-net.
            "default_policy/fc_1_3/kernel": "q_net."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_3/bias": "q_net."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_3/kernel": "q_net."
            "_logits._model.0.weight",
            "default_policy/fc_out_3/bias": "q_net."
            "_logits._model.0.bias",
            "default_policy/value_out_3/kernel": "q_net."
            "_value_branch._model.0.weight",
            "default_policy/value_out_3/bias": "q_net."
            "_value_branch._model.0.bias",
            "default_policy/log_alpha_1": "log_alpha",
        }
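
        # `map_` translates tf variable names (as returned by
        # policy.get_weights()) into the corresponding torch state_dict keys.
        # The target nets (suffixes _2/_3) map onto the same torch key names
        # (the target_model has its own state_dict), and kernels get
        # transposed later in `_translate_weights_to_torch`.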

        env = SimpleEnv
        batch_size = 100
        obs_size = (batch_size, 1)
        actions = np.random.random(size=(batch_size, 2))

        # Batch of size=n.
        input_ = self._get_batch_helper(obs_size, actions, batch_size)

        # Simply compare loss values AND grads of all frameworks with each
        # other.
        prev_fw_loss = weights_dict = None
        expect_c, expect_a, expect_e, expect_t = None, None, None, None

        # History of tf-updated NN-weights over n training steps.
        tf_updated_weights = []
        # History of input batches used.
        tf_inputs = []

        for fw, sess in framework_iterator(
                config, frameworks=("tf", "torch"), session=True):
            # Generate Trainer and get its default Policy object.
            trainer = sac.SACTrainer(config=config, env=env)
            policy = trainer.get_policy()
            p_sess = None
            if sess:
                p_sess = policy.get_session()

            # Set all weights (of all nets) to fixed values.
            if weights_dict is None:
                # Start with the tf vars-dict.
                assert fw in ["tf2", "tf", "tfe"]
                weights_dict = policy.get_weights()
                if fw == "tfe":
                    log_alpha = weights_dict[10]
                    weights_dict = self._translate_tfe_weights(
                        weights_dict, map_)
            else:
                assert fw == "torch"  # Then transfer that to the torch model.
                model_dict = self._translate_weights_to_torch(
                    weights_dict, map_)
                # Have to add this here (not a parameter in tf, but must be
                # one in torch, so it gets properly copied to the GPU(s)).
                model_dict["target_entropy"] = policy.model.target_entropy
                policy.model.load_state_dict(model_dict)
                policy.target_model.load_state_dict(model_dict)

            if fw == "tf":
                log_alpha = weights_dict["default_policy/log_alpha"]
            elif fw == "torch":
                # Actually convert to torch tensors (by accessing everything).
                input_ = policy._lazy_tensor_dict(input_)
                input_ = {k: input_[k] for k in input_.keys()}
                log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

            # Only run the expectation once; it should be the same anyway
            # for all frameworks.
            if expect_c is None:
                expect_c, expect_a, expect_e, expect_t = \
                    self._sac_loss_helper(input_, weights_dict,
                                          sorted(weights_dict.keys()),
                                          log_alpha, fw,
                                          gamma=config["gamma"], sess=sess)
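
            # The expected values come from `_sac_loss_helper` (below), which
            # re-computes the SAC losses with plain numpy on the same weights.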

            # Get actual outs and compare to expectation AND previous
            # framework. c=critic, a=actor, e=entropy, t=td-error.
            if fw == "tf":
                c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = \
                    p_sess.run([
                        policy.critic_loss,
                        policy.actor_loss,
                        policy.alpha_loss,
                        policy.td_error,
                        policy.optimizer().compute_gradients(
                            policy.critic_loss[0],
                            [v for v in policy.model.q_variables() if
                             "value_" not in v.name]),
                        policy.optimizer().compute_gradients(
                            policy.actor_loss,
                            [v for v in policy.model.policy_variables() if
                             "value_" not in v.name]),
                        policy.optimizer().compute_gradients(
                            policy.alpha_loss, policy.model.log_alpha)],
                        feed_dict=policy._get_loss_inputs_dict(
                            input_, shuffle=False))
                tf_c_grads = [g for g, v in tf_c_grads]
                tf_a_grads = [g for g, v in tf_a_grads]
                tf_e_grads = [g for g, v in tf_e_grads]

            elif fw == "tfe":
                with tf.GradientTape() as tape:
                    tf_loss(policy, policy.model, None, input_)
                c, a, e, t = policy.critic_loss, policy.actor_loss, \
                    policy.alpha_loss, policy.td_error
                vars = tape.watched_variables()
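                # The index slices below (2:6 -> policy-net, 6:10 -> Q-net,
                # 10 -> log_alpha) assume the order in which the eager loss
                # touched the variables on the tape; this is specific to the
                # simple fcnet models used in this test.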
                tf_c_grads = tape.gradient(c[0], vars[6:10])
                tf_a_grads = tape.gradient(a, vars[2:6])
                tf_e_grads = tape.gradient(e, vars[10])

            elif fw == "torch":
                loss_torch(policy, policy.model, None, input_)
                c, a, e, t = policy.get_tower_stats("critic_loss")[0], \
                    policy.get_tower_stats("actor_loss")[0], \
                    policy.get_tower_stats("alpha_loss")[0], \
                    policy.get_tower_stats("td_error")[0]

                # Test actor gradients.
                policy.actor_optim.zero_grad()
                assert all(v.grad is None for v in policy.model.q_variables())
                assert all(
                    v.grad is None for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                a.backward()
                # `actor_loss` depends on Q-net vars (but these grads must
                # be ignored and overridden in critic_loss.backward!).
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                    if v.grad is not None
                ]
                check(tf_a_grads[2],
                      np.transpose(torch_a_grads[0].detach().cpu()))

                # Test critic gradients.
                policy.critic_optims[0].zero_grad()
                assert all(
                    torch.mean(v.grad) == 0.0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert all(
                    torch.min(v.grad) == 0.0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert policy.model.log_alpha.grad is None
                c[0].backward()
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert not all(
                    torch.min(v.grad) == 0 for v in policy.model.q_variables()
                    if v.grad is not None)
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_c_grads = [v.grad for v in policy.model.q_variables()]
                check(tf_c_grads[0],
                      np.transpose(torch_c_grads[2].detach().cpu()))
                # Compare (unchanged(!) actor grads) with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                ]
                check(tf_a_grads[2],
                      np.transpose(torch_a_grads[0].detach().cpu()))

                # Test alpha gradient.
                policy.alpha_optim.zero_grad()
                assert policy.model.log_alpha.grad is None
                e.backward()
                assert policy.model.log_alpha.grad is not None
                check(policy.model.log_alpha.grad, tf_e_grads)

            check(c, expect_c)
            check(a, expect_a)
            check(e, expect_e)
            check(t, expect_t)

            # Store this framework's losses in prev_fw_loss to compare with
            # next framework's outputs.
            if prev_fw_loss is not None:
                check(c, prev_fw_loss[0])
                check(a, prev_fw_loss[1])
                check(e, prev_fw_loss[2])
                check(t, prev_fw_loss[3])

            prev_fw_loss = (c, a, e, t)

            # Update weights from our batch (n times).
            for update_iteration in range(5):
                print("train iteration {}".format(update_iteration))
                if fw == "tf":
                    in_ = self._get_batch_helper(
                        obs_size, actions, batch_size)
                    tf_inputs.append(in_)
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    updated_weights = policy.get_weights()
                    # Net must have changed.
                    if tf_updated_weights:
                        check(
                            updated_weights["default_policy/fc_1/kernel"],
                            tf_updated_weights[-1][
                                "default_policy/fc_1/kernel"],
                            false=True)
                    tf_updated_weights.append(updated_weights)

                # Compare with updated tf-weights. Must all be the same.
                else:
                    tf_weights = tf_updated_weights[update_iteration]
                    in_ = tf_inputs[update_iteration]
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    # Compare updated model.
                    for tf_key in sorted(tf_weights.keys()):
                        if re.search("_[23]|alpha", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(
                                tf_var,
                                np.transpose(torch_var.detach().cpu()),
                                atol=0.003)
                        else:
                            check(tf_var, torch_var, atol=0.003)
                    # And alpha.
                    check(policy.model.log_alpha,
                          tf_weights["default_policy/log_alpha"])
                    # Compare target nets.
                    for tf_key in sorted(tf_weights.keys()):
                        if not re.search("_[23]", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.target_model.state_dict()[map_[
                            tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(
                                tf_var,
                                np.transpose(torch_var.detach().cpu()),
                                atol=0.003)
                        else:
                            check(tf_var, torch_var, atol=0.003)

            trainer.stop()

    def _get_batch_helper(self, obs_size, actions, batch_size):
        return SampleBatch({
            SampleBatch.CUR_OBS: np.random.random(size=obs_size),
            SampleBatch.ACTIONS: actions,
            SampleBatch.REWARDS: np.random.random(size=(batch_size, )),
            SampleBatch.DONES: np.random.choice(
                [True, False], size=(batch_size, )),
            SampleBatch.NEXT_OBS: np.random.random(size=obs_size),
            "weights": np.random.random(size=(batch_size, )),
        })
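
    # `_get_batch_helper` (above) builds a random SampleBatch of the given
    # size; the "weights" column holds per-sample loss weights, as used by
    # prioritized replay.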

    def _sac_loss_helper(self, train_batch, weights, ks, log_alpha, fw, gamma,
                         sess):
        """Emulates SAC loss functions for tf and torch."""
        # ks:
        # 0=log_alpha
        # 1=target log-alpha (not used)
        # 2=action hidden bias
        # 3=action hidden kernel
        # 4=action out bias
        # 5=action out kernel
        # 6=Q hidden bias
        # 7=Q hidden kernel
        # 8=Q out bias
        # 9=Q out kernel
        # 14=target Q hidden bias
        # 15=target Q hidden kernel
        # 16=target Q out bias
        # 17=target Q out kernel
        alpha = np.exp(log_alpha)
        # cls = TorchSquashedGaussian if fw == "torch" else SquashedGaussian
        cls = TorchDirichlet if fw == "torch" else Dirichlet
        model_out_t = train_batch[SampleBatch.CUR_OBS]
        model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]
        target_model_out_tp1 = train_batch[SampleBatch.NEXT_OBS]

        # get_policy_output
        action_dist_t = cls(
            fc(
                relu(
                    fc(model_out_t,
                       weights[ks[1]],
                       weights[ks[0]],
                       framework=fw)), weights[ks[9]], weights[ks[8]]), None)
        policy_t = action_dist_t.deterministic_sample()
        log_pis_t = action_dist_t.logp(policy_t)
        if sess:
            log_pis_t = sess.run(log_pis_t)
            policy_t = sess.run(policy_t)
        log_pis_t = np.expand_dims(log_pis_t, -1)

        # Get policy output for t+1.
        action_dist_tp1 = cls(
            fc(
                relu(
                    fc(model_out_tp1,
                       weights[ks[1]],
                       weights[ks[0]],
                       framework=fw)), weights[ks[9]], weights[ks[8]]), None)
        policy_tp1 = action_dist_tp1.deterministic_sample()
        log_pis_tp1 = action_dist_tp1.logp(policy_tp1)
        if sess:
            log_pis_tp1 = sess.run(log_pis_tp1)
            policy_tp1 = sess.run(policy_tp1)
        log_pis_tp1 = np.expand_dims(log_pis_tp1, -1)

        # Q-values for the actually selected actions.
        # get_q_values
        q_t = fc(
            relu(
                fc(np.concatenate(
                    [model_out_t, train_batch[SampleBatch.ACTIONS]], -1),
                   weights[ks[3]],
                   weights[ks[2]],
                   framework=fw)),
            weights[ks[11]],
            weights[ks[10]],
            framework=fw)

        # Q-values for current policy in given current state.
        # get_q_values
        q_t_det_policy = fc(
            relu(
                fc(np.concatenate([model_out_t, policy_t], -1),
                   weights[ks[3]],
                   weights[ks[2]],
                   framework=fw)),
            weights[ks[11]],
            weights[ks[10]],
            framework=fw)

        # Target q network evaluation.
        # target_model.get_q_values
        if fw == "tf":
            q_tp1 = fc(
                relu(
                    fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
                       weights[ks[7]],
                       weights[ks[6]],
                       framework=fw)),
                weights[ks[15]],
                weights[ks[14]],
                framework=fw)
        else:
            assert fw == "tfe"
            q_tp1 = fc(
                relu(
                    fc(np.concatenate([target_model_out_tp1, policy_tp1], -1),
                       weights[ks[7]],
                       weights[ks[6]],
                       framework=fw)),
                weights[ks[9]],
                weights[ks[8]],
                framework=fw)

        q_t_selected = np.squeeze(q_t, axis=-1)
        q_tp1 -= alpha * log_pis_tp1
        q_tp1_best = np.squeeze(q_tp1, axis=-1)
        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        if fw == "torch":
            dones = dones.float().numpy()
            rewards = rewards.numpy()
        q_tp1_best_masked = (1.0 - dones) * q_tp1_best
        q_t_selected_target = rewards + gamma * q_tp1_best_masked

        base_td_error = np.abs(q_t_selected - q_t_selected_target)
        td_error = base_td_error
        critic_loss = [
            np.mean(train_batch["weights"] *
                    huber_loss(q_t_selected_target - q_t_selected))
        ]
        target_entropy = -np.prod((1, ))
        alpha_loss = -np.mean(log_alpha * (log_pis_t + target_entropy))
        actor_loss = np.mean(alpha * log_pis_t - q_t_det_policy)

        return critic_loss, actor_loss, alpha_loss, td_error
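
    # For reference, `_sac_loss_helper` above mirrors the SAC objectives
    # (with alpha = exp(log_alpha)):
    #   critic_loss = mean(weights * huber(Q_target - Q))
    #   actor_loss  = mean(alpha * log_pi - Q(s, pi(s)))
    #   alpha_loss  = -mean(log_alpha * (log_pi + target_entropy))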

    def _translate_weights_to_torch(self, weights_dict, map_):
        model_dict = {
            map_[k]: convert_to_torch_tensor(
                np.transpose(v) if re.search("kernel", k) else np.array([v])
                if re.search("log_alpha", k) else v)
            for i, (k, v) in enumerate(weights_dict.items()) if i < 13
        }

        return model_dict
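
    # In `_translate_weights_to_torch` above, kernels are transposed because
    # tf Dense layers store weights as (in_features, out_features) while
    # torch.nn.Linear stores (out_features, in_features); log_alpha becomes a
    # 1-element tensor to match the shape of the torch parameter.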

    def _translate_tfe_weights(self, weights_dict, map_):
        model_dict = {
            "default_policy/log_alpha": None,
            "default_policy/log_alpha_target": None,
            "default_policy/sequential/action_1/kernel": weights_dict[2],
            "default_policy/sequential/action_1/bias": weights_dict[3],
            "default_policy/sequential/action_out/kernel": weights_dict[4],
            "default_policy/sequential/action_out/bias": weights_dict[5],
            "default_policy/sequential_1/q_hidden_0/kernel": weights_dict[6],
            "default_policy/sequential_1/q_hidden_0/bias": weights_dict[7],
            "default_policy/sequential_1/q_out/kernel": weights_dict[8],
            "default_policy/sequential_1/q_out/bias": weights_dict[9],
            "default_policy/value_out/kernel": weights_dict[0],
            "default_policy/value_out/bias": weights_dict[1],
        }
        return model_dict


if __name__ == "__main__":
    import pytest
    import sys

    sys.exit(pytest.main(["-v", __file__]))