# learner_thread.py

import queue
import threading

from ray.util.timer import _Timer
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder
from ray.rllib.utils.metrics.window_stat import WindowStat

LEARNER_QUEUE_MAX_SIZE = 16

tf1, tf, tfv = try_import_tf()


class LearnerThread(threading.Thread):
    """Background thread that updates the local model from replay data.

    The learner thread communicates with the main thread through Queues. This
    is needed since Ray operations can only be run on the main thread. In
    addition, moving the heavyweight gradient computations (session runs) off
    the main thread improves overall throughput.
    """

    def __init__(self, local_worker):
        threading.Thread.__init__(self)
        self.learner_queue_size = WindowStat("size", 50)
        self.local_worker = local_worker
        self.inqueue = queue.Queue(maxsize=LEARNER_QUEUE_MAX_SIZE)
        self.outqueue = queue.Queue()
        self.queue_timer = _Timer()
        self.grad_timer = _Timer()
        self.overall_timer = _Timer()
        self.daemon = True
        self.policy_ids_updated = []
        self.stopped = False
        self.learner_info = {}

    def run(self):
        # Switch on eager mode if configured.
        if self.local_worker.config.framework_str == "tf2":
            tf1.enable_eager_execution()
        while not self.stopped:
            self.step()

    def step(self):
        with self.overall_timer:
            with self.queue_timer:
                replay_actor, ma_batch = self.inqueue.get()
            if ma_batch is not None:
                prio_dict = {}
                with self.grad_timer:
                    # Use LearnerInfoBuilder as a unified way to build the
                    # final results dict from `learn_on_loaded_batch` call(s).
                    # This makes sure results dicts always have the same
                    # structure no matter the setup (multi-GPU, multi-agent,
                    # minibatch SGD, tf vs torch).
                    learner_info_builder = LearnerInfoBuilder(num_devices=1)
                    multi_agent_results = self.local_worker.learn_on_batch(ma_batch)
                    self.policy_ids_updated.extend(list(multi_agent_results.keys()))
                    for pid, results in multi_agent_results.items():
                        learner_info_builder.add_learn_on_batch_results(results, pid)
                        td_error = results["td_error"]
                        # Switch off auto-conversion from numpy to torch/tf
                        # tensors for the indices. This may lead to errors
                        # when sent to the buffer for processing
                        # (may get manipulated if they are part of a tensor).
                        ma_batch.policy_batches[pid].set_get_interceptor(None)
                        prio_dict[pid] = (
                            ma_batch.policy_batches[pid].get("batch_indexes"),
                            td_error,
                        )
                    self.learner_info = learner_info_builder.finalize()
                    self.grad_timer.push_units_processed(ma_batch.count)

                # Put tuple: replay_actor, prio-dict, env-steps, and agent-steps
                # into the queue.
                self.outqueue.put(
                    (replay_actor, prio_dict, ma_batch.count, ma_batch.agent_steps())
                )

            self.learner_queue_size.push(self.inqueue.qsize())
            self.overall_timer.push_units_processed(ma_batch and ma_batch.count or 0)
            del ma_batch
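

# Usage sketch (illustrative, not part of the original module): roughly how a
# driver thread wires up the LearnerThread via its two queues, following the
# tuple shapes used in step() above. `my_local_worker`, `replay_actor`, and
# `ma_batch` are hypothetical placeholders for an RLlib RolloutWorker, a
# replay-buffer actor handle, and a MultiAgentBatch.
#
#     learner = LearnerThread(my_local_worker)
#     learner.start()
#     # Producer side: hand (replay_actor, batch) pairs to the learner.
#     learner.inqueue.put((replay_actor, ma_batch))
#     # Consumer side: read back (replay_actor, prio_dict, env_steps,
#     # agent_steps) and forward the new priorities to the replay actor.
#     replay_actor, prio_dict, env_steps, agent_steps = learner.outqueue.get()
#     # Setting `stopped` lets the while-loop in run() exit.
#     learner.stopped = True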