import logging
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, \
    TYPE_CHECKING

import ray
from ray.actor import ActorHandle
from ray.rllib.evaluation.rollout_worker import get_global_worker
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
    STEPS_SAMPLED_COUNTER, SAMPLE_TIMER, GRAD_WAIT_TIMER, \
    _check_sample_batch_type, _get_shared_metrics
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
    MultiAgentBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, \
    LEARNER_STATS_KEY
from ray.rllib.utils.sgd import standardized
from ray.rllib.utils.typing import PolicyID, SampleBatchType, ModelGradients
from ray.util.iter import from_actors, LocalIterator
from ray.util.iter_metrics import SharedMetrics

if TYPE_CHECKING:
    from ray.rllib.agents.trainer import Trainer
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

logger = logging.getLogger(__name__)


@ExperimentalAPI
def synchronous_parallel_sample(
    worker_set: WorkerSet,
    remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
) -> List[SampleBatch]:
    """Runs parallel and synchronous rollouts on all remote workers.

    Waits for all workers to return from the remote calls.

    If no remote workers exist (num_workers == 0), use the local worker
    for sampling.

    As an alternative to calling `worker.sample.remote()`, the user can
    provide a `remote_fn()`, which will be applied to the worker(s) instead.

    Args:
        worker_set: The WorkerSet to use for sampling.
        remote_fn: If provided, use `worker.apply.remote(remote_fn)` instead
            of `worker.sample.remote()` to generate the requests.

    Returns:
        The list of collected sample batch types (one for each parallel
        rollout worker in the given `worker_set`).

    Examples:
        >>> # 2 remote workers (num_workers=2):
        >>> batches = synchronous_parallel_sample(trainer.workers)
        >>> print(len(batches))
        ... 2
        >>> print(batches[0])
        ... SampleBatch(16: ['obs', 'actions', 'rewards', 'dones'])
        >>> # 0 remote workers (num_workers=0): Using the local worker.
        >>> batches = synchronous_parallel_sample(trainer.workers)
        >>> print(len(batches))
        ... 1
    """
    # No remote workers in the set -> Use local worker for collecting
    # samples.
    if not worker_set.remote_workers():
        if remote_fn is None:
            return [worker_set.local_worker().sample()]
        return [worker_set.local_worker().apply(remote_fn)]

    # Loop over remote workers' `sample()` method (or `remote_fn`, if
    # provided) in parallel.
    if remote_fn is None:
        sample_batches = ray.get(
            [r.sample.remote() for r in worker_set.remote_workers()])
    else:
        sample_batches = ray.get(
            [r.apply.remote(remote_fn) for r in worker_set.remote_workers()])

    # Return all collected batches.
    return sample_batches
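

# A minimal usage sketch (illustrative only, not original RLlib code):
# assuming an already built Trainer whose `trainer.workers` is a WorkerSet,
# one synchronous sampling round can be collected and merged into a single
# train batch as follows. The helper name and the `trainer` argument are
# assumptions made for this example.
def _example_synchronous_sampling(trainer: "Trainer") -> SampleBatchType:
    # Collect one batch from every remote rollout worker (or from the local
    # worker if num_workers=0) ...
    batches = synchronous_parallel_sample(trainer.workers)
    # ... and concatenate them into one training batch.
    return SampleBatch.concat_samples(batches)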


# TODO: Move to generic parallel ops module and rename to
#  `asynchronous_parallel_requests`:
@ExperimentalAPI
def asynchronous_parallel_sample(
    trainer: "Trainer",
    actors: List[ActorHandle],
    ray_wait_timeout_s: Optional[float] = None,
    max_remote_requests_in_flight_per_actor: int = 2,
    remote_fn: Optional[Callable[["RolloutWorker"], None]] = None,
    remote_args: Optional[List[List[Any]]] = None,
    remote_kwargs: Optional[List[Dict[str, Any]]] = None,
) -> Optional[List[SampleBatch]]:
    """Runs parallel and asynchronous rollouts on all remote workers.

    May use a timeout (if provided) on `ray.wait()` and returns only those
    samples that could be gathered in the timeout window. Allows a maximum
    of `max_remote_requests_in_flight_per_actor` remote calls to be in-flight
    per remote actor.

    As an alternative to calling `actor.sample.remote()`, the user can
    provide a `remote_fn()`, which will be applied to the actor(s) instead.

    Args:
        trainer: The Trainer object that we run the sampling for.
        actors: The List of ActorHandles to perform the remote requests on.
        ray_wait_timeout_s: Timeout (in sec) to be used for the underlying
            `ray.wait()` calls. If None (default), never time out (block
            until at least one actor returns something).
        max_remote_requests_in_flight_per_actor: Maximum number of remote
            requests sent to each actor. 2 (default) is probably
            sufficient to avoid idle times between two requests.
        remote_fn: If provided, use `actor.apply.remote(remote_fn)` instead of
            `actor.sample.remote()` to generate the requests.
        remote_args: If provided, use this list (per-actor) of lists (call
            args) as *args to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_args=[[...] <- *args for A, [...] <- *args for B].
        remote_kwargs: If provided, use this list (per-actor) of dicts
            (kwargs) as **kwargs to be passed to the `remote_fn`.
            E.g.: actors=[A, B],
            remote_kwargs=[{...} <- **kwargs for A, {...} <- **kwargs for B].

    Returns:
        The list of asynchronously collected sample batch types. None, if no
        samples are ready.

    Examples:
        >>> # 2 remote rollout workers (num_workers=2):
        >>> batches = asynchronous_parallel_sample(
        ...     trainer,
        ...     actors=trainer.workers.remote_workers(),
        ...     ray_wait_timeout_s=0.1,
        ...     remote_fn=lambda w: time.sleep(1)  # sleep 1sec
        ... )
        >>> # Expect a timeout to have happened: The workers are still busy
        >>> # sleeping, so no results are ready yet and None is returned.
        >>> print(batches)
        ... None
    """
    if remote_args is not None:
        assert len(remote_args) == len(actors)
    if remote_kwargs is not None:
        assert len(remote_kwargs) == len(actors)

    # Collect all currently pending remote requests into a single set of
    # object refs.
    pending_remotes = set()
    # Also build a map to get the associated actor for each remote request.
    remote_to_actor = {}
    for actor, set_ in trainer.remote_requests_in_flight.items():
        pending_remotes |= set_
        for r in set_:
            remote_to_actor[r] = actor

    # Add new requests, if possible (if
    # `max_remote_requests_in_flight_per_actor` setting allows it).
    for actor_idx, actor in enumerate(actors):
        # Still room for another request to this actor.
        if len(trainer.remote_requests_in_flight[actor]) < \
                max_remote_requests_in_flight_per_actor:
            if remote_fn is None:
                req = actor.sample.remote()
            else:
                args = remote_args[actor_idx] if remote_args else []
                kwargs = remote_kwargs[actor_idx] if remote_kwargs else {}
                req = actor.apply.remote(remote_fn, *args, **kwargs)
            # Add to our set to send to ray.wait().
            pending_remotes.add(req)
            # Keep our mappings properly updated.
            trainer.remote_requests_in_flight[actor].add(req)
            remote_to_actor[req] = actor

    # There must always be pending remote requests.
    assert len(pending_remotes) > 0
    pending_remote_list = list(pending_remotes)

    # No timeout: Block until at least one result is returned.
    if ray_wait_timeout_s is None:
        # First try to do a `ray.wait` w/o timeout for efficiency.
        ready, _ = ray.wait(
            pending_remote_list, num_returns=len(pending_remotes), timeout=0)
        # Nothing returned and `timeout` is None -> Fall back to a
        # blocking wait to make sure we can return something.
        if not ready:
            ready, _ = ray.wait(pending_remote_list, num_returns=1)
    # Timeout: Do a `ray.wait()` call w/ timeout.
    else:
        ready, _ = ray.wait(
            pending_remote_list,
            num_returns=len(pending_remotes),
            timeout=ray_wait_timeout_s)

        # Return None if nothing ready after the timeout.
        if not ready:
            return None

    for obj_ref in ready:
        # Remove in-flight record for this ref.
        trainer.remote_requests_in_flight[remote_to_actor[obj_ref]].remove(
            obj_ref)
        remote_to_actor.pop(obj_ref)

    results = ray.get(ready)

    return results
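

# A minimal usage sketch (illustrative only, not original RLlib code):
# poll the remote rollout workers without blocking for longer than 0.1s per
# call; an empty list is returned here if nothing finished in time. The
# helper name and the 0.1s timeout are assumptions made for this example.
def _example_asynchronous_sampling(trainer: "Trainer") -> List[SampleBatch]:
    batches = asynchronous_parallel_sample(
        trainer,
        actors=trainer.workers.remote_workers(),
        ray_wait_timeout_s=0.1,
    )
    # `None` means no remote request completed within the timeout window.
    return batches if batches is not None else []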


def ParallelRollouts(workers: WorkerSet, *, mode="bulk_sync",
                     num_async=1) -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially
    from the local worker instance instead.

    Args:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of 'async', 'bulk_sync', 'raw'. In 'async' mode,
            batches are returned as soon as they are computed by rollout
            workers with no order guarantees. In 'bulk_sync' mode, we collect
            one batch from each worker and concatenate them together into a
            large batch to return. In 'raw' mode, the ParallelIterator object
            is returned directly and the caller is responsible for
            implementing gather and updating the timesteps counter.
        num_async (int): In async mode, the max number of async
            requests in flight per actor.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.rollout_fragment_length

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.rollout_fragment_length * config.num_workers

    Updates the STEPS_SAMPLED_COUNTER counter in the local iterator context.
    """
    # Ensure workers are initially in sync.
    workers.sync_weights()

    def report_timesteps(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
                batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    if not workers.remote_workers():
        # Handle the `num_workers=0` case, in which the local worker
        # has to do sampling as well.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return (LocalIterator(sampler,
                              SharedMetrics()).for_each(report_timesteps))

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches)) \
            .for_each(report_timesteps)
    elif mode == "async":
        return rollouts.gather_async(
            num_async=num_async).for_each(report_timesteps)
    elif mode == "raw":
        return rollouts
    else:
        raise ValueError("mode must be one of 'bulk_sync', 'async', 'raw', "
                         "got '{}'".format(mode))
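

# A minimal usage sketch (illustrative only, not original RLlib code):
# pull one synchronized, concatenated batch from a WorkerSet (e.g.
# `trainer.workers`). The helper name is an assumption made for this example.
def _example_parallel_rollouts(workers: WorkerSet) -> SampleBatch:
    # "bulk_sync" collects one batch per remote worker per step and returns
    # them concatenated into a single batch.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    return next(rollouts)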


def AsyncGradients(
        workers: WorkerSet) -> LocalIterator[Tuple[ModelGradients, int]]:
    """Operator to compute gradients in parallel from rollout workers.

    Args:
        workers (WorkerSet): set of rollout workers to use.

    Returns:
        A local iterator over policy gradients computed on rollout workers.

    Examples:
        >>> grads_op = AsyncGradients(workers)
        >>> print(next(grads_op))
        {"var_0": ..., ...}, 50  # grads, batch count

    Updates the STEPS_SAMPLED_COUNTER counter and LEARNER_INFO field in the
    local iterator context.
    """
    # Ensure workers are initially in sync.
    workers.sync_weights()

    # This function will be applied remotely on the workers.
    def samples_to_grads(samples):
        return get_global_worker().compute_gradients(samples), samples.count

    # Record learner metrics and pass through (grads, count).
    class record_metrics:
        def _on_fetch_start(self):
            self.fetch_start_time = time.perf_counter()

        def __call__(self, item):
            (grads, info), count = item
            metrics = _get_shared_metrics()
            metrics.counters[STEPS_SAMPLED_COUNTER] += count
            metrics.info[LEARNER_INFO] = {
                DEFAULT_POLICY_ID: info
            } if LEARNER_STATS_KEY in info else info
            metrics.timers[GRAD_WAIT_TIMER].push(time.perf_counter() -
                                                 self.fetch_start_time)
            return grads, count

    rollouts = from_actors(workers.remote_workers())
    grads = rollouts.for_each(samples_to_grads)
    return grads.gather_async().for_each(record_metrics())
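

# A minimal usage sketch (illustrative only, not original RLlib code):
# consume asynchronously computed gradients one at a time and apply them on
# the local worker, roughly as an A3C-style execution plan would. The helper
# name is an assumption made for this example.
def _example_async_gradients(workers: WorkerSet) -> int:
    grads_op = AsyncGradients(workers)
    # Each item is a (gradients, sample_count) tuple from whichever remote
    # worker finished first.
    grads, count = next(grads_op)
    workers.local_worker().apply_gradients(grads)
    return count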


class ConcatBatches:
    """Callable used to merge batches into larger batches for training.

    This should be used with the .combine() operator.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> rollouts = rollouts.combine(ConcatBatches(
        ...     min_batch_size=10000, count_steps_by="env_steps"))
        >>> print(next(rollouts).count)
        10000
    """

    def __init__(self, min_batch_size: int,
                 count_steps_by: str = "env_steps"):
        self.min_batch_size = min_batch_size
        self.count_steps_by = count_steps_by
        self.buffer = []
        self.count = 0
        self.last_batch_time = time.perf_counter()

    def __call__(self, batch: SampleBatchType) -> List[SampleBatchType]:
        _check_sample_batch_type(batch)
        if self.count_steps_by == "env_steps":
            size = batch.count
        else:
            assert isinstance(batch, MultiAgentBatch), \
                "`count_steps_by=agent_steps` only allowed in multi-agent " \
                "environments!"
            size = batch.agent_steps()

        # Incoming batch is an empty dummy batch -> Ignore.
        # Possibly produced automatically by a PolicyServer to unblock
        # an external env waiting for inputs from unresponsive/disconnected
        # client(s).
        if size == 0:
            return []

        self.count += size
        self.buffer.append(batch)

        if self.count >= self.min_batch_size:
            if self.count > self.min_batch_size * 2:
                logger.info("Collected more training samples than expected "
                            "(actual={}, expected={}). ".format(
                                self.count, self.min_batch_size) +
                            "This may be because you have many workers or "
                            "long episodes in 'complete_episodes' batch mode.")
            out = SampleBatch.concat_samples(self.buffer)

            perf_counter = time.perf_counter()
            timer = _get_shared_metrics().timers[SAMPLE_TIMER]
            timer.push(perf_counter - self.last_batch_time)
            timer.push_units_processed(self.count)
            self.last_batch_time = perf_counter

            self.buffer = []
            self.count = 0
            return [out]
        return []
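

# A minimal usage sketch (illustrative only, not original RLlib code):
# build a rollout stream whose items each contain at least `min_batch_size`
# env steps, similar to how a PPO-style execution plan assembles its train
# batches. The helper name and the batch size of 4000 are assumptions made
# for this example.
def _example_concat_batches(
        workers: WorkerSet) -> LocalIterator[SampleBatchType]:
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # `.combine()` buffers incoming batches until ConcatBatches emits one
    # merged batch of >= 4000 env steps.
    return rollouts.combine(
        ConcatBatches(min_batch_size=4000, count_steps_by="env_steps"))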


class SelectExperiences:
    """Callable used to select experiences from a MultiAgentBatch.

    This should be used with the .for_each() operator.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> rollouts = rollouts.for_each(SelectExperiences(["pol1", "pol2"]))
        >>> print(next(rollouts).policy_batches.keys())
        {"pol1", "pol2"}
    """

    def __init__(self, policy_ids: List[PolicyID]):
        assert isinstance(policy_ids, list), policy_ids
        self.policy_ids = policy_ids

    def __call__(self, samples: SampleBatchType) -> SampleBatchType:
        _check_sample_batch_type(samples)

        if isinstance(samples, MultiAgentBatch):
            samples = MultiAgentBatch({
                k: v
                for k, v in samples.policy_batches.items()
                if k in self.policy_ids
            }, samples.count)

        return samples
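

# A minimal usage sketch (illustrative only, not original RLlib code):
# keep only the policy batches of the policies that are actually being
# trained. The helper name and the policy ID "learner_policy" are assumptions
# made for this example.
def _example_select_experiences(
        workers: WorkerSet) -> LocalIterator[SampleBatchType]:
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # Drop policy batches for any policy not listed here.
    return rollouts.for_each(SelectExperiences(["learner_policy"]))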


class StandardizeFields:
    """Callable used to standardize fields of batches.

    This should be used with the .for_each() operator. Note that the input
    may be mutated by this operator for efficiency.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> rollouts = rollouts.for_each(StandardizeFields(["advantages"]))
        >>> print(np.std(next(rollouts)["advantages"]))
        1.0
    """

    def __init__(self, fields: List[str]):
        self.fields = fields

    def __call__(self, samples: SampleBatchType) -> SampleBatchType:
        _check_sample_batch_type(samples)
        wrapped = False

        if isinstance(samples, SampleBatch):
            samples = samples.as_multi_agent()
            wrapped = True

        for policy_id in samples.policy_batches:
            batch = samples.policy_batches[policy_id]
            for field in self.fields:
                if field not in batch:
                    raise KeyError(
                        f"`{field}` not found in SampleBatch for policy "
                        f"`{policy_id}`! Maybe this policy fails to add "
                        f"{field} in its `postprocess_trajectory` method? Or "
                        "this policy is not meant to learn at all and you "
                        "forgot to add it to the list under `config."
                        "multiagent.policies_to_train`.")
                batch[field] = standardized(batch[field])

        if wrapped:
            samples = samples.policy_batches[DEFAULT_POLICY_ID]

        return samples
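

# A minimal usage sketch (illustrative only, not original RLlib code):
# standardize the "advantages" column of every incoming train batch, as a
# PPO-style execution plan typically does before handing batches to the
# learner. The helper name is an assumption made for this example.
def _example_standardize_advantages(
        workers: WorkerSet) -> LocalIterator[SampleBatchType]:
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # After this step, each batch's "advantages" column has ~zero mean and
    # unit standard deviation.
    return rollouts.for_each(StandardizeFields(["advantages"]))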