# worker_set.py

import gym
import logging
import importlib.util
from types import FunctionType
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union

import ray
from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.env_context import EnvContext
from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
    ShuffledInput, D4RLReader
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.typing import EnvCreator, EnvType, PolicyID, \
    TrainerConfigDict
from ray.tune.registry import registry_contains_input, registry_get_input

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)

# Generic type var for foreach_* methods.
T = TypeVar("T")


@DeveloperAPI
class WorkerSet:
    """Set of RolloutWorkers with n @ray.remote workers and one local worker.

    Where n may be 0.
    """

    def __init__(
        self,
        *,
        env_creator: Optional[EnvCreator] = None,
        validate_env: Optional[Callable[[EnvType], None]] = None,
        policy_class: Optional[Type[Policy]] = None,
        trainer_config: Optional[TrainerConfigDict] = None,
        num_workers: int = 0,
        local_worker: bool = True,
        logdir: Optional[str] = None,
        _setup: bool = True,
    ):
        """Initializes a WorkerSet instance.

        Args:
            env_creator: Function that returns env given env config.
            validate_env: Optional callable to validate the generated
                environment (only on worker=0).
            policy_class: An optional Policy class. If None, PolicySpecs can
                be generated automatically by using the Trainer's default
                class or via a given multi-agent policy config dict.
            trainer_config: Optional dict that extends the common config of
                the Trainer class.
            num_workers: Number of remote rollout workers to create.
            local_worker: Whether to create a local (non @ray.remote) worker
                in the returned set as well (default: True). If `num_workers`
                is 0, always create a local worker.
            logdir: Optional logging directory for workers.
            _setup: Whether to setup workers. This is only for testing.
        """
        if not trainer_config:
            from ray.rllib.agents.trainer import COMMON_CONFIG
            trainer_config = COMMON_CONFIG

        self._env_creator = env_creator
        self._policy_class = policy_class
        self._remote_config = trainer_config
        self._logdir = logdir

        if _setup:
            # Force a local worker if num_workers == 0 (no remote workers).
            # Otherwise, this WorkerSet would be empty.
            self._local_worker = None
            if num_workers == 0:
                local_worker = True

            self._local_config = merge_dicts(
                trainer_config,
                {"tf_session_args": trainer_config["local_tf_session_args"]})

            # Create a number of @ray.remote workers.
            self._remote_workers = []
            self.add_workers(num_workers)

            # Create a local worker, if needed.
            # If num_workers > 0 and we don't have an env on the local worker,
            # get the observation- and action spaces for each policy from
            # the first remote worker (which does have an env).
            if local_worker and self._remote_workers and \
                    not trainer_config.get("create_env_on_driver") and \
                    (not trainer_config.get("observation_space") or
                     not trainer_config.get("action_space")):
                remote_spaces = ray.get(self.remote_workers(
                )[0].foreach_policy.remote(
                    lambda p, pid: (pid, p.observation_space, p.action_space)))
                spaces = {
                    e[0]: (getattr(e[1], "original_space", e[1]), e[2])
                    for e in remote_spaces
                }
                # Try to add the actual env's obs/action spaces.
                try:
                    env_spaces = ray.get(self.remote_workers(
                    )[0].foreach_env.remote(
                        lambda env: (env.observation_space, env.action_space))
                    )[0]
                    spaces["__env__"] = env_spaces
                except Exception:
                    pass

                logger.info("Inferred observation/action spaces from remote "
                            f"worker (local worker has no env): {spaces}")
            else:
                spaces = None

            if local_worker:
                self._local_worker = self._make_worker(
                    cls=RolloutWorker,
                    env_creator=env_creator,
                    validate_env=validate_env,
                    policy_cls=self._policy_class,
                    worker_index=0,
                    num_workers=num_workers,
                    config=self._local_config,
                    spaces=spaces,
                )
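
    # Illustrative construction sketch (not part of the original file): with
    # `num_workers=0`, the constructor forces `local_worker=True`, so the set
    # degenerates to a single local RolloutWorker. The env creator, policy
    # class, and config names below are assumptions for demonstration only.
    #
    #   local_only = WorkerSet(
    #       env_creator=lambda ctx: gym.make("CartPole-v0"),
    #       policy_class=SomePolicyClass,       # hypothetical Policy subclass
    #       trainer_config=my_trainer_config,   # hypothetical config dict
    #       num_workers=0,
    #   )
    #   assert local_only.remote_workers() == []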

    def local_worker(self) -> RolloutWorker:
        """Returns the local rollout worker."""
        return self._local_worker

    def remote_workers(self) -> List[ActorHandle]:
        """Returns a list of remote rollout workers."""
        return self._remote_workers

    def sync_weights(self,
                     policies: Optional[List[PolicyID]] = None,
                     from_worker: Optional[RolloutWorker] = None) -> None:
        """Syncs model weights from the local worker to all remote workers.

        Args:
            policies: Optional list of PolicyIDs to sync weights for.
                If None (default), sync weights to/from all policies.
            from_worker: Optional RolloutWorker instance to sync from.
                If None (default), sync from this WorkerSet's local worker.
        """
        if self.local_worker() is None and from_worker is None:
            raise TypeError(
                "No `local_worker` in WorkerSet, must provide `from_worker` "
                "arg in `sync_weights()`!")

        # Only sync if we have remote workers or `from_worker` is provided.
        if self.remote_workers() or from_worker is not None:
            weights = (from_worker
                       or self.local_worker()).get_weights(policies)
            # Put weights only once into the object store and use the same
            # object ref to sync to all workers.
            weights_ref = ray.put(weights)
            # Sync to all remote workers in this WorkerSet.
            for to_worker in self.remote_workers():
                to_worker.set_weights.remote(weights_ref)

        # If `from_worker` is provided, also sync to this WorkerSet's
        # local worker.
        if from_worker is not None and self.local_worker() is not None:
            self.local_worker().set_weights(weights)
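
    # Illustrative usage sketch (not part of the original file): weights are
    # typically broadcast once per training iteration, after the local
    # worker's policies have been updated. "default_policy" is RLlib's usual
    # single-agent policy ID, but is an assumption here.
    #
    #   workers.sync_weights()                             # all policies
    #   workers.sync_weights(policies=["default_policy"])  # one policy only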

    def add_workers(self, num_workers: int) -> None:
        """Creates and adds a number of remote workers to this worker set.

        Can be called several times on the same WorkerSet to add more
        RolloutWorkers to the set.

        Args:
            num_workers: The number of remote Workers to add to this
                WorkerSet.
        """
        remote_args = {
            "num_cpus": self._remote_config["num_cpus_per_worker"],
            "num_gpus": self._remote_config["num_gpus_per_worker"],
            "resources": self._remote_config["custom_resources_per_worker"],
        }
        cls = RolloutWorker.as_remote(**remote_args).remote
        self._remote_workers.extend([
            self._make_worker(
                cls=cls,
                env_creator=self._env_creator,
                validate_env=None,
                policy_cls=self._policy_class,
                worker_index=i + 1,
                num_workers=num_workers,
                config=self._remote_config,
            ) for i in range(num_workers)
        ])
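
    # Illustrative usage sketch (not part of the original file): scale up an
    # existing set by two more remote workers; each new worker reserves the
    # CPU/GPU/custom resources specified in the remote config.
    #
    #   before = len(workers.remote_workers())
    #   workers.add_workers(2)
    #   assert len(workers.remote_workers()) == before + 2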

    def reset(self, new_remote_workers: List[ActorHandle]) -> None:
        """Hard overrides the remote workers in this set with the given ones.

        Args:
            new_remote_workers: A list of new RolloutWorkers
                (as `ActorHandles`) to use as remote workers.
        """
        self._remote_workers = new_remote_workers

    def stop(self) -> None:
        """Calls `stop` on all rollout workers (including the local one)."""
        try:
            # Stop the local worker, if one exists.
            if self.local_worker() is not None:
                self.local_worker().stop()
            tids = [w.stop.remote() for w in self.remote_workers()]
            ray.get(tids)
        except Exception:
            logger.exception("Failed to stop workers!")
        finally:
            for w in self.remote_workers():
                w.__ray_terminate__.remote()

    @DeveloperAPI
    def foreach_worker(self, func: Callable[[RolloutWorker], T]) -> List[T]:
        """Calls the given function with each worker instance as arg.

        Args:
            func: The function to call for each worker (as only arg).

        Returns:
            The list of return values of all calls to `func([worker])`.
        """
        local_result = []
        if self.local_worker() is not None:
            local_result = [func(self.local_worker())]
        remote_results = ray.get(
            [w.apply.remote(func) for w in self.remote_workers()])
        return local_result + remote_results
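
    # Illustrative usage sketch (not part of the original file): collect one
    # sample batch from every worker (local + remote) in parallel. Remote
    # calls go through `RolloutWorker.apply`, so `func` must be serializable.
    #
    #   batches = workers.foreach_worker(lambda w: w.sample())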

    @DeveloperAPI
    def foreach_worker_with_index(
            self, func: Callable[[RolloutWorker, int], T]) -> List[T]:
        """Calls `func` with each worker instance and worker idx as args.

        The index will be passed as the second arg to the given function.

        Args:
            func: The function to call for each worker and its index
                (as args). The local worker has index 0, all remote workers
                have indices > 0.

        Returns:
            The list of return values of all calls to `func([worker, idx])`.
                The first entry in this list is the local worker's result,
                followed by all remote workers' results.
        """
        local_result = []
        # Local worker: Index=0.
        if self.local_worker() is not None:
            local_result = [func(self.local_worker(), 0)]
        # Remote workers: Index > 0.
        remote_results = ray.get([
            w.apply.remote(func, i + 1)
            for i, w in enumerate(self.remote_workers())
        ])
        return local_result + remote_results

    @DeveloperAPI
    def foreach_policy(self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Calls `func` with each worker's (policy, PolicyID) tuple.

        Note that in the multi-agent case, each worker may have more than one
        policy.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies.

        Returns:
            The list of return values of `func` over all workers' policies.
                The length of this list is:
                (num_workers + 1 (local worker)) *
                [num policies in the multi-agent config dict].
                The local worker's results come first, followed by all remote
                workers' results.
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_policy(func)
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(
                worker.apply.remote(lambda w: w.foreach_policy(func)))
        remote_results = ray.get(ray_gets)
        for r in remote_results:
            results.extend(r)
        return results
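
    # Illustrative usage sketch (not part of the original file): gather the
    # weights of every policy on every worker, tagged with its PolicyID.
    #
    #   all_weights = workers.foreach_policy(
    #       lambda policy, pid: (pid, policy.get_weights()))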

    @DeveloperAPI
    def trainable_policies(self) -> List[PolicyID]:
        """Returns the list of trainable policy ids."""
        if self.local_worker() is not None:
            return self.local_worker().policies_to_train
        else:
            raise NotImplementedError

    @DeveloperAPI
    def foreach_trainable_policy(
            self, func: Callable[[Policy, PolicyID], T]) -> List[T]:
        """Apply `func` to all workers' Policies iff in `policies_to_train`.

        Args:
            func: A function - taking a Policy and its ID - that is
                called on all workers' Policies in `worker.policies_to_train`.

        Returns:
            List[any]: The list of n return values of all
                `func([trainable policy], [ID])`-calls.
        """
        results = []
        if self.local_worker() is not None:
            results = self.local_worker().foreach_trainable_policy(func)
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(
                worker.apply.remote(
                    lambda w: w.foreach_trainable_policy(func)))
        remote_results = ray.get(ray_gets)
        for r in remote_results:
            results.extend(r)
        return results

    @DeveloperAPI
    def foreach_env(self, func: Callable[[EnvType], List[T]]) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment as arg, e.g. a
        gym.Env object.

        Args:
            func: A function - taking an EnvType (normally a gym.Env object)
                as arg and returning a list of lists of return values, one
                value per underlying sub-environment per worker.

        Returns:
            The list (workers) of lists (sub environments) of results.
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [self.local_worker().foreach_env(func)]
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.foreach_env.remote(func))
        return local_results + ray.get(ray_gets)
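
    # Illustrative usage sketch (not part of the original file): read an
    # attribute from every underlying sub-environment on every worker.
    # `target_velocity` is a hypothetical env attribute.
    #
    #   per_worker = workers.foreach_env(
    #       lambda env: getattr(env, "target_velocity", None))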

    @DeveloperAPI
    def foreach_env_with_context(
            self,
            func: Callable[[BaseEnv, EnvContext], List[T]]) -> List[List[T]]:
        """Calls `func` with all workers' sub-environments and env_ctx as args.

        An "underlying sub environment" is a single clone of an env within
        a vectorized environment.
        `func` takes a single underlying sub environment and the env_context
        as args.

        Args:
            func: A function - taking a BaseEnv object and an EnvContext as
                arg - and returning a list of lists of return values over envs
                of the worker.

        Returns:
            The list (1 item per worker) of lists (1 item per sub-environment)
                of results.
        """
        local_results = []
        if self.local_worker() is not None:
            local_results = [
                self.local_worker().foreach_env_with_context(func)
            ]
        ray_gets = []
        for worker in self.remote_workers():
            ray_gets.append(worker.foreach_env_with_context.remote(func))
        return local_results + ray.get(ray_gets)

    @staticmethod
    def _from_existing(local_worker: RolloutWorker,
                       remote_workers: Optional[List[ActorHandle]] = None):
        workers = WorkerSet(
            env_creator=None,
            policy_class=None,
            trainer_config={},
            _setup=False)
        workers._local_worker = local_worker
        workers._remote_workers = remote_workers or []
        return workers

    def _make_worker(
            self,
            *,
            cls: Callable,
            env_creator: EnvCreator,
            validate_env: Optional[Callable[[EnvType], None]],
            policy_cls: Type[Policy],
            worker_index: int,
            num_workers: int,
            config: TrainerConfigDict,
            spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                                  gym.spaces.Space]]] = None,
    ) -> Union[RolloutWorker, ActorHandle]:
        def session_creator():
            logger.debug("Creating TF session {}".format(
                config["tf_session_args"]))
            return tf1.Session(
                config=tf1.ConfigProto(**config["tf_session_args"]))

        def valid_module(class_path):
            if isinstance(class_path, str) and "." in class_path:
                module_path, class_name = class_path.rsplit(".", 1)
                try:
                    spec = importlib.util.find_spec(module_path)
                    if spec is not None:
                        return True
                except (ModuleNotFoundError, ValueError):
                    print(
                        f"module {module_path} not found while trying to get "
                        f"input {class_path}")
            return False

        if isinstance(config["input"], FunctionType):
            input_creator = config["input"]
        elif config["input"] == "sampler":
            input_creator = (lambda ioctx: ioctx.default_sampler_input())
        elif isinstance(config["input"], dict):
            input_creator = (
                lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))
        elif isinstance(config["input"], str) and \
                registry_contains_input(config["input"]):
            input_creator = registry_get_input(config["input"])
        elif "d4rl" in config["input"]:
            env_name = config["input"].split(".")[-1]
            input_creator = (lambda ioctx: D4RLReader(env_name, ioctx))
        elif valid_module(config["input"]):
            input_creator = (lambda ioctx: ShuffledInput(from_config(
                config["input"], ioctx=ioctx)))
        else:
            input_creator = (
                lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx),
                                            config["shuffle_buffer_size"]))
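
        # Summary of the `config["input"]` dispatch above (descriptive note,
        # not in the original file): a callable is used as-is; "sampler"
        # samples from the env; a dict of sources becomes a MixedInput; a
        # name registered via ray.tune is looked up in the input registry; a
        # "d4rl.<env_name>" string builds a D4RLReader; an importable class
        # path is instantiated via `from_config`; anything else is treated as
        # a JSON file or directory path and read back with JsonReader.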
  398. if isinstance(config["output"], FunctionType):
  399. output_creator = config["output"]
  400. elif config["output"] is None:
  401. output_creator = (lambda ioctx: NoopOutput())
  402. elif config["output"] == "logdir":
  403. output_creator = (lambda ioctx: JsonWriter(
  404. ioctx.log_dir,
  405. ioctx,
  406. max_file_size=config["output_max_file_size"],
  407. compress_columns=config["output_compress_columns"]))
  408. else:
  409. output_creator = (lambda ioctx: JsonWriter(
  410. config["output"],
  411. ioctx,
  412. max_file_size=config["output_max_file_size"],
  413. compress_columns=config["output_compress_columns"]))
  414. if config["input"] == "sampler":
  415. input_evaluation = []
  416. else:
  417. input_evaluation = config["input_evaluation"]
  418. # Assert everything is correct in "multiagent" config dict (if given).
  419. ma_policies = config["multiagent"]["policies"]
  420. if ma_policies:
  421. for pid, policy_spec in ma_policies.copy().items():
  422. assert isinstance(policy_spec, PolicySpec)
  423. # Class is None -> Use `policy_cls`.
  424. if policy_spec.policy_class is None:
  425. ma_policies[pid] = ma_policies[pid]._replace(
  426. policy_class=policy_cls)
  427. policies = ma_policies
  428. # Create a policy_spec (MultiAgentPolicyConfigDict),
  429. # even if no "multiagent" setup given by user.
  430. else:
  431. policies = policy_cls
  432. if worker_index == 0:
  433. extra_python_environs = config.get(
  434. "extra_python_environs_for_driver", None)
  435. else:
  436. extra_python_environs = config.get(
  437. "extra_python_environs_for_worker", None)
  438. worker = cls(
  439. env_creator=env_creator,
  440. validate_env=validate_env,
  441. policy_spec=policies,
  442. policy_mapping_fn=config["multiagent"]["policy_mapping_fn"],
  443. policies_to_train=config["multiagent"]["policies_to_train"],
  444. tf_session_creator=(session_creator
  445. if config["tf_session_args"] else None),
  446. rollout_fragment_length=config["rollout_fragment_length"],
  447. count_steps_by=config["multiagent"]["count_steps_by"],
  448. batch_mode=config["batch_mode"],
  449. episode_horizon=config["horizon"],
  450. preprocessor_pref=config["preprocessor_pref"],
  451. sample_async=config["sample_async"],
  452. compress_observations=config["compress_observations"],
  453. num_envs=config["num_envs_per_worker"],
  454. observation_fn=config["multiagent"]["observation_fn"],
  455. observation_filter=config["observation_filter"],
  456. clip_rewards=config["clip_rewards"],
  457. normalize_actions=config["normalize_actions"],
  458. clip_actions=config["clip_actions"],
  459. env_config=config["env_config"],
  460. policy_config=config,
  461. worker_index=worker_index,
  462. num_workers=num_workers,
  463. record_env=config["record_env"],
  464. log_dir=self._logdir,
  465. log_level=config["log_level"],
  466. callbacks=config["callbacks"],
  467. input_creator=input_creator,
  468. input_evaluation=input_evaluation,
  469. output_creator=output_creator,
  470. remote_worker_envs=config["remote_worker_envs"],
  471. remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"],
  472. soft_horizon=config["soft_horizon"],
  473. no_done_at_end=config["no_done_at_end"],
  474. seed=(config["seed"] + worker_index)
  475. if config["seed"] is not None else None,
  476. fake_sampler=config["fake_sampler"],
  477. extra_python_environs=extra_python_environs,
  478. spaces=spaces,
  479. )
  480. return worker
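
# End-to-end usage sketch (illustrative, not part of the original module).
# The PPO config/policy imports and the "CartPole-v0" env are assumptions;
# any Policy class and matching trainer config dict would work the same way.
#
#   import ray
#   from ray.rllib.agents.ppo import DEFAULT_CONFIG
#   from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
#
#   ray.init()
#   workers = WorkerSet(
#       env_creator=lambda ctx: gym.make("CartPole-v0"),
#       policy_class=PPOTFPolicy,
#       trainer_config=DEFAULT_CONFIG,
#       num_workers=2,
#   )
#   batches = workers.foreach_worker(lambda w: w.sample())
#   workers.sync_weights()
#   workers.stop()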