#!/usr/bin/env python

import argparse
import collections
import copy
import gym
from gym import wrappers as gym_wrappers
import json
import os
from pathlib import Path
import shelve

import ray
import ray.cloudpickle as cloudpickle
from ray.rllib.agents.registry import get_trainer_class
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.env.env_context import EnvContext
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.tune.utils import merge_dicts
from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR

EXAMPLE_USAGE = """
Example usage via RLlib CLI:
    rllib evaluate /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
    --env CartPole-v0 --steps 1000000 --out rollouts.pkl

Example usage via executable:
    ./evaluate.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
    --env CartPole-v0 --steps 1000000 --out rollouts.pkl

Example usage w/o checkpoint (for testing purposes):
    ./evaluate.py --run PPO --env CartPole-v0 --episodes 500
"""

# Note: if you use any custom models or envs, register them here first, e.g.:
#
# from ray.rllib.examples.env.parametric_actions_cartpole import \
#     ParametricActionsCartPole
# from ray.rllib.examples.model.parametric_actions_model import \
#     ParametricActionsModel
# ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
# register_env("pa_cartpole", lambda _: ParametricActionsCartPole(10))
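#
# For those two registration calls, the imports would be, e.g.:
#
# from ray.rllib.models import ModelCatalog
# from ray.tune.registry import register_env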


def create_parser(parser_creator=None):
    parser_creator = parser_creator or argparse.ArgumentParser
    parser = parser_creator(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Roll out a reinforcement learning agent "
        "given a checkpoint.",
        epilog=EXAMPLE_USAGE)

    parser.add_argument(
        "checkpoint",
        type=str,
        nargs="?",
        help="(Optional) checkpoint from which to roll out. "
        "If none given, will use an initial (untrained) Trainer.")

    required_named = parser.add_argument_group("required named arguments")
    required_named.add_argument(
        "--run",
        type=str,
        required=True,
        help="The algorithm or model to train. This may refer to the name "
        "of a built-in algorithm (e.g. RLlib's `DQN` or `PPO`), or a "
        "user-defined trainable function or class registered in the "
        "tune registry.")
    required_named.add_argument(
        "--env",
        type=str,
        help="The environment specifier to use. This could be an OpenAI Gym "
        "specifier (e.g. `CartPole-v0`) or a full class-path (e.g. "
        "`ray.rllib.examples.env.simple_corridor.SimpleCorridor`).")
    parser.add_argument(
        "--local-mode",
        action="store_true",
        help="Run ray in local mode for easier debugging.")
    parser.add_argument(
        "--render",
        action="store_true",
        help="Render the environment while evaluating.")
    # Deprecated: Use --render, instead.
    parser.add_argument(
        "--no-render",
        default=False,
        action="store_const",
        const=True,
        help="Deprecated! Rendering is off by default now. "
        "Use `--render` to enable.")
    parser.add_argument(
        "--video-dir",
        type=str,
        default=None,
        help="Specifies the directory into which videos of all episode "
        "rollouts will be stored.")
    parser.add_argument(
        "--steps",
        default=10000,
        help="Number of timesteps to roll out. Rollout will also stop if "
        "`--episodes` limit is reached first. A value of 0 means no "
        "limitation on the number of timesteps run.")
    parser.add_argument(
        "--episodes",
        default=0,
        help="Number of complete episodes to roll out. Rollout will also stop "
        "if `--steps` (timesteps) limit is reached first. A value of 0 means "
        "no limitation on the number of episodes run.")
    parser.add_argument("--out", default=None, help="Output filename.")
    parser.add_argument(
        "--config",
        default="{}",
        type=json.loads,
        help="Algorithm-specific configuration (e.g. env, hyperparams). "
        "Gets merged with loaded configuration from checkpoint file and "
        "`evaluation_config` settings therein.")
    parser.add_argument(
        "--save-info",
        default=False,
        action="store_true",
        help="Save the info field generated by the step() method, "
        "as well as the action, observations, rewards and done fields.")
    parser.add_argument(
        "--use-shelve",
        default=False,
        action="store_true",
        help="Save rollouts into a python shelf file (will save each episode "
        "as it is generated). An output filename must be set using --out.")
    parser.add_argument(
        "--track-progress",
        default=False,
        action="store_true",
        help="Write progress to a temporary file (updated "
        "after each episode). An output filename must be set using --out; "
        "the progress file will live in the same folder.")
    return parser
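

# Example with illustrative values: `--config` takes a JSON string (parsed via
# json.loads) that is merged on top of the checkpoint's config, e.g.:
#
#   ./evaluate.py /tmp/ray/checkpoint_dir/checkpoint-0 --run PPO \
#       --env CartPole-v0 --steps 5000 \
#       --config '{"num_workers": 0, "explore": false}'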


class RolloutSaver:
    """Utility class for storing rollouts.

    Currently supports two behaviours: the original, which
    simply dumps everything to a pickle file once complete,
    and a mode which stores each rollout as an entry in a Python
    shelf db file. The latter mode is more robust to memory problems
    or crashes part-way through the rollout generation. Each rollout
    is stored with a key based on the episode number (0-indexed),
    and the number of episodes is stored with the key "num_episodes",
    so to load the shelf file, use something like:

    with shelve.open('rollouts.pkl') as rollouts:
        for episode_index in range(rollouts["num_episodes"]):
            rollout = rollouts[str(episode_index)]

    If outfile is None, this class does nothing.
    """

    def __init__(self,
                 outfile=None,
                 use_shelve=False,
                 write_update_file=False,
                 target_steps=None,
                 target_episodes=None,
                 save_info=False):
        self._outfile = outfile
        self._update_file = None
        self._use_shelve = use_shelve
        self._write_update_file = write_update_file
        self._shelf = None
        self._num_episodes = 0
        self._rollouts = []
        self._current_rollout = []
        self._total_steps = 0
        self._target_episodes = target_episodes
        self._target_steps = target_steps
        self._save_info = save_info

    def _get_tmp_progress_filename(self):
        outpath = Path(self._outfile)
        return outpath.parent / ("__progress_" + outpath.name)

    @property
    def outfile(self):
        return self._outfile

    def __enter__(self):
        if self._outfile:
            if self._use_shelve:
                # Open a shelf file to store each rollout as they come in.
                self._shelf = shelve.open(self._outfile)
            else:
                # Original behaviour - keep all rollouts in memory and save
                # them all at the end.
                # But check we can actually write to the outfile before going
                # through the effort of generating the rollouts:
                try:
                    with open(self._outfile, "wb") as _:
                        pass
                except IOError as x:
                    print("Can not open {} for writing - cancelling rollouts.".
                          format(self._outfile))
                    raise x
            if self._write_update_file:
                # Open a file to track rollout progress:
                self._update_file = self._get_tmp_progress_filename().open(
                    mode="w")
        return self

    def __exit__(self, type, value, traceback):
        if self._shelf:
            # Close the shelf file, and store the number of episodes for ease.
            self._shelf["num_episodes"] = self._num_episodes
            self._shelf.close()
        elif self._outfile and not self._use_shelve:
            # Dump everything as one big pickle:
            cloudpickle.dump(self._rollouts, open(self._outfile, "wb"))
        if self._update_file:
            # Remove the temp progress file:
            self._get_tmp_progress_filename().unlink()
            self._update_file = None

    def _get_progress(self):
        if self._target_episodes:
            return "{} / {} episodes completed".format(self._num_episodes,
                                                       self._target_episodes)
        elif self._target_steps:
            return "{} / {} steps completed".format(self._total_steps,
                                                    self._target_steps)
        else:
            return "{} episodes completed".format(self._num_episodes)

    def begin_rollout(self):
        self._current_rollout = []

    def end_rollout(self):
        if self._outfile:
            if self._use_shelve:
                # Save this episode as a new entry in the shelf database,
                # using the episode number as the key.
                self._shelf[str(self._num_episodes)] = self._current_rollout
            else:
                # Append this rollout to our list, to save later.
                self._rollouts.append(self._current_rollout)
        self._num_episodes += 1
        if self._update_file:
            self._update_file.seek(0)
            self._update_file.write(self._get_progress() + "\n")
            self._update_file.flush()

    def append_step(self, obs, action, next_obs, reward, done, info):
        """Add a step to the current rollout, if we are saving them."""
        if self._outfile:
            if self._save_info:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, done, info])
            else:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, done])
        self._total_steps += 1
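

# Example sketch (assumes the default, non-shelve mode and an --out file named
# "rollouts.pkl"): the output is a single cloudpickled list of episodes, each
# episode being a list of [obs, action, next_obs, reward, done] steps (plus
# `info` if --save-info was set). It can be read back with, e.g.:
#
#   import ray.cloudpickle as cloudpickle
#   with open("rollouts.pkl", "rb") as f:
#       episodes = cloudpickle.load(f)
#   print("num episodes:", len(episodes), "steps in first:", len(episodes[0]))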


def run(args, parser):
    # Load configuration from checkpoint file.
    config_path = ""
    if args.checkpoint:
        config_dir = os.path.dirname(args.checkpoint)
        config_path = os.path.join(config_dir, "params.pkl")
        # Try parent directory.
        if not os.path.exists(config_path):
            config_path = os.path.join(config_dir, "../params.pkl")

    # Load the config from the pickled file.
    if os.path.exists(config_path):
        with open(config_path, "rb") as f:
            config = cloudpickle.load(f)
    # If no pkl file found, require command line `--config`.
    else:
        # If no config in given checkpoint -> Error.
        if args.checkpoint:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory AND no `--config` given on command "
                "line!")
        # Use default config for given agent.
        _, config = get_trainer_class(args.run, return_config=True)

    # Make sure worker 0 has an Env.
    config["create_env_on_driver"] = True

    # Merge with `evaluation_config` (first try from command line, then from
    # pkl file).
    evaluation_config = copy.deepcopy(
        args.config.get("evaluation_config", config.get(
            "evaluation_config", {})))
    config = merge_dicts(config, evaluation_config)
    # Merge with command line `--config` settings (if not already the same
    # anyways).
    config = merge_dicts(config, args.config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    # Make sure we have evaluation workers.
    if not config.get("evaluation_num_workers"):
        config["evaluation_num_workers"] = config.get("num_workers", 0)
    if not config.get("evaluation_duration"):
        config["evaluation_duration"] = 1
    # Hard-override this as it raises a warning by Trainer otherwise.
    # It makes no sense anyway to have it set to None, as we don't call
    # `Trainer.train()` here.
    config["evaluation_interval"] = 1

    # Rendering and video recording settings.
    if args.no_render:
        deprecation_warning(old="--no-render", new="--render", error=False)
        args.render = False
    config["render_env"] = args.render
    config["record_env"] = args.video_dir

    ray.init(local_mode=args.local_mode)

    # Create the Trainer from config.
    cls = get_trainable_cls(args.run)
    agent = cls(env=args.env, config=config)

    # Load state from checkpoint, if provided.
    if args.checkpoint:
        agent.restore(args.checkpoint)

    num_steps = int(args.steps)
    num_episodes = int(args.episodes)

    # Determine the video output directory.
    video_dir = None
    # Allow user to specify a video output path.
    if args.video_dir:
        video_dir = os.path.expanduser(args.video_dir)

    # Do the actual rollout.
    with RolloutSaver(
            args.out,
            args.use_shelve,
            write_update_file=args.track_progress,
            target_steps=num_steps,
            target_episodes=num_episodes,
            save_info=args.save_info) as saver:
        rollout(agent, args.env, num_steps, num_episodes, saver,
                not args.render, video_dir)
    agent.stop()


class DefaultMapping(collections.defaultdict):
    """default_factory now takes as an argument the missing key."""

    def __missing__(self, key):
        self[key] = value = self.default_factory(key)
        return value
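

# Example sketch: unlike a plain defaultdict, DefaultMapping passes the missing
# key to its factory, so per-key defaults can be derived from it, e.g.:
#
#   m = DefaultMapping(lambda agent_id: "default_policy_for_" + agent_id)
#   assert m["agent_0"] == "default_policy_for_agent_0"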


def default_policy_agent_mapping(unused_agent_id):
    return DEFAULT_POLICY_ID


def keep_going(steps, num_steps, episodes, num_episodes):
    """Determine whether we've collected enough data."""
    # If num_episodes is set, stop if limit reached.
    if num_episodes and episodes >= num_episodes:
        return False
    # If num_steps is set, stop if limit reached.
    elif num_steps and steps >= num_steps:
        return False
    # Otherwise, keep going.
    return True


def rollout(agent,
            env_name,
            num_steps,
            num_episodes=0,
            saver=None,
            no_render=True,
            video_dir=None):
    policy_agent_mapping = default_policy_agent_mapping

    if saver is None:
        saver = RolloutSaver()

    # Normal case: Agent was set up correctly with an evaluation WorkerSet,
    # which we will now use to roll out.
    if hasattr(agent, "evaluation_workers") and isinstance(
            agent.evaluation_workers, WorkerSet):
        steps = 0
        episodes = 0
        while keep_going(steps, num_steps, episodes, num_episodes):
            saver.begin_rollout()
            eval_result = agent.evaluate()["evaluation"]
            # Increase timestep and episode counters.
            eps = agent.config["evaluation_duration"]
            episodes += eps
            steps += eps * eval_result["episode_len_mean"]
            # Print out results and continue.
            print("Episode #{}: reward: {}".format(
                episodes, eval_result["episode_reward_mean"]))
            saver.end_rollout()
        return

    # Agent has no evaluation workers, but RolloutWorkers.
    elif hasattr(agent, "workers") and isinstance(agent.workers, WorkerSet):
        env = agent.workers.local_worker().env
        multiagent = isinstance(env, MultiAgentEnv)
        if agent.workers.local_worker().multiagent:
            policy_agent_mapping = agent.config["multiagent"][
                "policy_mapping_fn"]
        policy_map = agent.workers.local_worker().policy_map
        state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
        use_lstm = {p: len(s) > 0 for p, s in state_init.items()}

    # Agent has neither evaluation- nor rollout workers.
    else:
        from gym import envs
        if envs.registry.env_specs.get(agent.config["env"]):
            # If the environment is a gym environment, load from gym.
            env = gym.make(agent.config["env"])
        else:
            # If the environment is a registered ray environment, load from
            # ray.
            env_creator = _global_registry.get(ENV_CREATOR,
                                               agent.config["env"])
            env_context = EnvContext(
                agent.config["env_config"] or {}, worker_index=0)
            env = env_creator(env_context)
        multiagent = False
        try:
            policy_map = {DEFAULT_POLICY_ID: agent.policy}
        except AttributeError:
            raise AttributeError(
                "Agent ({}) does not have a `policy` property! This is needed "
                "for performing (trained) agent rollouts.".format(agent))
        use_lstm = {DEFAULT_POLICY_ID: False}

    action_init = {
        p: flatten_to_single_ndarray(m.action_space.sample())
        for p, m in policy_map.items()
    }

    # If monitoring has been requested, manually wrap our environment with a
    # gym monitor, which is set to record every episode.
    if video_dir:
        env = gym_wrappers.Monitor(
            env=env,
            directory=video_dir,
            video_callable=lambda _: True,
            force=True)

    steps = 0
    episodes = 0
    while keep_going(steps, num_steps, episodes, num_episodes):
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        saver.begin_rollout()
        obs = env.reset()
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]])
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]])
        prev_rewards = collections.defaultdict(lambda: 0.)
        done = False
        reward_total = 0.0
        while not done and keep_going(steps, num_steps, episodes,
                                      num_episodes):
            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
            action_dict = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is not None:
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id))
                    p_use_lstm = use_lstm[policy_id]
                    if p_use_lstm:
                        a_action, p_state, _ = agent.compute_single_action(
                            a_obs,
                            state=agent_states[agent_id],
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                        agent_states[agent_id] = p_state
                    else:
                        a_action = agent.compute_single_action(
                            a_obs,
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                    a_action = flatten_to_single_ndarray(a_action)
                    action_dict[agent_id] = a_action
                    prev_actions[agent_id] = a_action
            action = action_dict
            action = action if multiagent else action[_DUMMY_AGENT_ID]

            next_obs, reward, done, info = env.step(action)
            if multiagent:
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
            else:
                prev_rewards[_DUMMY_AGENT_ID] = reward

            if multiagent:
                done = done["__all__"]
                reward_total += sum(
                    r for r in reward.values() if r is not None)
            else:
                reward_total += reward
            if not no_render:
                env.render()
            saver.append_step(obs, action, next_obs, reward, done, info)
            steps += 1
            obs = next_obs
        saver.end_rollout()
        print("Episode #{}: reward: {}".format(episodes, reward_total))
        if done:
            episodes += 1
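

# Example sketch (names and config values illustrative only): rollout() can
# also be driven programmatically with a freshly built, untrained Trainer:
#
#   from ray.rllib.agents.ppo import PPOTrainer
#   ray.init()
#   trainer = PPOTrainer(config={"num_workers": 0}, env="CartPole-v0")
#   with RolloutSaver("rollouts.pkl", target_episodes=2) as saver:
#       rollout(trainer, "CartPole-v0", num_steps=0, num_episodes=2,
#               saver=saver)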


def main():
    parser = create_parser()
    args = parser.parse_args()

    # --use-shelve w/o --out option.
    if args.use_shelve and not args.out:
        raise ValueError(
            "If you set --use-shelve, you must provide an output file via "
            "--out as well!")
    # --track-progress w/o --out option.
    if args.track_progress and not args.out:
        raise ValueError(
            "If you set --track-progress, you must provide an output file via "
            "--out as well!")

    run(args, parser)


if __name__ == "__main__":
    main()