# eager_tf_policy.py

"""Eager mode TF policy built using build_tf_policy().

It supports both traced and non-traced eager execution modes."""
import functools
import logging
import os
import threading
from typing import Dict, List, Optional, Tuple, Union

import tree  # pip install dm_tree

from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.policy.policy import Policy, PolicyState
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    deprecation_warning,
)
from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import (
    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
    NUM_AGENT_STEPS_TRAINED,
    NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_utils import get_gpu_devices
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import (
    LocalOptimizer,
    ModelGradients,
    TensorType,
    TensorStructType,
)
from ray.util.debug import log_once

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)
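
# Note on usage (an illustrative sketch, not part of the original module): users
# normally do not construct these eager policy classes directly. Configuring the
# "tf2" framework on an AlgorithmConfig (with `eager_tracing=True` for the traced
# variant) makes RLlib convert the corresponding graph-mode TF policy into an
# eager one, roughly like this:
#
#     from ray.rllib.algorithms.ppo import PPOConfig
#
#     config = (
#         PPOConfig()
#         .environment("CartPole-v1")
#         .framework("tf2", eager_tracing=True)
#     )
#     algo = config.build()
#     policy = algo.get_policy()  # -> an EagerTFPolicy subclass (traced).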


def _convert_to_tf(x, dtype=None):
    if isinstance(x, SampleBatch):
        dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
        return tree.map_structure(_convert_to_tf, dict_)
    elif isinstance(x, Policy):
        return x
    # Special handling of "Repeated" values.
    elif isinstance(x, RepeatedValues):
        return RepeatedValues(
            tree.map_structure(_convert_to_tf, x.values), x.lengths, x.max_len
        )

    if x is not None:
        d = dtype
        return tree.map_structure(
            lambda f: _convert_to_tf(f, d)
            if isinstance(f, RepeatedValues)
            else tf.convert_to_tensor(f, d)
            if f is not None and not tf.is_tensor(f)
            else f,
            x,
        )

    return x
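

# A minimal sketch of what `_convert_to_tf` does to a nested (numpy) structure in
# eager mode (the dict below is made up for illustration):
#
#     batch = {"obs": np.array([[0.1, 0.2]]), "t": np.array([3])}
#     tf_batch = _convert_to_tf(batch)
#     # -> {"obs": <tf.Tensor, shape=(1, 2)>, "t": <tf.Tensor, shape=(1,)>}
#
# Passing `dtype` forces that dtype onto every leaf that gets newly converted,
# which is how the "timestep" kwarg is handled in `_convert_eager_inputs` below.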


def _convert_to_numpy(x):
    def _map(x):
        if isinstance(x, tf.Tensor):
            return x.numpy()
        return x

    try:
        return tf.nest.map_structure(_map, x)
    except AttributeError:
        raise TypeError(
            ("Object of type {} has no method to convert to numpy.").format(type(x))
        )


def _convert_eager_inputs(func):
    @functools.wraps(func)
    def _func(*args, **kwargs):
        if tf.executing_eagerly():
            eager_args = [_convert_to_tf(x) for x in args]
            # TODO: (sven) find a way to remove key-specific hacks.
            eager_kwargs = {
                k: _convert_to_tf(v, dtype=tf.int64 if k == "timestep" else None)
                for k, v in kwargs.items()
                if k not in {"info_batch", "episodes"}
            }
            return func(*eager_args, **eager_kwargs)
        else:
            return func(*args, **kwargs)

    return _func


def _convert_eager_outputs(func):
    @functools.wraps(func)
    def _func(*args, **kwargs):
        out = func(*args, **kwargs)
        if tf.executing_eagerly():
            out = tf.nest.map_structure(_convert_to_numpy, out)
        return out

    return _func
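

# `_convert_eager_inputs` is applied around the traced helper methods further
# below (see `TracedEagerPolicy`), so that numpy inputs become tf tensors before
# they hit a `tf.function`; `_convert_eager_outputs` is the symmetric counterpart
# that maps tensor outputs back to numpy for the caller. A sketch of the intended
# combination (the helper names are hypothetical):
#
#     @_convert_eager_outputs
#     @_convert_eager_inputs
#     def _traced_helper(batch):
#         return traced_tf_fn(batch)  # some tf.function-wrapped callable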


def _disallow_var_creation(next_creator, **kw):
    v = next_creator(**kw)
    raise ValueError(
        "Detected a variable being created during an eager "
        "forward pass. Variables should only be created during "
        "model initialization: {}".format(v.name)
    )


def _check_too_many_retraces(obj):
    """Asserts that a given number of re-traces is not breached."""

    def _func(self_, *args, **kwargs):
        if (
            self_.config.get("eager_max_retraces") is not None
            and self_._re_trace_counter > self_.config["eager_max_retraces"]
        ):
            raise RuntimeError(
                "Too many tf-eager re-traces detected! This could lead to"
                " significant slow-downs (even slower than running in "
                "tf-eager mode w/ `eager_tracing=False`). To switch off "
                "these re-trace counting checks, set `eager_max_retraces`"
                " in your config to None."
            )
        return obj(self_, *args, **kwargs)

    return _func


@DeveloperAPI
class EagerTFPolicy(Policy):
    """Dummy class to recognize any eagerized TFPolicy by its inheritance."""

    pass


def _traced_eager_policy(eager_policy_cls):
    """Wrapper class that enables tracing for all eager policy methods.

    This is enabled by the `--trace`/`eager_tracing=True` config when
    framework=tf2.
    """

    class TracedEagerPolicy(eager_policy_cls):
        def __init__(self, *args, **kwargs):
            self._traced_learn_on_batch_helper = False
            self._traced_compute_actions_helper = False
            self._traced_compute_gradients_helper = False
            self._traced_apply_gradients_helper = False
            super(TracedEagerPolicy, self).__init__(*args, **kwargs)

        @_check_too_many_retraces
        @override(Policy)
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List[Episode]] = None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            """Traced version of Policy.compute_actions_from_input_dict."""

            # Create a traced version of `self._compute_actions_helper`.
            if self._traced_compute_actions_helper is False and not self._no_tracing:
                if self.config.get("_enable_rl_module_api"):
                    self._compute_actions_helper_rl_module_explore = (
                        _convert_eager_inputs(
                            tf.function(
                                super(
                                    TracedEagerPolicy, self
                                )._compute_actions_helper_rl_module_explore,
                                autograph=True,
                                reduce_retracing=True,
                            )
                        )
                    )
                    self._compute_actions_helper_rl_module_inference = (
                        _convert_eager_inputs(
                            tf.function(
                                super(
                                    TracedEagerPolicy, self
                                )._compute_actions_helper_rl_module_inference,
                                autograph=True,
                                reduce_retracing=True,
                            )
                        )
                    )
                else:
                    self._compute_actions_helper = _convert_eager_inputs(
                        tf.function(
                            super(TracedEagerPolicy, self)._compute_actions_helper,
                            autograph=False,
                            reduce_retracing=True,
                        )
                    )
                self._traced_compute_actions_helper = True
            # Now that the helper method is traced, call super's
            # `compute_actions_from_input_dict()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).compute_actions_from_input_dict(
                input_dict=input_dict,
                explore=explore,
                timestep=timestep,
                episodes=episodes,
                **kwargs,
            )

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def learn_on_batch(self, samples):
            """Traced version of Policy.learn_on_batch."""

            # Create a traced version of `self._learn_on_batch_helper`.
            if self._traced_learn_on_batch_helper is False and not self._no_tracing:
                self._learn_on_batch_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._learn_on_batch_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_learn_on_batch_helper = True
            # Now that the helper method is traced, call super's
            # `learn_on_batch()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).learn_on_batch(samples)

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def compute_gradients(self, samples: SampleBatch) -> ModelGradients:
            """Traced version of Policy.compute_gradients."""

            # Create a traced version of `self._compute_gradients_helper`.
            if self._traced_compute_gradients_helper is False and not self._no_tracing:
                self._compute_gradients_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._compute_gradients_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_compute_gradients_helper = True
            # Now that the helper method is traced, call super's
            # `compute_gradients()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).compute_gradients(samples)

        @_check_too_many_retraces
        @override(Policy)
        def apply_gradients(self, grads: ModelGradients) -> None:
            """Traced version of Policy.apply_gradients."""

            # Create a traced version of `self._apply_gradients_helper`.
            if self._traced_apply_gradients_helper is False and not self._no_tracing:
                self._apply_gradients_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._apply_gradients_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_apply_gradients_helper = True
            # Now that the helper method is traced, call super's
            # `apply_gradients()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).apply_gradients(grads)

        @classmethod
        def with_tracing(cls):
            # Already traced -> Return same class.
            return cls

    TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
    return TracedEagerPolicy
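

# The tracing pattern used by `TracedEagerPolicy` above, shown in isolation (a
# sketch; `helper` stands for any of the `_..._helper` methods that only take
# tensors or structures of tensors):
#
#     traced_helper = _convert_eager_inputs(
#         tf.function(helper, autograph=False, reduce_retracing=True)
#     )
#
# The first call triggers a trace (and increments `self._re_trace_counter` inside
# the helper); later calls with compatible input signatures reuse the traced
# graph, and `_check_too_many_retraces` raises if TF keeps re-tracing anyway.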


class _OptimizerWrapper:
    def __init__(self, tape):
        self.tape = tape

    def compute_gradients(self, loss, var_list):
        return list(zip(self.tape.gradient(loss, var_list), var_list))
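

# `_OptimizerWrapper` lets a user-provided `compute_gradients_fn` that was written
# against the tf1-style `optimizer.compute_gradients(loss, var_list)` API run
# unchanged in eager mode. A minimal sketch of the equivalence (variable names are
# illustrative):
#
#     with tf.GradientTape(persistent=True) as tape:
#         loss = loss_fn(...)
#     fake_optimizer = _OptimizerWrapper(tape)
#     grads_and_vars = fake_optimizer.compute_gradients(loss, var_list)
#     # == list(zip(tape.gradient(loss, var_list), var_list))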


def _build_eager_tf_policy(
    name,
    loss_fn,
    get_default_config=None,
    postprocess_fn=None,
    stats_fn=None,
    optimizer_fn=None,
    compute_gradients_fn=None,
    apply_gradients_fn=None,
    grad_stats_fn=None,
    extra_learn_fetches_fn=None,
    extra_action_out_fn=None,
    validate_spaces=None,
    before_init=None,
    before_loss_init=None,
    after_init=None,
    make_model=None,
    action_sampler_fn=None,
    action_distribution_fn=None,
    mixins=None,
    get_batch_divisibility_req=None,
    # Deprecated args.
    obs_include_prev_action_reward=DEPRECATED_VALUE,
    extra_action_fetches_fn=None,
    gradients_fn=None,
):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set `.framework("tf2", eager_tracing=False)` in your
    AlgorithmConfig to have it automatically be converted to an eager policy.

    This has the same signature as build_tf_policy()."""

    base = add_mixins(EagerTFPolicy, mixins)

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=True)

    if extra_action_fetches_fn is not None:
        deprecation_warning(
            old="extra_action_fetches_fn", new="extra_action_out_fn", error=True
        )

    if gradients_fn is not None:
        deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            # If this class runs as a @ray.remote actor, eager mode may not
            # have been activated yet.
            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()
            self.framework = config.get("framework", "tf2")
            EagerTFPolicy.__init__(self, observation_space, action_space, config)

            # Global timestep should be a tensor.
            self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
            self.explore = tf.Variable(
                self.config["explore"], trainable=False, dtype=tf.bool
            )

            # Log device and worker index.
            num_gpus = self._get_num_gpus_for_policy()
            if num_gpus > 0:
                gpu_ids = get_gpu_devices()
                logger.info(f"Found {len(gpu_ids)} visible cuda devices.")

            self._is_training = False

            # Only for `config.eager_tracing=True`: A counter to keep track of
            # how many times an eager-traced method (e.g.
            # `self._compute_actions_helper`) has been re-traced by tensorflow.
            # We will raise an error if more than n re-tracings have been
            # detected, since this would considerably slow down execution.
            # The variable below should only get incremented during the
            # tf.function trace operations, never when calling the already
            # traced function after that.
            self._re_trace_counter = 0

            self._loss_initialized = False

            # To ensure backward compatibility:
            # Old way: If `loss` provided here, use as-is (as a function).
            if loss_fn is not None:
                self._loss = loss_fn
            # New way: Convert the overridden `self.loss` into a plain
            # function, so it can be called the same way as `loss` would
            # be, ensuring backward compatibility.
            elif self.loss.__func__.__qualname__ != "Policy.loss":
                self._loss = self.loss.__func__
            # `loss` not provided nor overridden from Policy -> Set to None.
            else:
                self._loss = None

            self.batch_divisibility_req = (
                get_batch_divisibility_req(self)
                if callable(get_batch_divisibility_req)
                else (get_batch_divisibility_req or 1)
            )
            self._max_seq_len = config["model"]["max_seq_len"]

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given"
                    )
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"]
                )

            if make_model:
                self.model = make_model(self, observation_space, action_space, config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )

            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            if self.config.get("_enable_rl_module_api", False):
                # Maybe update view_requirements, e.g. for recurrent case.
                self.view_requirements = self.model.update_default_view_requirements(
                    self.view_requirements
                )
            else:
                # Auto-update model's inference view requirements, if recurrent.
                self._update_model_view_requirements_from_init_state()
                # Combine view_requirements for Model and Policy.
                self.view_requirements.update(self.model.view_requirements)

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if self.exploration:
                optimizers = self.exploration.get_exploration_optimizer(optimizers)

            # The list of local (tf) optimizers (one per loss term).
            self._optimizers: List[LocalOptimizer] = optimizers
            # Backward compatibility: A user's policy may only support a single
            # loss term and optimizer (no lists).
            self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Got to reset global_timestep again after fake run-throughs.
            self.global_timestep.assign(0)

        @override(Policy)
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List[Episode]] = None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            if not self.config.get("eager_tracing") and not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            self._is_training = False

            explore = explore if explore is not None else self.explore
            timestep = timestep if timestep is not None else self.global_timestep
            if isinstance(timestep, tf.Tensor):
                timestep = int(timestep.numpy())

            # Pass lazy (eager) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(False)

            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            self._state_in = state_batches
            self._is_recurrent = state_batches != []

            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(
                timestep=timestep, explore=explore, tf_sess=self.get_session()
            )

            ret = self._compute_actions_helper(
                input_dict,
                state_batches,
                # TODO: Passing episodes into a traced method does not work.
                None if self.config["eager_tracing"] else episodes,
                explore,
                timestep,
            )
            # Update our global timestep by the batch size.
            self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0])
            return convert_to_numpy(ret)
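
        # Example call (illustrative only; `my_policy` and the obs shape are made
        # up): inference expects a batched input dict and returns numpy actions,
        # RNN state-outs, and a dict of extra action fetches.
        #
        #     actions, state_outs, extra = my_policy.compute_actions_from_input_dict(
        #         SampleBatch({SampleBatch.OBS: np.array([[0.1, 0.2, 0.3, 0.4]])})
        #     )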

        @override(Policy)
        def compute_actions(
            self,
            obs_batch: Union[List[TensorStructType], TensorStructType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
            prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["Episode"]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            # Create input dict to simply pass the entire call to
            # self.compute_actions_from_input_dict().
            input_dict = SampleBatch(
                {
                    SampleBatch.CUR_OBS: obs_batch,
                },
                _is_training=tf.constant(False),
            )
            if state_batches is not None:
                for i, s in enumerate(state_batches):
                    input_dict[f"state_in_{i}"] = s
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            if info_batch is not None:
                input_dict[SampleBatch.INFOS] = info_batch
            return self.compute_actions_from_input_dict(
                input_dict=input_dict,
                explore=explore,
                timestep=timestep,
                episodes=episodes,
                **kwargs,
            )

        @with_lock
        @override(Policy)
        def compute_log_likelihoods(
            self,
            actions,
            obs_batch,
            state_batches=None,
            prev_action_batch=None,
            prev_reward_batch=None,
            actions_normalized=True,
            **kwargs,
        ):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError(
                    "Cannot compute log-prob/likelihood w/o an "
                    "`action_distribution_fn` and a provided "
                    "`action_sampler_fn`!"
                )

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_batch = SampleBatch(
                {SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)},
                _is_training=False,
            )
            if prev_action_batch is not None:
                input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
                    prev_action_batch
                )
            if prev_reward_batch is not None:
                input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
                    prev_reward_batch
                )

            if self.exploration:
                # Exploration hook before each forward pass.
                self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if action_distribution_fn:
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self, self.model, input_batch, explore=False, is_training=False
                )
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def postprocess_trajectory(
            self, sample_batch, other_agent_batches=None, episode=None
        ):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches, episode)
            return sample_batch

        @with_lock
        @override(Policy)
        def learn_on_batch(self, postprocessed_batch):
            # Callback handling.
            learn_stats = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=postprocessed_batch, result=learn_stats
            )

            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                max_seq_len=self._max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

            self._is_training = True
            postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
            postprocessed_batch.set_training(True)
            stats = self._learn_on_batch_helper(postprocessed_batch)
            self.num_grad_updates += 1

            stats.update(
                {
                    "custom_metrics": learn_stats,
                    NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    # -1, b/c we have to measure this diff before we do the update
                    # above.
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                        self.num_grad_updates
                        - 1
                        - (postprocessed_batch.num_grad_updates or 0)
                    ),
                }
            )

            return convert_to_numpy(stats)
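
        # Example call (illustrative; the required batch columns depend on the
        # algorithm's loss and view requirements): a postprocessed SampleBatch
        # goes in, a dict of numpy learner stats comes out.
        #
        #     results = my_policy.learn_on_batch(postprocessed_batch)
        #     learner_stats = results[LEARNER_STATS_KEY]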

        @override(Policy)
        def compute_gradients(
            self, postprocessed_batch: SampleBatch
        ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

            self._is_training = True
            self._lazy_tensor_dict(postprocessed_batch)
            postprocessed_batch.set_training(True)

            grads_and_vars, grads, stats = self._compute_gradients_helper(
                postprocessed_batch
            )
            return convert_to_numpy((grads, stats))

        @override(Policy)
        def apply_gradients(self, gradients: ModelGradients) -> None:
            self._apply_gradients_helper(
                list(
                    zip(
                        [
                            (tf.convert_to_tensor(g) if g is not None else None)
                            for g in gradients
                        ],
                        self.model.trainable_variables(),
                    )
                )
            )

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights), len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_exploration_state(self):
            return convert_to_numpy(self.exploration.get_state())

        @override(Policy)
        def is_recurrent(self):
            return self._is_recurrent

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_inputs)

        @override(Policy)
        def get_initial_state(self):
            if hasattr(self, "model"):
                return self.model.get_initial_state()
            return []

        @override(Policy)
        def get_state(self) -> PolicyState:
            # Legacy Policy state (w/o keras model and w/o PolicySpec).
            state = super().get_state()

            state["global_timestep"] = state["global_timestep"].numpy()

            if self._optimizer and len(self._optimizer.variables()) > 0:
                state["_optimizer_variables"] = self._optimizer.variables()

            # Add exploration state.
            if not self.config.get("_enable_rl_module_api", False) and self.exploration:
                # This is not compatible with RLModules, which have a method
                # `forward_exploration` to specify custom exploration behavior.
                state["_exploration_state"] = self.exploration.get_state()

            return state

        @override(Policy)
        def set_state(self, state: PolicyState) -> None:
            # Set optimizer vars first.
            optimizer_vars = state.get("_optimizer_variables", None)
            if optimizer_vars and self._optimizer.variables():
                if not type(self).__name__.endswith("_traced") and log_once(
                    "set_state_optimizer_vars_tf_eager_policy_v2"
                ):
                    logger.warning(
                        "Cannot restore an optimizer's state for tf eager! Keras "
                        "is not able to save the v1.x optimizers (from "
                        "tf.compat.v1.train) since they aren't compatible with "
                        "checkpoints."
                    )
                for opt_var, value in zip(self._optimizer.variables(), optimizer_vars):
                    opt_var.assign(value)
            # Set exploration's state.
            if hasattr(self, "exploration") and "_exploration_state" in state:
                self.exploration.set_state(state=state["_exploration_state"])

            # Restore global timestep (tf vars).
            self.global_timestep.assign(state["global_timestep"])

            # Then the Policy's (NN) weights and connectors.
            super().set_state(state)

        @override(Policy)
        def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
            """Exports the Policy's Model to local directory for serving.

            Note: Since the TfModelV2 class that EagerTfPolicy uses is-NOT-a
            tf.keras.Model, we need to assume that there is a `base_model` property
            within this TfModelV2 class that is-a tf.keras.Model. This base model
            will be used here for the export.

            TODO (kourosh): This restriction will be resolved once we move Policy and
            ModelV2 to the new Learner/RLModule APIs.

            Args:
                export_dir: Local writable directory.
                onnx: If given, will export model in ONNX format. The
                    value of this parameter sets the ONNX OpSet version to use.
            """
            if (
                hasattr(self, "model")
                and hasattr(self.model, "base_model")
                and isinstance(self.model.base_model, tf.keras.Model)
            ):
                # Store model in ONNX format.
                if onnx:
                    try:
                        import tf2onnx
                    except ImportError as e:
                        raise RuntimeError(
                            "Converting a TensorFlow model to ONNX requires "
                            "`tf2onnx` to be installed. Install with "
                            "`pip install tf2onnx`."
                        ) from e

                    model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
                        self.model.base_model,
                        output_path=os.path.join(export_dir, "model.onnx"),
                    )
                # Save the tf.keras.Model (architecture and weights, so it can be
                # retrieved w/o access to the original (custom) Model or Policy code).
                else:
                    try:
                        self.model.base_model.save(export_dir, save_format="tf")
                    except Exception:
                        logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
            else:
                logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
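
        # Example calls (illustrative; the paths and the ONNX opset are assumptions):
        #
        #     my_policy.export_model("/tmp/exported_policy")           # SavedModel
        #     my_policy.export_model("/tmp/exported_policy", onnx=11)  # ONNX, opset 11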

        def variables(self):
            """Return the list of all savable variables for this policy."""
            if isinstance(self.model, tf.keras.Model):
                return self.model.variables
            else:
                return self.model.variables()

        def loss_initialized(self):
            return self._loss_initialized

        @with_lock
        def _compute_actions_helper(
            self, input_dict, state_batches, episodes, explore, timestep
        ):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            # Calculate RNN sequence lengths.
            batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None

            # Add default and custom fetches.
            extra_fetches = {}

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    action_sampler_outputs = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes,
                    )
                    if len(action_sampler_outputs) == 4:
                        actions, logp, dist_inputs, state_out = action_sampler_outputs
                    else:
                        dist_inputs = None
                        state_out = []
                        actions, logp = action_sampler_outputs
                else:
                    if action_distribution_fn:
                        # Try new action_distribution_fn signature, supporting
                        # state_batches and seq_lens.
                        try:
                            (
                                dist_inputs,
                                self.dist_class,
                                state_out,
                            ) = action_distribution_fn(
                                self,
                                self.model,
                                input_dict=input_dict,
                                state_batches=state_batches,
                                seq_lens=seq_lens,
                                explore=explore,
                                timestep=timestep,
                                is_training=False,
                            )
                        # Trying the old way (to stay backward compatible).
                        # TODO: Remove in future.
                        except TypeError as e:
                            if (
                                "positional argument" in e.args[0]
                                or "unexpected keyword argument" in e.args[0]
                            ):
                                (
                                    dist_inputs,
                                    self.dist_class,
                                    state_out,
                                ) = action_distribution_fn(
                                    self,
                                    self.model,
                                    input_dict[SampleBatch.OBS],
                                    explore=explore,
                                    timestep=timestep,
                                    is_training=False,
                                )
                            else:
                                raise e
                    elif isinstance(self.model, tf.keras.Model):
                        input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                        if state_batches and "state_in_0" not in input_dict:
                            for i, s in enumerate(state_batches):
                                input_dict[f"state_in_{i}"] = s
                        self._lazy_tensor_dict(input_dict)
                        dist_inputs, state_out, extra_fetches = self.model(input_dict)
                    else:
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens
                        )

                    action_dist = self.dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore,
                    )

            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_out_fn:
                extra_fetches.update(extra_action_out_fn(self))

            return actions, state_out, extra_fetches

        # TODO: Figure out, why _ray_trace_ctx=None helps to prevent a crash in
        # AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner actors.
        # It seems there may be a clash between the traced-by-tf function and the
        # traced-by-ray functions (for making the policy class a ray actor).
        def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, _, stats = self._compute_gradients_helper(samples)
            self._apply_gradients_helper(grads_and_vars)
            return stats

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        @with_lock
        def _compute_gradients_helper(self, samples):
            """Computes and returns grads as eager tensors."""

            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            # Gather all variables for which to calculate losses.
            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()

            # Calculate the loss(es) inside a tf GradientTape.
            with tf.GradientTape(persistent=compute_gradients_fn is not None) as tape:
                losses = self._loss(self, self.model, self.dist_class, samples)
            losses = force_list(losses)

            # User provided a compute_gradients_fn.
            if compute_gradients_fn:
                # Wrap our tape inside a wrapper, such that the resulting
                # object looks like a "classic" tf.optimizer. This way, custom
                # compute_gradients_fn will work on both tf static graph
                # and tf-eager.
                optimizer = _OptimizerWrapper(tape)
                # More than one loss terms/optimizers.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    grads_and_vars = compute_gradients_fn(
                        self, [optimizer] * len(losses), losses
                    )
                # Only one loss and one optimizer.
                else:
                    grads_and_vars = [compute_gradients_fn(self, optimizer, losses[0])]
            # Default: Compute gradients using the above tape.
            else:
                grads_and_vars = [
                    list(zip(tape.gradient(loss, variables), variables))
                    for loss in losses
                ]

            if log_once("grad_vars"):
                for g_and_v in grads_and_vars:
                    for g, v in g_and_v:
                        if g is not None:
                            logger.info(f"Optimizing variable {v.name}")

            # `grads_and_vars` is returned as a list (len=num optimizers/losses)
            # of lists of (grad, var) tuples.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
            # `grads_and_vars` is returned as a list of (grad, var) tuples.
            else:
                grads_and_vars = grads_and_vars[0]
                grads = [g for g, _ in grads_and_vars]

            stats = self._stats(self, samples, grads)
            return grads_and_vars, grads, stats

        def _apply_gradients_helper(self, grads_and_vars):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            if apply_gradients_fn:
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    apply_gradients_fn(self, self._optimizers, grads_and_vars)
                else:
                    apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    for i, o in enumerate(self._optimizers):
                        o.apply_gradients(
                            [(g, v) for g, v in grads_and_vars[i] if g is not None]
                        )
                else:
                    self._optimizer.apply_gradients(
                        [(g, v) for g, v in grads_and_vars if g is not None]
                    )

        def _stats(self, outputs, samples, grads):
            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}
            if extra_learn_fetches_fn:
                fetches.update({k: v for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update(
                    {k: v for k, v in grad_stats_fn(self, samples, grads).items()}
                )
            return fetches

        def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
            # TODO: (sven): Keep for a while to ensure backward compatibility.
            if not isinstance(postprocessed_batch, SampleBatch):
                postprocessed_batch = SampleBatch(postprocessed_batch)
            postprocessed_batch.set_get_interceptor(_convert_to_tf)
            return postprocessed_batch

        @classmethod
        def with_tracing(cls):
            return _traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"

    return eager_policy_cls
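

# Sketch of how this builder is typically consumed (names here are illustrative;
# real algorithms pass many more of the hook functions above): the same arguments
# that `build_tf_policy()` receives produce the eager class, and `with_tracing()`
# yields the `tf.function`-traced variant used when `eager_tracing=True`.
#
#     MyEagerPolicy = _build_eager_tf_policy(
#         name="MyTFPolicy",
#         loss_fn=my_loss_fn,
#     )
#     MyTracedEagerPolicy = MyEagerPolicy.with_tracing()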