from collections import Counter
import copy
from gym.spaces import Box
import logging
import numpy as np
import random
import re
import time
import tree  # pip install dm_tree
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml

import ray
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
    try_import_torch
from ray.rllib.utils.typing import PartialTrainerConfigDict
from ray.tune import CLIReporter, run_experiments

jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
    eager_mode = None
    try:
        from tensorflow.python.eager.context import eager_mode
    except (ImportError, ModuleNotFoundError):
        pass

torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


def framework_iterator(
        config: Optional[PartialTrainerConfigDict] = None,
        frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"),
        session: bool = False,
        with_eager_tracing: bool = False,
        time_iterations: Optional[dict] = None,
) -> Union[str, Tuple[str, Optional["tf1.Session"]]]:
    """A generator that allows looping through n frameworks for testing.

    Provides the correct config entries ("framework") as well as the
    correct eager/non-eager contexts for tfe/tf.

    Args:
        config: An optional config dict to alter in place depending on the
            iteration.
        frameworks: A list/tuple of the frameworks to be tested.
            Allowed are: "tf2", "tf", "tfe", "torch", and None.
        session: If True and only in the tf-case: Enter a tf.Session()
            and yield that as second return value (otherwise yield
            (fw, None)). Also sets a seed (42) on the session to make the
            test deterministic.
        with_eager_tracing: Include `eager_tracing=True` in the returned
            configs, when framework=[tfe|tf2].
        time_iterations: If provided, will write to the given dict (by
            framework key) the times in seconds that each (framework's)
            iteration takes.

    Yields:
        If `session` is False: The current framework [tf2|tf|tfe|torch] used.
        If `session` is True: A tuple consisting of the current framework
        string and the tf1.Session (if fw="tf", otherwise None).
    """
    config = config or {}
    frameworks = [frameworks] if isinstance(frameworks, str) else \
        list(frameworks)

    # Both tf2 and tfe present -> remove "tfe" or "tf2" depending on version.
    if "tf2" in frameworks and "tfe" in frameworks:
        frameworks.remove("tfe" if tfv == 2 else "tf2")

    for fw in frameworks:
        # Skip non-installed frameworks.
        if fw == "torch" and not torch:
            logger.warning(
                "framework_iterator skipping torch (not installed)!")
            continue
        if fw != "torch" and not tf:
            logger.warning("framework_iterator skipping {} (tf not "
                           "installed)!".format(fw))
            continue
        elif fw == "tfe" and not eager_mode:
            logger.warning("framework_iterator skipping tf-eager (could not "
                           "import `eager_mode` from tensorflow.python)!")
            continue
        elif fw == "tf2" and tfv != 2:
            logger.warning(
                "framework_iterator skipping tf2.x (tf version is < 2.0)!")
            continue
        elif fw == "jax" and not jax:
            logger.warning("framework_iterator skipping JAX (not installed)!")
            continue
        assert fw in ["tf2", "tf", "tfe", "torch", "jax", None]

        # Do we need a test session?
        sess = None
        if fw == "tf" and session is True:
            sess = tf1.Session()
            sess.__enter__()
            tf1.set_random_seed(42)

        config["framework"] = fw

        eager_ctx = None
        # Enable eager mode for tf2 and tfe.
        if fw in ["tf2", "tfe"]:
            eager_ctx = eager_mode()
            eager_ctx.__enter__()
            assert tf1.executing_eagerly()
        # Make sure eager mode is off.
        elif fw == "tf":
            assert not tf1.executing_eagerly()

        # Additionally loop through eager_tracing=True + False, if necessary.
        if fw in ["tf2", "tfe"] and with_eager_tracing:
            for tracing in [True, False]:
                config["eager_tracing"] = tracing
                print(f"framework={fw} (eager-tracing={tracing})")
                time_started = time.time()
                yield fw if session is False else (fw, sess)
                if time_iterations is not None:
                    time_total = time.time() - time_started
                    time_iterations[fw + ("+tracing" if tracing else "")] = \
                        time_total
                    print(f".. took {time_total}sec")
            config["eager_tracing"] = False
        # Yield current framework + tf-session (if necessary).
        else:
            print(f"framework={fw}")
            time_started = time.time()
            yield fw if session is False else (fw, sess)
            if time_iterations is not None:
                time_total = time.time() - time_started
                # `tracing` is undefined in this branch -> key by framework
                # only.
                time_iterations[fw] = time_total
                print(f".. took {time_total}sec")

        # Exit any context we may have entered.
        if eager_ctx:
            eager_ctx.__exit__(None, None, None)
        elif sess:
            sess.__exit__(None, None, None)
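

# Illustrative usage sketch (not part of the original module): how a test
# would typically consume `framework_iterator`. `PGTrainer` and the config
# values below are placeholder assumptions chosen for illustration only.
def _example_framework_iterator_usage():
    from ray.rllib.agents.pg import PGTrainer

    config = {"env": "CartPole-v0", "num_workers": 0}
    for fw in framework_iterator(config, frameworks=("tf", "torch")):
        # `config["framework"]` has already been set to `fw` by the iterator.
        trainer = PGTrainer(config=config)
        results = trainer.train()
        print(fw, results["episode_reward_mean"])
        trainer.stop()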


def check(x, y, decimals=5, atol=None, rtol=None, false=False):
    """Checks two structures (dict, tuple, list, np.array, float, int, etc.)
    for (almost) numeric identity.

    All numbers in the two structures have to match up to `decimals` digits
    after the floating point. Uses assertions.

    Args:
        x (any): The value to be compared (to the expectation: `y`). This
            may be a Tensor.
        y (any): The expected value to be compared to `x`. This must not
            be a tf-Tensor, but may be a tfe/torch-Tensor.
        decimals (int): The number of digits after the floating point up to
            which all numeric values have to match.
        atol (float): Absolute tolerance of the difference between x and y
            (overrides `decimals` if given).
        rtol (float): Relative tolerance of the difference between x and y
            (overrides `decimals` if given).
        false (bool): Whether to check that x and y are NOT the same.
    """
    # A dict type.
    if isinstance(x, dict):
        assert isinstance(y, dict), \
            "ERROR: If x is dict, y needs to be a dict as well!"
        y_keys = set(x.keys())
        for key, value in x.items():
            assert key in y, \
                "ERROR: y does not have x's key='{}'! y={}".format(key, y)
            check(
                value,
                y[key],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
            y_keys.remove(key)
        assert not y_keys, \
            "ERROR: y contains keys ({}) that are not in x! y={}".\
            format(list(y_keys), y)
    # A tuple or list type.
    elif isinstance(x, (tuple, list)):
        assert isinstance(y, (tuple, list)), \
            "ERROR: If x is tuple/list, y needs to be a tuple/list as well!"
        assert len(y) == len(x), \
            "ERROR: y does not have the same length as x ({} vs {})!".\
            format(len(y), len(x))
        for i, value in enumerate(x):
            check(
                value,
                y[i],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
    # Boolean comparison.
    elif isinstance(x, (np.bool_, bool)):
        if false is True:
            assert bool(x) is not bool(y), \
                "ERROR: x ({}) is y ({})!".format(x, y)
        else:
            assert bool(x) is bool(y), \
                "ERROR: x ({}) is not y ({})!".format(x, y)
    # Nones or primitives.
    elif x is None or y is None or isinstance(x, (str, int)):
        if false is True:
            assert x != y, "ERROR: x ({}) is the same as y ({})!".format(x, y)
        else:
            assert x == y, \
                "ERROR: x ({}) is not the same as y ({})!".format(x, y)
    # String/byte comparisons.
    elif hasattr(x, "dtype") and \
            (x.dtype == object or str(x.dtype).startswith("<U")):
        try:
            np.testing.assert_array_equal(x, y)
        except AssertionError as e:
            if false is False:
                raise e
        # Arrays are equal: Only raise if we expected them NOT to be
        # (raising inside the `try` above would get swallowed).
        else:
            if false is True:
                assert False, \
                    "ERROR: x ({}) is the same as y ({})!".format(x, y)
    # Everything else (assume numeric or tf/torch.Tensor).
    else:
        if tf1 is not None:
            # y should never be a Tensor (y=expected value).
            if isinstance(y, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf.executing_eagerly():
                    y = y.numpy()
                else:
                    raise ValueError(
                        "`y` (expected value) must not be a Tensor. "
                        "Use numpy.ndarray instead.")
            if isinstance(x, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf1.executing_eagerly():
                    x = x.numpy()
                # Otherwise, use a new tf-session.
                else:
                    with tf1.Session() as sess:
                        x = sess.run(x)
                        return check(
                            x,
                            y,
                            decimals=decimals,
                            atol=atol,
                            rtol=rtol,
                            false=false)
        if torch is not None:
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().numpy()
            if isinstance(y, torch.Tensor):
                y = y.detach().cpu().numpy()

        # Using decimals.
        if atol is None and rtol is None:
            # Assert equality of both values.
            try:
                np.testing.assert_almost_equal(x, y, decimal=decimals)
            # Both values are not equal.
            except AssertionError as e:
                # Raise error in normal case.
                if false is False:
                    raise e
            # Both values are equal.
            else:
                # If false is set -> raise error (not expected to be equal).
                if false is True:
                    assert False, \
                        "ERROR: x ({}) is the same as y ({})!".format(x, y)
        # Using atol/rtol.
        else:
            # Provide defaults for either one of atol/rtol.
            if atol is None:
                atol = 0
            if rtol is None:
                rtol = 1e-7
            try:
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
            except AssertionError as e:
                if false is False:
                    raise e
            else:
                if false is True:
                    assert False, \
                        "ERROR: x ({}) is the same as y ({})!".format(x, y)


def check_compute_single_action(trainer,
                                include_state=False,
                                include_prev_action_reward=False):
    """Tests different combinations of args for trainer.compute_single_action.

    Args:
        trainer: The Trainer object to test.
        include_state: Whether to include the initial state of the Policy's
            Model in the `compute_single_action` call.
        include_prev_action_reward: Whether to include the prev-action and
            -reward in the `compute_single_action` call.

    Raises:
        ValueError: If anything unexpected happens.
    """
    # Have to import this here to avoid circular dependency.
    from ray.rllib.policy.sample_batch import SampleBatch

    # Some Trainers may not abide by the standard API.
    try:
        pol = trainer.get_policy()
    except AttributeError:
        pol = trainer.policy
    # Get the policy's model.
    model = pol.model

    action_space = pol.action_space

    def _test(what, method_to_test, obs_space, full_fetch, explore, timestep,
              unsquash, clip):
        call_kwargs = {}
        if what is trainer:
            call_kwargs["full_fetch"] = full_fetch

        obs = obs_space.sample()
        if isinstance(obs_space, Box):
            obs = np.clip(obs, -1.0, 1.0)
        state_in = None
        if include_state:
            state_in = model.get_initial_state()
            if not state_in:
                state_in = []
                i = 0
                while f"state_in_{i}" in model.view_requirements:
                    state_in.append(model.view_requirements[f"state_in_{i}"]
                                    .space.sample())
                    i += 1
        action_in = action_space.sample() \
            if include_prev_action_reward else None
        reward_in = 1.0 if include_prev_action_reward else None

        if method_to_test == "input_dict":
            assert what is pol

            input_dict = {SampleBatch.OBS: obs}
            if include_prev_action_reward:
                input_dict[SampleBatch.PREV_ACTIONS] = action_in
                input_dict[SampleBatch.PREV_REWARDS] = reward_in
            if state_in:
                for i, s in enumerate(state_in):
                    input_dict[f"state_in_{i}"] = s
            input_dict_batched = SampleBatch(
                tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict))
            action = pol.compute_actions_from_input_dict(
                input_dict=input_dict_batched,
                explore=explore,
                timestep=timestep,
                **call_kwargs)
            # Unbatch everything to be able to compare against single
            # action below.
            # ARS and ES return action batches as lists.
            if isinstance(action[0], list):
                action = (np.array(action[0]), action[1], action[2])
            action = tree.map_structure(lambda s: s[0], action)

            try:
                action2 = pol.compute_single_action(
                    input_dict=input_dict,
                    explore=explore,
                    timestep=timestep,
                    **call_kwargs)
                # Make sure these are the same, unless we have exploration
                # switched on (or noisy layers).
                if not explore and not pol.config.get("noisy"):
                    check(action, action2)
            except TypeError:
                pass
        else:
            action = what.compute_single_action(
                obs,
                state_in,
                prev_action=action_in,
                prev_reward=reward_in,
                explore=explore,
                timestep=timestep,
                unsquash_action=unsquash,
                clip_action=clip,
                **call_kwargs)

        state_out = None
        if state_in or full_fetch or what is pol:
            action, state_out, _ = action
        if state_out:
            for si, so in zip(state_in, state_out):
                check(list(si.shape), so.shape)

        # Test whether unsquash/clipping works on the Trainer's
        # compute_single_action method: Both flags should force the action
        # to be within the space's bounds.
        if method_to_test == "single" and what is trainer:
            if not action_space.contains(action) and \
                    (clip or unsquash or not isinstance(action_space, Box)):
                raise ValueError(
                    f"Returned action ({action}) of trainer/policy {what} "
                    f"not in Env's action_space {action_space}")
            # We are operating in normalized space: Expect only smaller action
            # values.
            if isinstance(action_space, Box) and not unsquash and \
                    what.config.get("normalize_actions") and \
                    np.any(np.abs(action) > 3.0):
                raise ValueError(
                    f"Returned action ({action}) of trainer/policy {what} "
                    "should be in normalized space, but seems too large/small "
                    "for that!")

    # Loop through: Policy vs Trainer; different API methods to calculate
    # actions; unsquash option; clip option; full fetch or not.
    for what in [pol, trainer]:
        if what is trainer:
            # Get the obs-space from Workers.env (not Policy) due to possible
            # pre-processor up front.
            worker_set = getattr(trainer, "workers",
                                 getattr(trainer, "_workers", None))
            assert worker_set
            if isinstance(worker_set, list):
                obs_space = trainer.get_policy().observation_space
            else:
                obs_space = worker_set.local_worker().for_policy(
                    lambda p: p.observation_space)
            obs_space = getattr(obs_space, "original_space", obs_space)
        else:
            obs_space = pol.observation_space

        for method_to_test in ["single"] + \
                (["input_dict"] if what is pol else []):
            for explore in [True, False]:
                for full_fetch in ([False, True]
                                   if what is trainer else [False]):
                    timestep = random.randint(0, 100000)
                    for unsquash in [True, False]:
                        for clip in ([False] if unsquash else [True, False]):
                            _test(what, method_to_test, obs_space, full_fetch,
                                  explore, timestep, unsquash, clip)
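

# Illustrative usage sketch (not part of the original module): run the
# combinatorial action-computation checks against a freshly built trainer.
# `PPOTrainer` and its config are placeholder assumptions.
def _example_check_compute_single_action():
    from ray.rllib.agents.ppo import PPOTrainer

    trainer = PPOTrainer(config={"num_workers": 0}, env="CartPole-v0")
    check_compute_single_action(
        trainer, include_state=False, include_prev_action_reward=True)
    trainer.stop()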


def check_learning_achieved(tune_results, min_reward, evaluation=False):
    """Throws an error if `min_reward` is not reached within tune_results.

    Checks the last iteration found in tune_results for its
    "episode_reward_mean" value and compares it to `min_reward`.

    Args:
        tune_results: The tune.run returned results object.
        min_reward (float): The min reward that must be reached.
        evaluation (bool): If True, use the reward reported by the
            evaluation workers (under the "evaluation" key of the results)
            instead of the training reward.

    Raises:
        ValueError: If `min_reward` not reached.
    """
    # Get maximum reward of all trials
    # (check if at least one trial achieved some learning).
    avg_rewards = [(trial.last_result["episode_reward_mean"]
                    if not evaluation else
                    trial.last_result["evaluation"]["episode_reward_mean"])
                   for trial in tune_results.trials]
    best_avg_reward = max(avg_rewards)
    if best_avg_reward < min_reward:
        raise ValueError("`stop-reward` of {} not reached!".format(min_reward))
    print("ok")


def check_train_results(train_results):
    """Checks proper structure of a Trainer.train() returned dict.

    Args:
        train_results: The train results dict to check.

    Raises:
        AssertionError: If `train_results` doesn't have the proper structure
            or data in it.
    """
    # Import these here to avoid circular dependencies.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
    from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
        LEARNER_STATS_KEY
    from ray.rllib.utils.multi_agent import check_multi_agent

    # Assert that some keys are where we would expect them.
    for key in [
            "agent_timesteps_total",
            "config",
            "custom_metrics",
            "episode_len_mean",
            "episode_reward_max",
            "episode_reward_mean",
            "episode_reward_min",
            "episodes_total",
            "hist_stats",
            "info",
            "iterations_since_restore",
            "num_healthy_workers",
            "perf",
            "policy_reward_max",
            "policy_reward_mean",
            "policy_reward_min",
            "sampler_perf",
            "time_since_restore",
            "time_this_iter_s",
            "timesteps_since_restore",
            "timesteps_total",
            "timers",
            "time_total_s",
            "training_iteration",
    ]:
        assert key in train_results, \
            f"'{key}' not found in `train_results` ({train_results})!"

    _, is_multi_agent = check_multi_agent(train_results["config"])

    # Check in particular the "info" dict.
    info = train_results["info"]
    assert LEARNER_INFO in info, \
        f"'{LEARNER_INFO}' not in train_results['info'] ({info})!"
    assert "num_steps_trained" in info, \
        f"'num_steps_trained' not in train_results['info'] ({info})!"

    learner_info = info[LEARNER_INFO]

    # Make sure we have a default_policy key if we are not in a
    # multi-agent setup.
    if not is_multi_agent:
        # APEX algos sometimes have an empty learner info dict (no metrics
        # collected yet).
        assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \
            f"'{DEFAULT_POLICY_ID}' not found in " \
            f"train_results['info'][{LEARNER_INFO}] ({learner_info})!"

    for pid, policy_stats in learner_info.items():
        if pid == "batch_count":
            continue
        # Expect td-errors to be per batch-item.
        if "td_error" in policy_stats:
            configured_b = train_results["config"]["train_batch_size"]
            actual_b = policy_stats["td_error"].shape[0]
            # R2D2 case.
            if (configured_b - actual_b) / actual_b > 0.1:
                assert configured_b / (
                    train_results["config"]["model"]["max_seq_len"] +
                    train_results["config"]["burn_in"]) == actual_b

        # Make sure each policy has the LEARNER_STATS_KEY under it.
        assert LEARNER_STATS_KEY in policy_stats
        learner_stats = policy_stats[LEARNER_STATS_KEY]
        for key, value in learner_stats.items():
            # Min- and max-stats should be single values.
            if key.startswith("min_") or key.startswith("max_"):
                assert np.isscalar(value), \
                    f"'{key}' value not a scalar ({value})!"

    return train_results
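

# Illustrative usage sketch (not part of the original module): validate the
# structure of a single `Trainer.train()` result dict. `DQNTrainer` is a
# placeholder choice.
def _example_check_train_results():
    from ray.rllib.agents.dqn import DQNTrainer

    trainer = DQNTrainer(config={"num_workers": 0}, env="CartPole-v0")
    results = trainer.train()
    check_train_results(results)
    trainer.stop()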


def run_learning_tests_from_yaml(
        yaml_files: List[str],
        *,
        max_num_repeats: int = 2,
        smoke_test: bool = False,
) -> Dict[str, Any]:
    """Runs the given experiments in yaml_files and returns results dict.

    Args:
        yaml_files (List[str]): List of yaml file names.
        max_num_repeats (int): How many times should we repeat a failed
            experiment?
        smoke_test (bool): Whether this is just a smoke-test. If True,
            set `time_total_s` to 0 (each trial only creates its trainer and
            runs a first iteration) and don't early out due to rewards or
            timesteps reached.

    Returns:
        A results dict containing times, trial states, per-experiment stats,
        and the passed/failed experiments.
    """
    print("Will run the following yaml files:")
    for yaml_file in yaml_files:
        print("->", yaml_file)

    # All trials we'll ever run in this test script.
    all_trials = []
    # The experiments (by name) we'll run up to `max_num_repeats` times.
    experiments = {}
    # The results per experiment.
    checks = {}
    # Metrics per experiment.
    stats = {}

    start_time = time.monotonic()

    # Loop through all collected files and gather experiments.
    # Augment all by the `torch` framework.
    for yaml_file in yaml_files:
        with open(yaml_file) as f:
            tf_experiments = yaml.safe_load(f)

        # Add torch version of all experiments to the list.
        for k, e in tf_experiments.items():
            # If framework explicitly given, only test for that framework.
            # Some algos do not have both versions available.
            if "frameworks" in e:
                frameworks = e["frameworks"]
            else:
                # By default we don't run tf2, because tf2's multi-gpu support
                # isn't complete yet.
                frameworks = ["tf", "torch"]
            # Pop frameworks key to not confuse Tune.
            e.pop("frameworks", None)

            e["stop"] = e["stop"] if "stop" in e else {}
            e["pass_criteria"] = e[
                "pass_criteria"] if "pass_criteria" in e else {}

            # For smoke-tests, stop each trial as soon as possible.
            if smoke_test:
                # 0sec for each(!) experiment/trial.
                # This is such that if there are many experiments/trials
                # in a test (e.g. rllib_learning_test), each one can at least
                # create its trainer and run a first iteration.
                e["stop"]["time_total_s"] = 0
            else:
                # We also stop early, once we reach the desired reward.
                min_reward = e.get("pass_criteria",
                                   {}).get("episode_reward_mean")
                if min_reward is not None:
                    e["stop"]["episode_reward_mean"] = min_reward

            # Generate `checks` dict for all experiments
            # (tf, tf2 and/or torch).
            for framework in frameworks:
                k_ = k + "-" + framework
                ec = copy.deepcopy(e)
                ec["config"]["framework"] = framework
                if framework == "tf2":
                    ec["config"]["eager_tracing"] = True

                checks[k_] = {
                    "min_reward": ec["pass_criteria"].get(
                        "episode_reward_mean", 0.0),
                    "min_throughput": ec["pass_criteria"].get(
                        "timesteps_total", 0.0) /
                    (ec["stop"].get("time_total_s", 1.0) or 1.0),
                    "time_total_s": ec["stop"].get("time_total_s"),
                    "failures": 0,
                    "passed": False,
                }
                # This key would break tune.
                ec.pop("pass_criteria", None)

                # One experiment to run.
                experiments[k_] = ec

    # Print out the actual config.
    print("== Test config ==")
    print(yaml.dump(experiments))

    # Keep track of those experiments we still have to run.
    # If an experiment passes, we'll remove it from this dict.
    experiments_to_run = experiments.copy()

    try:
        ray.init(address="auto")
    except ConnectionError:
        ray.init()

    for i in range(max_num_repeats):
        # We are done.
        if len(experiments_to_run) == 0:
            print("All experiments finished.")
            break

        print(f"Starting learning test iteration {i}...")

        # Run remaining experiments.
        trials = run_experiments(
            experiments_to_run,
            resume=False,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "episode_reward_mean": "reward_mean",
                },
                sort_by_metric=True,
                max_report_frequency=30,
            ))

        all_trials.extend(trials)

        # Check each experiment for whether it passed.
        # Criteria are: a) reach the desired reward AND b) reach the
        # throughput defined by `timesteps_total` / `time_total_s`.
        for experiment in experiments_to_run.copy():
            print(f"Analyzing experiment {experiment} ...")
            # Collect all trials within this experiment (some experiments may
            # have num_samples or grid_searches defined).
            trials_for_experiment = []
            for t in trials:
                trial_exp = re.sub(".+/([^/]+)$", "\\1", t.local_dir)
                if trial_exp == experiment:
                    trials_for_experiment.append(t)
            print(f" ... Trials: {trials_for_experiment}.")

            # If we have evaluation workers, use their rewards.
            # This is useful for offline learning tests, where
            # we evaluate against an actual environment.
            check_eval = experiments[experiment]["config"].get(
                "evaluation_interval", None) is not None

            # Error: Increase failure count and repeat.
            if any(t.status == "ERROR" for t in trials_for_experiment):
                print(" ... ERROR.")
                checks[experiment]["failures"] += 1
            # Smoke-tests always succeed.
            elif smoke_test:
                print(" ... SMOKE TEST (mark ok).")
                checks[experiment]["passed"] = True
                del experiments_to_run[experiment]
            # Experiment finished: Check reward achieved and timesteps done
            # (throughput).
            else:
                if check_eval:
                    episode_reward_mean = np.mean([
                        t.last_result["evaluation"]["episode_reward_mean"]
                        for t in trials_for_experiment
                    ])
                else:
                    episode_reward_mean = np.mean([
                        t.last_result["episode_reward_mean"]
                        for t in trials_for_experiment
                    ])
                desired_reward = checks[experiment]["min_reward"]

                timesteps_total = np.mean([
                    t.last_result["timesteps_total"]
                    for t in trials_for_experiment
                ])
                total_time_s = np.mean([
                    t.last_result["time_total_s"]
                    for t in trials_for_experiment
                ])

                # TODO(jungong): track trainer and env throughput separately.
                throughput = timesteps_total / (total_time_s or 1.0)
                desired_throughput = checks[experiment]["min_throughput"]

                # Record performance.
                stats[experiment] = {
                    "episode_reward_mean": episode_reward_mean,
                    "throughput": throughput,
                }

                print(f" ... Desired reward={desired_reward}; "
                      f"desired throughput={desired_throughput}")

                # We failed to reach the desired reward or throughput.
                if (desired_reward and
                        episode_reward_mean < desired_reward) or \
                        (desired_throughput and
                         throughput < desired_throughput):
                    print(" ... Not successful: Actual "
                          f"reward={episode_reward_mean}; "
                          f"actual throughput={throughput}")
                    checks[experiment]["failures"] += 1
                # We succeeded!
                else:
                    print(" ... Successful: (mark ok).")
                    checks[experiment]["passed"] = True
                    del experiments_to_run[experiment]

    ray.shutdown()

    time_taken = time.monotonic() - start_time

    # Create and return the results dict.
    result = {
        "time_taken": time_taken,
        "trial_states": dict(Counter([trial.status for trial in all_trials])),
        "last_update": time.time(),
        "stats": stats,
        "passed": [k for k, exp in checks.items() if exp["passed"]],
        "failures": {
            k: exp["failures"]
            for k, exp in checks.items() if exp["failures"] > 0
        }
    }

    return result
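

# Illustrative usage sketch (not part of the original module): a minimal
# learning-test yaml and how to run it. The file name, experiment name, and
# all criteria below are placeholder values; `pass_criteria` is stripped
# before the experiment is handed to tune, as done above.
#
#   # cartpole_ppo.yaml
#   cartpole-ppo:
#       env: CartPole-v0
#       run: PPO
#       stop:
#           time_total_s: 600
#       pass_criteria:
#           episode_reward_mean: 150.0
#           timesteps_total: 100000
#       config:
#           num_workers: 1
def _example_run_learning_tests():
    results = run_learning_tests_from_yaml(
        ["cartpole_ppo.yaml"], max_num_repeats=2)
    print("passed:", results["passed"])
    print("failures:", results["failures"])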