io_context.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import os
  2. from typing import Any, Optional, TYPE_CHECKING
  3. from ray.rllib.utils.annotations import PublicAPI
  4. from ray.rllib.utils.typing import TrainerConfigDict
  5. if TYPE_CHECKING:
  6. from ray.rllib.evaluation.sampler import SamplerInput
  7. @PublicAPI
  8. class IOContext:
  9. """Class containing attributes to pass to input/output class constructors.
  10. RLlib auto-sets these attributes when constructing input/output classes,
  11. such as InputReaders and OutputWriters.
  12. """
  13. @PublicAPI
  14. def __init__(self,
  15. log_dir: Optional[str] = None,
  16. config: Optional[TrainerConfigDict] = None,
  17. worker_index: int = 0,
  18. worker: Optional[Any] = None):
  19. """Initializes a IOContext object.
  20. Args:
  21. log_dir: The logging directory to read from/write to.
  22. config: The Trainer's main config dict.
  23. worker_index (int): When there are multiple workers created, this
  24. uniquely identifies the current worker. 0 for the local
  25. worker, >0 for any of the remote workers.
  26. worker (RolloutWorker): The RolloutWorker object reference.
  27. """
  28. self.log_dir = log_dir or os.getcwd()
  29. self.config = config or {}
  30. self.worker_index = worker_index
  31. self.worker = worker
  32. @PublicAPI
  33. def default_sampler_input(self) -> Optional["SamplerInput"]:
  34. """Returns the RolloutWorker's SamplerInput object, if any.
  35. Returns None if the RolloutWorker has no SamplerInput. Note that local
  36. workers in case there are also one or more remote workers by default
  37. do not create a SamplerInput object.
  38. Returns:
  39. The RolloutWorkers' SamplerInput object or None if none exists.
  40. """
  41. return self.worker.sampler
  42. @PublicAPI
  43. @property
  44. def input_config(self):
  45. return self.config.get("input_config", {})