12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085 |
- """Eager mode TF policy built using build_tf_policy().
- It supports both traced and non-traced eager execution modes."""
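
# Illustrative note (config usage as per RLlib's AlgorithmConfig API): an eager
# policy is normally obtained by configuring an algorithm with, e.g.,
#
#     config = AlgorithmConfig().framework("tf2", eager_tracing=True)
#
# With `eager_tracing=True`, the policy class built below is additionally
# wrapped via `with_tracing()` (see `_traced_eager_policy()`); with
# `eager_tracing=False`, every op runs eagerly (slower, but easy to debug).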

import functools
import logging
import os
import threading
from typing import Dict, List, Optional, Tuple, Union

import tree  # pip install dm_tree

from ray.rllib.evaluation.episode import Episode
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.repeated_values import RepeatedValues
from ray.rllib.policy.policy import Policy, PolicyState
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import add_mixins, force_list
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.deprecation import (
    DEPRECATED_VALUE,
    deprecation_warning,
)
from ray.rllib.utils.error import ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import (
    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY,
    NUM_AGENT_STEPS_TRAINED,
    NUM_GRAD_UPDATES_LIFETIME,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import normalize_action
from ray.rllib.utils.tf_utils import get_gpu_devices
from ray.rllib.utils.threading import with_lock
from ray.rllib.utils.typing import (
    LocalOptimizer,
    ModelGradients,
    TensorType,
    TensorStructType,
)
from ray.util.debug import log_once

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


def _convert_to_tf(x, dtype=None):
    if isinstance(x, SampleBatch):
        dict_ = {k: v for k, v in x.items() if k != SampleBatch.INFOS}
        return tree.map_structure(_convert_to_tf, dict_)
    elif isinstance(x, Policy):
        return x
    # Special handling of "Repeated" values.
    elif isinstance(x, RepeatedValues):
        return RepeatedValues(
            tree.map_structure(_convert_to_tf, x.values), x.lengths, x.max_len
        )

    if x is not None:
        d = dtype
        return tree.map_structure(
            lambda f: _convert_to_tf(f, d)
            if isinstance(f, RepeatedValues)
            else tf.convert_to_tensor(f, d)
            if f is not None and not tf.is_tensor(f)
            else f,
            x,
        )

    return x
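
# Illustrative sketch (values are hypothetical; `np` stands for numpy):
# `_convert_to_tf` recursively converts nested structures to tf tensors,
# leaving existing tensors (and `None`s) untouched:
#
#     batch = {"obs": np.array([[0.1, 0.2]]), "t": 3}
#     tf_batch = _convert_to_tf(batch)
#     # -> {"obs": <tf.Tensor, shape=(1, 2)>, "t": <tf.Tensor, 3>}
#
# An explicit `dtype` (e.g. tf.int64 for the "timestep" kwarg) is forced onto
# all newly created tensors.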


def _convert_to_numpy(x):
    def _map(x):
        if isinstance(x, tf.Tensor):
            return x.numpy()
        return x

    try:
        return tf.nest.map_structure(_map, x)
    except AttributeError:
        raise TypeError(
            ("Object of type {} has no method to convert to numpy.").format(type(x))
        )


def _convert_eager_inputs(func):
    @functools.wraps(func)
    def _func(*args, **kwargs):
        if tf.executing_eagerly():
            eager_args = [_convert_to_tf(x) for x in args]
            # TODO: (sven) find a way to remove key-specific hacks.
            eager_kwargs = {
                k: _convert_to_tf(v, dtype=tf.int64 if k == "timestep" else None)
                for k, v in kwargs.items()
                if k not in {"info_batch", "episodes"}
            }
            return func(*eager_args, **eager_kwargs)
        else:
            return func(*args, **kwargs)

    return _func


def _convert_eager_outputs(func):
    @functools.wraps(func)
    def _func(*args, **kwargs):
        out = func(*args, **kwargs)
        if tf.executing_eagerly():
            out = tf.nest.map_structure(_convert_to_numpy, out)
        return out

    return _func


def _disallow_var_creation(next_creator, **kw):
    v = next_creator(**kw)
    raise ValueError(
        "Detected a variable being created during an eager "
        "forward pass. Variables should only be created during "
        "model initialization: {}".format(v.name)
    )


def _check_too_many_retraces(obj):
    """Asserts that a given number of re-traces is not breached."""

    def _func(self_, *args, **kwargs):
        if (
            self_.config.get("eager_max_retraces") is not None
            and self_._re_trace_counter > self_.config["eager_max_retraces"]
        ):
            raise RuntimeError(
                "Too many tf-eager re-traces detected! This could lead to"
                " significant slow-downs (even slower than running in "
                "tf-eager mode w/ `eager_tracing=False`). To switch off "
                "these re-trace counting checks, set `eager_max_retraces`"
                " in your config to None."
            )
        return obj(self_, *args, **kwargs)

    return _func


@DeveloperAPI
class EagerTFPolicy(Policy):
    """Dummy class to recognize any eagerized TFPolicy by its inheritance."""

    pass


def _traced_eager_policy(eager_policy_cls):
    """Wrapper class that enables tracing for all eager policy methods.

    This is enabled by the `--trace`/`eager_tracing=True` config when
    framework=tf2.
    """

    class TracedEagerPolicy(eager_policy_cls):
        def __init__(self, *args, **kwargs):
            self._traced_learn_on_batch_helper = False
            self._traced_compute_actions_helper = False
            self._traced_compute_gradients_helper = False
            self._traced_apply_gradients_helper = False
            super(TracedEagerPolicy, self).__init__(*args, **kwargs)

        @_check_too_many_retraces
        @override(Policy)
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List[Episode]] = None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            """Traced version of Policy.compute_actions_from_input_dict."""

            # Create a traced version of `self._compute_actions_helper`.
            if self._traced_compute_actions_helper is False and not self._no_tracing:
                if self.config.get("_enable_rl_module_api"):
                    self._compute_actions_helper_rl_module_explore = (
                        _convert_eager_inputs(
                            tf.function(
                                super(
                                    TracedEagerPolicy, self
                                )._compute_actions_helper_rl_module_explore,
                                autograph=True,
                                reduce_retracing=True,
                            )
                        )
                    )
                    self._compute_actions_helper_rl_module_inference = (
                        _convert_eager_inputs(
                            tf.function(
                                super(
                                    TracedEagerPolicy, self
                                )._compute_actions_helper_rl_module_inference,
                                autograph=True,
                                reduce_retracing=True,
                            )
                        )
                    )
                else:
                    self._compute_actions_helper = _convert_eager_inputs(
                        tf.function(
                            super(TracedEagerPolicy, self)._compute_actions_helper,
                            autograph=False,
                            reduce_retracing=True,
                        )
                    )
                self._traced_compute_actions_helper = True

            # Now that the helper method is traced, call super's
            # `compute_actions_from_input_dict()` (which will call the traced
            # helper).
            return super(TracedEagerPolicy, self).compute_actions_from_input_dict(
                input_dict=input_dict,
                explore=explore,
                timestep=timestep,
                episodes=episodes,
                **kwargs,
            )

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def learn_on_batch(self, samples):
            """Traced version of Policy.learn_on_batch."""

            # Create a traced version of `self._learn_on_batch_helper`.
            if self._traced_learn_on_batch_helper is False and not self._no_tracing:
                self._learn_on_batch_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._learn_on_batch_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_learn_on_batch_helper = True

            # Now that the helper method is traced, call super's
            # `learn_on_batch()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).learn_on_batch(samples)

        @_check_too_many_retraces
        @override(eager_policy_cls)
        def compute_gradients(self, samples: SampleBatch) -> ModelGradients:
            """Traced version of Policy.compute_gradients."""

            # Create a traced version of `self._compute_gradients_helper`.
            if self._traced_compute_gradients_helper is False and not self._no_tracing:
                self._compute_gradients_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._compute_gradients_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_compute_gradients_helper = True

            # Now that the helper method is traced, call super's
            # `compute_gradients()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).compute_gradients(samples)

        @_check_too_many_retraces
        @override(Policy)
        def apply_gradients(self, grads: ModelGradients) -> None:
            """Traced version of Policy.apply_gradients."""

            # Create a traced version of `self._apply_gradients_helper`.
            if self._traced_apply_gradients_helper is False and not self._no_tracing:
                self._apply_gradients_helper = _convert_eager_inputs(
                    tf.function(
                        super(TracedEagerPolicy, self)._apply_gradients_helper,
                        autograph=False,
                        reduce_retracing=True,
                    )
                )
                self._traced_apply_gradients_helper = True

            # Now that the helper method is traced, call super's
            # `apply_gradients()` (which will call the traced helper).
            return super(TracedEagerPolicy, self).apply_gradients(grads)

        @classmethod
        def with_tracing(cls):
            # Already traced -> Return same class.
            return cls

    TracedEagerPolicy.__name__ = eager_policy_cls.__name__ + "_traced"
    TracedEagerPolicy.__qualname__ = eager_policy_cls.__qualname__ + "_traced"
    return TracedEagerPolicy


class _OptimizerWrapper:
    def __init__(self, tape):
        self.tape = tape

    def compute_gradients(self, loss, var_list):
        return list(zip(self.tape.gradient(loss, var_list), var_list))
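
# Design note: `_OptimizerWrapper` duck-types the one optimizer method that a
# custom `compute_gradients_fn` calls, so the same user code runs under both
# static-graph tf (a real tf1 optimizer) and eager mode (a GradientTape). A
# hypothetical sketch (`model_vars` is assumed to be a list of tf.Variables):
#
#     with tf.GradientTape(persistent=True) as tape:
#         loss = ...  # forward pass + loss, recorded on the tape
#     optimizer = _OptimizerWrapper(tape)
#     grads_and_vars = optimizer.compute_gradients(loss, model_vars)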


def _build_eager_tf_policy(
    name,
    loss_fn,
    get_default_config=None,
    postprocess_fn=None,
    stats_fn=None,
    optimizer_fn=None,
    compute_gradients_fn=None,
    apply_gradients_fn=None,
    grad_stats_fn=None,
    extra_learn_fetches_fn=None,
    extra_action_out_fn=None,
    validate_spaces=None,
    before_init=None,
    before_loss_init=None,
    after_init=None,
    make_model=None,
    action_sampler_fn=None,
    action_distribution_fn=None,
    mixins=None,
    get_batch_divisibility_req=None,
    # Deprecated args.
    obs_include_prev_action_reward=DEPRECATED_VALUE,
    extra_action_fetches_fn=None,
    gradients_fn=None,
):
    """Build an eager TF policy.

    An eager policy runs all operations in eager mode, which makes debugging
    much simpler, but has lower performance.

    You shouldn't need to call this directly. Rather, prefer to build a TF
    graph policy and set `.framework("tf2", eager_tracing=False)` in your
    AlgorithmConfig to have it automatically converted to an eager policy.

    This has the same signature as build_tf_policy()."""
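
    # Illustrative usage sketch (all names here are hypothetical; this builder
    # is normally reached indirectly via `build_tf_policy()` with
    # framework="tf2"):
    #
    #     MyEagerPolicy = _build_eager_tf_policy(
    #         name="MyPolicy",
    #         loss_fn=my_loss_fn,    # (policy, model, dist_class, train_batch)
    #         stats_fn=my_stats_fn,  # optional learner stats
    #     )
    #     policy = MyEagerPolicy(obs_space, action_space, config)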

    base = add_mixins(EagerTFPolicy, mixins)

    if obs_include_prev_action_reward != DEPRECATED_VALUE:
        deprecation_warning(old="obs_include_prev_action_reward", error=True)

    if extra_action_fetches_fn is not None:
        deprecation_warning(
            old="extra_action_fetches_fn", new="extra_action_out_fn", error=True
        )

    if gradients_fn is not None:
        deprecation_warning(old="gradients_fn", new="compute_gradients_fn", error=True)

    class eager_policy_cls(base):
        def __init__(self, observation_space, action_space, config):
            # If this class runs as a @ray.remote actor, eager mode may not
            # have been activated yet.
            if not tf1.executing_eagerly():
                tf1.enable_eager_execution()
            self.framework = config.get("framework", "tf2")
            EagerTFPolicy.__init__(self, observation_space, action_space, config)

            # Global timestep should be a tensor.
            self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64)
            self.explore = tf.Variable(
                self.config["explore"], trainable=False, dtype=tf.bool
            )

            # Log device and worker index.
            num_gpus = self._get_num_gpus_for_policy()
            if num_gpus > 0:
                gpu_ids = get_gpu_devices()
                logger.info(f"Found {len(gpu_ids)} visible CUDA devices.")

            self._is_training = False

            # Only for `config.eager_tracing=True`: A counter to keep track of
            # how many times an eager-traced method (e.g.
            # `self._compute_actions_helper`) has been re-traced by tensorflow.
            # We will raise an error if more than n re-tracings have been
            # detected, since this would considerably slow down execution.
            # The variable below should only get incremented during the
            # tf.function trace operations, never when calling the already
            # traced function after that.
            self._re_trace_counter = 0

            self._loss_initialized = False

            # To ensure backward compatibility:
            # Old way: If `loss` provided here, use as-is (as a function).
            if loss_fn is not None:
                self._loss = loss_fn
            # New way: Convert the overridden `self.loss` into a plain
            # function, so it can be called the same way as `loss` would
            # be, ensuring backward compatibility.
            elif self.loss.__func__.__qualname__ != "Policy.loss":
                self._loss = self.loss.__func__
            # `loss` not provided nor overridden from Policy -> Set to None.
            else:
                self._loss = None

            self.batch_divisibility_req = (
                get_batch_divisibility_req(self)
                if callable(get_batch_divisibility_req)
                else (get_batch_divisibility_req or 1)
            )
            self._max_seq_len = config["model"]["max_seq_len"]

            if validate_spaces:
                validate_spaces(self, observation_space, action_space, config)

            if before_init:
                before_init(self, observation_space, action_space, config)

            self.config = config
            self.dist_class = None
            if action_sampler_fn or action_distribution_fn:
                if not make_model:
                    raise ValueError(
                        "`make_model` is required if `action_sampler_fn` OR "
                        "`action_distribution_fn` is given"
                    )
            else:
                self.dist_class, logit_dim = ModelCatalog.get_action_dist(
                    action_space, self.config["model"]
                )

            if make_model:
                self.model = make_model(self, observation_space, action_space, config)
            else:
                self.model = ModelCatalog.get_model_v2(
                    observation_space,
                    action_space,
                    logit_dim,
                    config["model"],
                    framework=self.framework,
                )
            # Lock used for locking some methods on the object-level.
            # This prevents possible race conditions when calling the model
            # first, then its value function (e.g. in a loss function), in
            # between of which another model call is made (e.g. to compute an
            # action).
            self._lock = threading.RLock()

            if self.config.get("_enable_rl_module_api", False):
                # Maybe update view_requirements, e.g. for recurrent case.
                self.view_requirements = self.model.update_default_view_requirements(
                    self.view_requirements
                )
            else:
                # Auto-update model's inference view requirements, if recurrent.
                self._update_model_view_requirements_from_init_state()
                # Combine view_requirements for Model and Policy.
                self.view_requirements.update(self.model.view_requirements)

            self.exploration = self._create_exploration()
            self._state_inputs = self.model.get_initial_state()
            self._is_recurrent = len(self._state_inputs) > 0

            if before_loss_init:
                before_loss_init(self, observation_space, action_space, config)

            if optimizer_fn:
                optimizers = optimizer_fn(self, config)
            else:
                optimizers = tf.keras.optimizers.Adam(config["lr"])
            optimizers = force_list(optimizers)
            if self.exploration:
                optimizers = self.exploration.get_exploration_optimizer(optimizers)
            # The list of local (tf) optimizers (one per loss term).
            self._optimizers: List[LocalOptimizer] = optimizers
            # Backward compatibility: A user's policy may only support a single
            # loss term and optimizer (no lists).
            self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None

            self._initialize_loss_from_dummy_batch(
                auto_remove_unneeded_view_reqs=True,
                stats_fn=stats_fn,
            )
            self._loss_initialized = True

            if after_init:
                after_init(self, observation_space, action_space, config)

            # Reset `global_timestep` again, after the fake run-throughs
            # performed during loss initialization above.
            self.global_timestep.assign(0)

        @override(Policy)
        def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            episodes: Optional[List[Episode]] = None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            if not self.config.get("eager_tracing") and not tf1.executing_eagerly():
                tf1.enable_eager_execution()

            self._is_training = False

            explore = explore if explore is not None else self.explore
            timestep = timestep if timestep is not None else self.global_timestep
            if isinstance(timestep, tf.Tensor):
                timestep = int(timestep.numpy())

            # Pass lazy (eager) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            input_dict.set_training(False)

            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            self._state_in = state_batches
            self._is_recurrent = state_batches != []

            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(
                timestep=timestep, explore=explore, tf_sess=self.get_session()
            )

            ret = self._compute_actions_helper(
                input_dict,
                state_batches,
                # TODO: Passing episodes into a traced method does not work.
                None if self.config["eager_tracing"] else episodes,
                explore,
                timestep,
            )
            # Update our global timestep by the batch size.
            self.global_timestep.assign_add(tree.flatten(ret[0])[0].shape.as_list()[0])
            return convert_to_numpy(ret)

        @override(Policy)
        def compute_actions(
            self,
            obs_batch: Union[List[TensorStructType], TensorStructType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorStructType], TensorStructType] = None,
            prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["Episode"]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            **kwargs,
        ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
            # Create input dict to simply pass the entire call to
            # self.compute_actions_from_input_dict().
            input_dict = SampleBatch(
                {
                    SampleBatch.CUR_OBS: obs_batch,
                },
                _is_training=tf.constant(False),
            )
            if state_batches is not None:
                for i, s in enumerate(state_batches):
                    input_dict[f"state_in_{i}"] = s
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            if info_batch is not None:
                input_dict[SampleBatch.INFOS] = info_batch

            return self.compute_actions_from_input_dict(
                input_dict=input_dict,
                explore=explore,
                timestep=timestep,
                episodes=episodes,
                **kwargs,
            )
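
        # Hypothetical call sketch: a single observation can be evaluated via
        #
        #     actions, state_outs, extra = policy.compute_actions(
        #         obs_batch=np.array([obs]), explore=False
        #     )
        #
        # (`np` and `obs` are placeholders). All inputs are packed into a
        # SampleBatch and routed through `compute_actions_from_input_dict()`
        # above.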

        @with_lock
        @override(Policy)
        def compute_log_likelihoods(
            self,
            actions,
            obs_batch,
            state_batches=None,
            prev_action_batch=None,
            prev_reward_batch=None,
            actions_normalized=True,
            **kwargs,
        ):
            if action_sampler_fn and action_distribution_fn is None:
                raise ValueError(
                    "Cannot compute log-prob/likelihood w/o an "
                    "`action_distribution_fn` and a provided "
                    "`action_sampler_fn`!"
                )

            seq_lens = tf.ones(len(obs_batch), dtype=tf.int32)
            input_batch = SampleBatch(
                {SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_batch)},
                _is_training=False,
            )
            if prev_action_batch is not None:
                input_batch[SampleBatch.PREV_ACTIONS] = tf.convert_to_tensor(
                    prev_action_batch
                )
            if prev_reward_batch is not None:
                input_batch[SampleBatch.PREV_REWARDS] = tf.convert_to_tensor(
                    prev_reward_batch
                )

            if self.exploration:
                # Exploration hook before each forward pass.
                self.exploration.before_compute_actions(explore=False)

            # Action dist class and inputs are generated via custom function.
            if action_distribution_fn:
                dist_inputs, dist_class, _ = action_distribution_fn(
                    self, self.model, input_batch, explore=False, is_training=False
                )
            # Default log-likelihood calculation.
            else:
                dist_inputs, _ = self.model(input_batch, state_batches, seq_lens)
                dist_class = self.dist_class

            action_dist = dist_class(dist_inputs, self.model)

            # Normalize actions if necessary.
            if not actions_normalized and self.config["normalize_actions"]:
                actions = normalize_action(actions, self.action_space_struct)

            log_likelihoods = action_dist.logp(actions)

            return log_likelihoods

        @override(Policy)
        def postprocess_trajectory(
            self, sample_batch, other_agent_batches=None, episode=None
        ):
            assert tf.executing_eagerly()
            # Call super's postprocess_trajectory first.
            sample_batch = EagerTFPolicy.postprocess_trajectory(self, sample_batch)
            if postprocess_fn:
                return postprocess_fn(self, sample_batch, other_agent_batches, episode)
            return sample_batch

        @with_lock
        @override(Policy)
        def learn_on_batch(self, postprocessed_batch):
            # Callback handling.
            learn_stats = {}
            self.callbacks.on_learn_on_batch(
                policy=self, train_batch=postprocessed_batch, result=learn_stats
            )

            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                max_seq_len=self._max_seq_len,
                shuffle=False,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

            self._is_training = True
            postprocessed_batch = self._lazy_tensor_dict(postprocessed_batch)
            postprocessed_batch.set_training(True)
            stats = self._learn_on_batch_helper(postprocessed_batch)
            self.num_grad_updates += 1

            stats.update(
                {
                    "custom_metrics": learn_stats,
                    NUM_AGENT_STEPS_TRAINED: postprocessed_batch.count,
                    NUM_GRAD_UPDATES_LIFETIME: self.num_grad_updates,
                    # -1, b/c we have to measure this diff before we do the
                    # update above.
                    DIFF_NUM_GRAD_UPDATES_VS_SAMPLER_POLICY: (
                        self.num_grad_updates
                        - 1
                        - (postprocessed_batch.num_grad_updates or 0)
                    ),
                }
            )

            return convert_to_numpy(stats)

        @override(Policy)
        def compute_gradients(
            self, postprocessed_batch: SampleBatch
        ) -> Tuple[ModelGradients, Dict[str, TensorType]]:
            pad_batch_to_sequences_of_same_size(
                postprocessed_batch,
                shuffle=False,
                max_seq_len=self._max_seq_len,
                batch_divisibility_req=self.batch_divisibility_req,
                view_requirements=self.view_requirements,
            )

            self._is_training = True
            self._lazy_tensor_dict(postprocessed_batch)
            postprocessed_batch.set_training(True)
            grads_and_vars, grads, stats = self._compute_gradients_helper(
                postprocessed_batch
            )
            return convert_to_numpy((grads, stats))

        @override(Policy)
        def apply_gradients(self, gradients: ModelGradients) -> None:
            self._apply_gradients_helper(
                list(
                    zip(
                        [
                            (tf.convert_to_tensor(g) if g is not None else None)
                            for g in gradients
                        ],
                        self.model.trainable_variables(),
                    )
                )
            )

        @override(Policy)
        def get_weights(self, as_dict=False):
            variables = self.variables()
            if as_dict:
                return {v.name: v.numpy() for v in variables}
            return [v.numpy() for v in variables]

        @override(Policy)
        def set_weights(self, weights):
            variables = self.variables()
            assert len(weights) == len(variables), (len(weights), len(variables))
            for v, w in zip(variables, weights):
                v.assign(w)

        @override(Policy)
        def get_exploration_state(self):
            return convert_to_numpy(self.exploration.get_state())

        @override(Policy)
        def is_recurrent(self):
            return self._is_recurrent

        @override(Policy)
        def num_state_tensors(self):
            return len(self._state_inputs)

        @override(Policy)
        def get_initial_state(self):
            if hasattr(self, "model"):
                return self.model.get_initial_state()
            return []

        @override(Policy)
        def get_state(self) -> PolicyState:
            # Legacy Policy state (w/o keras model and w/o PolicySpec).
            state = super().get_state()

            state["global_timestep"] = state["global_timestep"].numpy()
            if self._optimizer and len(self._optimizer.variables()) > 0:
                state["_optimizer_variables"] = self._optimizer.variables()
            # Add exploration state.
            if (
                not self.config.get("_enable_rl_module_api", False)
                and self.exploration
            ):
                # This is not compatible with RLModules, which have a method
                # `forward_exploration` to specify custom exploration behavior.
                state["_exploration_state"] = self.exploration.get_state()
            return state

        @override(Policy)
        def set_state(self, state: PolicyState) -> None:
            # Set optimizer vars first.
            optimizer_vars = state.get("_optimizer_variables", None)
            if optimizer_vars and self._optimizer.variables():
                if not type(self).__name__.endswith("_traced") and log_once(
                    "set_state_optimizer_vars_tf_eager_policy_v2"
                ):
                    logger.warning(
                        "Cannot restore an optimizer's state for tf eager! Keras "
                        "is not able to save the v1.x optimizers (from "
                        "tf.compat.v1.train) since they aren't compatible with "
                        "checkpoints."
                    )
                for opt_var, value in zip(self._optimizer.variables(), optimizer_vars):
                    opt_var.assign(value)
            # Set exploration's state.
            if hasattr(self, "exploration") and "_exploration_state" in state:
                self.exploration.set_state(state=state["_exploration_state"])

            # Restore global timestep (tf var).
            self.global_timestep.assign(state["global_timestep"])

            # Then the Policy's (NN) weights and connectors.
            super().set_state(state)

        @override(Policy)
        def export_model(self, export_dir, onnx: Optional[int] = None) -> None:
            """Exports the Policy's Model to local directory for serving.

            Note: Since the TfModelV2 class that EagerTfPolicy uses is-NOT-a
            tf.keras.Model, we need to assume that there is a `base_model`
            property within this TfModelV2 class that is-a tf.keras.Model. This
            base model will be used here for the export.

            TODO (kourosh): This restriction will be resolved once we move
            Policy and ModelV2 to the new Learner/RLModule APIs.

            Args:
                export_dir: Local writable directory.
                onnx: If given, will export model in ONNX format. The
                    value of this parameter sets the ONNX OpSet version to use.
            """
            if (
                hasattr(self, "model")
                and hasattr(self.model, "base_model")
                and isinstance(self.model.base_model, tf.keras.Model)
            ):
                # Store model in ONNX format.
                if onnx:
                    try:
                        import tf2onnx
                    except ImportError as e:
                        raise RuntimeError(
                            "Converting a TensorFlow model to ONNX requires "
                            "`tf2onnx` to be installed. Install with "
                            "`pip install tf2onnx`."
                        ) from e

                    model_proto, external_tensor_storage = tf2onnx.convert.from_keras(
                        self.model.base_model,
                        output_path=os.path.join(export_dir, "model.onnx"),
                    )
                # Save the tf.keras.Model (architecture and weights, so it can
                # be retrieved w/o access to the original (custom) Model or
                # Policy code).
                else:
                    try:
                        self.model.base_model.save(export_dir, save_format="tf")
                    except Exception:
                        logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)
            else:
                logger.warning(ERR_MSG_TF_POLICY_CANNOT_SAVE_KERAS_MODEL)

        def variables(self):
            """Return the list of all savable variables for this policy."""
            if isinstance(self.model, tf.keras.Model):
                return self.model.variables
            else:
                return self.model.variables()

        def loss_initialized(self):
            return self._loss_initialized

        @with_lock
        def _compute_actions_helper(
            self, input_dict, state_batches, episodes, explore, timestep
        ):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            # Calculate RNN sequence lengths.
            batch_size = tree.flatten(input_dict[SampleBatch.OBS])[0].shape[0]
            seq_lens = tf.ones(batch_size, dtype=tf.int32) if state_batches else None

            # Add default and custom fetches.
            extra_fetches = {}

            # Use Exploration object.
            with tf.variable_creator_scope(_disallow_var_creation):
                if action_sampler_fn:
                    action_sampler_outputs = action_sampler_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        episodes=episodes,
                    )
                    if len(action_sampler_outputs) == 4:
                        actions, logp, dist_inputs, state_out = action_sampler_outputs
                    else:
                        dist_inputs = None
                        state_out = []
                        actions, logp = action_sampler_outputs
                else:
                    if action_distribution_fn:
                        # Try new action_distribution_fn signature, supporting
                        # state_batches and seq_lens.
                        try:
                            (
                                dist_inputs,
                                self.dist_class,
                                state_out,
                            ) = action_distribution_fn(
                                self,
                                self.model,
                                input_dict=input_dict,
                                state_batches=state_batches,
                                seq_lens=seq_lens,
                                explore=explore,
                                timestep=timestep,
                                is_training=False,
                            )
                        # Trying the old way (to stay backward compatible).
                        # TODO: Remove in future.
                        except TypeError as e:
                            if (
                                "positional argument" in e.args[0]
                                or "unexpected keyword argument" in e.args[0]
                            ):
                                (
                                    dist_inputs,
                                    self.dist_class,
                                    state_out,
                                ) = action_distribution_fn(
                                    self,
                                    self.model,
                                    input_dict[SampleBatch.OBS],
                                    explore=explore,
                                    timestep=timestep,
                                    is_training=False,
                                )
                            else:
                                raise e
                    elif isinstance(self.model, tf.keras.Model):
                        input_dict = SampleBatch(input_dict, seq_lens=seq_lens)
                        if state_batches and "state_in_0" not in input_dict:
                            for i, s in enumerate(state_batches):
                                input_dict[f"state_in_{i}"] = s
                        self._lazy_tensor_dict(input_dict)
                        dist_inputs, state_out, extra_fetches = self.model(input_dict)
                    else:
                        dist_inputs, state_out = self.model(
                            input_dict, state_batches, seq_lens
                        )

                    action_dist = self.dist_class(dist_inputs, self.model)

                    # Get the exploration action from the forward results.
                    actions, logp = self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore,
                    )

            # Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = tf.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            # Custom extra fetches.
            if extra_action_out_fn:
                extra_fetches.update(extra_action_out_fn(self))

            return actions, state_out, extra_fetches

        # TODO: Figure out why _ray_trace_ctx=None helps to prevent a crash in
        #  AlphaStar w/ framework=tf2; eager_tracing=True on the policy learner
        #  actors. It seems there may be a clash between the traced-by-tf
        #  function and the traced-by-ray functions (for making the policy
        #  class a ray actor).
        def _learn_on_batch_helper(self, samples, _ray_trace_ctx=None):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            with tf.variable_creator_scope(_disallow_var_creation):
                grads_and_vars, _, stats = self._compute_gradients_helper(samples)
            self._apply_gradients_helper(grads_and_vars)
            return stats

        def _get_is_training_placeholder(self):
            return tf.convert_to_tensor(self._is_training)

        @with_lock
        def _compute_gradients_helper(self, samples):
            """Computes and returns grads as eager tensors."""

            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            # Gather all variables for which to calculate losses.
            if isinstance(self.model, tf.keras.Model):
                variables = self.model.trainable_variables
            else:
                variables = self.model.trainable_variables()

            # Calculate the loss(es) inside a tf GradientTape.
            with tf.GradientTape(persistent=compute_gradients_fn is not None) as tape:
                losses = self._loss(self, self.model, self.dist_class, samples)
            losses = force_list(losses)

            # User provided a compute_gradients_fn.
            if compute_gradients_fn:
                # Wrap our tape inside a wrapper, such that the resulting
                # object looks like a "classic" tf.optimizer. This way, custom
                # compute_gradients_fn will work on both tf static graph
                # and tf-eager.
                optimizer = _OptimizerWrapper(tape)
                # More than one loss terms/optimizers.
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    grads_and_vars = compute_gradients_fn(
                        self, [optimizer] * len(losses), losses
                    )
                # Only one loss and one optimizer.
                else:
                    grads_and_vars = [compute_gradients_fn(self, optimizer, losses[0])]
            # Default: Compute gradients using the above tape.
            else:
                grads_and_vars = [
                    list(zip(tape.gradient(loss, variables), variables))
                    for loss in losses
                ]

            if log_once("grad_vars"):
                for g_and_v in grads_and_vars:
                    for g, v in g_and_v:
                        if g is not None:
                            logger.info(f"Optimizing variable {v.name}")

            # `grads_and_vars` is returned as a list (len=num optimizers/losses)
            # of lists of (grad, var) tuples.
            if self.config["_tf_policy_handles_more_than_one_loss"]:
                grads = [[g for g, _ in g_and_v] for g_and_v in grads_and_vars]
            # `grads_and_vars` is returned as a list of (grad, var) tuples.
            else:
                grads_and_vars = grads_and_vars[0]
                grads = [g for g, _ in grads_and_vars]

            stats = self._stats(self, samples, grads)
            return grads_and_vars, grads, stats

        def _apply_gradients_helper(self, grads_and_vars):
            # Increase the tracing counter to make sure we don't re-trace too
            # often. If eager_tracing=True, this counter should only get
            # incremented during the @tf.function trace operations, never when
            # calling the already traced function after that.
            self._re_trace_counter += 1

            if apply_gradients_fn:
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    apply_gradients_fn(self, self._optimizers, grads_and_vars)
                else:
                    apply_gradients_fn(self, self._optimizer, grads_and_vars)
            else:
                if self.config["_tf_policy_handles_more_than_one_loss"]:
                    for i, o in enumerate(self._optimizers):
                        o.apply_gradients(
                            [(g, v) for g, v in grads_and_vars[i] if g is not None]
                        )
                else:
                    self._optimizer.apply_gradients(
                        [(g, v) for g, v in grads_and_vars if g is not None]
                    )

        def _stats(self, outputs, samples, grads):
            fetches = {}
            if stats_fn:
                fetches[LEARNER_STATS_KEY] = {
                    k: v for k, v in stats_fn(outputs, samples).items()
                }
            else:
                fetches[LEARNER_STATS_KEY] = {}
            if extra_learn_fetches_fn:
                fetches.update({k: v for k, v in extra_learn_fetches_fn(self).items()})
            if grad_stats_fn:
                fetches.update(
                    {k: v for k, v in grad_stats_fn(self, samples, grads).items()}
                )
            return fetches

        def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch):
            # TODO: (sven): Keep for a while to ensure backward compatibility.
            if not isinstance(postprocessed_batch, SampleBatch):
                postprocessed_batch = SampleBatch(postprocessed_batch)
            postprocessed_batch.set_get_interceptor(_convert_to_tf)
            return postprocessed_batch

        @classmethod
        def with_tracing(cls):
            return _traced_eager_policy(cls)

    eager_policy_cls.__name__ = name + "_eager"
    eager_policy_cls.__qualname__ = name + "_eager"
    return eager_policy_cls