tree_agg.py

import logging
import platform
from typing import List, Dict, Any

import ray
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import AGENT_STEPS_SAMPLED_COUNTER, \
    STEPS_SAMPLED_COUNTER, _get_shared_metrics
from ray.rllib.execution.replay_ops import MixInReplay
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actors import create_colocated
from ray.rllib.utils.typing import SampleBatchType, ModelWeights
from ray.util.iter import ParallelIterator, ParallelIteratorWorker, \
    from_actors, LocalIterator

logger = logging.getLogger(__name__)


@ray.remote(num_cpus=0)
class Aggregator(ParallelIteratorWorker):
    """An aggregation worker used by gather_experiences_tree_aggregation().

    Each of these actors is a shard of a parallel iterator that consumes
    batches from RolloutWorker actors, and emits batches of size
    train_batch_size. This allows expensive decompression / concatenation
    work to be offloaded to these actors instead of run in the learner.
    """

    def __init__(self, config: Dict,
                 rollout_group: "ParallelIterator[SampleBatchType]"):
        self.weights = None
        self.global_vars = None

        def generator():
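            # Pull batches from the assigned rollout workers, keeping up to
            # max_sample_requests_in_flight_per_worker sample requests in
            # flight to each worker at once.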
            it = rollout_group.gather_async(
                num_async=config["max_sample_requests_in_flight_per_worker"])

            # Update the rollout worker with our latest policy weights.
            def update_worker(item):
                worker, batch = item
                if self.weights:
                    worker.set_weights.remote(self.weights, self.global_vars)
                return batch
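
            # Weights flow learner -> aggregator -> rollout worker: new
            # weights arrive on this actor via set_weights(), and
            # update_worker() relays them to each rollout worker the next
            # time one of its batches passes through the pipeline.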

            # Augment with replay and concat to desired train batch size.
            it = it.zip_with_source_actor() \
                .for_each(update_worker) \
                .for_each(lambda batch: batch.decompress_if_needed()) \
                .for_each(MixInReplay(
                    num_slots=config["replay_buffer_num_slots"],
                    replay_proportion=config["replay_proportion"])) \
                .flatten() \
                .combine(
                    ConcatBatches(
                        min_batch_size=config["train_batch_size"],
                        count_steps_by=config["multiagent"]["count_steps_by"],
                    ))
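
            # At this point `it` yields batches of at least train_batch_size
            # (counted in env steps or agent steps, per count_steps_by),
            # ready to be handed to the learner.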

            for train_batch in it:
                yield train_batch

        super().__init__(generator, repeat=False)
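        # repeat=False: generator() never exhausts anyway, since the rollout
        # stream it consumes is infinite, so it never needs restarting.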

    def get_host(self) -> str:
        return platform.node()

    def set_weights(self, weights: ModelWeights, global_vars: Dict) -> None:
        self.weights = weights
        self.global_vars = global_vars


def gather_experiences_tree_aggregation(workers: WorkerSet,
                                        config: Dict) -> "LocalIterator[Any]":
    """Tree aggregation version of gather_experiences_directly()."""
    rollouts = ParallelRollouts(workers, mode="raw")

    # Divide up the workers between aggregators.
    worker_assignments = [[] for _ in range(config["num_aggregation_workers"])]
    i = 0
    for worker_idx in range(len(workers.remote_workers())):
        worker_assignments[i].append(worker_idx)
        i += 1
        i %= len(worker_assignments)
    logger.info("Worker assignments: {}".format(worker_assignments))
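    # (Round-robin: e.g., 8 rollout workers and 2 aggregation workers yield
    # assignments [[0, 2, 4, 6], [1, 3, 5, 7]].)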

    # Create parallel iterators that represent each aggregation group.
    rollout_groups: List["ParallelIterator[SampleBatchType]"] = [
        rollouts.select_shards(assigned) for assigned in worker_assignments
    ]

    # This spawns |num_aggregation_workers| intermediate actors that aggregate
    # experiences in parallel. We force colocation on the same node to maximize
    # data bandwidth between them and the driver.
    train_batches = from_actors([
        create_colocated(Aggregator, [config, g], 1)[0] for g in rollout_groups
    ])
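    # create_colocated() returns a list of actors; we request exactly one
    # Aggregator per rollout group, hence the [0].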

    # TODO(ekl) properly account for replay.
    def record_steps_sampled(batch):
        metrics = _get_shared_metrics()
        metrics.counters[STEPS_SAMPLED_COUNTER] += batch.count
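        # In the multi-agent case, one env step can contain several agent
        # steps, so agent steps are tracked in a separate counter.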
        if isinstance(batch, MultiAgentBatch):
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += \
                batch.agent_steps()
        else:
            metrics.counters[AGENT_STEPS_SAMPLED_COUNTER] += batch.count
        return batch

    return train_batches.gather_async().for_each(record_steps_sampled)
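

# A minimal usage sketch (illustration only, not part of the original
# module). In an IMPALA-style execution plan, this iterator stands in for
# gather_experiences_directly() when aggregation workers are configured;
# `workers` and `config` are assumed to come from the trainer setup:
#
#     if config["num_aggregation_workers"] > 0:
#         train_batches = gather_experiences_tree_aggregation(workers, config)
#     else:
#         train_batches = gather_experiences_directly(workers, config)
#     for batch in train_batches:
#         ...  # feed each ~train_batch_size batch to the learner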