__init__.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import logging
  2. # Note: do not introduce unnecessary library dependencies here, e.g. gym.
  3. # This file is imported from the tune module in order to register RLlib agents.
  4. from ray.rllib.env.base_env import BaseEnv
  5. from ray.rllib.env.external_env import ExternalEnv
  6. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  7. from ray.rllib.env.vector_env import VectorEnv
  8. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  9. from ray.rllib.policy.policy import Policy
  10. from ray.rllib.policy.sample_batch import SampleBatch
  11. from ray.rllib.policy.tf_policy import TFPolicy
  12. from ray.rllib.policy.torch_policy import TorchPolicy
  13. from ray.tune.registry import register_trainable
  14. def _setup_logger():
  15. logger = logging.getLogger("ray.rllib")
  16. handler = logging.StreamHandler()
  17. handler.setFormatter(
  18. logging.Formatter(
  19. "%(asctime)s\t%(levelname)s %(filename)s:%(lineno)s -- %(message)s"
  20. ))
  21. logger.addHandler(handler)
  22. logger.propagate = False
  23. def _register_all():
  24. from ray.rllib.agents.trainer import Trainer
  25. from ray.rllib.agents.registry import ALGORITHMS, get_trainer_class
  26. from ray.rllib.contrib.registry import CONTRIBUTED_ALGORITHMS
  27. for key in list(ALGORITHMS.keys()) + list(CONTRIBUTED_ALGORITHMS.keys(
  28. )) + ["__fake", "__sigmoid_fake_data", "__parameter_tuning"]:
  29. register_trainable(key, get_trainer_class(key))
  30. def _see_contrib(name):
  31. """Returns dummy agent class warning algo is in contrib/."""
  32. class _SeeContrib(Trainer):
  33. def setup(self, config):
  34. raise NameError(
  35. "Please run `contrib/{}` instead.".format(name))
  36. return _SeeContrib
  37. # also register the aliases minus contrib/ to give a good error message
  38. for key in list(CONTRIBUTED_ALGORITHMS.keys()):
  39. assert key.startswith("contrib/")
  40. alias = key.split("/", 1)[1]
  41. register_trainable(alias, _see_contrib(alias))
  42. _setup_logger()
  43. _register_all()
  44. __all__ = [
  45. "Policy",
  46. "TFPolicy",
  47. "TorchPolicy",
  48. "RolloutWorker",
  49. "SampleBatch",
  50. "BaseEnv",
  51. "MultiAgentEnv",
  52. "VectorEnv",
  53. "ExternalEnv",
  54. ]