appo_tf_policy.py
  1. """
  2. TensorFlow policy class used for APPO.
  3. Adapted from VTraceTFPolicy to use the PPO surrogate loss.
  4. Keep in sync with changes to VTraceTFPolicy.
  5. """
  6. import numpy as np
  7. import logging
  8. import gym
  9. from typing import Dict, List, Optional, Type, Union
  10. from ray.rllib.agents.impala import vtrace_tf as vtrace
  11. from ray.rllib.agents.impala.vtrace_tf_policy import _make_time_major, \
  12. clip_gradients, choose_optimizer
  13. from ray.rllib.evaluation.episode import Episode
  14. from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch, \
  15. Postprocessing
  16. from ray.rllib.models.tf.tf_action_dist import Categorical
  17. from ray.rllib.policy.policy import Policy
  18. from ray.rllib.policy.sample_batch import SampleBatch
  19. from ray.rllib.policy.tf_policy_template import build_tf_policy
  20. from ray.rllib.policy.tf_policy import EntropyCoeffSchedule, \
  21. LearningRateSchedule, TFPolicy
  22. from ray.rllib.agents.ppo.ppo_tf_policy import KLCoeffMixin, ValueNetworkMixin
  23. from ray.rllib.models.catalog import ModelCatalog
  24. from ray.rllib.models.modelv2 import ModelV2
  25. from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
  26. from ray.rllib.utils.annotations import override
  27. from ray.rllib.utils.framework import try_import_tf
  28. from ray.rllib.utils.tf_utils import explained_variance, make_tf_callable
  29. from ray.rllib.utils.typing import AgentID, TensorType, TrainerConfigDict
  30. tf1, tf, tfv = try_import_tf()
  31. POLICY_SCOPE = "func"
  32. TARGET_POLICY_SCOPE = "target_func"
  33. logger = logging.getLogger(__name__)

def make_appo_model(policy: Policy, obs_space: gym.spaces.Space,
                    action_space: gym.spaces.Space,
                    config: TrainerConfigDict) -> ModelV2:
    """Builds model and target model for APPO.

    Args:
        policy (Policy): The Policy, which will use the model for optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Policy's config.

    Returns:
        ModelV2: The Model for the Policy to use.
            Note: The target model will not be returned, just assigned to
            `policy.target_model`.
    """
    # Get the num_outputs for the following model construction calls.
    _, logit_dim = ModelCatalog.get_action_dist(action_space, config["model"])

    # Construct the (main) model.
    policy.model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        logit_dim,
        config["model"],
        name=POLICY_SCOPE,
        framework="torch" if config["framework"] == "torch" else "tf")
    policy.model_variables = policy.model.variables()

    # Construct the target model.
    policy.target_model = ModelCatalog.get_model_v2(
        obs_space,
        action_space,
        logit_dim,
        config["model"],
        name=TARGET_POLICY_SCOPE,
        framework="torch" if config["framework"] == "torch" else "tf")
    policy.target_model_variables = policy.target_model.variables()

    # Return only the model (not the target model).
    return policy.model
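
# The target model built above serves as the (more) stable "old" policy
# (pi_old): the surrogate loss below measures importance ratios against it and
# it is only refreshed via TargetNetworkMixin.update_target() (defined further
# down). `config["model"]` is the usual RLlib model config dict, e.g. something
# like {"fcnet_hiddens": [256, 256]} (an illustrative example, not required).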

def appo_surrogate_loss(
        policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for APPO.

    With IS modifications and V-trace for Advantage Estimation.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(policy.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [policy.action_space.n]
    elif isinstance(policy.action_space,
                    gym.spaces.multi_discrete.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = policy.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    # TODO: (sven) deprecate this when trajectory view API gets activated.
    def make_time_major(*args, **kw):
        return _make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS),
                                *args, **kw)
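
    # `make_time_major` reshapes the flat [B*T, ...] batch tensors into
    # time-major [T, B, ...] tensors, the layout the V-trace ops below expect;
    # with drop_last=True the trailing timestep is additionally cut off (its
    # value estimate then only serves as the V-trace bootstrap value).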
    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]

    target_model_out, _ = policy.target_model(train_batch)
    prev_action_dist = dist_class(behaviour_logits, policy.model)
    values = policy.model.value_function()
    values_time_major = make_time_major(values)

    policy.model_vars = policy.model.variables()
    policy.target_model_vars = policy.target_model.variables()

    if policy.is_recurrent():
        max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS])
        mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len)
        mask = tf.reshape(mask, [-1])
        # Keep the mask's time dimension consistent with the loss tensors
        # (they are only shortened when vtrace AND vtrace_drop_last_ts are on).
        mask = make_time_major(
            mask,
            drop_last=policy.config["vtrace"]
            and policy.config["vtrace_drop_last_ts"])

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, mask))

    else:
        reduce_mean_valid = tf.reduce_mean
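
    # Both branches below optimize the standard PPO clipped surrogate,
    #   L_clip = E[ min(r * A, clip(r, 1 - clip_param, 1 + clip_param) * A) ],
    # they only differ in how the advantages A and the value targets are
    # obtained: V-trace (computed here from the off-policy batch) vs. GAE
    # (precomputed in postprocess_trajectory()).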
    if policy.config["vtrace"]:
        drop_last = policy.config["vtrace_drop_last_ts"]
        logger.debug("Using V-Trace surrogate loss (vtrace=True; "
                     f"drop_last={drop_last})")

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else tf.expand_dims(
            actions, axis=1)

        old_policy_behaviour_logits = tf.stop_gradient(target_model_out)
        old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)

        # Prepare KL for loss.
        mean_kl = make_time_major(
            old_policy_action_dist.multi_kl(action_dist), drop_last=drop_last)

        unpacked_behaviour_logits = tf.split(
            behaviour_logits, output_hidden_shape, axis=1)
        unpacked_old_policy_behaviour_logits = tf.split(
            old_policy_behaviour_logits, output_hidden_shape, axis=1)

        # Compute vtrace on the CPU for better perf.
        with tf.device("/cpu:0"):
            vtrace_returns = vtrace.multi_from_logits(
                behaviour_policy_logits=make_time_major(
                    unpacked_behaviour_logits, drop_last=drop_last),
                target_policy_logits=make_time_major(
                    unpacked_old_policy_behaviour_logits, drop_last=drop_last),
                actions=tf.unstack(
                    make_time_major(loss_actions, drop_last=drop_last),
                    axis=2),
                discounts=tf.cast(
                    ~make_time_major(
                        tf.cast(dones, tf.bool), drop_last=drop_last),
                    tf.float32) * policy.config["gamma"],
                rewards=make_time_major(rewards, drop_last=drop_last),
                values=values_time_major[:-1]
                if drop_last else values_time_major,
                bootstrap_value=values_time_major[-1],
                dist_class=Categorical if is_multidiscrete else dist_class,
                model=model,
                clip_rho_threshold=tf.cast(
                    policy.config["vtrace_clip_rho_threshold"], tf.float32),
                clip_pg_rho_threshold=tf.cast(
                    policy.config["vtrace_clip_pg_rho_threshold"], tf.float32),
            )
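
        # `vtrace_returns.vs` are the V-trace corrected value targets and
        # `vtrace_returns.pg_advantages` the rho-clipped policy-gradient
        # advantages; clip_rho_threshold / clip_pg_rho_threshold bound the
        # importance weights so stale (off-policy) samples cannot blow up
        # either estimate (see the IMPALA paper for the exact definitions).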

        actions_logp = make_time_major(
            action_dist.logp(actions), drop_last=drop_last)
        prev_actions_logp = make_time_major(
            prev_action_dist.logp(actions), drop_last=drop_last)
        old_policy_actions_logp = make_time_major(
            old_policy_action_dist.logp(actions), drop_last=drop_last)

        is_ratio = tf.clip_by_value(
            tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0)
        logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp)
        policy._is_ratio = is_ratio
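
        # Decomposition of the surrogate ratio: `is_ratio` is the (clipped)
        # behaviour-vs-target correction pi_behaviour / pi_old, and
        # exp(actions_logp - prev_actions_logp) is pi_current / pi_behaviour,
        # so their product approximates pi_current / pi_old, i.e. the PPO
        # ratio taken w.r.t. the slow-moving target network rather than the
        # (possibly stale) rollout policy that generated the batch.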

        advantages = vtrace_returns.pg_advantages
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages *
            tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))

        action_kl = tf.reduce_mean(mean_kl, axis=0) \
            if is_multidiscrete else mean_kl
        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        if drop_last:
            delta = values_time_major[:-1] - vtrace_returns.vs
        else:
            delta = values_time_major - vtrace_returns.vs
        value_targets = vtrace_returns.vs
        mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

        # The entropy loss.
        actions_entropy = make_time_major(
            action_dist.multi_entropy(), drop_last=drop_last)
        mean_entropy = reduce_mean_valid(actions_entropy)

    else:
        logger.debug("Using PPO surrogate loss (vtrace=False)")

        # Prepare KL for loss.
        mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist))

        logp_ratio = tf.math.exp(
            make_time_major(action_dist.logp(actions)) -
            make_time_major(prev_action_dist.logp(actions)))

        advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES])
        surrogate_loss = tf.minimum(
            advantages * logp_ratio,
            advantages *
            tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"],
                             1 + policy.config["clip_param"]))

        action_kl = tf.reduce_mean(mean_kl, axis=0) \
            if is_multidiscrete else mean_kl
        mean_kl_loss = reduce_mean_valid(action_kl)
        mean_policy_loss = -reduce_mean_valid(surrogate_loss)

        # The value function loss.
        value_targets = make_time_major(
            train_batch[Postprocessing.VALUE_TARGETS])
        delta = values_time_major - value_targets
        mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta))

        # The entropy loss.
        mean_entropy = reduce_mean_valid(
            make_time_major(action_dist.multi_entropy()))

    # The summed weighted loss.
    total_loss = mean_policy_loss - \
        mean_entropy * policy.entropy_coeff

    # Optional KL loss.
    if policy.config["use_kl_loss"]:
        total_loss += policy.kl_coeff * mean_kl_loss

    # Optional vf loss (or in a separate term due to separate
    # optimizers/networks).
    loss_wo_vf = total_loss
    if not policy.config["_separate_vf_optimizer"]:
        total_loss += mean_vf_loss * policy.config["vf_loss_coeff"]

    # Store stats in policy for stats_fn.
    policy._total_loss = total_loss
    policy._loss_wo_vf = loss_wo_vf
    policy._mean_policy_loss = mean_policy_loss
    # Backward compatibility: Deprecate policy._mean_kl.
    policy._mean_kl_loss = policy._mean_kl = mean_kl_loss
    policy._mean_vf_loss = mean_vf_loss
    policy._mean_entropy = mean_entropy
    policy._value_targets = value_targets

    # Return one total loss or two losses: vf vs rest (policy + kl).
    if policy.config["_separate_vf_optimizer"]:
        return loss_wo_vf, mean_vf_loss
    else:
        return total_loss

def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    values_batched = _make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"]
        and policy.config["vtrace_drop_last_ts"])

    stats_dict = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": policy._total_loss,
        "policy_loss": policy._mean_policy_loss,
        "entropy": policy._mean_entropy,
        "var_gnorm": tf.linalg.global_norm(policy.model.trainable_variables()),
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy._value_targets, [-1]),
            tf.reshape(values_batched, [-1])),
        "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
    }

    if policy.config["vtrace"]:
        is_stat_mean, is_stat_var = tf.nn.moments(policy._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy._mean_kl_loss
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict

def postprocess_trajectory(
        policy: Policy,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
        episode: Optional[Episode] = None) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).

    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy (Policy): The Policy used to generate the trajectory
            (`sample_batch`).
        sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from the
            same episode). NOTE: The other agents use the same policy.
        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """
    if not policy.config["vtrace"]:
        sample_batch = compute_gae_for_sample_batch(
            policy, sample_batch, other_agent_batches, episode)

    return sample_batch
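
# When vtrace=False, the GAE computation above needs per-timestep value
# predictions in the rollout batch; `add_values()` below stores them under
# SampleBatch.VF_PREDS at sampling time for exactly that purpose.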

def add_values(policy):
    out = {}
    if not policy.config["vtrace"]:
        out[SampleBatch.VF_PREDS] = policy.model.value_function()
    return out

class TargetNetworkMixin:
    """Target NN is updated by the master learner via `update_target`.

    Updates happen every `trainer.update_target_frequency` steps. All worker
    batches are importance sampled w.r.t. the target network to ensure a more
    stable pi_old in PPO.
    """

    def __init__(self, obs_space, action_space, config):
        @make_tf_callable(self.get_session())
        def do_update():
            assign_ops = []
            assert len(self.model_vars) == len(self.target_model_vars)
            for var, var_target in zip(self.model_vars,
                                       self.target_model_vars):
                assign_ops.append(var_target.assign(var))
            return tf.group(*assign_ops)

        self.update_target = do_update

    @override(TFPolicy)
    def variables(self):
        return self.model_vars + self.target_model_vars
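
# Rough usage sketch (illustrative, not invoked from this module): the learner
# periodically calls `policy.update_target()`, which copies all `model_vars`
# into `target_model_vars` and thereby refreshes the pi_old that
# appo_surrogate_loss() measures its ratios against.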

def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors before APPOPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
    KLCoeffMixin.__init__(policy, config)
    ValueNetworkMixin.__init__(policy, obs_space, action_space, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"],
                                  config["entropy_coeff_schedule"])


def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
                      action_space: gym.spaces.Space,
                      config: TrainerConfigDict) -> None:
    """Call all mixin classes' constructors after APPOPolicy initialization.

    Args:
        policy (Policy): The Policy object.
        obs_space (gym.spaces.Space): The Policy's observation space.
        action_space (gym.spaces.Space): The Policy's action space.
        config (TrainerConfigDict): The Policy's config.
    """
    TargetNetworkMixin.__init__(policy, obs_space, action_space, config)

# Build a child class of `DynamicTFPolicy`, given the custom functions defined
# above.
AsyncPPOTFPolicy = build_tf_policy(
    name="AsyncPPOTFPolicy",
    make_model=make_appo_model,
    loss_fn=appo_surrogate_loss,
    stats_fn=stats,
    postprocess_fn=postprocess_trajectory,
    optimizer_fn=choose_optimizer,
    compute_gradients_fn=clip_gradients,
    extra_action_out_fn=add_values,
    before_loss_init=setup_mixins,
    after_init=setup_late_mixins,
    mixins=[
        LearningRateSchedule,
        KLCoeffMixin,
        TargetNetworkMixin,
        ValueNetworkMixin,
        EntropyCoeffSchedule,
    ],
    get_batch_divisibility_req=lambda p: p.config["rollout_fragment_length"])
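
# Example usage (a minimal sketch, assuming the APPO trainer shipped with the
# same Ray release; it picks AsyncPPOTFPolicy as its default TF policy, so the
# class is rarely instantiated by hand):
#
#     import ray
#     from ray.rllib.agents.ppo import APPOTrainer
#
#     ray.init()
#     trainer = APPOTrainer(
#         env="CartPole-v0",
#         config={
#             "framework": "tf",
#             "vtrace": True,        # use the V-trace branch of the loss
#             "clip_param": 0.4,     # PPO surrogate clipping
#             "use_kl_loss": False,  # optional extra KL term (see above)
#             "num_workers": 2,
#         })
#     for _ in range(3):
#         print(trainer.train()["episode_reward_mean"])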