__init__.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import logging
  2. # Note: do not introduce unnecessary library dependencies here, e.g. gym.
  3. # This file is imported from the tune module in order to register RLlib agents.
  4. from ray.rllib.env.base_env import BaseEnv
  5. from ray.rllib.env.external_env import ExternalEnv
  6. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  7. from ray.rllib.env.vector_env import VectorEnv
  8. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  9. from ray.rllib.policy.policy import Policy
  10. from ray.rllib.policy.sample_batch import SampleBatch
  11. from ray.rllib.policy.tf_policy import TFPolicy
  12. from ray.rllib.policy.torch_policy import TorchPolicy
  13. from ray.tune.registry import register_trainable
  14. def _setup_logger():
  15. logger = logging.getLogger("ray.rllib")
  16. handler = logging.StreamHandler()
  17. handler.setFormatter(
  18. logging.Formatter(
  19. "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
  20. ))
  21. logger.addHandler(handler)
  22. logger.propagate = False
  23. def _register_all():
  24. from ray.rllib.agents.trainer import Trainer, with_common_config
  25. from ray.rllib.agents.registry import ALGORITHMS, get_trainer_class
  26. from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
  27. for key in list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys(
  28. )) + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
  29. register_trainable(key, get_trainer_class(key))
  30. def _see_contrib(name):
  31. """Returns dummy agent class warning algo is in contrib/."""
  32. class _SeeContrib(Trainer):
  33. _name = "SeeContrib"
  34. _default_config = with_common_config({})
  35. def setup(self, config):
  36. raise NameError(
  37. "Please run `contrib/{}` instead.".format(name))
  38. return _SeeContrib
  39. # also register the aliases minus contrib/ to give a good error message
  40. for key in list(CONTRIBUTED_ALGORITHMS.keys()):
  41. assert key.startswith("contrib/")
  42. alias = key.split("/", 1)[1]
  43. register_trainable(alias, _see_contrib(alias))
# Import-time side effects: attach RLlib's console log handler to the
# "ray.rllib" logger, then register RLlib trainers with Tune's registry.
_setup_logger()
_register_all()
  46. __all__ = [
  47. "Policy",
  48. "TFPolicy",
  49. "TorchPolicy",
  50. "RolloutWorker",
  51. "SampleBatch",
  52. "BaseEnv",
  53. "MultiAgentEnv",
  54. "VectorEnv",
  55. "ExternalEnv",
  56. ]