import errno
import gym
import logging
import math
import numpy as np
import os
import tree  # pip install dm_tree
from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING

import ray
import ray.experimental.tf_utils
from ray.util.debug import log_once
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
from ray.rllib.utils.framework import try_import_tf, get_variable
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.schedules import PiecewiseSchedule
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_utils import get_gpu_devices
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils.typing import LocalOptimizer, ModelGradients, \
    TensorType, TrainerConfigDict

if TYPE_CHECKING:
    from ray.rllib.evaluation import Episode

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


@DeveloperAPI
class TFPolicy(Policy):
    """An agent policy and loss implemented in TensorFlow.

    Do not sub-class this class directly (neither should you sub-class
    DynamicTFPolicy), but rather use
    rllib.policy.tf_policy_template.build_tf_policy
    to generate your custom tf (graph-mode or eager) Policy classes.

    Extending this class enables RLlib to perform TensorFlow specific
    optimizations on the policy, e.g., parallelization across gpus or
    fusing multiple graphs together in the multi-agent setting.

    Input tensors are typically shaped like [BATCH_SIZE, ...].

    Examples:
        >>> policy = TFPolicySubclass(
        ...     sess, obs_input, sampled_action, loss, loss_inputs)
        >>> print(policy.compute_actions([1, 0, 2]))
        (array([0, 1, 1]), [], {})
        >>> print(policy.postprocess_trajectory(SampleBatch({...})))
        SampleBatch({"action": ..., "advantages": ..., ...})
    """

    @DeveloperAPI
    def __init__(self,
                 observation_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 config: TrainerConfigDict,
                 sess: "tf1.Session",
                 obs_input: TensorType,
                 sampled_action: TensorType,
                 loss: Union[TensorType, List[TensorType]],
                 loss_inputs: List[Tuple[str, TensorType]],
                 model: Optional[ModelV2] = None,
                 sampled_action_logp: Optional[TensorType] = None,
                 action_input: Optional[TensorType] = None,
                 log_likelihood: Optional[TensorType] = None,
                 dist_inputs: Optional[TensorType] = None,
                 dist_class: Optional[type] = None,
                 state_inputs: Optional[List[TensorType]] = None,
                 state_outputs: Optional[List[TensorType]] = None,
                 prev_action_input: Optional[TensorType] = None,
                 prev_reward_input: Optional[TensorType] = None,
                 seq_lens: Optional[TensorType] = None,
                 max_seq_len: int = 20,
                 batch_divisibility_req: int = 1,
                 update_ops: List[TensorType] = None,
                 explore: Optional[TensorType] = None,
                 timestep: Optional[TensorType] = None):
  77. """Initializes a Policy object.
  78. Args:
  79. observation_space: Observation space of the policy.
  80. action_space: Action space of the policy.
  81. config: Policy-specific configuration data.
  82. sess: The TensorFlow session to use.
  83. obs_input: Input placeholder for observations, of shape
  84. [BATCH_SIZE, obs...].
  85. sampled_action: Tensor for sampling an action, of shape
  86. [BATCH_SIZE, action...]
  87. loss: Scalar policy loss output tensor or a list thereof
  88. (in case there is more than one loss).
  89. loss_inputs: A (name, placeholder) tuple for each loss input
  90. argument. Each placeholder name must
  91. correspond to a SampleBatch column key returned by
  92. postprocess_trajectory(), and has shape [BATCH_SIZE, data...].
  93. These keys will be read from postprocessed sample batches and
  94. fed into the specified placeholders during loss computation.
  95. model: The optional ModelV2 to use for calculating actions and
  96. losses. If not None, TFPolicy will provide functionality for
  97. getting variables, calling the model's custom loss (if
  98. provided), and importing weights into the model.
  99. sampled_action_logp: log probability of the sampled action.
  100. action_input: Input placeholder for actions for
  101. logp/log-likelihood calculations.
  102. log_likelihood: Tensor to calculate the log_likelihood (given
  103. action_input and obs_input).
  104. dist_class: An optional ActionDistribution class to use for
  105. generating a dist object from distribution inputs.
  106. dist_inputs: Tensor to calculate the distribution
  107. inputs/parameters.
  108. state_inputs: List of RNN state input Tensors.
  109. state_outputs: List of RNN state output Tensors.
  110. prev_action_input: placeholder for previous actions.
  111. prev_reward_input: placeholder for previous rewards.
  112. seq_lens: Placeholder for RNN sequence lengths, of shape
  113. [NUM_SEQUENCES].
  114. Note that NUM_SEQUENCES << BATCH_SIZE. See
  115. policy/rnn_sequencing.py for more information.
  116. max_seq_len: Max sequence length for LSTM training.
  117. batch_divisibility_req: pad all agent experiences batches to
  118. multiples of this value. This only has an effect if not using
  119. a LSTM model.
  120. update_ops: override the batchnorm update ops
  121. to run when applying gradients. Otherwise we run all update
  122. ops found in the current variable scope.
  123. explore: Placeholder for `explore` parameter into call to
  124. Exploration.get_exploration_action. Explicitly set this to
  125. False for not creating any Exploration component.
  126. timestep: Placeholder for the global sampling timestep.
  127. """
        self.framework = "tf"
        super().__init__(observation_space, action_space, config)

        # Get devices to build the graph on.
        worker_idx = self.config.get("worker_index", 0)
        if not config["_fake_gpus"] and \
                ray.worker._mode() == ray.worker.LOCAL_MODE:
            num_gpus = 0
        elif worker_idx == 0:
            num_gpus = config["num_gpus"]
        else:
            num_gpus = config["num_gpus_per_worker"]
        gpu_ids = get_gpu_devices()

        # Place on one or more CPU(s) when either:
        # - Fake GPU mode.
        # - num_gpus=0 (either set by user or we are in local_mode=True).
        # - no GPUs available.
        if config["_fake_gpus"] or num_gpus == 0 or not gpu_ids:
            logger.info("TFPolicy (worker={}) running on {}.".format(
                worker_idx
                if worker_idx > 0 else "local", f"{num_gpus} fake-GPUs"
                if config["_fake_gpus"] else "CPU"))
            self.devices = [
                "/cpu:0" for _ in range(int(math.ceil(num_gpus)) or 1)
            ]
        # Place on one or more actual GPU(s), when:
        # - num_gpus > 0 (set by user) AND
        # - local_mode=False AND
        # - actual GPUs available AND
        # - non-fake GPU mode.
        else:
            logger.info("TFPolicy (worker={}) running on {} GPU(s).".format(
                worker_idx if worker_idx > 0 else "local", num_gpus))

            # We are a remote worker (WORKER_MODE=1):
            # GPUs should be assigned to us by ray.
            if ray.worker._mode() == ray.worker.WORKER_MODE:
                gpu_ids = ray.get_gpu_ids()

            if len(gpu_ids) < num_gpus:
                raise ValueError(
                    "TFPolicy was not able to find enough GPU IDs! Found "
                    f"{gpu_ids}, but num_gpus={num_gpus}.")

            self.devices = [
                f"/gpu:{i}" for i, _ in enumerate(gpu_ids) if i < num_gpus
            ]

        # Disable env-info placeholder.
        if SampleBatch.INFOS in self.view_requirements:
            self.view_requirements[SampleBatch.INFOS].used_for_training = False
            self.view_requirements[
                SampleBatch.INFOS].used_for_compute_actions = False

        assert model is None or isinstance(model, (ModelV2, tf.keras.Model)), \
            "Model classes for TFPolicy other than `ModelV2|tf.keras.Model` " \
            "not allowed! You passed in {}.".format(model)
        self.model = model
        # Auto-update model's inference view requirements, if recurrent.
        if self.model is not None:
            self._update_model_view_requirements_from_init_state()

        # If `explore` is explicitly set to False, don't create an exploration
        # component.
        self.exploration = self._create_exploration() if explore is not False \
            else None

        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampled_action = sampled_action
        self._is_training = self._get_is_training_placeholder()
        self._is_exploring = explore if explore is not None else \
            tf1.placeholder_with_default(True, (), name="is_exploring")
        self._sampled_action_logp = sampled_action_logp
        self._sampled_action_prob = (tf.math.exp(self._sampled_action_logp)
                                     if self._sampled_action_logp is not None
                                     else None)
        self._action_input = action_input  # For logp calculations.
        self._dist_inputs = dist_inputs
        self.dist_class = dist_class

        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len

        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined")

        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        self._apply_op = None
        self._stats_fetches = {}
        self._timestep = timestep if timestep is not None else \
            tf1.placeholder_with_default(
                tf.zeros((), dtype=tf.int64), (), name="timestep")

        self._optimizers: List[LocalOptimizer] = []
        # Backward compatibility and for some code shared with tf-eager Policy.
        self._optimizer = None

        self._grads_and_vars: Union[ModelGradients, List[ModelGradients]] = []
        self._grads: Union[ModelGradients, List[ModelGradients]] = []
        # Policy tf-variables (weights), whose values to get/set via
        # get_weights/set_weights.
        self._variables = None
        # Local optimizer(s)' tf-variables (e.g. state vars for Adam).
        # Will be stored alongside `self._variables` when checkpointing.
        self._optimizer_variables: \
            Optional[ray.experimental.tf_utils.TensorFlowVariables] = None

        # The loss tf-op(s). Number of losses must match number of optimizers.
        self._losses = []
        # Backward compatibility (in case custom child TFPolicies access this
        # property).
        self._loss = None
        # A batch dict passed into loss function as input.
        self._loss_input_dict = {}
        losses = force_list(loss)
        if len(losses) > 0:
            self._initialize_loss(losses, loss_inputs)

        # The log-likelihood calculator op.
        self._log_likelihood = log_likelihood
        if self._log_likelihood is None and self._dist_inputs is not None and \
                self.dist_class is not None:
            self._log_likelihood = self.dist_class(
                self._dist_inputs, self.model).logp(self._action_input)

    @override(Policy)
    def compute_actions_from_input_dict(
            self,
            input_dict: Union[SampleBatch, Dict[str, TensorType]],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List["Episode"]] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Switch off is_training flag in our batch.
        input_dict["is_training"] = False

        builder = TFRunBuilder(self.get_session(),
                               "compute_actions_from_input_dict")
        obs_batch = input_dict[SampleBatch.OBS]
        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep)

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += len(obs_batch) if isinstance(obs_batch, list) \
            else len(input_dict) if isinstance(input_dict, SampleBatch) \
            else obs_batch.shape[0]

        return fetched

    @override(Policy)
    def compute_actions(
            self,
            obs_batch: Union[List[TensorType], TensorType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorType], TensorType] = None,
            prev_reward_batch: Union[List[TensorType], TensorType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["Episode"]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            **kwargs):
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        builder = TFRunBuilder(self.get_session(), "compute_actions")

        input_dict = {SampleBatch.OBS: obs_batch, "is_training": False}
        if state_batches:
            for i, s in enumerate(state_batches):
                input_dict[f"state_in_{i}"] = s
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch

        to_fetch = self._build_compute_actions(
            builder, input_dict=input_dict, explore=explore, timestep=timestep)

        # Execute session run to get action (and other fetches).
        fetched = builder.get(to_fetch)

        # Update our global timestep by the batch size.
        self.global_timestep += \
            len(obs_batch) if isinstance(obs_batch, list) \
            else tree.flatten(obs_batch)[0].shape[0]

        return fetched

    @override(Policy)
    def compute_log_likelihoods(
            self,
            actions: Union[List[TensorType], TensorType],
            obs_batch: Union[List[TensorType], TensorType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            prev_reward_batch: Optional[Union[List[TensorType],
                                              TensorType]] = None,
            actions_normalized: bool = True,
    ) -> TensorType:
        if self._log_likelihood is None:
            raise ValueError("Cannot compute log-prob/likelihood w/o a "
                             "self._log_likelihood op!")

        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(
            explore=False, tf_sess=self.get_session())

        builder = TFRunBuilder(self.get_session(), "compute_log_likelihoods")

        # Normalize actions if necessary.
        if actions_normalized is False and self.config["normalize_actions"]:
            actions = normalize_action(actions, self.action_space_struct)
        # Feed actions (for which we want logp values) into graph.
        builder.add_feed_dict({self._action_input: actions})
        # Feed observations.
        builder.add_feed_dict({self._obs_input: obs_batch})
        # Internal states.
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".
                format(self._state_inputs, state_batches))
        builder.add_feed_dict(
            {k: v
             for k, v in zip(self._state_inputs, state_batches)})
        if state_batches:
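            # Each obs row is fed as its own length-1 sequence here.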
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Prev-a and r.
        if self._prev_action_input is not None and \
                prev_action_batch is not None:
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if self._prev_reward_input is not None and \
                prev_reward_batch is not None:
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        # Fetch the log_likelihoods output and return.
        fetches = builder.add_fetches([self._log_likelihood])
        return builder.get(fetches)[0]

    @override(Policy)
    @DeveloperAPI
    def learn_on_batch(
            self, postprocessed_batch: SampleBatch) -> Dict[str, TensorType]:
        assert self.loss_initialized()

        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)

        builder = TFRunBuilder(self.get_session(), "learn_on_batch")

        # Callback handling.
        learn_stats = {}
        self.callbacks.on_learn_on_batch(
            policy=self, train_batch=postprocessed_batch, result=learn_stats)

        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        stats = builder.get(fetches)
        stats.update({"custom_metrics": learn_stats})
        return stats

    @override(Policy)
    @DeveloperAPI
    def compute_gradients(
            self,
            postprocessed_batch: SampleBatch) -> \
            Tuple[ModelGradients, Dict[str, TensorType]]:
        assert self.loss_initialized()
        # Switch on is_training flag in our batch.
        postprocessed_batch.set_training(True)
        builder = TFRunBuilder(self.get_session(), "compute_gradients")
        fetches = self._build_compute_gradients(builder, postprocessed_batch)
        return builder.get(fetches)

    @override(Policy)
    @DeveloperAPI
    def apply_gradients(self, gradients: ModelGradients) -> None:
        assert self.loss_initialized()
        builder = TFRunBuilder(self.get_session(), "apply_gradients")
        fetches = self._build_apply_gradients(builder, gradients)
        builder.get(fetches)

    @override(Policy)
    @DeveloperAPI
    def get_weights(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        return self._variables.get_weights()

    @override(Policy)
    @DeveloperAPI
    def set_weights(self, weights) -> None:
        return self._variables.set_weights(weights)

    @override(Policy)
    @DeveloperAPI
    def get_exploration_state(self) -> Dict[str, TensorType]:
        return self.exploration.get_state(sess=self.get_session())

    @Deprecated(new="get_exploration_state", error=False)
    def get_exploration_info(self) -> Dict[str, TensorType]:
        return self.get_exploration_state()

    @override(Policy)
    @DeveloperAPI
    def is_recurrent(self) -> bool:
        return len(self._state_inputs) > 0

    @override(Policy)
    @DeveloperAPI
    def num_state_tensors(self) -> int:
        return len(self._state_inputs)

    @override(Policy)
    @DeveloperAPI
    def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
        # For tf Policies, return Policy weights and optimizer var values.
        state = super().get_state()
        if len(self._optimizer_variables.variables) > 0:
            state["_optimizer_variables"] = \
                self.get_session().run(self._optimizer_variables.variables)
        # Add exploration state.
        state["_exploration_state"] = \
            self.exploration.get_state(self.get_session())
        return state

    @override(Policy)
    @DeveloperAPI
    def set_state(self, state: dict) -> None:
        # Set optimizer vars first.
        optimizer_vars = state.get("_optimizer_variables", None)
        if optimizer_vars is not None:
            self._optimizer_variables.set_weights(optimizer_vars)
        # Set exploration's state.
        if hasattr(self, "exploration") and "_exploration_state" in state:
            self.exploration.set_state(
                state=state["_exploration_state"], sess=self.get_session())
        # Set the Policy's (NN) weights.
        super().set_state(state)

    @override(Policy)
    @DeveloperAPI
    def export_checkpoint(self,
                          export_dir: str,
                          filename_prefix: str = "model") -> None:
        """Export tensorflow checkpoint to export_dir."""
        try:
            os.makedirs(export_dir)
        except OSError as e:
            # ignore error if export dir already exists
            if e.errno != errno.EEXIST:
                raise
        save_path = os.path.join(export_dir, filename_prefix)
        with self.get_session().graph.as_default():
            saver = tf1.train.Saver()
            saver.save(self.get_session(), save_path)

    @override(Policy)
    @DeveloperAPI
    def export_model(self, export_dir: str,
                     onnx: Optional[int] = None) -> None:
        """Export tensorflow graph to export_dir for serving."""
        if onnx:
            try:
                import tf2onnx
            except ImportError as e:
                raise RuntimeError(
                    "Converting a TensorFlow model to ONNX requires "
                    "`tf2onnx` to be installed. Install with "
                    "`pip install tf2onnx`.") from e

            with self.get_session().graph.as_default():
                signature_def_map = self._build_signature_def()

                sd = signature_def_map[tf1.saved_model.signature_constants.
                                       DEFAULT_SERVING_SIGNATURE_DEF_KEY]
                inputs = [v.name for k, v in sd.inputs.items()]
                outputs = [v.name for k, v in sd.outputs.items()]

                from tf2onnx import tf_loader
                frozen_graph_def = tf_loader.freeze_session(
                    self._sess, input_names=inputs, output_names=outputs)

            with tf1.Session(graph=tf.Graph()) as session:
                tf.import_graph_def(frozen_graph_def, name="")

                g = tf2onnx.tfonnx.process_tf_graph(
                    session.graph,
                    input_names=inputs,
                    output_names=outputs,
                    inputs_as_nchw=inputs)

                model_proto = g.make_model("onnx_model")
                tf2onnx.utils.save_onnx_model(
                    export_dir,
                    "saved_model",
                    feed_dict={},
                    model_proto=model_proto)
        else:
            with self.get_session().graph.as_default():
                signature_def_map = self._build_signature_def()
                builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
                builder.add_meta_graph_and_variables(
                    self.get_session(),
                    [tf1.saved_model.tag_constants.SERVING],
                    signature_def_map=signature_def_map,
                    saver=tf1.summary.FileWriter(export_dir).add_graph(
                        graph=self.get_session().graph))
                builder.save()

    @override(Policy)
    @DeveloperAPI
    def import_model_from_h5(self, import_file: str) -> None:
        """Imports weights into tf model."""
        if self.model is None:
            raise NotImplementedError("No `self.model` to import into!")

        # Make sure the session is the right one (see issue #7046).
        with self.get_session().graph.as_default():
            with self.get_session().as_default():
                return self.model.import_from_h5(import_file)

    @override(Policy)
    def get_session(self) -> Optional["tf1.Session"]:
        """Returns a reference to the TF session for this policy."""
        return self._sess

    def variables(self):
        """Return the list of all savable variables for this policy."""
        if self.model is None:
            raise NotImplementedError("No `self.model` to get variables for!")
        elif isinstance(self.model, tf.keras.Model):
            return self.model.variables
        else:
            return self.model.variables()

    def get_placeholder(self, name) -> "tf1.placeholder":
        """Returns the given action or loss input placeholder by name.

        If the loss has not been initialized and a loss input placeholder is
        requested, an error is raised.

        Args:
            name (str): The name of the placeholder to return. One of
                SampleBatch.CUR_OBS|PREV_ACTION/REWARD or a valid key from
                `self._loss_input_dict`.

        Returns:
            tf1.placeholder: The placeholder under the given str key.
        """
        if name == SampleBatch.CUR_OBS:
            return self._obs_input
        elif name == SampleBatch.PREV_ACTIONS:
            return self._prev_action_input
        elif name == SampleBatch.PREV_REWARDS:
            return self._prev_reward_input

        assert self._loss_input_dict, \
            "You need to populate `self._loss_input_dict` before " \
            "`get_placeholder()` can be called"
        return self._loss_input_dict[name]

    def loss_initialized(self) -> bool:
        """Returns whether the loss term(s) have been initialized."""
        return len(self._losses) > 0

    def _initialize_loss(self, losses: List[TensorType],
                         loss_inputs: List[Tuple[str, TensorType]]) -> None:
        """Initializes the loss op from given loss tensor and placeholders.

        Args:
            losses (List[TensorType]): The list of loss ops returned by some
                loss function.
            loss_inputs (List[Tuple[str, TensorType]]): The list of Tuples:
                (name, tf1.placeholders) needed for calculating the loss.
        """
        self._loss_input_dict = dict(loss_inputs)
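        # Keep a view of the loss inputs w/o RNN state and seq-len
        # placeholders; its keys are used as `feature_keys` when zero-padding
        # batches in `_get_loss_inputs_dict()`.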
        self._loss_input_dict_no_rnn = {
            k: v
            for k, v in self._loss_input_dict.items()
            if (v not in self._state_inputs and v != self._seq_lens)
        }
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        if self.model and not isinstance(self.model, tf.keras.Model):
            self._losses = force_list(
                self.model.custom_loss(losses, self._loss_input_dict))
            self._stats_fetches.update({"model": self.model.metrics()})
        else:
            self._losses = losses
        # Backward compatibility.
        self._loss = self._losses[0] if self._losses is not None else None

        if not self._optimizers:
            self._optimizers = force_list(self.optimizer())
            # Backward compatibility.
            self._optimizer = self._optimizers[0] if self._optimizers else None

        # Supporting more than one loss/optimizer.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            self._grads_and_vars = []
            self._grads = []
            for group in self.gradients(self._optimizers, self._losses):
                g_and_v = [(g, v) for (g, v) in group if g is not None]
                self._grads_and_vars.append(g_and_v)
                self._grads.append([g for (g, _) in g_and_v])
        # Only one optimizer and one loss term.
        else:
            self._grads_and_vars = [
                (g, v)
                for (g, v) in self.gradients(self._optimizer, self._loss)
                if g is not None
            ]
            self._grads = [g for (g, _) in self._grads_and_vars]

        if self.model:
            self._variables = ray.experimental.tf_utils.TensorFlowVariables(
                [], self.get_session(), self.variables())

        # Gather update ops for any batch norm layers.
        if len(self.devices) <= 1:
            if not self._update_ops:
                self._update_ops = tf1.get_collection(
                    tf1.GraphKeys.UPDATE_OPS,
                    scope=tf1.get_variable_scope().name)
            if self._update_ops:
                logger.info("Update ops to run on apply gradient: {}".format(
                    self._update_ops))
            with tf1.control_dependencies(self._update_ops):
                self._apply_op = self.build_apply_op(
                    optimizer=self._optimizers
                    if self.config["_tf_policy_handles_more_than_one_loss"]
                    else self._optimizer,
                    grads_and_vars=self._grads_and_vars)

        if log_once("loss_used"):
            logger.debug("These tensors were used in the loss functions:"
                         f"\n{summarize(self._loss_input_dict)}\n")

        self.get_session().run(tf1.global_variables_initializer())

        # TensorFlowVariables holding a flat list of all our optimizers'
        # variables.
        self._optimizer_variables = \
            ray.experimental.tf_utils.TensorFlowVariables(
                [v for o in self._optimizers for v in o.variables()],
                self.get_session())

    @DeveloperAPI
    def copy(self,
             existing_inputs: List[Tuple[str, "tf1.placeholder"]]) -> \
            "TFPolicy":
        """Creates a copy of self using existing input placeholders.

        Optional: Only required to work with the multi-GPU optimizer.

        Args:
            existing_inputs (List[Tuple[str, tf1.placeholder]]): List of
                tuples mapping names (str) to tf1.placeholders to re-use
                (share) with the returned copy of self.

        Returns:
            TFPolicy: A copy of self.
        """
        raise NotImplementedError

    @DeveloperAPI
    def extra_compute_action_feed_dict(self) -> Dict[TensorType, TensorType]:
        """Extra dict to pass to the compute actions session run.

        Returns:
            Dict[TensorType, TensorType]: A feed dict to be added to the
                feed_dict passed to the compute_actions session.run() call.
        """
        return {}

    @DeveloperAPI
    def extra_compute_action_fetches(self) -> Dict[str, TensorType]:
        """Extra values to fetch and return from compute_actions().

        By default we return action probability/log-likelihood info
        and action distribution inputs (if present).

        Returns:
            Dict[str, TensorType]: An extra fetch-dict to be passed to and
                returned from the compute_actions() call.
        """
        extra_fetches = {}
        # Action-logp and action-prob.
        if self._sampled_action_logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = self._sampled_action_prob
            extra_fetches[SampleBatch.ACTION_LOGP] = self._sampled_action_logp
        # Action-dist inputs.
        if self._dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = self._dist_inputs
        return extra_fetches

    @DeveloperAPI
    def extra_compute_grad_feed_dict(self) -> Dict[TensorType, TensorType]:
        """Extra dict to pass to the compute gradients session run.

        Returns:
            Dict[TensorType, TensorType]: Extra feed_dict to be passed to the
                compute_gradients Session.run() call.
        """
        return {}  # e.g., kl_coeff

    @DeveloperAPI
    def extra_compute_grad_fetches(self) -> Dict[str, any]:
        """Extra values to fetch and return from compute_gradients().

        Returns:
            Dict[str, any]: Extra fetch dict to be added to the fetch dict
                of the compute_gradients Session.run() call.
        """
        return {LEARNER_STATS_KEY: {}}  # e.g., stats, td error, etc.

    @DeveloperAPI
    def optimizer(self) -> "tf.keras.optimizers.Optimizer":
        """TF optimizer to use for policy optimization.

        Returns:
            tf.keras.optimizers.Optimizer: The local optimizer to use for this
                Policy's Model.
        """
        if hasattr(self, "config") and "lr" in self.config:
            return tf1.train.AdamOptimizer(learning_rate=self.config["lr"])
        else:
            return tf1.train.AdamOptimizer()

    @DeveloperAPI
    def gradients(
            self,
            optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
            loss: Union[TensorType, List[TensorType]],
    ) -> Union[List[ModelGradients], List[List[ModelGradients]]]:
        """Override this for a custom gradient computation behavior.

        Args:
            optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): A single
                LocalOptimizer or a list thereof to use for gradient
                calculations. If more than one optimizer given, the number of
                optimizers must match the number of losses provided.
            loss (Union[TensorType, List[TensorType]]): A single loss term
                or a list thereof to use for gradient calculations.
                If more than one loss given, the number of loss terms must
                match the number of optimizers provided.

        Returns:
            Union[List[ModelGradients], List[List[ModelGradients]]]: List of
                ModelGradients (grads and vars OR just grads) OR List of List
                of ModelGradients in case we have more than one
                optimizer/loss.
        """
        optimizers = force_list(optimizer)
        losses = force_list(loss)

        # We have more than one optimizer and loss term.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            grads = []
            for optim, loss_ in zip(optimizers, losses):
                grads.append(optim.compute_gradients(loss_))
            return grads
        # We have only one optimizer and one loss term.
        else:
            return optimizers[0].compute_gradients(losses[0])

    @DeveloperAPI
    def build_apply_op(
            self,
            optimizer: Union[LocalOptimizer, List[LocalOptimizer]],
            grads_and_vars: Union[ModelGradients, List[ModelGradients]],
    ) -> "tf.Operation":
        """Override this for a custom gradient apply computation behavior.

        Args:
            optimizer (Union[LocalOptimizer, List[LocalOptimizer]]): The local
                tf optimizer to use for applying the grads and vars.
            grads_and_vars (Union[ModelGradients, List[ModelGradients]]): List
                of tuples with grad values and the grad-value's corresponding
                tf.variable in it.

        Returns:
            tf.Operation: The tf op that applies all computed gradients
                (`grads_and_vars`) to the model(s) via the given optimizer(s).
        """
        optimizers = force_list(optimizer)

        # We have more than one optimizer and loss term.
        if self.config["_tf_policy_handles_more_than_one_loss"]:
            ops = []
            for i, optim in enumerate(optimizers):
                # Specify global_step (e.g. for TD3 which needs to count the
                # num updates that have happened).
                ops.append(
                    optim.apply_gradients(
                        grads_and_vars[i],
                        global_step=tf1.train.get_or_create_global_step()))
            return tf.group(ops)
        # We have only one optimizer and one loss term.
        else:
            return optimizers[0].apply_gradients(
                grads_and_vars,
                global_step=tf1.train.get_or_create_global_step())

    def _get_is_training_placeholder(self):
        """Get the placeholder for _is_training, i.e., for batch norm layers.

        This can be called safely before __init__ has run.
        """
        if not hasattr(self, "_is_training"):
            self._is_training = tf1.placeholder_with_default(
                False, (), name="is_training")
        return self._is_training

    def _debug_vars(self):
        if log_once("grad_vars"):
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                for group in self._grads_and_vars:
                    for _, v in group:
                        logger.info("Optimizing variable {}".format(v))
            else:
                for _, v in self._grads_and_vars:
                    logger.info("Optimizing variable {}".format(v))

    def _extra_input_signature_def(self):
        """Extra input signatures to add when exporting tf model.

        Inferred from extra_compute_action_feed_dict()
        """
        feed_dict = self.extra_compute_action_feed_dict()
        return {
            k.name: tf1.saved_model.utils.build_tensor_info(k)
            for k in feed_dict.keys()
        }

    def _extra_output_signature_def(self):
        """Extra output signatures to add when exporting tf model.

        Inferred from extra_compute_action_fetches()
        """
        fetches = self.extra_compute_action_fetches()
        return {
            k: tf1.saved_model.utils.build_tensor_info(fetches[k])
            for k in fetches.keys()
        }

    def _build_signature_def(self):
        """Build signature def map for tensorflow SavedModelBuilder.
        """
        # build input signatures
        input_signature = self._extra_input_signature_def()
        input_signature["observations"] = \
            tf1.saved_model.utils.build_tensor_info(self._obs_input)

        if self._seq_lens is not None:
            input_signature[SampleBatch.SEQ_LENS] = \
                tf1.saved_model.utils.build_tensor_info(self._seq_lens)
        if self._prev_action_input is not None:
            input_signature["prev_action"] = \
                tf1.saved_model.utils.build_tensor_info(
                    self._prev_action_input)
        if self._prev_reward_input is not None:
            input_signature["prev_reward"] = \
                tf1.saved_model.utils.build_tensor_info(
                    self._prev_reward_input)

        input_signature["is_training"] = \
            tf1.saved_model.utils.build_tensor_info(self._is_training)

        if self._timestep is not None:
            input_signature["timestep"] = \
                tf1.saved_model.utils.build_tensor_info(self._timestep)

        for state_input in self._state_inputs:
            input_signature[state_input.name] = \
                tf1.saved_model.utils.build_tensor_info(state_input)

        # build output signatures
        output_signature = self._extra_output_signature_def()
        for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
            output_signature["actions_{}".format(i)] = \
                tf1.saved_model.utils.build_tensor_info(a)

        for state_output in self._state_outputs:
            output_signature[state_output.name] = \
                tf1.saved_model.utils.build_tensor_info(state_output)

        signature_def = (
            tf1.saved_model.signature_def_utils.build_signature_def(
                input_signature, output_signature,
                tf1.saved_model.signature_constants.PREDICT_METHOD_NAME))
        signature_def_key = (tf1.saved_model.signature_constants.
                             DEFAULT_SERVING_SIGNATURE_DEF_KEY)
        signature_def_map = {signature_def_key: signature_def}
        return signature_def_map

    def _build_compute_actions(self,
                               builder,
                               *,
                               input_dict=None,
                               obs_batch=None,
                               state_batches=None,
                               prev_action_batch=None,
                               prev_reward_batch=None,
                               episodes=None,
                               explore=None,
                               timestep=None):
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        # Call the exploration before_compute_actions hook.
        self.exploration.before_compute_actions(
            timestep=timestep, explore=explore, tf_sess=self.get_session())

        builder.add_feed_dict(self.extra_compute_action_feed_dict())

        # `input_dict` given: Simply build what's in that dict.
        if input_dict is not None:
            if hasattr(self, "_input_dict"):
                for key, value in input_dict.items():
                    if key in self._input_dict:
                        # Handle complex/nested spaces as well.
                        tree.map_structure(
                            lambda k, v: builder.add_feed_dict({k: v}),
                            self._input_dict[key], value,
                        )
            # For policies that inherit directly from TFPolicy.
            else:
                builder.add_feed_dict({
                    self._obs_input: input_dict[SampleBatch.OBS]
                })
                if SampleBatch.PREV_ACTIONS in input_dict:
                    builder.add_feed_dict({
                        self._prev_action_input: input_dict[
                            SampleBatch.PREV_ACTIONS]
                    })
                if SampleBatch.PREV_REWARDS in input_dict:
                    builder.add_feed_dict({
                        self._prev_reward_input: input_dict[
                            SampleBatch.PREV_REWARDS]
                    })
                state_batches = []
                i = 0
                while "state_in_{}".format(i) in input_dict:
                    state_batches.append(input_dict["state_in_{}".format(i)])
                    i += 1
                builder.add_feed_dict(
                    dict(zip(self._state_inputs, state_batches)))

            if "state_in_0" in input_dict:
                builder.add_feed_dict({
                    self._seq_lens: np.ones(len(input_dict["state_in_0"]))
                })
        # Hardcoded old way: Build fixed fields, if provided.
        # TODO: (sven) This can be deprecated after trajectory view API flag is
        #  removed and always True.
        else:
            if log_once("_build_compute_actions_input_dict"):
                deprecation_warning(
                    old="_build_compute_actions(.., obs_batch=.., ..)",
                    new="_build_compute_actions(.., input_dict=..)",
                    error=False,
                )
            state_batches = state_batches or []
            if len(self._state_inputs) != len(state_batches):
                raise ValueError(
                    "Must pass in RNN state batches for placeholders {}, "
                    "got {}".format(self._state_inputs, state_batches))

            tree.map_structure(
                lambda k, v: builder.add_feed_dict({k: v}),
                self._obs_input, obs_batch,
            )
            if state_batches:
                builder.add_feed_dict({
                    self._seq_lens: np.ones(len(obs_batch))
                })
            if self._prev_action_input is not None and \
                    prev_action_batch is not None:
                builder.add_feed_dict({
                    self._prev_action_input: prev_action_batch
                })
            if self._prev_reward_input is not None and \
                    prev_reward_batch is not None:
                builder.add_feed_dict({
                    self._prev_reward_input: prev_reward_batch
                })
            builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))

        builder.add_feed_dict({self._is_training: False})
        builder.add_feed_dict({self._is_exploring: explore})
        if timestep is not None:
            builder.add_feed_dict({self._timestep: timestep})

        # Determine, what exactly to fetch from the graph.
        to_fetch = [self._sampled_action] + self._state_outputs + \
            [self.extra_compute_action_fetches()]

        # Perform the session call.
        fetches = builder.add_fetches(to_fetch)
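        # Return: sampled action(s), list of RNN state outputs, and the
        # extra-fetches dict (in that order).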
        return fetches[0], fetches[1:-1], fetches[-1]

    def _build_compute_gradients(self, builder, postprocessed_batch):
        self._debug_vars()
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict(
            self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
        fetches = builder.add_fetches(
            [self._grads, self._get_grad_and_stats_fetches()])
        return fetches[0], fetches[1]

    def _build_apply_gradients(self, builder, gradients):
        if len(gradients) != len(self._grads):
            raise ValueError(
                "Unexpected number of gradients to apply, got {} for {}".
                format(gradients, self._grads))
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(dict(zip(self._grads, gradients)))
        fetches = builder.add_fetches([self._apply_op])
        return fetches[0]

    def _build_learn_on_batch(self, builder, postprocessed_batch):
        self._debug_vars()
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict(
            self._get_loss_inputs_dict(postprocessed_batch, shuffle=False))
        fetches = builder.add_fetches([
            self._apply_op,
            self._get_grad_and_stats_fetches(),
        ])
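        # fetches[0] is the (ignored) apply-gradients op; fetches[1] holds the
        # grad/stats fetches to be returned.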
        return fetches[1]

    def _get_grad_and_stats_fetches(self):
        fetches = self.extra_compute_grad_fetches()
        if LEARNER_STATS_KEY not in fetches:
            raise ValueError(
                "Grad fetches should contain 'stats': {...} entry")
        if self._stats_fetches:
            fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches,
                                              **fetches[LEARNER_STATS_KEY])
        return fetches

    def _get_loss_inputs_dict(self, train_batch: SampleBatch, shuffle: bool):
        """Return a feed dict from a batch.

        Args:
            train_batch (SampleBatch): batch of data to derive inputs from.
            shuffle (bool): whether to shuffle batch sequences. Shuffle may
                be done in-place. This only makes sense if you're further
                applying minibatch SGD after getting the outputs.

        Returns:
            Feed dict of data.
        """
        # Get batch ready for RNNs, if applicable.
        if not isinstance(train_batch,
                          SampleBatch) or not train_batch.zero_padded:
            pad_batch_to_sequences_of_same_size(
                train_batch,
                max_seq_len=self._max_seq_len,
                shuffle=shuffle,
                batch_divisibility_req=self._batch_divisibility_req,
                feature_keys=list(self._loss_input_dict_no_rnn.keys()),
                view_requirements=self.view_requirements,
            )

        # Mark the batch as "is_training" so the Model can use this
        # information.
        train_batch.set_training(True)

        # Build the feed dict from the batch.
        feed_dict = {}
        for key, placeholders in self._loss_input_dict.items():
            tree.map_structure(
                lambda ph, v: feed_dict.__setitem__(ph, v),
                placeholders,
                train_batch[key],
            )

        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        for key in state_keys:
            feed_dict[self._loss_input_dict[key]] = train_batch[key]
        if state_keys:
            feed_dict[self._seq_lens] = train_batch[SampleBatch.SEQ_LENS]

        return feed_dict


@DeveloperAPI
class LearningRateSchedule:
    """Mixin for TFPolicy that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
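        # `lr_schedule` is expected to be a list of [timestep, lr-value]
        # pairs (PiecewiseSchedule endpoints), e.g. [[0, 1e-3], [1000000,
        # 1e-5]]; values in between are interpolated.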
        self._lr_schedule = None
        if lr_schedule is None:
            self.cur_lr = tf1.get_variable(
                "lr", initializer=lr, trainable=False)
        else:
            self._lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
            self.cur_lr = tf1.get_variable(
                "lr", initializer=self._lr_schedule.value(0), trainable=False)
            if self.framework == "tf":
                self._lr_placeholder = tf1.placeholder(
                    dtype=tf.float32, name="lr")
                self._lr_update = self.cur_lr.assign(
                    self._lr_placeholder, read_value=False)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        if self._lr_schedule is not None:
            new_val = self._lr_schedule.value(global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._lr_update, feed_dict={self._lr_placeholder: new_val})
            else:
                self.cur_lr.assign(new_val, read_value=False)
                # This property (self._optimizer) is (still) accessible for
                # both TFPolicy and any TFPolicy_eager.
                self._optimizer.learning_rate.assign(self.cur_lr)

    @override(TFPolicy)
    def optimizer(self):
        return tf1.train.AdamOptimizer(learning_rate=self.cur_lr)


@DeveloperAPI
class EntropyCoeffSchedule:
    """Mixin for TFPolicy that adds entropy coeff decay."""

    @DeveloperAPI
    def __init__(self, entropy_coeff, entropy_coeff_schedule):
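        # `entropy_coeff_schedule` may be a list of [timestep, coeff] pairs
        # (same format as `lr_schedule`), or a single timestep by which the
        # coefficient is annealed from `entropy_coeff` down to 0.0 (see
        # below).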
        self._entropy_coeff_schedule = None
        if entropy_coeff_schedule is None:
            self.entropy_coeff = get_variable(
                entropy_coeff,
                framework="tf",
                tf_name="entropy_coeff",
                trainable=False)
        else:
            # Allows for custom schedule similar to lr_schedule format
            if isinstance(entropy_coeff_schedule, list):
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    entropy_coeff_schedule,
                    outside_value=entropy_coeff_schedule[-1][-1],
                    framework=None)
            else:
                # Implements previous version but enforces outside_value
                self._entropy_coeff_schedule = PiecewiseSchedule(
                    [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
                    outside_value=0.0,
                    framework=None)
            self.entropy_coeff = get_variable(
                self._entropy_coeff_schedule.value(0),
                framework="tf",
                tf_name="entropy_coeff",
                trainable=False)
            if self.framework == "tf":
                self._entropy_coeff_placeholder = tf1.placeholder(
                    dtype=tf.float32, name="entropy_coeff")
                self._entropy_coeff_update = self.entropy_coeff.assign(
                    self._entropy_coeff_placeholder, read_value=False)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        super(EntropyCoeffSchedule, self).on_global_var_update(global_vars)
        if self._entropy_coeff_schedule is not None:
            new_val = self._entropy_coeff_schedule.value(
                global_vars["timestep"])
            if self.framework == "tf":
                self.get_session().run(
                    self._entropy_coeff_update,
                    feed_dict={self._entropy_coeff_placeholder: new_val})
            else:
                self.entropy_coeff.assign(new_val, read_value=False)