registry.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. """Registry of connector names for global access."""
from typing import Any, Type

from ray.util.annotations import PublicAPI
from ray.rllib.connectors.connector import Connector, ConnectorContext
  5. ALL_CONNECTORS = dict()
  6. @PublicAPI(stability="alpha")
  7. def register_connector(name: str, cls: Connector):
  8. """Register a connector for use with RLlib.
  9. Args:
  10. name: Name to register.
  11. cls: Callable that creates an env.
  12. """
  13. if name in ALL_CONNECTORS:
  14. return
  15. if not issubclass(cls, Connector):
  16. raise TypeError("Can only register Connector type.", cls)
  17. # Record it in local registry in case we need to register everything
  18. # again in the global registry, for example in the event of cluster
  19. # restarts.
  20. ALL_CONNECTORS[name] = cls
  21. @PublicAPI(stability="alpha")
  22. def get_connector(name: str, ctx: ConnectorContext, params: Any = None) -> Connector:
  23. # TODO(jungong) : switch the order of parameters man!!
  24. """Get a connector by its name and serialized config.
  25. Args:
  26. name: name of the connector.
  27. ctx: Connector context.
  28. params: serialized parameters of the connector.
  29. Returns:
  30. Constructed connector.
  31. """
  32. if name not in ALL_CONNECTORS:
  33. raise NameError("connector not found.", name)
  34. return ALL_CONNECTORS[name].from_state(ctx, params)