evaluate.py

#!/usr/bin/env python

import collections
import copy
import gymnasium as gym
import json
import os
from pathlib import Path
import shelve

import typer

import ray
import ray.cloudpickle as cloudpickle
from ray.rllib.common import CLIArguments as cli
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner_group import EnvRunnerGroup
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_LEN_MEAN,
    EPISODE_RETURN_MEAN,
)
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.train._checkpoint import Checkpoint
from ray.train._internal.session import _TrainingResult
from ray.tune.utils import merge_dicts
from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR

# Create the "evaluate" Typer app.
eval_app = typer.Typer()


class RolloutSaver:
    """Utility class for storing rollouts.

    Currently supports two behaviours: the original, which
    simply dumps everything to a pickle file once complete,
    and a mode which stores each rollout as an entry in a Python
    shelf db file. The latter mode is more robust to memory problems
    or crashes part-way through the rollout generation. Each rollout
    is stored with a key based on the episode number (0-indexed),
    and the number of episodes is stored with the key "num_episodes",
    so to load the shelf file, use something like:

    with shelve.open('rollouts.pkl') as rollouts:
        for episode_index in range(rollouts["num_episodes"]):
            rollout = rollouts[str(episode_index)]

    If outfile is None, this class does nothing.
    """

    def __init__(
        self,
        outfile=None,
        use_shelve=False,
        write_update_file=False,
        target_steps=None,
        target_episodes=None,
        save_info=False,
    ):
        self._outfile = outfile
        self._update_file = None
        self._use_shelve = use_shelve
        self._write_update_file = write_update_file
        self._shelf = None
        self._num_episodes = 0
        self._rollouts = []
        self._current_rollout = []
        self._total_steps = 0
        self._target_episodes = target_episodes
        self._target_steps = target_steps
        self._save_info = save_info

    def _get_tmp_progress_filename(self):
        outpath = Path(self._outfile)
        return outpath.parent / ("__progress_" + outpath.name)

    @property
    def outfile(self):
        return self._outfile

    def __enter__(self):
        if self._outfile:
            if self._use_shelve:
                # Open a shelf file to store each rollout as they come in.
                self._shelf = shelve.open(self._outfile)
            else:
                # Original behaviour - keep all rollouts in memory and save
                # them all at the end.
                # But check we can actually write to the outfile before going
                # through the effort of generating the rollouts:
                try:
                    with open(self._outfile, "wb") as _:
                        pass
                except IOError as x:
                    print(
                        "Can not open {} for writing - cancelling rollouts.".format(
                            self._outfile
                        )
                    )
                    raise x
            if self._write_update_file:
                # Open a file to track rollout progress:
                self._update_file = self._get_tmp_progress_filename().open(mode="w")
        return self

    def __exit__(self, type, value, traceback):
        if self._shelf:
            # Close the shelf file, and store the number of episodes for ease.
            self._shelf["num_episodes"] = self._num_episodes
            self._shelf.close()
        elif self._outfile and not self._use_shelve:
            # Dump everything as one big pickle:
            cloudpickle.dump(self._rollouts, open(self._outfile, "wb"))
        if self._update_file:
            # Remove the temp progress file:
            self._get_tmp_progress_filename().unlink()
            self._update_file = None

    def _get_progress(self):
        if self._target_episodes:
            return "{} / {} episodes completed".format(
                self._num_episodes, self._target_episodes
            )
        elif self._target_steps:
            return "{} / {} steps completed".format(
                self._total_steps, self._target_steps
            )
        else:
            return "{} episodes completed".format(self._num_episodes)

    def begin_rollout(self):
        self._current_rollout = []

    def end_rollout(self):
        if self._outfile:
            if self._use_shelve:
                # Save this episode as a new entry in the shelf database,
                # using the episode number as the key.
                self._shelf[str(self._num_episodes)] = self._current_rollout
            else:
                # Append this rollout to our list, to save later.
                self._rollouts.append(self._current_rollout)
        self._num_episodes += 1
        if self._update_file:
            self._update_file.seek(0)
            self._update_file.write(self._get_progress() + "\n")
            self._update_file.flush()

    def append_step(self, obs, action, next_obs, reward, terminated, truncated, info):
        """Add a step to the current rollout, if we are saving them"""
        if self._outfile:
            if self._save_info:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, terminated, truncated, info]
                )
            else:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, terminated, truncated]
                )
        self._total_steps += 1
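

# A minimal usage sketch (illustrative only; `obs`, `action`, etc. below are
# placeholders, not values defined in this file): RolloutSaver is meant to be used
# as a context manager. Each episode is bracketed by begin_rollout()/end_rollout(),
# and every transition is recorded via append_step():
#
#     with RolloutSaver(outfile="rollouts.pkl", target_episodes=2) as saver:
#         saver.begin_rollout()
#         saver.append_step(obs, action, next_obs, reward, terminated, truncated, info)
#         saver.end_rollout()
#
# This is exactly how the `rollout()` function further below drives the saver.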


@eval_app.command()
def run(
    checkpoint: str = cli.Checkpoint,
    algo: str = cli.Algo,
    env: str = cli.Env,
    local_mode: bool = cli.LocalMode,
    render: bool = cli.Render,
    steps: int = cli.Steps,
    episodes: int = cli.Episodes,
    out: str = cli.Out,
    config: str = cli.Config,
    save_info: bool = cli.SaveInfo,
    use_shelve: bool = cli.UseShelve,
    track_progress: bool = cli.TrackProgress,
):
    if use_shelve and not out:
        raise ValueError(
            "If you set --use-shelve, you must provide an output file via "
            "--out as well!"
        )
    if track_progress and not out:
        raise ValueError(
            "If you set --track-progress, you must provide an output file via "
            "--out as well!"
        )

    # Load configuration from checkpoint file.
    config_args = json.loads(config)
    config_path = ""
    if checkpoint:
        config_dir = os.path.dirname(checkpoint)
        config_path = os.path.join(config_dir, "params.pkl")
        # Try parent directory.
        if not os.path.exists(config_path):
            config_path = os.path.join(config_dir, "../params.pkl")

    # Load the config from pickled.
    if os.path.exists(config_path):
        with open(config_path, "rb") as f:
            config = cloudpickle.load(f)
    # If no pkl file found, require command line `--config`.
    else:
        # If no config in given checkpoint -> Error.
        if checkpoint:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory AND no `--config` given on command "
                "line!"
            )

        # Use default config for given agent.
        if not algo:
            raise ValueError("Please provide an algorithm via `--algo`.")
        algo_cls = get_trainable_cls(algo)
        config = algo_cls.get_default_config()

    # Make sure worker 0 has an Env.
    config["create_env_on_driver"] = True

    # Merge with `evaluation_config` (first try from command line, then from
    # pkl file).
    evaluation_config = copy.deepcopy(
        config_args.get("evaluation_config", config.get("evaluation_config", {}))
    )
    config = merge_dicts(config, evaluation_config)
    # Merge with command line `--config` settings (if not already the same anyways).
    config = merge_dicts(config, config_args)
    if not env:
        if not config.get("env"):
            raise ValueError(
                "You either need to provide an --env argument or pass "
                "an `env` key with a valid environment to your `config` "
                "argument."
            )
        env = config.get("env")

    # Make sure we have evaluation workers.
    if not config.get(
        "evaluation_num_workers", config.get("evaluation_num_env_runners")
    ):
        config["evaluation_num_env_runners"] = config.get(
            "num_env_runners", config.get("num_workers")
        )
    if not config.get("evaluation_duration"):
        config["evaluation_duration"] = 1
    # Hard-override this as it raises a warning by Algorithm otherwise.
    # Makes no sense anyways, to have it set to None as we don't call
    # `Algorithm.train()` here.
    config["evaluation_interval"] = 1

    # Rendering settings.
    config["render_env"] = render

    ray.init(local_mode=local_mode)

    # Create the Algorithm from config.
    cls = get_trainable_cls(algo)
    algorithm = cls(config=config)

    # Load state from checkpoint, if provided.
    if checkpoint:
        if os.path.isdir(checkpoint):
            checkpoint_dir = checkpoint
        else:
            checkpoint_dir = str(Path(checkpoint).parent)
        print(f"Restoring algorithm from {checkpoint_dir}")
        restore_result = _TrainingResult(
            checkpoint=Checkpoint.from_directory(checkpoint_dir), metrics={}
        )
        algorithm.restore(restore_result)

    # Do the actual rollout.
    with RolloutSaver(
        outfile=out,
        use_shelve=use_shelve,
        write_update_file=track_progress,
        target_steps=steps,
        target_episodes=episodes,
        save_info=save_info,
    ) as saver:
        rollout(algorithm, env, steps, episodes, saver, not render)
    algorithm.stop()


class DefaultMapping(collections.defaultdict):
    """default_factory now takes as an argument the missing key."""

    def __missing__(self, key):
        self[key] = value = self.default_factory(key)
        return value
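

# Illustration (hypothetical factory, not used elsewhere in this file): unlike a
# plain collections.defaultdict, DefaultMapping's factory receives the missing key:
#
#     mapping = DefaultMapping(lambda agent_id: "policy_for_" + agent_id)
#     mapping["agent_0"]  # -> "policy_for_agent_0", and the entry is cached.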


def default_policy_agent_mapping(unused_agent_id) -> str:
    return DEFAULT_POLICY_ID


def keep_going(steps: int, num_steps: int, episodes: int, num_episodes: int) -> bool:
    """Determine whether we've run enough steps or episodes."""
    episode_limit_reached = num_episodes and episodes >= num_episodes
    step_limit_reached = num_steps and steps >= num_steps
    return False if episode_limit_reached or step_limit_reached else True
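

# For example (made-up numbers): with no step limit (num_steps=0) and an episode
# limit of 3, keep_going(steps=10, num_steps=0, episodes=3, num_episodes=3) returns
# False because the episode limit has been reached, whereas
# keep_going(steps=10, num_steps=0, episodes=2, num_episodes=3) returns True.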


def rollout(
    agent,
    env_name,  # keep me, used in tests
    num_steps,
    num_episodes=0,
    saver=None,
    no_render=True,
):
    """Roll out episodes with the given agent until the step/episode limits are hit."""
    policy_agent_mapping = default_policy_agent_mapping

    if saver is None:
        saver = RolloutSaver()

    # Normal case: Agent was set up correctly with an evaluation EnvRunnerGroup,
    # which we will now use to rollout.
    if hasattr(agent, "eval_env_runner_group") and isinstance(
        agent.eval_env_runner_group, EnvRunnerGroup
    ):
        steps = 0
        episodes = 0
        while keep_going(steps, num_steps, episodes, num_episodes):
            saver.begin_rollout()
            eval_result = agent.evaluate()
            # Increase time-step and episode counters.
            eps = agent.config["evaluation_duration"]
            episodes += eps
            steps += eps * eval_result[ENV_RUNNER_RESULTS][EPISODE_LEN_MEAN]
            # Print out results and continue.
            print(
                "Episode #{}: reward: {}".format(
                    episodes, eval_result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
                )
            )
            saver.end_rollout()
        return

    # Agent has no evaluation workers, but RolloutWorkers.
    elif hasattr(agent, "env_runner_group") and isinstance(
        agent.env_runner_group, EnvRunnerGroup
    ):
        env = agent.env_runner.env
        multiagent = isinstance(env, MultiAgentEnv)
        if agent.env_runner.multiagent:
            policy_agent_mapping = agent.config.policy_mapping_fn

        policy_map = agent.env_runner.policy_map
        state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
        use_lstm = {p: len(s) > 0 for p, s in state_init.items()}

    # Agent has neither evaluation- nor rollout workers.
    else:
        from gymnasium import envs

        if envs.registry.get(agent.config["env"]):
            # If the environment is a registered gymnasium environment, load it
            # from gymnasium.
            env = gym.make(agent.config["env"])
        else:
            # If the environment is a registered Ray environment, load it from Ray.
            env_creator = _global_registry.get(ENV_CREATOR, agent.config["env"])
            env_context = EnvContext(agent.config["env_config"] or {}, worker_index=0)
            env = env_creator(env_context)

        multiagent = False
        try:
            policy_map = {DEFAULT_POLICY_ID: agent.policy}
        except AttributeError:
            raise AttributeError(
                "Agent ({}) does not have a `policy` property! This is needed "
                "for performing (trained) agent rollouts.".format(agent)
            )
        use_lstm = {DEFAULT_POLICY_ID: False}

    action_init = {
        p: flatten_to_single_ndarray(m.action_space.sample())
        for p, m in policy_map.items()
    }

    steps = 0
    episodes = 0
    while keep_going(steps, num_steps, episodes, num_episodes):
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        saver.begin_rollout()
        obs, info = env.reset()
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]]  # noqa
        )
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]]  # noqa
        )
        prev_rewards = collections.defaultdict(lambda: 0.0)
        terminated = truncated = False
        reward_total = 0.0
        while (
            not terminated
            and not truncated
            and keep_going(steps, num_steps, episodes, num_episodes)
        ):
            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
            action_dict = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is not None:
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id)
                    )
                    p_use_lstm = use_lstm[policy_id]
                    if p_use_lstm:
                        a_action, p_state, _ = agent.compute_single_action(
                            a_obs,
                            state=agent_states[agent_id],
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id,
                        )
                        agent_states[agent_id] = p_state
                    else:
                        a_action = agent.compute_single_action(
                            a_obs,
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id,
                        )
                    a_action = flatten_to_single_ndarray(a_action)
                    action_dict[agent_id] = a_action
                    prev_actions[agent_id] = a_action
            action = action_dict

            action = action if multiagent else action[_DUMMY_AGENT_ID]

            next_obs, reward, terminated, truncated, info = env.step(action)
            if multiagent:
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
            else:
                prev_rewards[_DUMMY_AGENT_ID] = reward

            if multiagent:
                terminated = terminated["__all__"]
                truncated = truncated["__all__"]
                reward_total += sum(r for r in reward.values() if r is not None)
            else:
                reward_total += reward
            if not no_render:
                env.render()
            saver.append_step(
                obs, action, next_obs, reward, terminated, truncated, info
            )
            steps += 1
            obs = next_obs
        saver.end_rollout()
        print("Episode #{}: reward: {}".format(episodes, reward_total))
        if terminated or truncated:
            episodes += 1


def main():
    """Run the CLI."""
    eval_app()


if __name__ == "__main__":
    main()
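

# Example invocation (paths and values are placeholders): this module backs RLlib's
# `evaluate` CLI sub-command, so a typical call might look like
#
#     rllib evaluate ~/ray_results/my_experiment/checkpoint_000010 \
#         --algo PPO --env CartPole-v1 --episodes 5 --out rollouts.pkl
#
# The same can be achieved by running this file directly, e.g. `python evaluate.py ...`.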