input_reader.py

from abc import ABCMeta, abstractmethod
import logging
import threading
from typing import Dict, List

import numpy as np

from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import SampleBatchType, TensorType

tf1, tf, tfv = try_import_tf()

logger = logging.getLogger(__name__)
@PublicAPI
class InputReader(metaclass=ABCMeta):
    """API for collecting and returning experiences during policy evaluation."""

    @abstractmethod
    @PublicAPI
    def next(self) -> SampleBatchType:
        """Returns the next batch of read experiences.

        Returns:
            The experience read (SampleBatch or MultiAgentBatch).
        """
        raise NotImplementedError
    @PublicAPI
    def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]:
        """Returns TensorFlow queue ops for reading inputs from this reader.

        The main use of these ops is for integration into custom model losses.
        For example, you can use tf_input_ops() to read from files of external
        experiences to add an imitation learning loss to your model.

        This method creates a queue runner thread that will call next() on this
        reader repeatedly to feed the TensorFlow queue.

        Args:
            queue_size: Max elements to allow in the TF queue.

        Example:
            >>> from ray.rllib.models.modelv2 import ModelV2
            >>> from ray.rllib.offline.json_reader import JsonReader
            >>> imitation_loss = ... # doctest: +SKIP
            >>> class MyModel(ModelV2): # doctest: +SKIP
            ...     def custom_loss(self, policy_loss, loss_inputs):
            ...         reader = JsonReader(...)
            ...         input_ops = reader.tf_input_ops()
            ...         logits, _ = self._build_layers_v2(
            ...             {"obs": input_ops["obs"]},
            ...             self.num_outputs, self.options)
            ...         il_loss = imitation_loss(logits, input_ops["action"])
            ...         return policy_loss + il_loss

            You can find a runnable version of this in examples/custom_loss.py.

        Returns:
            Dict of Tensors, one for each column of the read SampleBatch.
        """
        if hasattr(self, "_queue_runner"):
            raise ValueError(
                "A queue runner already exists for this input reader. "
                "You can only call tf_input_ops() once per reader."
            )

        logger.info("Reading initial batch of data from input reader.")
        batch = self.next()
        if isinstance(batch, MultiAgentBatch):
            raise NotImplementedError(
                "tf_input_ops() is not implemented for multi-agent batches."
            )

        # Note on the `np.asarray(batch[k])` casts: To find all keys that hold
        # numeric data, we convert everything that is not already a numpy
        # array. This is needed because SampleBatches used to hold only numpy
        # arrays, but since our RNN efforts under RLModules, they may contain
        # lists as well.
        keys = [
            k
            for k in sorted(batch.keys())
            if np.issubdtype(np.asarray(batch[k]).dtype, np.number)
        ]
        dtypes = [np.asarray(batch[k]).dtype for k in keys]
        shapes = {k: (-1,) + np.asarray(batch[k]).shape[1:] for k in keys}
        queue = tf1.FIFOQueue(capacity=queue_size, dtypes=dtypes, names=keys)
        tensors = queue.dequeue()

        logger.info("Creating TF queue runner for {}".format(self))
        self._queue_runner = _QueueRunner(self, queue, keys, dtypes)
        # Pre-enqueue the batch read above, then let the runner thread keep
        # the queue fed by calling next() in a loop.
        self._queue_runner.enqueue(batch)
        self._queue_runner.start()

        out = {k: tf.reshape(t, shapes[k]) for k, t in tensors.items()}
        return out
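

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of RLlib: the smallest possible InputReader
# just implements next() per the contract defined above. `_ExampleCycleReader`
# and its attributes are hypothetical names used only for demonstration.
# ---------------------------------------------------------------------------
class _ExampleCycleReader(InputReader):
    """Cycles forever over a fixed list of batches (illustration only)."""

    def __init__(self, batches: List[SampleBatchType]):
        assert batches, "Need at least one batch to cycle over."
        self._batches = batches
        self._index = 0

    def next(self) -> SampleBatchType:
        # Round-robin over the fixed batches; a real reader would, e.g.,
        # stream experiences from disk (see JsonReader) or from a sampler.
        batch = self._batches[self._index % len(self._batches)]
        self._index += 1
        return batch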
class _QueueRunner(threading.Thread):
    """Thread that feeds a TF queue from an InputReader."""

    def __init__(
        self,
        input_reader: InputReader,
        queue: "tf1.FIFOQueue",
        keys: List[str],
        dtypes: "tf.dtypes.DType",
    ):
        threading.Thread.__init__(self)
        self.sess = tf1.get_default_session()
        self.daemon = True
        self.input_reader = input_reader
        self.keys = keys
        self.queue = queue
        self.placeholders = [tf1.placeholder(dtype) for dtype in dtypes]
        self.enqueue_op = queue.enqueue(dict(zip(keys, self.placeholders)))
    def enqueue(self, batch: SampleBatchType):
        data = {self.placeholders[i]: batch[key] for i, key in enumerate(self.keys)}
        self.sess.run(self.enqueue_op, feed_dict=data)

    def run(self):
        # Keep the queue fed; on a read error, log it and keep going rather
        # than killing the (daemon) thread.
        while True:
            try:
                batch = self.input_reader.next()
                self.enqueue(batch)
            except Exception:
                logger.exception("Error reading from input")
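

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical; assumes TF1-style graph mode is available):
# wires the example reader above into tf_input_ops() and pulls one reshaped
# tensor out of the queue. The batch contents are made up for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from ray.rllib.policy.sample_batch import SampleBatch

    # tf_input_ops() builds tf1 placeholders and queues, so eager execution
    # must be disabled when running under TF2.
    tf1.disable_eager_execution()

    batch = SampleBatch(
        {
            "obs": np.arange(12, dtype=np.float32).reshape(4, 3),
            "actions": np.array([0, 1, 0, 1], dtype=np.int64),
        }
    )
    reader = _ExampleCycleReader([batch])

    # _QueueRunner captures the current default session, so the input ops
    # must be created inside this `with` block.
    with tf1.Session() as sess:
        input_ops = reader.tf_input_ops(queue_size=2)
        obs = sess.run(input_ops["obs"])
        print("Dequeued obs with shape:", obs.shape)  # -> (4, 3)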