multi_gpu_learner_thread.py

import logging
from six.moves import queue
import threading

from ray.rllib.execution.learner_thread import LearnerThread
from ray.rllib.execution.minibatch_buffer import MinibatchBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder, \
    LEARNER_STATS_KEY
from ray.rllib.utils.timer import TimerStat
from ray.rllib.evaluation.rollout_worker import RolloutWorker

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)


class MultiGPULearnerThread(LearnerThread):
    """Learner that can use multiple GPUs and parallel loading.

    This class is used for async sampling algorithms.

    Example workflow: 2 GPUs and 3 multi-GPU tower stacks.
    -> On each GPU, there are 3 slots for batches, indexed 0, 1, and 2.

    Workers collect data from the env and push it into the inqueue:
    Workers -> (data) -> self.inqueue

    We also have two queues indicating which stacks are loaded and which
    are not:
    - idle_tower_stacks = [0, 1, 2]  <- all 3 stacks are free at first.
    - ready_tower_stacks = []  <- None of the 3 stacks is loaded with data.

    `ready_tower_stacks` is managed by `ready_tower_stacks_buffer` to allow
    several minibatch-SGD iterations per loaded batch (this avoids a reload
    from CPU to GPU for each SGD iter).

    n `_MultiGPULoaderThread`s: self.inqueue -get()->
        policy.load_batch_into_buffer() -> ready_stacks = [0 ...]
    This thread: self.ready_tower_stacks_buffer -get()->
        policy.learn_on_loaded_batch() -> if SGD-iters done,
        put stack index back in idle_tower_stacks queue.
    """

    def __init__(
            self,
            local_worker: RolloutWorker,
            num_gpus: int = 1,
            lr=None,  # deprecated.
            train_batch_size: int = 500,
            num_multi_gpu_tower_stacks: int = 1,
            num_sgd_iter: int = 1,
            learner_queue_size: int = 16,
            learner_queue_timeout: int = 300,
            num_data_load_threads: int = 16,
            _fake_gpus: bool = False,
            # Deprecated arg; no longer needed (the tower-stack buffer is
            # managed internally, see `ready_tower_stacks_buffer` below).
            minibatch_buffer_size=None,
    ):
  51. """Initializes a MultiGPULearnerThread instance.
  52. Args:
  53. local_worker (RolloutWorker): Local RolloutWorker holding
  54. policies this thread will call `load_batch_into_buffer` and
  55. `learn_on_loaded_batch` on.
  56. num_gpus (int): Number of GPUs to use for data-parallel SGD.
  57. train_batch_size (int): Size of batches (minibatches if
  58. `num_sgd_iter` > 1) to learn on.
  59. num_multi_gpu_tower_stacks (int): Number of buffers to parallelly
  60. load data into on one device. Each buffer is of size of
  61. `train_batch_size` and hence increases GPU memory usage
  62. accordingly.
  63. num_sgd_iter (int): Number of passes to learn on per train batch
  64. (minibatch if `num_sgd_iter` > 1).
  65. learner_queue_size (int): Max size of queue of inbound
  66. train batches to this thread.
  67. num_data_load_threads (int): Number of threads to use to load
  68. data into GPU memory in parallel.
  69. """
        # Deprecated: No need to specify as we don't need the actual
        # minibatch-buffer anyways.
        if minibatch_buffer_size:
            deprecation_warning(
                old="MultiGPULearnerThread.minibatch_buffer_size",
                error=False,
            )

        super().__init__(
            local_worker=local_worker,
            minibatch_buffer_size=0,
            num_sgd_iter=num_sgd_iter,
            learner_queue_size=learner_queue_size,
            learner_queue_timeout=learner_queue_timeout,
        )
        # Delete reference to parent's minibatch_buffer, which is not needed.
        # Instead, in multi-GPU mode, we pull tower stack indices from the
        # `self.ready_tower_stacks_buffer` buffer, whose size is exactly
        # `num_multi_gpu_tower_stacks`.
        self.minibatch_buffer = None

        self.train_batch_size = train_batch_size

        self.policy_map = self.local_worker.policy_map
        self.devices = next(iter(self.policy_map.values())).devices

        logger.info("MultiGPULearnerThread devices {}".format(self.devices))
        assert self.train_batch_size % len(self.devices) == 0
        assert self.train_batch_size >= len(self.devices), "batch too small"

        self.tower_stack_indices = list(range(num_multi_gpu_tower_stacks))

        # Two queues for tower stacks:
        # a) Those that are loaded with data ("ready").
        # b) Those that are ready to be loaded with new data ("idle").
        self.idle_tower_stacks = queue.Queue()
        self.ready_tower_stacks = queue.Queue()
        # In the beginning, all stacks are idle (no loading has taken place
        # yet).
        for idx in self.tower_stack_indices:
            self.idle_tower_stacks.put(idx)
        # Start n threads that are responsible for loading data into the
        # different (idle) stacks.
        for i in range(num_data_load_threads):
            self.loader_thread = _MultiGPULoaderThread(
                self, share_stats=(i == 0))
            self.loader_thread.start()

        # Create a buffer that holds stack indices that are "ready"
        # (loaded with data). Those are stacks that we can call
        # `learn_on_loaded_batch` on.
        self.ready_tower_stacks_buffer = MinibatchBuffer(
            self.ready_tower_stacks, num_multi_gpu_tower_stacks,
            learner_queue_timeout, num_sgd_iter)
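        # Illustration (added note): the buffer above can hand out a ready
        # stack index up to `num_sgd_iter` times before flagging it as
        # `released`, so a single CPU->GPU load can serve several SGD passes
        # before `step()` returns the stack to `idle_tower_stacks`.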

    @override(LearnerThread)
    def step(self) -> None:
        assert self.loader_thread.is_alive()
        with self.load_wait_timer:
            buffer_idx, released = self.ready_tower_stacks_buffer.get()

        get_num_samples_loaded_into_buffer = 0
        with self.grad_timer:
            # Use LearnerInfoBuilder as a unified way to build the final
            # results dict from `learn_on_loaded_batch` call(s).
            # This makes sure results dicts always have the same structure
            # no matter the setup (multi-GPU, multi-agent, minibatch SGD,
            # tf vs torch).
            learner_info_builder = LearnerInfoBuilder(
                num_devices=len(self.devices))

            for pid in self.policy_map.keys():
                # Not a policy-to-train.
                if pid not in self.local_worker.policies_to_train:
                    continue
                policy = self.policy_map[pid]
                default_policy_results = policy.learn_on_loaded_batch(
                    offset=0, buffer_index=buffer_idx)
                learner_info_builder.add_learn_on_batch_results(
                    default_policy_results)
                self.weights_updated = True
                get_num_samples_loaded_into_buffer += \
                    policy.get_num_samples_loaded_into_buffer(buffer_idx)

            self.learner_info = learner_info_builder.finalize()
            learner_stats = {
                pid: self.learner_info[pid][LEARNER_STATS_KEY]
                for pid in self.learner_info.keys()
            }

        if released:
            self.idle_tower_stacks.put(buffer_idx)

        self.outqueue.put((get_num_samples_loaded_into_buffer, learner_stats))
        self.learner_queue_size.push(self.inqueue.qsize())
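
# Note (added): each `step()` call above pushes a
# `(num_samples_trained, learner_stats)` tuple into `self.outqueue`, where
# `learner_stats` maps each trained policy ID to its LEARNER_STATS_KEY
# sub-dict of the learner info. Illustrative shape only (the exact keys are
# policy- and framework-specific):
#   (500, {"default_policy": {"policy_loss": ..., "vf_loss": ...}})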


class _MultiGPULoaderThread(threading.Thread):
    def __init__(self, multi_gpu_learner_thread: MultiGPULearnerThread,
                 share_stats: bool):
        threading.Thread.__init__(self)
        self.multi_gpu_learner_thread = multi_gpu_learner_thread
        self.daemon = True
        if share_stats:
            self.queue_timer = multi_gpu_learner_thread.queue_timer
            self.load_timer = multi_gpu_learner_thread.load_timer
        else:
            self.queue_timer = TimerStat()
            self.load_timer = TimerStat()

    def run(self) -> None:
        while True:
            self._step()

    def _step(self) -> None:
        s = self.multi_gpu_learner_thread
        policy_map = s.policy_map

        # Get a new batch from the data (inqueue).
        with self.queue_timer:
            batch = s.inqueue.get()

        # Get next idle stack for loading.
        buffer_idx = s.idle_tower_stacks.get()

        # Load the batch into the idle stack.
        with self.load_timer:
            for pid in policy_map.keys():
                if pid not in s.local_worker.policies_to_train:
                    continue
                policy = policy_map[pid]
                policy.load_batch_into_buffer(
                    batch=batch if isinstance(batch, SampleBatch) else
                    batch.policy_batches[pid],
                    buffer_index=buffer_idx)

        # Tag just-loaded stack as "ready".
        s.ready_tower_stacks.put(buffer_idx)
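

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module and
# never called here). It only demonstrates the in-/outqueue contract around
# the learner thread: the zero-filled SampleBatch is a stand-in for batches
# that RolloutWorkers would produce in a real async-sampling setup, and a
# real policy would typically require more columns to learn on.
# ---------------------------------------------------------------------------
def _example_feed_and_train(learner_thread: MultiGPULearnerThread) -> None:
    import numpy as np

    learner_thread.start()
    # Producer side: _MultiGPULoaderThreads pick batches off `inqueue` and
    # load them onto idle tower stacks.
    dummy_batch = SampleBatch({
        SampleBatch.OBS: np.zeros((500, 4), dtype=np.float32),
        SampleBatch.ACTIONS: np.zeros((500,), dtype=np.int64),
        SampleBatch.REWARDS: np.zeros((500,), dtype=np.float32),
    })
    learner_thread.inqueue.put(dummy_batch)
    # Consumer side: block until `step()` has trained on one loaded stack.
    num_samples, learner_stats = learner_thread.outqueue.get()
    logger.info("Trained on %s samples: %s", num_samples, learner_stats)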