from collections import Counter
import copy
from gym.spaces import Box
import logging
import numpy as np
import random
import re
import time
import tree  # pip install dm_tree
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml

import ray
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
    try_import_torch
from ray.rllib.utils.typing import PartialTrainerConfigDict
from ray.tune import CLIReporter, run_experiments

jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
    eager_mode = None
    try:
        from tensorflow.python.eager.context import eager_mode
    except (ImportError, ModuleNotFoundError):
        pass

torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


def framework_iterator(
        config: Optional[PartialTrainerConfigDict] = None,
        frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"),
        session: bool = False,
        with_eager_tracing: bool = False,
        time_iterations: Optional[dict] = None,
) -> Union[str, Tuple[str, Optional["tf1.Session"]]]:
    """A generator that allows for looping through n frameworks for testing.

    Provides the correct config entries ("framework") as well as the correct
    eager/non-eager contexts for tfe/tf.

    Args:
        config: An optional config dict to alter in place depending on the
            iteration.
        frameworks: A list/tuple of the frameworks to be tested.
            Allowed are: "tf2", "tf", "tfe", "torch", and None.
        session: If True and only in the tf-case: Enter a tf.Session() and
            yield that as the second return value (otherwise yield
            (fw, None)). Also sets a seed (42) on the session to make the
            test deterministic.
        with_eager_tracing: Include `eager_tracing=True` in the returned
            configs, when framework=[tfe|tf2].
        time_iterations: If provided, will write to the given dict (by
            framework key) the times in seconds that each (framework's)
            iteration takes.

    Yields:
        If `session` is False: The current framework [tf2|tf|tfe|torch] used.
        If `session` is True: A tuple consisting of the current framework
        string and the tf1.Session (if fw="tf", otherwise None).
    """
    config = config or {}
    frameworks = [frameworks] if isinstance(frameworks, str) else \
        list(frameworks)

    # Both tf2 and tfe present -> remove "tfe" or "tf2" depending on version.
    if "tf2" in frameworks and "tfe" in frameworks:
        frameworks.remove("tfe" if tfv == 2 else "tf2")

    for fw in frameworks:
        # Skip non-installed frameworks.
        if fw == "torch" and not torch:
            logger.warning(
                "framework_iterator skipping torch (not installed)!")
            continue
        if fw != "torch" and not tf:
            logger.warning("framework_iterator skipping {} (tf not "
                           "installed)!".format(fw))
            continue
        elif fw == "tfe" and not eager_mode:
            logger.warning("framework_iterator skipping tf-eager (could not "
                           "import `eager_mode` from tensorflow.python)!")
            continue
        elif fw == "tf2" and tfv != 2:
            logger.warning(
                "framework_iterator skipping tf2.x (tf version is < 2.0)!")
            continue
        elif fw == "jax" and not jax:
            logger.warning("framework_iterator skipping JAX (not installed)!")
            continue
        assert fw in ["tf2", "tf", "tfe", "torch", "jax", None]

        # Do we need a test session?
        sess = None
        if fw == "tf" and session is True:
            sess = tf1.Session()
            sess.__enter__()
            tf1.set_random_seed(42)

        config["framework"] = fw

        eager_ctx = None
        # Enable eager mode for tf2 and tfe.
        if fw in ["tf2", "tfe"]:
            eager_ctx = eager_mode()
            eager_ctx.__enter__()
            assert tf1.executing_eagerly()
        # Make sure, eager mode is off.
        elif fw == "tf":
            assert not tf1.executing_eagerly()

        # Additionally loop through eager_tracing=True + False, if necessary.
        if fw in ["tf2", "tfe"] and with_eager_tracing:
            for tracing in [True, False]:
                config["eager_tracing"] = tracing
                print(f"framework={fw} (eager-tracing={tracing})")
                time_started = time.time()
                yield fw if session is False else (fw, sess)
                if time_iterations is not None:
                    time_total = time.time() - time_started
                    time_iterations[fw + ("+tracing" if tracing else "")] = \
                        time_total
                    print(f".. took {time_total}sec")
            config["eager_tracing"] = False
        # Yield current framework + tf-session (if necessary).
        else:
            print(f"framework={fw}")
            time_started = time.time()
            yield fw if session is False else (fw, sess)
            if time_iterations is not None:
                time_total = time.time() - time_started
                time_iterations[fw] = time_total
                print(f".. took {time_total}sec")

        # Exit any context we may have entered.
        if eager_ctx:
            eager_ctx.__exit__(None, None, None)
        elif sess:
            sess.__exit__(None, None, None)


def check(x, y, decimals=5, atol=None, rtol=None, false=False):
    """
    Checks two structures (dict, tuple, list, np.array, float, int, etc..)
    for (almost) numeric identity.
    All numbers in the two structures have to match up to `decimals` digits
    after the floating point. Uses assertions.

    Args:
        x (any): The value to be compared (to the expectation: `y`). This
            may be a Tensor.
        y (any): The expected value to be compared to `x`. This must not
            be a tf-Tensor, but may be a tfe/torch-Tensor.
        decimals (int): The number of digits after the floating point up to
            which all numeric values have to match.
        atol (float): Absolute tolerance of the difference between x and y
            (overrides `decimals` if given).
        rtol (float): Relative tolerance of the difference between x and y
            (overrides `decimals` if given).
        false (bool): Whether to check that x and y are NOT the same.
    """
    # A dict type.
    if isinstance(x, dict):
        assert isinstance(y, dict), \
            "ERROR: If x is dict, y needs to be a dict as well!"
        y_keys = set(x.keys())
        for key, value in x.items():
            assert key in y, \
                "ERROR: y does not have x's key='{}'! y={}".format(key, y)
            check(
                value,
                y[key],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
            y_keys.remove(key)
        assert not y_keys, \
            "ERROR: y contains keys ({}) that are not in x! y={}".format(
                list(y_keys), y)
    # A tuple type.
    elif isinstance(x, (tuple, list)):
        assert isinstance(y, (tuple, list)), \
            "ERROR: If x is tuple, y needs to be a tuple as well!"
        assert len(y) == len(x), \
            "ERROR: y does not have the same length as x ({} vs {})!".format(
                len(y), len(x))
        for i, value in enumerate(x):
            check(
                value,
                y[i],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
    # Boolean comparison.
    elif isinstance(x, (np.bool_, bool)):
        if false is True:
            assert bool(x) is not bool(y), \
                "ERROR: x ({}) is y ({})!".format(x, y)
        else:
            assert bool(x) is bool(y), \
                "ERROR: x ({}) is not y ({})!".format(x, y)
    # Nones or primitives.
    elif x is None or y is None or isinstance(x, (str, int)):
        if false is True:
            assert x != y, "ERROR: x ({}) is the same as y ({})!".format(x, y)
        else:
            assert x == y, \
                "ERROR: x ({}) is not the same as y ({})!".format(x, y)
    # String/byte comparisons.
    elif hasattr(x, "dtype") and \
            (x.dtype == object or str(x.dtype).startswith("<U")):
        try:
            np.testing.assert_array_equal(x, y)
        except AssertionError as e:
            if false is False:
                raise e
        else:
            if false is True:
                assert False, \
                    "ERROR: x ({}) is the same as y ({})!".format(x, y)
    # Everything else (assume numeric or tf/torch.Tensor).
    else:
        if tf1 is not None:
            # y should never be a Tensor (y=expected value).
            if isinstance(y, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf1.executing_eagerly():
                    y = y.numpy()
                else:
                    raise ValueError(
                        "`y` (expected value) must not be a Tensor. "
                        "Use numpy.ndarray instead")
            if isinstance(x, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf1.executing_eagerly():
                    x = x.numpy()
                # Otherwise, evaluate `x` in a new tf-session.
                else:
                    with tf1.Session() as sess:
                        x = sess.run(x)
                        return check(
                            x, y, decimals=decimals, atol=atol, rtol=rtol,
                            false=false)
        if torch is not None:
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().numpy()
            if isinstance(y, torch.Tensor):
                y = y.detach().cpu().numpy()

        # Using decimals.
        if atol is None and rtol is None:
            # Assert equality of both values.
            try:
                np.testing.assert_almost_equal(x, y, decimal=decimals)
            # Both values are not equal.
            except AssertionError as e:
                # Raise error in normal case.
                if false is False:
                    raise e
            # Both values are equal.
            else:
                # If false is set -> raise error (not expected to be equal).
                if false is True:
                    assert False, \
                        "ERROR: x ({}) is the same as y ({})!".format(x, y)

        # Using atol/rtol.
        else:
            # Provide defaults for either one of atol/rtol.
            if atol is None:
                atol = 0
            if rtol is None:
                rtol = 1e-7

            try:
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
            except AssertionError as e:
                if false is False:
                    raise e
            else:
                if false is True:
                    assert False, \
                        "ERROR: x ({}) is the same as y ({})!".format(x, y)


def check_compute_single_action(trainer,
                                include_state=False,
                                include_prev_action_reward=False):
    """Tests different combinations of args for trainer.compute_single_action.

    Args:
        trainer: The Trainer object to test.
        include_state: Whether to include the initial state of the Policy's
            Model in the `compute_single_action` call.
        include_prev_action_reward: Whether to include the prev-action and
            -reward in the `compute_single_action` call.

    Raises:
        ValueError: If anything unexpected happens.
    """
    # Have to import this here to avoid circular dependency.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

    # Some Trainers may not abide to the standard API.
    pid = DEFAULT_POLICY_ID
    try:
        # Multi-agent: Pick any policy (or DEFAULT_POLICY if it's the only
        # one).
        pid = next(iter(trainer.workers.local_worker().policy_map))
        pol = trainer.get_policy(pid)
    except AttributeError:
        pol = trainer.policy
    # Get the policy's model.
    model = pol.model

    action_space = pol.action_space

    def _test(what, method_to_test, obs_space, full_fetch, explore, timestep,
              unsquash, clip):
        call_kwargs = {}
        if what is trainer:
            call_kwargs["full_fetch"] = full_fetch
            call_kwargs["policy_id"] = pid

        obs = obs_space.sample()
        if isinstance(obs_space, Box):
            obs = np.clip(obs, -1.0, 1.0)
        state_in = None
        if include_state:
            state_in = model.get_initial_state()
            if not state_in:
                state_in = []
                i = 0
                while f"state_in_{i}" in model.view_requirements:
                    state_in.append(model.view_requirements[f"state_in_{i}"]
                                    .space.sample())
                    i += 1
        action_in = action_space.sample() \
            if include_prev_action_reward else None
        reward_in = 1.0 if include_prev_action_reward else None

        if method_to_test == "input_dict":
            assert what is pol

            input_dict = {SampleBatch.OBS: obs}
            if include_prev_action_reward:
                input_dict[SampleBatch.PREV_ACTIONS] = action_in
                input_dict[SampleBatch.PREV_REWARDS] = reward_in
            if state_in:
                for i, s in enumerate(state_in):
                    input_dict[f"state_in_{i}"] = s
            input_dict_batched = SampleBatch(
                tree.map_structure(lambda s: np.expand_dims(s, 0),
                                   input_dict))
            action = pol.compute_actions_from_input_dict(
                input_dict=input_dict_batched,
                explore=explore,
                timestep=timestep,
                **call_kwargs)
            # Unbatch everything to be able to compare against single
            # action below.
            # ARS and ES return action batches as lists.
            if isinstance(action[0], list):
                action = (np.array(action[0]), action[1], action[2])
            action = tree.map_structure(lambda s: s[0], action)

            try:
                action2 = pol.compute_single_action(
                    input_dict=input_dict,
                    explore=explore,
                    timestep=timestep,
                    **call_kwargs)
                # Make sure these are the same, unless we have exploration
                # switched on (or noisy layers).
                if not explore and not pol.config.get("noisy"):
                    check(action, action2)
            except TypeError:
                pass
        else:
            action = what.compute_single_action(
                obs,
                state_in,
                prev_action=action_in,
                prev_reward=reward_in,
                explore=explore,
                timestep=timestep,
                unsquash_action=unsquash,
                clip_action=clip,
                **call_kwargs)

        state_out = None
        if state_in or full_fetch or what is pol:
            action, state_out, _ = action
        if state_out:
            for si, so in zip(state_in, state_out):
                check(list(si.shape), so.shape)

        if unsquash is None:
            unsquash = what.config["normalize_actions"]
        if clip is None:
            clip = what.config["clip_actions"]

        # Test whether unsquash/clipping works on the Trainer's
        # compute_single_action method: Both flags should force the action
        # to be within the space's bounds.
        if method_to_test == "single" and what == trainer:
            if not action_space.contains(action) and \
                    (clip or unsquash or not isinstance(action_space, Box)):
                raise ValueError(
                    f"Returned action ({action}) of trainer/policy {what} "
                    f"not in Env's action_space {action_space}")
            # We are operating in normalized space: Expect only smaller action
            # values.
            if isinstance(action_space, Box) and not unsquash and \
                    what.config.get("normalize_actions") and \
                    np.any(np.abs(action) > 3.0):
                raise ValueError(
                    f"Returned action ({action}) of trainer/policy {what} "
                    "should be in normalized space, but seems too large/small "
                    "for that!")

    # Loop through: Policy vs Trainer; Different API methods to calculate
    # actions; unsquash option; clip option; full fetch or not.
    for what in [pol, trainer]:
        if what is trainer:
            # Get the obs-space from Workers.env (not Policy) due to possible
            # pre-processor up front.
            worker_set = getattr(trainer, "workers", None)
            # TODO: ES and ARS use `self._workers` instead of `self.workers`
            #  to store their rollout worker set. Change to `self.workers`.
            if worker_set is None:
                worker_set = getattr(trainer, "_workers", None)
            assert worker_set
            if isinstance(worker_set, list):
                obs_space = trainer.get_policy(pid).observation_space
            else:
                obs_space = worker_set.local_worker().for_policy(
                    lambda p: p.observation_space, policy_id=pid)
            obs_space = getattr(obs_space, "original_space", obs_space)
        else:
            obs_space = pol.observation_space

        for method_to_test in ["single"] + \
                (["input_dict"] if what is pol else []):
            for explore in [True, False]:
                for full_fetch in ([False, True]
                                   if what is trainer else [False]):
                    timestep = random.randint(0, 100000)
                    for unsquash in [True, False, None]:
                        for clip in ([False] if unsquash else
                                     [True, False, None]):
                            _test(what, method_to_test, obs_space, full_fetch,
                                  explore, timestep, unsquash, clip)


def check_learning_achieved(tune_results, min_reward, evaluation=False):
    """Throws an error if `min_reward` is not reached within tune_results.

    Checks the last iteration found in tune_results for its
    "episode_reward_mean" value and compares it to `min_reward`.

    Args:
        tune_results: The tune.run returned results object.
        min_reward (float): The min reward that must be reached.
        evaluation (bool): If True, use the rewards reported by the
            evaluation workers instead of the training rewards.

    Raises:
        ValueError: If `min_reward` not reached.
    """
    # Get maximum reward of all trials
    # (check if at least one trial achieved some learning).
    avg_rewards = [(trial.last_result["episode_reward_mean"]
                    if not evaluation else
                    trial.last_result["evaluation"]["episode_reward_mean"])
                   for trial in tune_results.trials]
    best_avg_reward = max(avg_rewards)
    if best_avg_reward < min_reward:
        raise ValueError("`stop-reward` of {} not reached!".format(min_reward))
    print("ok")


def check_train_results(train_results):
    """Checks proper structure of a Trainer.train() returned dict.

    Args:
        train_results: The train results dict to check.

    Raises:
        AssertionError: If `train_results` doesn't have the proper structure
            or data in it.
    """
    # Import these here to avoid circular dependencies.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
    from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
        LEARNER_STATS_KEY
    from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent

    # Assert that some keys are where we would expect them.
    for key in [
            "agent_timesteps_total",
            "config",
            "custom_metrics",
            "episode_len_mean",
            "episode_reward_max",
            "episode_reward_mean",
            "episode_reward_min",
            "episodes_total",
            "hist_stats",
            "info",
            "iterations_since_restore",
            "num_healthy_workers",
            "perf",
            "policy_reward_max",
            "policy_reward_mean",
            "policy_reward_min",
            "sampler_perf",
            "time_since_restore",
            "time_this_iter_s",
            "timesteps_since_restore",
            "timesteps_total",
            "timers",
            "time_total_s",
            "training_iteration",
    ]:
        assert key in train_results, \
            f"'{key}' not found in `train_results` ({train_results})!"

    _, is_multi_agent = check_multi_agent(train_results["config"])

    # Check in particular the "info" dict.
    info = train_results["info"]
    assert LEARNER_INFO in info, \
        f"'learner' not in train_results['info'] ({info})!"
    assert "num_steps_trained" in info or "num_env_steps_trained" in info, \
        f"'num_(env_)?steps_trained' not in train_results['info'] ({info})!"

    learner_info = info[LEARNER_INFO]

    # Make sure we have a default_policy key if we are not in a
    # multi-agent setup.
    if not is_multi_agent:
        # APEX algos sometimes have an empty learner info dict (no metrics
        # collected yet).
        assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \
            f"'{DEFAULT_POLICY_ID}' not found in " \
            f"train_results['info']['learner'] ({learner_info})!"

    for pid, policy_stats in learner_info.items():
        if pid == "batch_count":
            continue

        # Expect td-errors to be per batch-item.
        if "td_error" in policy_stats:
            configured_b = train_results["config"]["train_batch_size"]
            actual_b = policy_stats["td_error"].shape[0]
            # R2D2 case.
            if (configured_b - actual_b) / actual_b > 0.1:
                assert configured_b / (
                    train_results["config"]["model"]["max_seq_len"] +
                    train_results["config"]["burn_in"]) == actual_b

        # Make sure each policy has the LEARNER_STATS_KEY under it.
        assert LEARNER_STATS_KEY in policy_stats

        learner_stats = policy_stats[LEARNER_STATS_KEY]
        for key, value in learner_stats.items():
            # Min- and max-stats should be single values.
            if key.startswith("min_") or key.startswith("max_"):
                assert np.isscalar(
                    value), f"'{key}' value not a scalar ({value})!"

    return train_results


def run_learning_tests_from_yaml(
        yaml_files: List[str],
        *,
        max_num_repeats: int = 2,
        smoke_test: bool = False,
) -> Dict[str, Any]:
    """Runs the given experiments in yaml_files and returns results dict.

    Args:
        yaml_files (List[str]): List of yaml file names.
        max_num_repeats (int): How many times should we repeat a failed
            experiment?
        smoke_test (bool): Whether this is just a smoke-test. If True, run
            each experiment/trial only long enough to build its trainer and
            run a first iteration, and don't add any reward-based stopping
            criteria.
    """
    print("Will run the following yaml files:")
    for yaml_file in yaml_files:
        print("->", yaml_file)

    # All trials we'll ever run in this test script.
    all_trials = []
    # The experiments (by name) we'll run up to `max_num_repeats` times.
    experiments = {}
    # The results per experiment.
    checks = {}
    # Metrics per experiment.
    stats = {}

    start_time = time.monotonic()

    def should_check_eval(experiment):
        # If we have evaluation workers, use their rewards.
        # This is useful for offline learning tests, where
        # we evaluate against an actual environment.
        return experiment["config"].get("evaluation_interval",
                                        None) is not None

    # Loop through all collected files and gather experiments.
    # Augment all by `torch` framework.
    for yaml_file in yaml_files:
        tf_experiments = yaml.safe_load(open(yaml_file).read())

        # Add torch version of all experiments to the list.
        for k, e in tf_experiments.items():
            # If framework explicitly given, only test for that framework.
            # Some algos do not have both versions available.
            if "frameworks" in e:
                frameworks = e["frameworks"]
            else:
                # By default we don't run tf2, because tf2's multi-gpu support
                # isn't complete yet.
                frameworks = ["tf", "torch"]
            # Pop frameworks key to not confuse Tune.
            e.pop("frameworks", None)

            e["stop"] = e["stop"] if "stop" in e else {}
            e["pass_criteria"] = e[
                "pass_criteria"] if "pass_criteria" in e else {}

            # For smoke-tests, we just run for a minimal amount of time.
            if smoke_test:
                # 0sec for each(!) experiment/trial.
                # This is such that if there are many experiments/trials
                # in a test (e.g. rllib_learning_test), each one can at least
                # create its trainer and run a first iteration.
                e["stop"]["time_total_s"] = 0
            else:
                check_eval = should_check_eval(e)
                episode_reward_key = ("episode_reward_mean" if not check_eval
                                      else "evaluation/episode_reward_mean")
                # We also stop early, once we reach the desired reward.
                min_reward = e.get("pass_criteria",
                                   {}).get(episode_reward_key)
                if min_reward is not None:
                    e["stop"][episode_reward_key] = min_reward

            # Generate `checks` dict for all experiments
            # (tf, tf2 and/or torch).
            for framework in frameworks:
                k_ = k + "-" + framework
                ec = copy.deepcopy(e)
                ec["config"]["framework"] = framework
                if framework == "tf2":
                    ec["config"]["eager_tracing"] = True

                checks[k_] = {
                    "min_reward": ec["pass_criteria"].get(
                        "episode_reward_mean", 0.0),
                    "min_throughput": ec["pass_criteria"].get(
                        "timesteps_total", 0.0) /
                    (ec["stop"].get("time_total_s", 1.0) or 1.0),
                    "time_total_s": ec["stop"].get("time_total_s"),
                    "failures": 0,
                    "passed": False,
                }
                # This key would break tune.
                ec.pop("pass_criteria", None)

                # One experiment to run.
                experiments[k_] = ec

    # Print out the actual config.
    print("== Test config ==")
    print(yaml.dump(experiments))

    # Keep track of those experiments we still have to run.
    # If an experiment passes, we'll remove it from this dict.
    experiments_to_run = experiments.copy()

    try:
        ray.init(address="auto")
    except ConnectionError:
        ray.init()

    for i in range(max_num_repeats):
        # We are done.
        if len(experiments_to_run) == 0:
            print("All experiments finished.")
            break

        print(f"Starting learning test iteration {i}...")

        # Run remaining experiments.
        trials = run_experiments(
            experiments_to_run,
            resume=False,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "episode_reward_mean": "reward_mean",
                    "evaluation/episode_reward_mean": "eval_reward_mean",
                },
                sort_by_metric=True,
                max_report_frequency=30,
            ))

        all_trials.extend(trials)

        # Check each experiment for whether it passed.
        # Criteria are: a) reach the desired reward AND b) reach the
        # throughput defined by `timesteps_total` / `time_total_s`.
        for experiment in experiments_to_run.copy():
            print(f"Analyzing experiment {experiment} ...")
            # Collect all trials within this experiment (some experiments may
            # have num_samples or grid_searches defined).
            trials_for_experiment = []
            for t in trials:
                trial_exp = re.sub(".+/([^/]+)$", "\\1", t.local_dir)
                if trial_exp == experiment:
                    trials_for_experiment.append(t)
            print(f" ... Trials: {trials_for_experiment}.")

            check_eval = should_check_eval(experiments[experiment])

            # Error: Increase failure count and repeat.
            if any(t.status == "ERROR" for t in trials_for_experiment):
                print(" ... ERROR.")
                checks[experiment]["failures"] += 1
            # Smoke-tests always succeed.
            elif smoke_test:
                print(" ... SMOKE TEST (mark ok).")
                checks[experiment]["passed"] = True
                del experiments_to_run[experiment]
            # Experiment finished: Check reward achieved and timesteps done
            # (throughput).
            else:
                if check_eval:
                    episode_reward_mean = np.mean([
                        t.last_result["evaluation"]["episode_reward_mean"]
                        for t in trials_for_experiment
                    ])
                else:
                    episode_reward_mean = np.mean([
                        t.last_result["episode_reward_mean"]
                        for t in trials_for_experiment
                    ])
                desired_reward = checks[experiment]["min_reward"]

                timesteps_total = np.mean([
                    t.last_result["timesteps_total"]
                    for t in trials_for_experiment
                ])
                total_time_s = np.mean([
                    t.last_result["time_total_s"]
                    for t in trials_for_experiment
                ])

                # TODO(jungong) : track trainer and env throughput separately.
                throughput = timesteps_total / (total_time_s or 1.0)
                # TODO(jungong) : enable throughput check again after
                #  TD3_HalfCheetahBulletEnv is fixed and verified.
                # desired_throughput = checks[experiment]["min_throughput"]
                desired_throughput = None

                # Record performance.
                stats[experiment] = {
                    "episode_reward_mean": float(episode_reward_mean),
                    "throughput": (float(throughput)
                                   if throughput is not None else 0.0),
                }

                print(f" ... Desired reward={desired_reward}; "
                      f"desired throughput={desired_throughput}")

                # We failed to reach desired reward or the desired throughput.
                if (desired_reward and
                        episode_reward_mean < desired_reward) or \
                        (desired_throughput and
                         throughput < desired_throughput):
                    print(" ... Not successful: Actual "
                          f"reward={episode_reward_mean}; "
                          f"actual throughput={throughput}")
                    checks[experiment]["failures"] += 1
                # We succeeded!
                else:
                    print(" ... Successful: (mark ok).")
                    checks[experiment]["passed"] = True
                    del experiments_to_run[experiment]

    ray.shutdown()

    time_taken = time.monotonic() - start_time

    # Create results dict and write it to disk.
    result = {
        "time_taken": float(time_taken),
        "trial_states": dict(Counter([trial.status for trial in all_trials])),
        "last_update": float(time.time()),
        "stats": stats,
        "passed": [k for k, exp in checks.items() if exp["passed"]],
        "failures": {
            k: exp["failures"]
            for k, exp in checks.items() if exp["failures"] > 0
        }
    }

    return result
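

# ---------------------------------------------------------------------------
# Minimal usage sketch: shows how `check` and `framework_iterator` might be
# combined in a test. The config values below are placeholder assumptions for
# illustration only; a real test would build and train a Trainer inside the
# loop rather than just printing.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # `check` recursively compares nested structures with a numeric tolerance
    # (here: equal up to 5 decimals).
    check({"a": [1.0, 2.0]}, {"a": [1.0000001, 2.0]})
    # With `false=True`, `check` asserts the two values are NOT (almost)
    # equal.
    check(1.0, 2.0, false=True)

    # `framework_iterator` mutates `config` in place and yields each
    # installed framework in turn, entering/exiting eager contexts as needed.
    config = {"num_workers": 0}  # placeholder config for demonstration
    for fw in framework_iterator(config, frameworks=("tf", "torch")):
        assert config["framework"] == fw
        print(f"Would build and test a Trainer with framework={fw} here.")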