123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- import logging
- import platform
- from typing import List, Dict, Any
- import ray
- from ray.rllib.evaluation.worker_set import WorkerSet
- from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
- STEPS_SAMPLED_COUNTER, _get_shared_metrics
- from ray.rllib.execution.replay_ops import MixInReplay
- from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
- from ray.rllib.policy.sample_batch import MultiAgentBatch
- from ray.rllib.utils.actors import create_colocated_actors
- from ray.rllib.utils.typing import SampleBatchType, ModelWeights
- from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \
- from_actors, LocalIterator
- logger = logging.getLogger(__name__)
@ray.remote(num_cpus=0)
class Aggregator(ParallelIteratorWorker):
    """Intermediate actor that turns raw rollouts into train batches.

    Instances are created by gather_experiences_tree_aggregation(). Each
    one serves as a single shard of a parallel iterator: it pulls sample
    batches from its group of rollout workers, decompresses them, mixes in
    replayed batches and concatenates the result, using
    config["train_batch_size"] as the minimum batch size. Doing this here
    keeps the decompression / concatenation cost off the learner.
    """

    def __init__(self, config: Dict,
                 rollout_group: "ParallelIterator[SampleBatchType]"):
        """Set up the (lazily evaluated) batch-building pipeline.

        Args:
            config: Config dict. The keys read are
                "max_sample_requests_in_flight_per_worker",
                "replay_buffer_num_slots", "replay_proportion",
                "train_batch_size" and config["multiagent"]["count_steps_by"].
            rollout_group: Parallel iterator over the rollout workers
                assigned to this aggregator.
        """
        # Latest values received through set_weights(); both stay None
        # until the first call.
        self.weights = None
        self.global_vars = None

        def forward_weights(source_and_batch):
            # Push our latest policy weights to the rollout worker that
            # produced this batch, then hand the batch on unchanged.
            source_actor, sample_batch = source_and_batch
            if self.weights:
                source_actor.set_weights.remote(self.weights,
                                                self.global_vars)
            return sample_batch

        def decompress(sample_batch):
            return sample_batch.decompress_if_needed()

        def generator():
            # Nothing below runs until the first item is requested.
            batches = rollout_group.gather_async(
                num_async=config["max_sample_requests_in_flight_per_worker"])
            batches = batches.zip_with_source_actor().for_each(
                forward_weights)
            batches = batches.for_each(decompress)

            # Augment with replay ...
            replay_op = MixInReplay(
                num_slots=config["replay_buffer_num_slots"],
                replay_proportion=config["replay_proportion"])
            batches = batches.for_each(replay_op).flatten()

            # ... and concat to the desired train batch size.
            concat_op = ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )
            yield from batches.combine(concat_op)

        super().__init__(generator, repeat=False)

    def get_host(self) -> str:
        """Return the name of the node this actor runs on."""
        return platform.node()

    def set_weights(self, weights: ModelWeights, global_vars: Dict) -> None:
        """Store the weights / global vars to forward to rollout workers."""
        self.weights = weights
        self.global_vars = global_vars
def gather_experiences_tree_aggregation(workers: WorkerSet,
                                        config: Dict) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly().

    The remote rollout workers are split round-robin into
    config["num_aggregation_workers"] groups. One Aggregator actor is
    created per group, and the batches the aggregators emit are gathered
    asynchronously into a single local iterator.

    Args:
        workers: Worker set whose remote workers produce the rollouts.
        config: Config dict. "num_aggregation_workers" is read here; the
            whole dict is also passed on to every Aggregator.

    Returns:
        A local iterator over the aggregators' output batches. Each batch
        that passes through also updates the shared sampled-steps counters.

    Raises:
        ValueError: If config["num_aggregation_workers"] is smaller than 1.
        AssertionError: If `platform.node()` returns an empty string.
    """
    num_aggregators = config["num_aggregation_workers"]
    # Fail fast with a clear message: without this check a value < 1 either
    # surfaces as an opaque IndexError below or builds a pipeline with no
    # aggregators at all.
    if num_aggregators < 1:
        raise ValueError(
            "`num_aggregation_workers` must be >= 1 for tree aggregation, "
            "got {}.".format(num_aggregators))

    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators (round-robin): aggregator k
    # gets worker indices k, k + n, k + 2n, ...
    num_workers = len(workers.remote_workers())
    worker_assignments = [
        list(range(k, num_workers, num_aggregators))
        for k in range(num_aggregators)
    ]
    logger.info("Worker assignments: %s", worker_assignments)

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node (localhost)
    # to maximize data bandwidth between them and the driver.
    localhost = platform.node()
    assert localhost != "", \
        "ERROR: Cannot determine local node name! " \
        "`platform.node()` returned empty string."
    all_co_located = create_colocated_actors(
        actor_specs=[
            # (class, args, kwargs={}, count=1)
            (Aggregator, [config, g], {}, 1) for g in rollout_groups
        ],
        node=localhost)

    # Use the first ([0]) of each created group (each group only has one
    # actor: count=1).
    train_batches = from_actors([group[0] for group in all_co_located])

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        # Add this batch to the shared env-step and agent-step counters.
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
                batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)
|