# test_utils.py
from collections import Counter
import copy
from gym.spaces import Box
import logging
import numpy as np
import random
import re
import time
import tree  # pip install dm_tree
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml

import ray
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
    try_import_torch
from ray.rllib.utils.typing import PartialTrainerConfigDict
from ray.tune import CLIReporter, run_experiments

jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
    eager_mode = None
    try:
        from tensorflow.python.eager.context import eager_mode
    except (ImportError, ModuleNotFoundError):
        pass

torch, _ = try_import_torch()

logger = logging.getLogger(__name__)


def framework_iterator(
        config: Optional[PartialTrainerConfigDict] = None,
        frameworks: Sequence[str] = ("tf2", "tf", "tfe", "torch"),
        session: bool = False,
        with_eager_tracing: bool = False,
        time_iterations: Optional[dict] = None,
) -> Union[str, Tuple[str, Optional["tf1.Session"]]]:
    """A generator that allows for looping through n frameworks for testing.

    Provides the correct config entries ("framework") as well as the correct
    eager/non-eager contexts for tfe/tf.

    Args:
        config: An optional config dict to alter in place depending on the
            iteration.
        frameworks: A list/tuple of the frameworks to be tested.
            Allowed are: "tf2", "tf", "tfe", "torch", and None.
        session: If True and only in the tf-case: Enter a tf.Session()
            and yield that as second return value (otherwise yield
            (fw, None)). Also sets a seed (42) on the session to make the
            test deterministic.
        with_eager_tracing: Include `eager_tracing=True` in the returned
            configs, when framework=[tfe|tf2].
        time_iterations: If provided, will write to the given dict (by
            framework key) the times in seconds that each (framework's)
            iteration takes.

    Yields:
        If `session` is False: The current framework [tf2|tf|tfe|torch] used.
        If `session` is True: A tuple consisting of the current framework
            string and the tf1.Session (if fw="tf", otherwise None).
    """
    config = config or {}
    frameworks = [frameworks] if isinstance(frameworks, str) else \
        list(frameworks)

    # Both tf2 and tfe present -> remove "tfe" or "tf2" depending on version.
    if "tf2" in frameworks and "tfe" in frameworks:
        frameworks.remove("tfe" if tfv == 2 else "tf2")

    for fw in frameworks:
        # Skip non-installed frameworks.
        if fw == "torch" and not torch:
            logger.warning(
                "framework_iterator skipping torch (not installed)!")
            continue
        if fw != "torch" and not tf:
            logger.warning("framework_iterator skipping {} (tf not "
                           "installed)!".format(fw))
            continue
        elif fw == "tfe" and not eager_mode:
            logger.warning("framework_iterator skipping tf-eager (could not "
                           "import `eager_mode` from tensorflow.python)!")
            continue
        elif fw == "tf2" and tfv != 2:
            logger.warning(
                "framework_iterator skipping tf2.x (tf version is < 2.0)!")
            continue
        elif fw == "jax" and not jax:
            logger.warning("framework_iterator skipping JAX (not installed)!")
            continue
        assert fw in ["tf2", "tf", "tfe", "torch", "jax", None]

        # Do we need a test session?
        sess = None
        if fw == "tf" and session is True:
            sess = tf1.Session()
            sess.__enter__()
            tf1.set_random_seed(42)

        config["framework"] = fw

        eager_ctx = None
        # Enable eager mode for tf2 and tfe.
        if fw in ["tf2", "tfe"]:
            eager_ctx = eager_mode()
            eager_ctx.__enter__()
            assert tf1.executing_eagerly()
        # Make sure, eager mode is off.
        elif fw == "tf":
            assert not tf1.executing_eagerly()

        # Additionally loop through eager_tracing=True + False, if necessary.
        if fw in ["tf2", "tfe"] and with_eager_tracing:
            for tracing in [True, False]:
                config["eager_tracing"] = tracing
                print(f"framework={fw} (eager-tracing={tracing})")
                time_started = time.time()
                yield fw if session is False else (fw, sess)
                if time_iterations is not None:
                    time_total = time.time() - time_started
                    time_iterations[fw + ("+tracing" if tracing else "")] = \
                        time_total
                    print(f".. took {time_total}sec")
            config["eager_tracing"] = False
        # Yield current framework + tf-session (if necessary).
        else:
            print(f"framework={fw}")
            time_started = time.time()
            yield fw if session is False else (fw, sess)
            if time_iterations is not None:
                time_total = time.time() - time_started
                # No eager-tracing loop in this branch -> key on the
                # framework name only.
                time_iterations[fw] = time_total
                print(f".. took {time_total}sec")

        # Exit any context we may have entered.
        if eager_ctx:
            eager_ctx.__exit__(None, None, None)
        elif sess:
            sess.__exit__(None, None, None)
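
# Example usage of `framework_iterator` (a minimal sketch; the config values
# and `MyTrainer` are hypothetical and not part of this module):
#
#   config = {"num_workers": 0, "train_batch_size": 128}
#   for fw in framework_iterator(config, frameworks=("tf", "torch")):
#       # `config["framework"]` has been set to `fw` in place here.
#       trainer = MyTrainer(config=config, env="CartPole-v0")
#       trainer.train()
#       trainer.stop()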


def check(x, y, decimals=5, atol=None, rtol=None, false=False):
    """Checks two structures (dict, tuple, list, np.array, float, int, etc..)
    for (almost) numeric identity.

    All numbers in the two structures have to match up to `decimals` digits
    after the floating point. Uses assertions.

    Args:
        x (any): The value to be compared (to the expectation: `y`). This
            may be a Tensor.
        y (any): The expected value to be compared to `x`. This must not
            be a tf-Tensor, but may be a tfe/torch-Tensor.
        decimals (int): The number of digits after the floating point up to
            which all numeric values have to match.
        atol (float): Absolute tolerance of the difference between x and y
            (overrides `decimals` if given).
        rtol (float): Relative tolerance of the difference between x and y
            (overrides `decimals` if given).
        false (bool): Whether to check that x and y are NOT the same.
    """
    # A dict type.
    if isinstance(x, dict):
        assert isinstance(y, dict), \
            "ERROR: If x is dict, y needs to be a dict as well!"
        y_keys = set(x.keys())
        for key, value in x.items():
            assert key in y, \
                "ERROR: y does not have x's key='{}'! y={}".format(key, y)
            check(
                value,
                y[key],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
            y_keys.remove(key)
        assert not y_keys, \
            "ERROR: y contains keys ({}) that are not in x! y={}".\
            format(list(y_keys), y)
    # A tuple type.
    elif isinstance(x, (tuple, list)):
        assert isinstance(y, (tuple, list)),\
            "ERROR: If x is tuple, y needs to be a tuple as well!"
        assert len(y) == len(x),\
            "ERROR: y does not have the same length as x ({} vs {})!".\
            format(len(y), len(x))
        for i, value in enumerate(x):
            check(
                value,
                y[i],
                decimals=decimals,
                atol=atol,
                rtol=rtol,
                false=false)
    # Boolean comparison.
    elif isinstance(x, (np.bool_, bool)):
        if false is True:
            assert bool(x) is not bool(y), \
                "ERROR: x ({}) is y ({})!".format(x, y)
        else:
            assert bool(x) is bool(y), \
                "ERROR: x ({}) is not y ({})!".format(x, y)
    # Nones or primitives.
    elif x is None or y is None or isinstance(x, (str, int)):
        if false is True:
            assert x != y, "ERROR: x ({}) is the same as y ({})!".format(x, y)
        else:
            assert x == y, \
                "ERROR: x ({}) is not the same as y ({})!".format(x, y)
    # String/byte comparisons.
    elif hasattr(x, "dtype") and \
            (x.dtype == object or str(x.dtype).startswith("<U")):
        try:
            np.testing.assert_array_equal(x, y)
            if false is True:
                assert False, \
                    "ERROR: x ({}) is the same as y ({})!".format(x, y)
        except AssertionError as e:
            if false is False:
                raise e
    # Everything else (assume numeric or tf/torch.Tensor).
    else:
        if tf1 is not None:
            # y should never be a Tensor (y=expected value).
            if isinstance(y, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf.executing_eagerly():
                    y = y.numpy()
                else:
                    raise ValueError(
                        "`y` (expected value) must not be a Tensor. "
                        "Use numpy.ndarray instead")
            if isinstance(x, (tf1.Tensor, tf1.Variable)):
                # In eager mode, numpyize tensors.
                if tf1.executing_eagerly():
                    x = x.numpy()
                # Otherwise, use a new tf-session.
                else:
                    with tf1.Session() as sess:
                        x = sess.run(x)
                        return check(
                            x,
                            y,
                            decimals=decimals,
                            atol=atol,
                            rtol=rtol,
                            false=false)
        if torch is not None:
            if isinstance(x, torch.Tensor):
                x = x.detach().cpu().numpy()
            if isinstance(y, torch.Tensor):
                y = y.detach().cpu().numpy()

        # Using decimals.
        if atol is None and rtol is None:
            # Assert equality of both values.
            try:
                np.testing.assert_almost_equal(x, y, decimal=decimals)
            # Both values are not equal.
            except AssertionError as e:
                # Raise error in normal case.
                if false is False:
                    raise e
            # Both values are equal.
            else:
                # If false is set -> raise error (not expected to be equal).
                if false is True:
                    assert False, \
                        "ERROR: x ({}) is the same as y ({})!".format(x, y)
        # Using atol/rtol.
        else:
            # Provide defaults for either one of atol/rtol.
            if atol is None:
                atol = 0
            if rtol is None:
                rtol = 1e-7
            try:
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
            except AssertionError as e:
                if false is False:
                    raise e
            else:
                if false is True:
                    assert False, \
                        "ERROR: x ({}) is the same as y ({})!".format(x, y)


def check_compute_single_action(trainer,
                                include_state=False,
                                include_prev_action_reward=False):
    """Tests different combinations of args for trainer.compute_single_action.

    Args:
        trainer: The Trainer object to test.
        include_state: Whether to include the initial state of the Policy's
            Model in the `compute_single_action` call.
        include_prev_action_reward: Whether to include the prev-action and
            -reward in the `compute_single_action` call.

    Raises:
        ValueError: If anything unexpected happens.
    """
    # Have to import this here to avoid circular dependency.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch

    # Some Trainers may not abide to the standard API.
    pid = DEFAULT_POLICY_ID
    try:
        # Multi-agent: Pick any policy (or DEFAULT_POLICY if it's the only
        # one).
        pid = next(iter(trainer.workers.local_worker().policy_map))
        pol = trainer.get_policy(pid)
    except AttributeError:
        pol = trainer.policy
    # Get the policy's model.
    model = pol.model

    action_space = pol.action_space

    def _test(what, method_to_test, obs_space, full_fetch, explore, timestep,
              unsquash, clip):
        call_kwargs = {}
        if what is trainer:
            call_kwargs["full_fetch"] = full_fetch
            call_kwargs["policy_id"] = pid

        obs = obs_space.sample()
        if isinstance(obs_space, Box):
            obs = np.clip(obs, -1.0, 1.0)
        state_in = None
        if include_state:
            state_in = model.get_initial_state()
            if not state_in:
                state_in = []
                i = 0
                while f"state_in_{i}" in model.view_requirements:
                    state_in.append(model.view_requirements[f"state_in_{i}"]
                                    .space.sample())
                    i += 1
        action_in = action_space.sample() \
            if include_prev_action_reward else None
        reward_in = 1.0 if include_prev_action_reward else None

        if method_to_test == "input_dict":
            assert what is pol

            input_dict = {SampleBatch.OBS: obs}
            if include_prev_action_reward:
                input_dict[SampleBatch.PREV_ACTIONS] = action_in
                input_dict[SampleBatch.PREV_REWARDS] = reward_in
            if state_in:
                for i, s in enumerate(state_in):
                    input_dict[f"state_in_{i}"] = s
            input_dict_batched = SampleBatch(
                tree.map_structure(lambda s: np.expand_dims(s, 0), input_dict))
            action = pol.compute_actions_from_input_dict(
                input_dict=input_dict_batched,
                explore=explore,
                timestep=timestep,
                **call_kwargs)
            # Unbatch everything to be able to compare against single
            # action below.
            # ARS and ES return action batches as lists.
            if isinstance(action[0], list):
                action = (np.array(action[0]), action[1], action[2])
            action = tree.map_structure(lambda s: s[0], action)

            try:
                action2 = pol.compute_single_action(
                    input_dict=input_dict,
                    explore=explore,
                    timestep=timestep,
                    **call_kwargs)
                # Make sure these are the same, unless we have exploration
                # switched on (or noisy layers).
                if not explore and not pol.config.get("noisy"):
                    check(action, action2)
            except TypeError:
                pass
        else:
            action = what.compute_single_action(
                obs,
                state_in,
                prev_action=action_in,
                prev_reward=reward_in,
                explore=explore,
                timestep=timestep,
                unsquash_action=unsquash,
                clip_action=clip,
                **call_kwargs)

        state_out = None
        if state_in or full_fetch or what is pol:
            action, state_out, _ = action
        if state_out:
            for si, so in zip(state_in, state_out):
                check(list(si.shape), so.shape)

        if unsquash is None:
            unsquash = what.config["normalize_actions"]
        if clip is None:
            clip = what.config["clip_actions"]

        # Test whether unsquash/clipping works on the Trainer's
        # compute_single_action method: Both flags should force the action
        # to be within the space's bounds.
        if method_to_test == "single" and what == trainer:
            if not action_space.contains(action) and \
                    (clip or unsquash or not isinstance(action_space, Box)):
                raise ValueError(
                    f"Returned action ({action}) of trainer/policy {what} "
                    f"not in Env's action_space {action_space}")
            # We are operating in normalized space: Expect only smaller action
            # values.
            if isinstance(action_space, Box) and not unsquash and \
                    what.config.get("normalize_actions") and \
                    np.any(np.abs(action) > 3.0):
                raise ValueError(
                    f"Returned action ({action}) of trainer/policy {what} "
                    "should be in normalized space, but seems too large/small "
                    "for that!")

    # Loop through: Policy vs Trainer; Different API methods to calculate
    # actions; unsquash option; clip option; full fetch or not.
    for what in [pol, trainer]:
        if what is trainer:
            # Get the obs-space from Workers.env (not Policy) due to possible
            # pre-processor up front.
            worker_set = getattr(trainer, "workers", None)
            # TODO: ES and ARS use `self._workers` instead of `self.workers`
            #  to store their rollout worker set. Change to `self.workers`.
            if worker_set is None:
                worker_set = getattr(trainer, "_workers", None)
            assert worker_set
            if isinstance(worker_set, list):
                obs_space = trainer.get_policy(pid).observation_space
            else:
                obs_space = worker_set.local_worker().for_policy(
                    lambda p: p.observation_space, policy_id=pid)
            obs_space = getattr(obs_space, "original_space", obs_space)
        else:
            obs_space = pol.observation_space

        for method_to_test in ["single"] + \
                (["input_dict"] if what is pol else []):
            for explore in [True, False]:
                for full_fetch in ([False, True]
                                   if what is trainer else [False]):
                    timestep = random.randint(0, 100000)
                    for unsquash in [True, False, None]:
                        for clip in ([False]
                                     if unsquash else [True, False, None]):
                            _test(what, method_to_test, obs_space, full_fetch,
                                  explore, timestep, unsquash, clip)
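
# Example usage of `check_compute_single_action` (a minimal sketch; the
# PPOTrainer setup below is illustrative only):
#
#   from ray.rllib.agents.ppo import PPOTrainer
#   trainer = PPOTrainer(config={"framework": "torch"}, env="CartPole-v0")
#   trainer.train()
#   check_compute_single_action(trainer, include_prev_action_reward=True)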


def check_learning_achieved(tune_results, min_reward, evaluation=False):
    """Throws an error if `min_reward` is not reached within tune_results.

    Checks the last results of all trials in `tune_results` for their
    "episode_reward_mean" value and compares the best one to `min_reward`.

    Args:
        tune_results: The tune.run returned results object.
        min_reward (float): The min reward that must be reached.
        evaluation (bool): Whether to check the evaluation results (instead
            of the training results) for `min_reward`.

    Raises:
        ValueError: If `min_reward` not reached.
    """
    # Get maximum reward of all trials
    # (check if at least one trial achieved some learning).
    avg_rewards = [(trial.last_result["episode_reward_mean"]
                    if not evaluation else
                    trial.last_result["evaluation"]["episode_reward_mean"])
                   for trial in tune_results.trials]
    best_avg_reward = max(avg_rewards)
    if best_avg_reward < min_reward:
        raise ValueError("`stop-reward` of {} not reached!".format(min_reward))
    print("ok")


def check_train_results(train_results):
    """Checks proper structure of a Trainer.train() returned dict.

    Args:
        train_results: The train results dict to check.

    Raises:
        AssertionError: If `train_results` doesn't have the proper structure
            or data in it.
    """
    # Import these here to avoid circular dependencies.
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
    from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
        LEARNER_STATS_KEY
    from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent

    # Assert that some keys are where we would expect them.
    for key in [
            "agent_timesteps_total",
            "config",
            "custom_metrics",
            "episode_len_mean",
            "episode_reward_max",
            "episode_reward_mean",
            "episode_reward_min",
            "episodes_total",
            "hist_stats",
            "info",
            "iterations_since_restore",
            "num_healthy_workers",
            "perf",
            "policy_reward_max",
            "policy_reward_mean",
            "policy_reward_min",
            "sampler_perf",
            "time_since_restore",
            "time_this_iter_s",
            "timesteps_since_restore",
            "timesteps_total",
            "timers",
            "time_total_s",
            "training_iteration",
    ]:
        assert key in train_results, \
            f"'{key}' not found in `train_results` ({train_results})!"

    _, is_multi_agent = check_multi_agent(train_results["config"])

    # Check in particular the "info" dict.
    info = train_results["info"]
    assert LEARNER_INFO in info, \
        f"'learner' not in train_results['info'] ({info})!"
    assert "num_steps_trained" in info or "num_env_steps_trained" in info, \
        f"'num_(env_)?steps_trained' not in train_results['info'] ({info})!"

    learner_info = info[LEARNER_INFO]

    # Make sure we have a default_policy key if we are not in a
    # multi-agent setup.
    if not is_multi_agent:
        # APEX algos sometimes have an empty learner info dict (no metrics
        # collected yet).
        assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \
            f"'{DEFAULT_POLICY_ID}' not found in " \
            f"train_results['info']['learner'] ({learner_info})!"

    for pid, policy_stats in learner_info.items():
        if pid == "batch_count":
            continue
        # Expect td-errors to be per batch-item.
        if "td_error" in policy_stats:
            configured_b = train_results["config"]["train_batch_size"]
            actual_b = policy_stats["td_error"].shape[0]
            # R2D2 case.
            if (configured_b - actual_b) / actual_b > 0.1:
                assert configured_b / (
                    train_results["config"]["model"]["max_seq_len"] +
                    train_results["config"]["burn_in"]) == actual_b

        # Make sure each policy has the LEARNER_STATS_KEY under it.
        assert LEARNER_STATS_KEY in policy_stats
        learner_stats = policy_stats[LEARNER_STATS_KEY]
        for key, value in learner_stats.items():
            # Min- and max-stats should be single values.
            if key.startswith("min_") or key.startswith("max_"):
                assert np.isscalar(
                    value), f"'{key}' value not a scalar ({value})!"

    return train_results
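
# Example usage of `check_train_results` (a minimal sketch; the DQNTrainer
# setup below is illustrative only):
#
#   from ray.rllib.agents.dqn import DQNTrainer
#   trainer = DQNTrainer(config={"framework": "torch"}, env="CartPole-v0")
#   results = trainer.train()
#   check_train_results(results)  # raises AssertionError if malformed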


def run_learning_tests_from_yaml(
        yaml_files: List[str],
        *,
        max_num_repeats: int = 2,
        smoke_test: bool = False,
) -> Dict[str, Any]:
    """Runs the given experiments in yaml_files and returns results dict.

    Args:
        yaml_files (List[str]): List of yaml file names.
        max_num_repeats (int): How many times should we repeat a failed
            experiment?
        smoke_test (bool): Whether this is just a smoke-test. If True,
            set `time_total_s` to 0 (only create the trainer and run a first
            iteration) and don't early out due to rewards or timesteps
            reached.
    """
    print("Will run the following yaml files:")
    for yaml_file in yaml_files:
        print("->", yaml_file)

    # All trials we'll ever run in this test script.
    all_trials = []
    # The experiments (by name) we'll run up to `max_num_repeats` times.
    experiments = {}
    # The results per experiment.
    checks = {}
    # Metrics per experiment.
    stats = {}

    start_time = time.monotonic()

    def should_check_eval(experiment):
        # If we have evaluation workers, use their rewards.
        # This is useful for offline learning tests, where
        # we evaluate against an actual environment.
        return experiment["config"].get("evaluation_interval",
                                        None) is not None

    # Loop through all collected files and gather experiments.
    # Augment all by `torch` framework.
    for yaml_file in yaml_files:
        tf_experiments = yaml.safe_load(open(yaml_file).read())

        # Add torch version of all experiments to the list.
        for k, e in tf_experiments.items():
            # If framework explicitly given, only test for that framework.
            # Some algos do not have both versions available.
            if "frameworks" in e:
                frameworks = e["frameworks"]
            else:
                # By default we don't run tf2, because tf2's multi-gpu support
                # isn't complete yet.
                frameworks = ["tf", "torch"]
            # Pop frameworks key to not confuse Tune.
            e.pop("frameworks", None)

            e["stop"] = e["stop"] if "stop" in e else {}
            e["pass_criteria"] = e[
                "pass_criteria"] if "pass_criteria" in e else {}

            # For smoke-tests, we only run for a very short time.
            if smoke_test:
                # 0sec for each(!) experiment/trial.
                # This is such that if there are many experiments/trials
                # in a test (e.g. rllib_learning_test), each one can at least
                # create its trainer and run a first iteration.
                e["stop"]["time_total_s"] = 0
            else:
                check_eval = should_check_eval(e)
                episode_reward_key = ("episode_reward_mean" if not check_eval
                                      else "evaluation/episode_reward_mean")
                # We also stop early, once we reach the desired reward.
                min_reward = e.get("pass_criteria", {}).get(episode_reward_key)
                if min_reward is not None:
                    e["stop"][episode_reward_key] = min_reward

            # Generate `checks` dict for all experiments
            # (tf, tf2 and/or torch).
            for framework in frameworks:
                k_ = k + "-" + framework
                ec = copy.deepcopy(e)
                ec["config"]["framework"] = framework
                if framework == "tf2":
                    ec["config"]["eager_tracing"] = True

                checks[k_] = {
                    "min_reward": ec["pass_criteria"].get(
                        "episode_reward_mean", 0.0),
                    "min_throughput": ec["pass_criteria"].get(
                        "timesteps_total", 0.0) /
                    (ec["stop"].get("time_total_s", 1.0) or 1.0),
                    "time_total_s": ec["stop"].get("time_total_s"),
                    "failures": 0,
                    "passed": False,
                }
                # This key would break tune.
                ec.pop("pass_criteria", None)

                # One experiment to run.
                experiments[k_] = ec

    # Print out the actual config.
    print("== Test config ==")
    print(yaml.dump(experiments))

    # Keep track of those experiments we still have to run.
    # If an experiment passes, we'll remove it from this dict.
    experiments_to_run = experiments.copy()

    try:
        ray.init(address="auto")
    except ConnectionError:
        ray.init()

    for i in range(max_num_repeats):
        # We are done.
        if len(experiments_to_run) == 0:
            print("All experiments finished.")
            break

        print(f"Starting learning test iteration {i}...")

        # Run remaining experiments.
        trials = run_experiments(
            experiments_to_run,
            resume=False,
            verbose=2,
            progress_reporter=CLIReporter(
                metric_columns={
                    "training_iteration": "iter",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "ts",
                    "episodes_this_iter": "train_episodes",
                    "episode_reward_mean": "reward_mean",
                    "evaluation/episode_reward_mean": "eval_reward_mean",
                },
                sort_by_metric=True,
                max_report_frequency=30,
            ))

        all_trials.extend(trials)

        # Check each experiment for whether it passed.
        # Criteria are to a) reach the desired reward AND b) to have reached
        # the throughput defined by `timesteps_total` / `time_total_s`.
        for experiment in experiments_to_run.copy():
            print(f"Analyzing experiment {experiment} ...")
            # Collect all trials within this experiment (some experiments may
            # have num_samples or grid_searches defined).
            trials_for_experiment = []
            for t in trials:
                trial_exp = re.sub(".+/([^/]+)$", "\\1", t.local_dir)
                if trial_exp == experiment:
                    trials_for_experiment.append(t)
            print(f" ... Trials: {trials_for_experiment}.")

            check_eval = should_check_eval(experiments[experiment])

            # Error: Increase failure count and repeat.
            if any(t.status == "ERROR" for t in trials_for_experiment):
                print(" ... ERROR.")
                checks[experiment]["failures"] += 1
            # Smoke-tests always succeed.
            elif smoke_test:
                print(" ... SMOKE TEST (mark ok).")
                checks[experiment]["passed"] = True
                del experiments_to_run[experiment]
            # Experiment finished: Check reward achieved and timesteps done
            # (throughput).
            else:
                if check_eval:
                    episode_reward_mean = np.mean([
                        t.last_result["evaluation"]["episode_reward_mean"]
                        for t in trials_for_experiment
                    ])
                else:
                    episode_reward_mean = np.mean([
                        t.last_result["episode_reward_mean"]
                        for t in trials_for_experiment
                    ])
                desired_reward = checks[experiment]["min_reward"]

                timesteps_total = np.mean([
                    t.last_result["timesteps_total"]
                    for t in trials_for_experiment
                ])
                total_time_s = np.mean([
                    t.last_result["time_total_s"]
                    for t in trials_for_experiment
                ])

                # TODO(jungong) : track trainer and env throughput separately.
                throughput = timesteps_total / (total_time_s or 1.0)
                # TODO(jungong) : enable throughput check again after
                #  TD3_HalfCheetahBulletEnv is fixed and verified.
                # desired_throughput = checks[experiment]["min_throughput"]
                desired_throughput = None

                # Record performance.
                stats[experiment] = {
                    "episode_reward_mean": float(episode_reward_mean),
                    "throughput": (float(throughput)
                                   if throughput is not None else 0.0),
                }

                print(f" ... Desired reward={desired_reward}; "
                      f"desired throughput={desired_throughput}")

                # We failed to reach desired reward or the desired throughput.
                if (desired_reward and
                        episode_reward_mean < desired_reward) or \
                        (desired_throughput and
                         throughput < desired_throughput):
                    print(" ... Not successful: Actual "
                          f"reward={episode_reward_mean}; "
                          f"actual throughput={throughput}")
                    checks[experiment]["failures"] += 1
                # We succeeded!
                else:
                    print(" ... Successful: (mark ok).")
                    checks[experiment]["passed"] = True
                    del experiments_to_run[experiment]

    ray.shutdown()

    time_taken = time.monotonic() - start_time

    # Create and return the results dict.
    result = {
        "time_taken": float(time_taken),
        "trial_states": dict(Counter([trial.status for trial in all_trials])),
        "last_update": float(time.time()),
        "stats": stats,
        "passed": [k for k, exp in checks.items() if exp["passed"]],
        "failures": {
            k: exp["failures"]
            for k, exp in checks.items() if exp["failures"] > 0
        }
    }

    return result
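
# Example usage of `run_learning_tests_from_yaml` (a minimal sketch; the yaml
# file name below is illustrative and assumed to contain tuned RLlib
# experiments with `pass_criteria` defined):
#
#   results = run_learning_tests_from_yaml(
#       ["cartpole-ppo.yaml"], smoke_test=True)
#   print("passed:", results["passed"])
#   print("failures:", results["failures"])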