#!/usr/bin/env python

import collections
import copy
import gymnasium as gym
import json
import os
from pathlib import Path
import shelve

import typer

import ray
import ray.cloudpickle as cloudpickle
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.rllib.common import CLIArguments as cli
from ray.train._checkpoint import Checkpoint
from ray.train._internal.session import _TrainingResult
from ray.tune.utils import merge_dicts
from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR

# create the "evaluate" Typer app
eval_app = typer.Typer()


class RolloutSaver:
    """Utility class for storing rollouts.

    Currently supports two behaviours: the original, which
    simply dumps everything to a pickle file once complete,
    and a mode which stores each rollout as an entry in a Python
    shelf db file. The latter mode is more robust to memory problems
    or crashes part-way through the rollout generation. Each rollout
    is stored with a key based on the episode number (0-indexed),
    and the number of episodes is stored with the key "num_episodes",
    so to load the shelf file, use something like:

    with shelve.open('rollouts.pkl') as rollouts:
        for episode_index in range(rollouts["num_episodes"]):
            rollout = rollouts[str(episode_index)]

    If outfile is None, this class does nothing.
    """

    def __init__(
        self,
        outfile=None,
        use_shelve=False,
        write_update_file=False,
        target_steps=None,
        target_episodes=None,
        save_info=False,
    ):
        self._outfile = outfile
        self._update_file = None
        self._use_shelve = use_shelve
        self._write_update_file = write_update_file
        self._shelf = None
        self._num_episodes = 0
        self._rollouts = []
        self._current_rollout = []
        self._total_steps = 0
        self._target_episodes = target_episodes
        self._target_steps = target_steps
        self._save_info = save_info

    def _get_tmp_progress_filename(self):
        outpath = Path(self._outfile)
        return outpath.parent / ("__progress_" + outpath.name)

    @property
    def outfile(self):
        return self._outfile

    def __enter__(self):
        if self._outfile:
            if self._use_shelve:
                # Open a shelf file to store each rollout as they come in.
                self._shelf = shelve.open(self._outfile)
            else:
                # Original behaviour - keep all rollouts in memory and save
                # them all at the end.
                # But check we can actually write to the outfile before going
                # through the effort of generating the rollouts:
                try:
                    with open(self._outfile, "wb") as _:
                        pass
                except IOError as x:
                    print(
                        "Can not open {} for writing - cancelling rollouts.".format(
                            self._outfile
                        )
                    )
                    raise x
            if self._write_update_file:
                # Open a file to track rollout progress:
                self._update_file = self._get_tmp_progress_filename().open(mode="w")
        return self

    def __exit__(self, type, value, traceback):
        if self._shelf:
            # Close the shelf file, and store the number of episodes for ease.
            self._shelf["num_episodes"] = self._num_episodes
            self._shelf.close()
        elif self._outfile and not self._use_shelve:
            # Dump everything as one big pickle:
            cloudpickle.dump(self._rollouts, open(self._outfile, "wb"))
        if self._update_file:
            # Remove the temp progress file:
            self._get_tmp_progress_filename().unlink()
            self._update_file = None

    def _get_progress(self):
        if self._target_episodes:
            return "{} / {} episodes completed".format(
                self._num_episodes, self._target_episodes
            )
        elif self._target_steps:
            return "{} / {} steps completed".format(
                self._total_steps, self._target_steps
            )
        else:
            return "{} episodes completed".format(self._num_episodes)

    def begin_rollout(self):
        self._current_rollout = []

    def end_rollout(self):
        if self._outfile:
            if self._use_shelve:
                # Save this episode as a new entry in the shelf database,
                # using the episode number as the key.
                self._shelf[str(self._num_episodes)] = self._current_rollout
            else:
                # Append this rollout to our list, to save later.
                self._rollouts.append(self._current_rollout)
        self._num_episodes += 1
        if self._update_file:
            self._update_file.seek(0)
            self._update_file.write(self._get_progress() + "\n")
            self._update_file.flush()

    def append_step(self, obs, action, next_obs, reward, terminated, truncated, info):
        """Add a step to the current rollout, if we are saving them."""
        if self._outfile:
            if self._save_info:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, terminated, truncated, info]
                )
            else:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, terminated, truncated]
                )
        self._total_steps += 1
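

# Example sketch (not part of the original script): reading rollouts back from
# the two output formats RolloutSaver can write. The helper names
# `iter_shelf_rollouts` and `load_pickled_rollouts` are illustrative assumptions.
def iter_shelf_rollouts(shelf_path):
    """Yield each saved rollout from a shelf db written with `use_shelve=True`."""
    with shelve.open(shelf_path) as rollouts:
        for episode_index in range(rollouts["num_episodes"]):
            yield rollouts[str(episode_index)]


def load_pickled_rollouts(pickle_path):
    """Load the single cloudpickle dump written with `use_shelve=False`.

    Note that the whole list of rollouts is materialized in memory at once.
    """
    with open(pickle_path, "rb") as f:
        return cloudpickle.load(f)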


@eval_app.command()
def run(
    checkpoint: str = cli.Checkpoint,
    algo: str = cli.Algo,
    env: str = cli.Env,
    local_mode: bool = cli.LocalMode,
    render: bool = cli.Render,
    steps: int = cli.Steps,
    episodes: int = cli.Episodes,
    out: str = cli.Out,
    config: str = cli.Config,
    save_info: bool = cli.SaveInfo,
    use_shelve: bool = cli.UseShelve,
    track_progress: bool = cli.TrackProgress,
):
    if use_shelve and not out:
        raise ValueError(
            "If you set --use-shelve, you must provide an output file via "
            "--out as well!"
        )
    if track_progress and not out:
        raise ValueError(
            "If you set --track-progress, you must provide an output file via "
            "--out as well!"
        )

    # Load configuration from checkpoint file.
    config_args = json.loads(config)
    config_path = ""
    if checkpoint:
        config_dir = os.path.dirname(checkpoint)
        config_path = os.path.join(config_dir, "params.pkl")
        # Try parent directory.
        if not os.path.exists(config_path):
            config_path = os.path.join(config_dir, "../params.pkl")

    # Load the config from pickled.
    if os.path.exists(config_path):
        with open(config_path, "rb") as f:
            config = cloudpickle.load(f)
    # If no pkl file found, require command line `--config`.
    else:
        # If no config in given checkpoint -> Error.
        if checkpoint:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory AND no `--config` given on command "
                "line!"
            )
        # Use default config for given agent.
        if not algo:
            raise ValueError("Please provide an algorithm via `--algo`.")
        algo_cls = get_trainable_cls(algo)
        config = algo_cls.get_default_config()

    # Make sure worker 0 has an Env.
    config["create_env_on_driver"] = True

    # Merge with `evaluation_config` (first try from command line, then from
    # pkl file).
    evaluation_config = copy.deepcopy(
        config_args.get("evaluation_config", config.get("evaluation_config", {}))
    )
    config = merge_dicts(config, evaluation_config)
    # Merge with command line `--config` settings (if not already the same anyways).
    config = merge_dicts(config, config_args)

    if not env:
        if not config.get("env"):
            raise ValueError(
                "You either need to provide an --env argument or pass "
                "an `env` key with a valid environment to your `config` "
                "argument."
            )
        env = config.get("env")

    # Make sure we have evaluation workers.
    if not config.get("evaluation_num_workers"):
        config["evaluation_num_workers"] = config.get("num_workers", 0)
    if not config.get("evaluation_duration"):
        config["evaluation_duration"] = 1
    # Hard-override this as it raises a warning by Algorithm otherwise.
    # Makes no sense anyways, to have it set to None as we don't call
    # `Algorithm.train()` here.
    config["evaluation_interval"] = 1

    # Rendering settings.
    config["render_env"] = render

    ray.init(local_mode=local_mode)

    # Create the Algorithm from config.
    cls = get_trainable_cls(algo)
    algorithm = cls(env=env, config=config)

    # Load state from checkpoint, if provided.
    if checkpoint:
        if os.path.isdir(checkpoint):
            checkpoint_dir = checkpoint
        else:
            checkpoint_dir = str(Path(checkpoint).parent)
        print(f"Restoring algorithm from {checkpoint_dir}")
        restore_result = _TrainingResult(
            checkpoint=Checkpoint.from_directory(checkpoint_dir), metrics={}
        )
        algorithm.restore(restore_result)

    # Do the actual rollout.
    with RolloutSaver(
        outfile=out,
        use_shelve=use_shelve,
        write_update_file=track_progress,
        target_steps=steps,
        target_episodes=episodes,
        save_info=save_info,
    ) as saver:
        rollout(algorithm, env, steps, episodes, saver, not render)
    algorithm.stop()
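

# Example invocation sketch (paths and values are illustrative assumptions, not
# from the original file). Since `run` is the only command on the Typer app,
# a call could look like:
#
#   python evaluate.py /path/to/checkpoint_000010 \
#       --algo PPO --env CartPole-v1 --episodes 5 \
#       --out rollouts.pkl --use-shelve --track-progress
#
# When this module is wired into RLlib's `rllib` CLI, the equivalent form is
# `rllib evaluate /path/to/checkpoint_000010 --algo PPO ...`.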


class DefaultMapping(collections.defaultdict):
    """default_factory now takes as an argument the missing key."""

    def __missing__(self, key):
        self[key] = value = self.default_factory(key)
        return value
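

# Illustrative example (not in the original script): unlike a plain
# collections.defaultdict, the factory receives the missing key, so the default
# can be derived from the key itself:
#
#   mapping = DefaultMapping(lambda agent_id: "default_policy_for_" + agent_id)
#   mapping["agent_0"]  # -> "default_policy_for_agent_0", now cached in the dict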


def default_policy_agent_mapping(unused_agent_id) -> str:
    return DEFAULT_POLICY_ID


def keep_going(steps: int, num_steps: int, episodes: int, num_episodes: int) -> bool:
    """Determine whether we've run enough steps or episodes."""
    episode_limit_reached = num_episodes and episodes >= num_episodes
    step_limit_reached = num_steps and steps >= num_steps
    return False if episode_limit_reached or step_limit_reached else True


def rollout(
    agent,
    env_name,  # keep me, used in tests
    num_steps,
    num_episodes=0,
    saver=None,
    no_render=True,
):
    policy_agent_mapping = default_policy_agent_mapping
    if saver is None:
        saver = RolloutSaver()

    # Normal case: Agent was set up correctly with an evaluation WorkerSet,
    # which we will now use to roll out.
    if hasattr(agent, "evaluation_workers") and isinstance(
        agent.evaluation_workers, WorkerSet
    ):
        steps = 0
        episodes = 0
        while keep_going(steps, num_steps, episodes, num_episodes):
            saver.begin_rollout()
            eval_result = agent.evaluate()["evaluation"]
            # Increase time-step and episode counters.
            eps = agent.config["evaluation_duration"]
            episodes += eps
            steps += eps * eval_result["episode_len_mean"]
            # Print out results and continue.
            print(
                "Episode #{}: reward: {}".format(
                    episodes, eval_result["episode_reward_mean"]
                )
            )
            saver.end_rollout()
        return

    # Agent has no evaluation workers, but RolloutWorkers.
    elif hasattr(agent, "workers") and isinstance(agent.workers, WorkerSet):
        env = agent.workers.local_worker().env
        multiagent = isinstance(env, MultiAgentEnv)
        if agent.workers.local_worker().multiagent:
            policy_agent_mapping = agent.config.policy_mapping_fn
        policy_map = agent.workers.local_worker().policy_map
        state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
        use_lstm = {p: len(s) > 0 for p, s in state_init.items()}

    # Agent has neither evaluation- nor rollout workers.
    else:
        from gymnasium import envs

        if envs.registry.get(agent.config["env"]):
            # If the env is a registered gymnasium environment, load it from
            # there (gymnasium's `registry` is a plain dict of env specs).
            env = gym.make(agent.config["env"])
        else:
            # If the env is registered with Ray, load it via the Tune
            # registry's env creator.
            env_creator = _global_registry.get(ENV_CREATOR, agent.config["env"])
            env_context = EnvContext(agent.config["env_config"] or {}, worker_index=0)
            env = env_creator(env_context)
        multiagent = False
        try:
            policy_map = {DEFAULT_POLICY_ID: agent.policy}
        except AttributeError:
            raise AttributeError(
                "Agent ({}) does not have a `policy` property! This is needed "
                "for performing (trained) agent rollouts.".format(agent)
            )
        use_lstm = {DEFAULT_POLICY_ID: False}

    action_init = {
        p: flatten_to_single_ndarray(m.action_space.sample())
        for p, m in policy_map.items()
    }

    steps = 0
    episodes = 0
    while keep_going(steps, num_steps, episodes, num_episodes):
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        saver.begin_rollout()
        obs, info = env.reset()
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]]
        )
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]]
        )
        prev_rewards = collections.defaultdict(lambda: 0.0)
        terminated = truncated = False
        reward_total = 0.0
        while (
            not terminated
            and not truncated
            and keep_going(steps, num_steps, episodes, num_episodes)
        ):
            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
            action_dict = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is not None:
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id)
                    )
                    p_use_lstm = use_lstm[policy_id]
                    if p_use_lstm:
                        a_action, p_state, _ = agent.compute_single_action(
                            a_obs,
                            state=agent_states[agent_id],
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id,
                        )
                        agent_states[agent_id] = p_state
                    else:
                        a_action = agent.compute_single_action(
                            a_obs,
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id,
                        )
                    a_action = flatten_to_single_ndarray(a_action)
                    action_dict[agent_id] = a_action
                    prev_actions[agent_id] = a_action
            action = action_dict

            action = action if multiagent else action[_DUMMY_AGENT_ID]
            next_obs, reward, terminated, truncated, info = env.step(action)
            if multiagent:
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
            else:
                prev_rewards[_DUMMY_AGENT_ID] = reward

            if multiagent:
                terminated = terminated["__all__"]
                truncated = truncated["__all__"]
                reward_total += sum(r for r in reward.values() if r is not None)
            else:
                reward_total += reward
            if not no_render:
                env.render()
            saver.append_step(
                obs, action, next_obs, reward, terminated, truncated, info
            )
            steps += 1
            obs = next_obs
        saver.end_rollout()
        print("Episode #{}: reward: {}".format(episodes, reward_total))
        if terminated or truncated:
            episodes += 1
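

# Example sketch (not part of the original script): driving `rollout()`
# programmatically with a freshly built algorithm instead of going through the
# Typer command above. PPO, CartPole-v1, and the evaluation settings are
# illustrative assumptions; `run()` performs an equivalent setup before calling
# `rollout()` on a restored checkpoint. Call this helper manually to try it.
def _example_programmatic_rollout():
    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        # Give the algorithm evaluation workers so rollout() can take its
        # `evaluation_workers` code path (mirroring what run() enforces).
        .evaluation(
            evaluation_interval=1,
            evaluation_duration=1,
            evaluation_num_workers=1,
        )
    )
    algo = config.build()
    with RolloutSaver(outfile="example_rollouts.pkl", target_episodes=2) as saver:
        rollout(algo, "CartPole-v1", num_steps=0, num_episodes=2, saver=saver)
    algo.stop()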


def main():
    """Run the CLI."""
    eval_app()


if __name__ == "__main__":
    main()