__init__.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import logging
  2. from ray._private.usage import usage_lib
  3. # Note: do not introduce unnecessary library dependencies here, e.g. gym.
  4. # This file is imported from the tune module in order to register RLlib agents.
  5. from ray.rllib.env.base_env import BaseEnv
  6. from ray.rllib.env.external_env import ExternalEnv
  7. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  8. from ray.rllib.env.vector_env import VectorEnv
  9. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  10. from ray.rllib.policy.policy import Policy
  11. from ray.rllib.policy.sample_batch import SampleBatch
  12. from ray.rllib.policy.tf_policy import TFPolicy
  13. from ray.rllib.policy.torch_policy import TorchPolicy
  14. from ray.tune.registry import register_trainable
  15. def _setup_logger():
  16. logger = logging.getLogger("ray.rllib")
  17. handler = logging.StreamHandler()
  18. handler.setFormatter(
  19. logging.Formatter(
  20. "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
  21. )
  22. )
  23. logger.addHandler(handler)
  24. logger.propagate = False
  25. def _register_all():
  26. from ray.rllib.algorithms.registry import ALGORITHMS, _get_algorithm_class
  27. for key, get_trainable_class_and_config in ALGORITHMS.items():
  28. register_trainable(key, get_trainable_class_and_config()[0])
  29. for key in ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
  30. register_trainable(key, _get_algorithm_class(key))
  31. _setup_logger()
  32. usage_lib.record_library_usage("rllib")
  33. __all__ = [
  34. "Policy",
  35. "TFPolicy",
  36. "TorchPolicy",
  37. "RolloutWorker",
  38. "SampleBatch",
  39. "BaseEnv",
  40. "MultiAgentEnv",
  41. "VectorEnv",
  42. "ExternalEnv",
  43. ]