learner_thread.py

import copy
import queue
import threading
from typing import Dict, Optional

from ray.util.timer import _Timer
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, LEARNER_INFO
from ray.rllib.utils.metrics.window_stat import WindowStat
from ray.util.iter import _NextValueNotReady
from ray.util import log_once

tf1, tf, tfv = try_import_tf()


class LearnerThread(threading.Thread):
    """Background thread that updates the local model from sample trajectories.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving heavyweight gradient updates (such as TF session runs)
    off the main thread improves overall throughput.
    """

    def __init__(
        self,
        local_worker: RolloutWorker,
        minibatch_buffer_size: int,
        num_sgd_iter: int,
        learner_queue_size: int,
        learner_queue_timeout: int,
    ):
        """Initialize the learner thread.

        Args:
            local_worker: Process-local rollout worker holding the policies
                this thread will call `learn_on_batch()` on.
            minibatch_buffer_size: Max number of train batches to store in
                the minibatching buffer.
            num_sgd_iter: Number of passes to learn on per train batch.
            learner_queue_size: Max size of the queue of inbound train
                batches to this thread.
            learner_queue_timeout: Raise an exception if the inbound queue
                has been empty for this many seconds.
        """
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=learner_queue_size)
        self.outqueue = queue.Queue()
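        # The minibatch buffer replays each inbound train batch up to
        # `num_sgd_iter` times, so the learner can take several SGD passes
        # over a batch before discarding it.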
        self.minibatch_buffer = MinibatchBuffer(
            inqueue=self.inqueue,
            size=minibatch_buffer_size,
            timeout=learner_queue_timeout,
            num_passes=num_sgd_iter,
            init_num_passes=num_sgd_iter,
        )
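        # Timers backing the `timing_breakdown` metrics reported by
        # `add_learner_metrics()`. The two load timers are left for
        # subclasses (e.g. multi-GPU learner threads) to populate.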
        self.queue_timer = _Timer()
        self.grad_timer = _Timer()
        self.load_timer = _Timer()
        self.load_wait_timer = _Timer()
        self.daemon = True
        self.policy_ids_updated = []
        self.learner_info = {}
        self.stopped = False
        self.num_steps = 0

        if log_once("learner-thread-deprecation-warning"):
            deprecation_warning(old="ray.rllib.execution.learner_thread.LearnerThread")

    def run(self) -> None:
        # Switch on eager mode if configured.
        if self.local_worker.config.framework_str == "tf2":
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self) -> Optional[_NextValueNotReady]:
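        # Fetch the next train batch, waiting up to `learner_queue_timeout`
        # seconds if the minibatch buffer has nothing cached.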
        with self.queue_timer:
            try:
                batch, _ = self.minibatch_buffer.get()
            except queue.Empty:
                return _NextValueNotReady()

        with self.grad_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(num_devices=1)
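            # If policy states can be swapped in and out of (GPU) memory,
            # lock the worker so no swap happens mid-update.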
            if self.local_worker.config.policy_states_are_swappable:
                self.local_worker.lock()
            multi_agent_results = self.local_worker.learn_on_batch(batch)
            if self.local_worker.config.policy_states_are_swappable:
                self.local_worker.unlock()
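            # Remember which policy IDs were just updated, so the driver
            # knows whose weights to re-broadcast.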
            self.policy_ids_updated.extend(list(multi_agent_results.keys()))
            for pid, results in multi_agent_results.items():
                learner_info_builder.add_learn_on_batch_results(results, pid)
            self.learner_info = learner_info_builder.finalize()

        self.num_steps += 1
        # Put tuple: env-steps, agent-steps, and learner info into the queue.
        self.outqueue.put((batch.count, batch.agent_steps(), self.learner_info))
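        # Record the current inbound queue occupancy for the "learner_queue"
        # window stats.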
        self.learner_queue_size.push(self.inqueue.qsize())

    def add_learner_metrics(self, result: Dict, overwrite_learner_info=True) -> Dict:
        """Add internal metrics to a result dict."""

        def timer_to_ms(timer):
            return round(1000 * timer.mean, 3)
        if overwrite_learner_info:
            result["info"].update(
                {
                    "learner_queue": self.learner_queue_size.stats(),
                    LEARNER_INFO: copy.deepcopy(self.learner_info),
                    "timing_breakdown": {
                        "learner_grad_time_ms": timer_to_ms(self.grad_timer),
                        "learner_load_time_ms": timer_to_ms(self.load_timer),
                        "learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
                        "learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
                    },
                }
            )
        else:
            result["info"].update(
                {
                    "learner_queue": self.learner_queue_size.stats(),
                    "timing_breakdown": {
                        "learner_grad_time_ms": timer_to_ms(self.grad_timer),
                        "learner_load_time_ms": timer_to_ms(self.load_timer),
                        "learner_load_wait_time_ms": timer_to_ms(self.load_wait_timer),
                        "learner_dequeue_time_ms": timer_to_ms(self.queue_timer),
                    },
                }
            )
        return result
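

# -----------------------------------------------------------------------------
# Usage illustration (a minimal sketch, not part of the module above): the
# thread-plus-queues handshake LearnerThread implements, reduced to plain
# stdlib pieces. A driver thread feeds batches into `inqueue`; the learner
# drains them, "learns", and reports stats via `outqueue`. All names here
# (`_ToyLearnerThread`, `fake_batch`) are hypothetical stand-ins for the
# RLlib objects used above.
# -----------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyLearnerThread(threading.Thread):
        def __init__(self):
            threading.Thread.__init__(self)
            self.daemon = True
            self.inqueue = queue.Queue(maxsize=16)
            self.outqueue = queue.Queue()
            self.stopped = False

        def run(self):
            while not self.stopped:
                try:
                    batch = self.inqueue.get(timeout=1.0)
                except queue.Empty:
                    continue
                # Heavyweight "gradient" work would happen here, off the
                # main thread.
                self.outqueue.put({"batch_size": len(batch)})

    learner = _ToyLearnerThread()
    learner.start()
    for fake_batch in ([0.0] * 32, [0.0] * 64):
        learner.inqueue.put(fake_batch)  # Main thread produces batches ...
        print(learner.outqueue.get())  # ... and consumes learner stats.
    learner.stopped = True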