123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543 |
- import gym
- import logging
- import importlib.util
- from types import FunctionType
- from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
- import ray
- from ray.actor import ActorHandle
- from ray.rllib.evaluation.rollout_worker import RolloutWorker
- from ray.rllib.env.base_env import BaseEnv
- from ray.rllib.env.env_context import EnvContext
- from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
- ShuffledInput, D4RLReader
- from ray.rllib.policy.policy import Policy, PolicySpec
- from ray.rllib.utils import merge_dicts
- from ray.rllib.utils.annotations import DeveloperAPI
- from ray.rllib.utils.framework import try_import_tf
- from ray.rllib.utils.from_config import from_config
- from ray.rllib.utils.typing import EnvCreator, EnvType, PolicyID, \
- TrainerConfigDict
- from ray.tune.registry import registry_contains_input, registry_get_input
# TensorFlow handles as returned by RLlib's `try_import_tf()` helper. In the
# code visible here only `tf1` is used (to build TF sessions and their
# ConfigProto inside `WorkerSet._make_worker`).
tf1, tf, tfv = try_import_tf()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Generic type var for foreach_* methods.
T = TypeVar("T")
@DeveloperAPI
class WorkerSet:
    """Set of RolloutWorkers with n @ray.remote workers and one local worker.

    Where n may be 0.
    """

    def __init__(
            self,
            *,
            env_creator: Optional[EnvCreator] = None,
            validate_env: Optional[Callable[[EnvType], None]] = None,
            policy_class: Optional[Type[Policy]] = None,
            trainer_config: Optional[TrainerConfigDict] = None,
            num_workers: int = 0,
            local_worker: bool = True,
            logdir: Optional[str] = None,
            _setup: bool = True,
    ):
        """Initializes a WorkerSet instance.

        Args:
            env_creator: Function that returns env given env config.
            validate_env: Optional callable to validate the generated
                environment (only on worker=0).
            policy_class: An optional Policy class. If None, PolicySpecs can be
                generated automatically by using the Trainer's default class
                or via a given multi-agent policy config dict.
            trainer_config: Optional dict that extends the common config of
                the Trainer class. If None or empty, the Trainer's
                COMMON_CONFIG is used.
            num_workers: Number of remote rollout workers to create.
            local_worker: Whether to create a local (non @ray.remote) worker
                in the returned set as well (default: True). If `num_workers`
                is 0, always create a local worker.
            logdir: Optional logging directory for workers.
            _setup: Whether to setup workers. This is only for testing.
        """
        if not trainer_config:
            # Imported here (not at module level), only when no config
            # was provided.
            from ray.rllib.agents.trainer import COMMON_CONFIG
            trainer_config = COMMON_CONFIG

        self._env_creator = env_creator
        self._policy_class = policy_class
        self._remote_config = trainer_config
        self._logdir = logdir
        # Always define both worker containers, so that the accessors
        # (`local_worker()`, `remote_workers()`) and `add_workers()` also
        # work on instances created with `_setup=False`.
        self._local_worker = None
        self._remote_workers = []

        if not _setup:
            return

        # Force a local worker if num_workers == 0 (no remote workers).
        # Otherwise, this WorkerSet would be empty.
        if num_workers == 0:
            local_worker = True
        # The local worker uses its own TF session args.
        self._local_config = merge_dicts(
            trainer_config,
            {"tf_session_args": trainer_config["local_tf_session_args"]})

        # Create a number of @ray.remote workers.
        self.add_workers(num_workers)

        # Create a local worker, if needed.
        # If num_workers > 0 and we don't have an env on the local worker,
        # get the observation- and action spaces for each policy from
        # the first remote worker (which does have an env).
        spaces = None
        if local_worker and self._remote_workers and \
                not trainer_config.get("create_env_on_driver") and \
                (not trainer_config.get("observation_space") or
                 not trainer_config.get("action_space")):
            remote_spaces = ray.get(
                self.remote_workers()[0].foreach_policy.remote(
                    lambda p, pid: (
                        pid, p.observation_space, p.action_space)))
            spaces = {
                e[0]: (getattr(e[1], "original_space", e[1]), e[2])
                for e in remote_spaces
            }
            # Try to add the actual env's obs/action spaces. This is
            # best-effort only: On failure, continue with the policies'
            # spaces alone.
            try:
                env_spaces = ray.get(
                    self.remote_workers()[0].foreach_env.remote(
                        lambda env: (
                            env.observation_space, env.action_space)))[0]
                spaces["__env__"] = env_spaces
            except Exception:
                logger.debug(
                    "Could not retrieve env spaces from first remote "
                    "worker; using the policies' spaces only.",
                    exc_info=True)

            logger.info("Inferred observation/action spaces from remote "
                        f"worker (local worker has no env): {spaces}")

        if local_worker:
            self._local_worker = self._make_worker(
                cls=RolloutWorker,
                env_creator=env_creator,
                validate_env=validate_env,
                policy_cls=self._policy_class,
                worker_index=0,
                num_workers=num_workers,
                config=self._local_config,
                spaces=spaces,
            )

    def local_worker(self) -> RolloutWorker:
        """Returns the local rollout worker (None, if there is none)."""
        return self._local_worker

    def remote_workers(self) -> List[ActorHandle]:
        """Returns a list of remote rollout workers."""
        return self._remote_workers

    def sync_weights(self,
                     policies: Optional[List[PolicyID]] = None,
                     from_worker: Optional[RolloutWorker] = None) -> None:
        """Syncs model weights from the local worker to all remote workers.

        Args:
            policies: Optional list of PolicyIDs to sync weights for.
                If None (default), sync weights to/from all policies.
            from_worker: Optional RolloutWorker instance to sync from.
                If None (default), sync from this WorkerSet's local worker.

        Raises:
            TypeError: If this WorkerSet has no local worker and no
                `from_worker` is given.
        """
        if self.local_worker() is None and from_worker is None:
            raise TypeError(
                "No `local_worker` in WorkerSet, must provide `from_worker` "
                "arg in `sync_weights()`!")

        # Only sync if we have remote workers or `from_worker` is provided.
        if self.remote_workers() or from_worker is not None:
            weights = (from_worker
                       or self.local_worker()).get_weights(policies)
            if self.remote_workers():
                # Put weights only once into object store and use same
                # object ref to synch to all workers.
                weights_ref = ray.put(weights)
                # Sync to all remote workers in this WorkerSet.
                for to_worker in self.remote_workers():
                    to_worker.set_weights.remote(weights_ref)

            # If `from_worker` is provided, also sync to this WorkerSet's
            # local worker.
            if from_worker is not None and self.local_worker() is not None:
                self.local_worker().set_weights(weights)

    def add_workers(self, num_workers: int) -> None:
        """Creates and adds a number of remote workers to this worker set.

        Can be called several times on the same WorkerSet to add more
        RolloutWorkers to the set. Worker indices continue after the ones
        already present, so repeated calls never create duplicate indices.

        Args:
            num_workers: The number of remote Workers to add to this
                WorkerSet.
        """
        remote_args = {
            "num_cpus": self._remote_config["num_cpus_per_worker"],
            "num_gpus": self._remote_config["num_gpus_per_worker"],
            "resources": self._remote_config["custom_resources_per_worker"],
        }
        cls = RolloutWorker.as_remote(**remote_args).remote
        # Offset new indices by the number of already existing remote
        # workers (index 0 is reserved for the local worker).
        old_num_workers = len(self._remote_workers)
        self._remote_workers.extend([
            self._make_worker(
                cls=cls,
                env_creator=self._env_creator,
                validate_env=None,
                policy_cls=self._policy_class,
                worker_index=old_num_workers + i + 1,
                num_workers=old_num_workers + num_workers,
                config=self._remote_config,
            ) for i in range(num_workers)
        ])

    def reset(self, new_remote_workers: List[ActorHandle]) -> None:
        """Hard overrides the remote workers in this set with the given one.

        Args:
            new_remote_workers: A list of new RolloutWorkers
                (as `ActorHandles`) to use as remote workers.
        """
        self._remote_workers = new_remote_workers

    def stop(self) -> None:
        """Calls `stop` on all rollout workers (including the local one).

        Errors while stopping are logged, not raised. All remote workers
        are sent a terminate request afterwards in any case.
        """
        try:
            # There may be no local worker in this set.
            if self.local_worker() is not None:
                self.local_worker().stop()
            tids = [w.stop.remote() for w in self.remote_workers()]
            ray.get(tids)
        except Exception:
            logger.exception("Failed to stop workers!")
        finally:
            for w in self.remote_workers():
                w.__ray_terminate__.remote()

    @DeveloperAPI
    def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
        """Calls the given function with each worker instance as arg.

        Args:
            func: The function to call for each worker (as only arg).

        Returns:
            The list of return values of all calls to `func([worker])`.
            The local worker's result (if any) comes first.
        """
        local_result = []
        if self.local_worker() is not None:
            local_result = [func(self.local_worker())]
        remote_results = ray.get(
            [w.apply.remote(func) for w in self.remote_workers()])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_worker_with_index(
            self, func: Callable[[RolloutWorker, int], T]) -> List[T]:
        """Calls `func` with each worker instance and worker idx as args.

        The index will be passed as the second arg to the given function.

        Args:
            func: The function to call for each worker and its index
                (as args). The local worker has index 0, all remote workers
                have indices > 0.

        Returns:
            The list of return values of all calls to `func([worker, idx])`.
                The first entry in this list are the results of the local
                worker, followed by all remote workers' results.
        """
        local_result = []
        # Local worker: Index=0.
        if self.local_worker() is not None:
            local_result = [func(self.local_worker(), 0)]
        # Remote workers: Index > 0.
        remote_results = ray.get([
            w.apply.remote(func, i + 1)
            for i, w in enumerate(self.remote_workers())
        ])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Calls `func` with each worker's (policy, PolicyID) tuple.

        Note that in the multi-agent case, each worker may have more than one
        policy.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies.

        Returns:
            The list of return values of func over all workers' policies. The
                length of this list is:
                (num_workers + 1 (local-worker)) *
                [num policies in the multi-agent config dict].
                The local workers' results are first, followed by all remote
                workers' results
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_policy(func)
        # Launch all remote calls first, then collect their results.
        ray_gets = [
            worker.apply.remote(lambda w: w.foreach_policy(func))
            for worker in self.remote_workers()
        ]
        for r in ray.get(ray_gets):
            results.extend(r)
        return results

    @DeveloperAPI
    def trainable_policies(self) -> List[PolicyID]:
        """Returns the list of trainable policy ids.

        Raises:
            NotImplementedError: If this WorkerSet has no local worker.
        """
        if self.local_worker() is not None:
            return self.local_worker().policies_to_train
        raise NotImplementedError(
            "`trainable_policies()` requires a local worker!")

    @DeveloperAPI
    def foreach_trainable_policy(
            self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Apply `func` to all workers' Policies iff in `policies_to_train`.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies in `worker.policies_to_train`.

        Returns:
            List[any]: The list of n return values of all
                `func([trainable policy], [ID])`-calls.
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_trainable_policy(func)
        # Launch all remote calls first, then collect their results.
        ray_gets = [
            worker.apply.remote(lambda w: w.foreach_trainable_policy(func))
            for worker in self.remote_workers()
        ]
        for r in ray.get(ray_gets):
            results.extend(r)
        return results

    @DeveloperAPI
    def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment as arg, e.g. a
        gym.Env object.

        Args:
            func: A function - taking an EnvType (normally a gym.Env object)
                as arg and returning a list of lists of return values, one
                value per underlying sub-environment per each worker.

        Returns:
            The list (workers) of lists (sub environments) of results.
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [self.local_worker().foreach_env(func)]
        ray_gets = [
            worker.foreach_env.remote(func)
            for worker in self.remote_workers()
        ]
        return local_results + ray.get(ray_gets)

    @DeveloperAPI
    def foreach_env_with_context(
            self,
            func: Callable[[BaseEnv, EnvContext], List[T]]) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments and env_ctx as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment and the env_context
        as args.

        Args:
            func: A function - taking a BaseEnv object and an EnvContext as
                arg - and returning a list of lists of return values over envs
                of the worker.

        Returns:
            The list (1 item per workers) of lists (1 item per sub-environment)
                of results.
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [
                self.local_worker().foreach_env_with_context(func)
            ]
        ray_gets = [
            worker.foreach_env_with_context.remote(func)
            for worker in self.remote_workers()
        ]
        return local_results + ray.get(ray_gets)

    @staticmethod
    def _from_existing(
            local_worker: RolloutWorker,
            remote_workers: Optional[List[ActorHandle]] = None
    ) -> "WorkerSet":
        """Wraps already existing workers into a new WorkerSet.

        No workers are created (`_setup=False`). The empty `trainer_config`
        passed here makes the constructor fall back to COMMON_CONFIG.

        Args:
            local_worker: The worker to use as the set's local worker.
            remote_workers: Optional list of remote workers. None is
                treated as an empty list.

        Returns:
            The new WorkerSet holding the given workers.
        """
        workers = WorkerSet(
            env_creator=None,
            policy_class=None,
            trainer_config={},
            _setup=False)
        workers._local_worker = local_worker
        workers._remote_workers = remote_workers or []
        return workers

    def _make_worker(
            self,
            *,
            cls: Callable,
            env_creator: EnvCreator,
            validate_env: Optional[Callable[[EnvType], None]],
            policy_cls: Type[Policy],
            worker_index: int,
            num_workers: int,
            config: TrainerConfigDict,
            spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                  gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        """Builds a single (local or remote) worker from the given config.

        Args:
            cls: Callable creating the worker: Either the RolloutWorker
                class itself (local worker) or the `.remote` constructor of
                its remote version.
            env_creator: Function that returns env given env config.
            validate_env: Optional callable to validate the generated env.
            policy_cls: Policy class to use if there is no multi-agent
                policies dict or for those PolicySpecs in it, whose
                `policy_class` is None.
            worker_index: Index of the worker (0 is used for the local
                worker). Also added to `config["seed"]`, if that is set.
            num_workers: Passed through to the worker as `num_workers`.
            config: The config dict to build the worker from.
            spaces: Optional dict mapping policy IDs to
                (observation space, action space) tuples, passed through
                to the worker.

        Returns:
            Whatever `cls(...)` returns: A RolloutWorker or an ActorHandle.
        """

        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf1.Session(
                config=tf1.ConfigProto(**config["tf_session_args"]))

        def valid_module(class_path):
            # True, iff `class_path` is a dotted string ("module.Class"),
            # whose module part can be located.
            if isinstance(class_path, str) and "." in class_path:
                module_path, class_name = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    logger.warning(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}")
            return False

        # Pick the input creator based on the type/value of
        # `config["input"]`. The order of the checks matters.
        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        elif isinstance(config["input"], dict):
            input_creator = (
                lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))
        elif isinstance(config["input"], str) and \
                registry_contains_input(config["input"]):
            input_creator = registry_get_input(config["input"])
        # Only check strings for "d4rl": The env name is extracted via
        # `str.split` below.
        elif isinstance(config["input"], str) and "d4rl" in config["input"]:
            env_name = config["input"].split(".")[-1]
            input_creator = (lambda ioctx: D4RLReader(env_name, ioctx))
        elif valid_module(config["input"]):
            input_creator = (lambda ioctx: ShuffledInput(from_config(
                config["input"], ioctx=ioctx)))
        else:
            # Fallback: Hand `config["input"]` to a JsonReader.
            input_creator = (
                lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))

        # Pick the output creator based on `config["output"]`.
        if isinstance(config["output"], FunctionType):
            output_creator = config["output"]
        elif config["output"] is None:
            output_creator = (lambda ioctx: NoopOutput())
        elif config["output"] == "logdir":
            output_creator = (lambda ioctx: JsonWriter(
                ioctx.log_dir,
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))
        else:
            output_creator = (lambda ioctx: JsonWriter(
                config["output"],
                ioctx,
                max_file_size=config["output_max_file_size"],
                compress_columns=config["output_compress_columns"]))

        # `input_evaluation` is only passed on for non-sampler inputs.
        if config["input"] == "sampler":
            input_evaluation = []
        else:
            input_evaluation = config["input_evaluation"]

        # Assert everything is correct in "multiagent" config dict (if given).
        ma_policies = config["multiagent"]["policies"]
        if ma_policies:
            # Iterate over a copy, as entries of the dict get replaced.
            for pid, policy_spec in ma_policies.copy().items():
                assert isinstance(policy_spec, PolicySpec)
                # Class is None -> Use `policy_cls`.
                if policy_spec.policy_class is None:
                    ma_policies[pid] = ma_policies[pid]._replace(
                        policy_class=policy_cls)
            policies = ma_policies

        # Create a policy_spec (MultiAgentPolicyConfigDict),
        # even if no "multiagent" setup given by user.
        else:
            policies = policy_cls

        # The local worker (index 0) and the remote workers read their
        # extra python environs from different config keys.
        if worker_index == 0:
            extra_python_environs = config.get(
                "extra_python_environs_for_driver", None)
        else:
            extra_python_environs = config.get(
                "extra_python_environs_for_worker", None)

        worker = cls(
            env_creator=env_creator,
            validate_env=validate_env,
            policy_spec=policies,
            policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
            policies_to_train=config["multiagent"]["policies_to_train"],
            tf_session_creator=(session_creator
                                if config["tf_session_args"] else None),
            rollout_fragment_length=config["rollout_fragment_length"],
            count_steps_by=config["multiagent"]["count_steps_by"],
            batch_mode=config["batch_mode"],
            episode_horizon=config["horizon"],
            preprocessor_pref=config["preprocessor_pref"],
            sample_async=config["sample_async"],
            compress_observations=config["compress_observations"],
            num_envs=config["num_envs_per_worker"],
            observation_fn=config["multiagent"]["observation_fn"],
            observation_filter=config["observation_filter"],
            clip_rewards=config["clip_rewards"],
            normalize_actions=config["normalize_actions"],
            clip_actions=config["clip_actions"],
            env_config=config["env_config"],
            policy_config=config,
            worker_index=worker_index,
            num_workers=num_workers,
            record_env=config["record_env"],
            log_dir=self._logdir,
            log_level=config["log_level"],
            callbacks=config["callbacks"],
            input_creator=input_creator,
            input_evaluation=input_evaluation,
            output_creator=output_creator,
            remote_worker_envs=config["remote_worker_envs"],
            remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
            soft_horizon=config["soft_horizon"],
            no_done_at_end=config["no_done_at_end"],
            # Give each worker its own seed (if seeding is on at all).
            seed=(config["seed"] + worker_index)
            if config["seed"] is not None else None,
            fake_sampler=config["fake_sampler"],
            extra_python_environs=extra_python_environs,
            spaces=spaces,
        )

        return worker
|