#!/usr/bin/env python
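"""Rollout/evaluation utilities and the "evaluate" Typer CLI app for trained RLlib Algorithms."""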
import collections
import copy
import json
import os
import shelve
from pathlib import Path

import gymnasium as gym
import typer

import ray
import ray.cloudpickle as cloudpickle
from ray.rllib.common import CLIArguments as cli
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner_group import EnvRunnerGroup
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics import (
    ENV_RUNNER_RESULTS,
    EPISODE_LEN_MEAN,
    EPISODE_RETURN_MEAN,
)
from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray
from ray.train._checkpoint import Checkpoint
from ray.train._internal.session import _TrainingResult
from ray.tune.registry import ENV_CREATOR, _global_registry, get_trainable_cls
from ray.tune.utils import merge_dicts

# Create the "evaluate" Typer app.
eval_app = typer.Typer()


class RolloutSaver:
    """Utility class for storing rollouts.

    Currently supports two behaviours: the original, which simply dumps
    everything to a pickle file once complete, and a mode which stores each
    rollout as an entry in a Python shelf db file. The latter mode is more
    robust to memory problems or crashes part-way through rollout
    generation. Each rollout is stored with a key based on the episode
    number (0-indexed), and the number of episodes is stored with the key
    "num_episodes", so to load the shelf file, use something like:

    with shelve.open('rollouts.pkl') as rollouts:
        for episode_index in range(rollouts["num_episodes"]):
            rollout = rollouts[str(episode_index)]

    If outfile is None, this class does nothing.
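
    Writing side, a minimal sketch (the file name and the step variables
    obs/action/etc. are illustrative placeholders):

    with RolloutSaver("rollouts.pkl", use_shelve=True) as saver:
        saver.begin_rollout()
        saver.append_step(obs, action, next_obs, reward, terminated,
                          truncated, info)
        saver.end_rollout()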
- """
- def __init__(
- self,
- outfile=None,
- use_shelve=False,
- write_update_file=False,
- target_steps=None,
- target_episodes=None,
- save_info=False,
- ):
- self._outfile = outfile
- self._update_file = None
- self._use_shelve = use_shelve
- self._write_update_file = write_update_file
- self._shelf = None
- self._num_episodes = 0
- self._rollouts = []
- self._current_rollout = []
- self._total_steps = 0
- self._target_episodes = target_episodes
- self._target_steps = target_steps
- self._save_info = save_info

    def _get_tmp_progress_filename(self):
        outpath = Path(self._outfile)
        return outpath.parent / ("__progress_" + outpath.name)

    @property
    def outfile(self):
        return self._outfile

    def __enter__(self):
        if self._outfile:
            if self._use_shelve:
                # Open a shelf file to store each rollout as it comes in.
                self._shelf = shelve.open(self._outfile)
            else:
                # Original behaviour: Keep all rollouts in memory and save
                # them all at the end.
                # But check that we can actually write to the outfile before
                # going through the effort of generating the rollouts:
                try:
                    with open(self._outfile, "wb") as _:
                        pass
                except IOError as x:
                    print(
                        "Cannot open {} for writing - cancelling rollouts.".format(
                            self._outfile
                        )
                    )
                    raise x
            if self._write_update_file:
                # Open a file to track rollout progress:
                self._update_file = self._get_tmp_progress_filename().open(mode="w")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._shelf:
            # Close the shelf file, and store the number of episodes for
            # easy access when loading.
            self._shelf["num_episodes"] = self._num_episodes
            self._shelf.close()
        elif self._outfile and not self._use_shelve:
            # Dump everything as one big pickle:
            with open(self._outfile, "wb") as f:
                cloudpickle.dump(self._rollouts, f)
        if self._update_file:
            # Close and remove the temp progress file:
            self._update_file.close()
            self._get_tmp_progress_filename().unlink()
            self._update_file = None

    def _get_progress(self):
        if self._target_episodes:
            return "{} / {} episodes completed".format(
                self._num_episodes, self._target_episodes
            )
        elif self._target_steps:
            return "{} / {} steps completed".format(
                self._total_steps, self._target_steps
            )
        else:
            return "{} episodes completed".format(self._num_episodes)

    def begin_rollout(self):
        self._current_rollout = []

    def end_rollout(self):
        if self._outfile:
            if self._use_shelve:
                # Save this episode as a new entry in the shelf database,
                # using the episode number as the key.
                self._shelf[str(self._num_episodes)] = self._current_rollout
            else:
                # Append this rollout to our list, to save later.
                self._rollouts.append(self._current_rollout)
        self._num_episodes += 1
        if self._update_file:
            self._update_file.seek(0)
            self._update_file.write(self._get_progress() + "\n")
            self._update_file.flush()

    def append_step(self, obs, action, next_obs, reward, terminated, truncated, info):
        """Add a step to the current rollout, if we are saving them."""
        if self._outfile:
            if self._save_info:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, terminated, truncated, info]
                )
            else:
                self._current_rollout.append(
                    [obs, action, next_obs, reward, terminated, truncated]
                )
        self._total_steps += 1


@eval_app.command()
def run(
    checkpoint: str = cli.Checkpoint,
    algo: str = cli.Algo,
    env: str = cli.Env,
    local_mode: bool = cli.LocalMode,
    render: bool = cli.Render,
    steps: int = cli.Steps,
    episodes: int = cli.Episodes,
    out: str = cli.Out,
    config: str = cli.Config,
    save_info: bool = cli.SaveInfo,
    use_shelve: bool = cli.UseShelve,
    track_progress: bool = cli.TrackProgress,
):
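    """Roll out a (trained, checkpointed) algorithm and optionally save the rollouts."""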
    if use_shelve and not out:
        raise ValueError(
            "If you set --use-shelve, you must provide an output file via "
            "--out as well!"
        )
    if track_progress and not out:
        raise ValueError(
            "If you set --track-progress, you must provide an output file via "
            "--out as well!"
        )

    # Load configuration from the checkpoint file.
    config_args = json.loads(config)
    config_path = ""
    if checkpoint:
        config_dir = os.path.dirname(checkpoint)
        config_path = os.path.join(config_dir, "params.pkl")
        # Try the parent directory.
        if not os.path.exists(config_path):
            config_path = os.path.join(config_dir, "../params.pkl")

    # Load the config from the pickled params file.
    if os.path.exists(config_path):
        with open(config_path, "rb") as f:
            config = cloudpickle.load(f)
    # If no pkl file found, require a command line `--config`.
    else:
        # If no config in given checkpoint -> Error.
        if checkpoint:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory AND no `--config` given on command "
                "line!"
            )
        # Use the default config for the given algorithm.
        if not algo:
            raise ValueError("Please provide an algorithm via `--algo`.")
        algo_cls = get_trainable_cls(algo)
        config = algo_cls.get_default_config()

    # Make sure worker 0 has an Env.
    config["create_env_on_driver"] = True

    # Merge with `evaluation_config` (first try from command line, then from
    # pkl file).
    evaluation_config = copy.deepcopy(
        config_args.get("evaluation_config", config.get("evaluation_config", {}))
    )
    config = merge_dicts(config, evaluation_config)
    # Merge with command line `--config` settings (if not already the same
    # anyways).
    config = merge_dicts(config, config_args)

    if not env:
        if not config.get("env"):
            raise ValueError(
                "You either need to provide an --env argument or pass "
                "an `env` key with a valid environment to your `config` "
                "argument."
            )
        env = config.get("env")

    # Make sure we have evaluation workers.
    if not config.get(
        "evaluation_num_workers", config.get("evaluation_num_env_runners")
    ):
        config["evaluation_num_env_runners"] = config.get(
            "num_env_runners", config.get("num_workers")
        )
    if not config.get("evaluation_duration"):
        config["evaluation_duration"] = 1
    # Hard-override this, as Algorithm raises a warning otherwise. Leaving it
    # at None makes no sense anyway, since we don't call `Algorithm.train()`
    # here.
    config["evaluation_interval"] = 1

    # Rendering settings.
    config["render_env"] = render

    ray.init(local_mode=local_mode)

    # Create the Algorithm from config.
    cls = get_trainable_cls(algo)
    algorithm = cls(config=config)

    # Load state from checkpoint, if provided.
    if checkpoint:
        if os.path.isdir(checkpoint):
            checkpoint_dir = checkpoint
        else:
            checkpoint_dir = str(Path(checkpoint).parent)
        print(f"Restoring algorithm from {checkpoint_dir}")
        restore_result = _TrainingResult(
            checkpoint=Checkpoint.from_directory(checkpoint_dir), metrics={}
        )
        algorithm.restore(restore_result)

    # Do the actual rollout.
    with RolloutSaver(
        outfile=out,
        use_shelve=use_shelve,
        write_update_file=track_progress,
        target_steps=steps,
        target_episodes=episodes,
        save_info=save_info,
    ) as saver:
        rollout(algorithm, env, steps, episodes, saver, not render)
    algorithm.stop()


class DefaultMapping(collections.defaultdict):
    """A defaultdict whose `default_factory` receives the missing key as its argument."""

    def __missing__(self, key):
        self[key] = value = self.default_factory(key)
        return value
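

# A minimal illustration of DefaultMapping (names and values are hypothetical):
#
#   >>> m = DefaultMapping(lambda key: "policy_for_" + key)
#   >>> m["agent_0"]
#   'policy_for_agent_0'
#
# A plain collections.defaultdict would call its factory with no arguments here.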


def default_policy_agent_mapping(unused_agent_id) -> str:
    return DEFAULT_POLICY_ID


def keep_going(steps: int, num_steps: int, episodes: int, num_episodes: int) -> bool:
    """Determine whether we've run enough steps or episodes."""
    episode_limit_reached = num_episodes and episodes >= num_episodes
    step_limit_reached = num_steps and steps >= num_steps
    return not (episode_limit_reached or step_limit_reached)
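

# For example (illustrative values): a falsy limit (0 or None) means "no limit"
# along that dimension, so:
#   keep_going(steps=100, num_steps=0, episodes=3, num_episodes=5)  -> True
#   keep_going(steps=100, num_steps=0, episodes=5, num_episodes=5)  -> False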


def rollout(
    agent,
    env_name,  # keep me, used in tests
    num_steps,
    num_episodes=0,
    saver=None,
    no_render=True,
):
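    """Roll out the given agent for `num_steps` steps or `num_episodes` episodes."""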
    policy_agent_mapping = default_policy_agent_mapping

    if saver is None:
        saver = RolloutSaver()

    # Normal case: Agent was set up correctly with an evaluation
    # EnvRunnerGroup, which we will now use for the rollout.
    if hasattr(agent, "eval_env_runner_group") and isinstance(
        agent.eval_env_runner_group, EnvRunnerGroup
    ):
        steps = 0
        episodes = 0
        while keep_going(steps, num_steps, episodes, num_episodes):
            saver.begin_rollout()
            eval_result = agent.evaluate()
            # Increase time-step and episode counters.
            eps = agent.config["evaluation_duration"]
            episodes += eps
            steps += eps * eval_result[ENV_RUNNER_RESULTS][EPISODE_LEN_MEAN]
            # Print out results and continue.
            print(
                "Episode #{}: reward: {}".format(
                    episodes, eval_result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
                )
            )
            saver.end_rollout()
        return

    # Agent has no evaluation EnvRunners, but a regular EnvRunnerGroup.
    elif hasattr(agent, "env_runner_group") and isinstance(
        agent.env_runner_group, EnvRunnerGroup
    ):
        env = agent.env_runner.env
        multiagent = isinstance(env, MultiAgentEnv)
        if agent.env_runner.multiagent:
            policy_agent_mapping = agent.config.policy_mapping_fn

        policy_map = agent.env_runner.policy_map
        state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
        use_lstm = {p: len(s) > 0 for p, s in state_init.items()}

    # Agent has neither evaluation- nor rollout EnvRunners.
    else:
        from gymnasium import envs

        # In gymnasium, `envs.registry` is a plain dict mapping env-id -> EnvSpec.
        if envs.registry.get(agent.config["env"]):
            # If the environment is a registered gymnasium environment, load
            # it via `gym.make()`.
            env = gym.make(agent.config["env"])
        else:
            # Otherwise, it must be a Ray-registered environment; load it
            # from Ray's global registry.
            env_creator = _global_registry.get(ENV_CREATOR, agent.config["env"])
            env_context = EnvContext(agent.config["env_config"] or {}, worker_index=0)
            env = env_creator(env_context)
        multiagent = False
        try:
            policy_map = {DEFAULT_POLICY_ID: agent.policy}
        except AttributeError:
            raise AttributeError(
                "Agent ({}) does not have a `policy` property! This is needed "
                "for performing (trained) agent rollouts.".format(agent)
            )
        use_lstm = {DEFAULT_POLICY_ID: False}

    action_init = {
        p: flatten_to_single_ndarray(m.action_space.sample())
        for p, m in policy_map.items()
    }

    steps = 0
    episodes = 0
    while keep_going(steps, num_steps, episodes, num_episodes):
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        saver.begin_rollout()
        obs, info = env.reset()
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]]  # noqa
        )
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]]  # noqa
        )
        prev_rewards = collections.defaultdict(lambda: 0.0)
        terminated = truncated = False
        reward_total = 0.0
        while (
            not terminated
            and not truncated
            and keep_going(steps, num_steps, episodes, num_episodes)
        ):
            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
            action_dict = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is not None:
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id)
                    )
                    p_use_lstm = use_lstm[policy_id]
                    if p_use_lstm:
                        a_action, p_state, _ = agent.compute_single_action(
                            a_obs,
                            state=agent_states[agent_id],
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id,
                        )
                        agent_states[agent_id] = p_state
                    else:
                        a_action = agent.compute_single_action(
                            a_obs,
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id,
                        )
                    a_action = flatten_to_single_ndarray(a_action)
                    action_dict[agent_id] = a_action
                    prev_actions[agent_id] = a_action
            action = action_dict
            action = action if multiagent else action[_DUMMY_AGENT_ID]
            next_obs, reward, terminated, truncated, info = env.step(action)
            if multiagent:
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
            else:
                prev_rewards[_DUMMY_AGENT_ID] = reward
            if multiagent:
                terminated = terminated["__all__"]
                truncated = truncated["__all__"]
                reward_total += sum(r for r in reward.values() if r is not None)
            else:
                reward_total += reward
            if not no_render:
                env.render()
            saver.append_step(
                obs, action, next_obs, reward, terminated, truncated, info
            )
            steps += 1
            obs = next_obs
        saver.end_rollout()
        print("Episode #{}: reward: {}".format(episodes, reward_total))
        if terminated or truncated:
            episodes += 1


def main():
    """Run the CLI."""
    eval_app()


if __name__ == "__main__":
    main()
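

# Example invocation via the RLlib CLI (paths and values are illustrative):
#
#   rllib evaluate /tmp/my_checkpoint --algo PPO --env CartPole-v1 \
#       --steps 1000 --out rollouts.pkl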