12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
- """Registry of connector names for global access."""
from typing import Any, Type

from ray.rllib.connectors.connector import Connector, ConnectorContext
from ray.util.annotations import PublicAPI
# Module-level registry mapping a connector's registered name to its
# Connector class. Written by register_connector(), read by get_connector().
ALL_CONNECTORS = {}
@PublicAPI(stability="alpha")
def register_connector(name: str, cls: Type[Connector]):
    """Register a connector class under a name for use with RLlib.

    Registration is first-come: if ``name`` is already registered, this call
    is a no-op and the existing entry is kept (``cls`` is not validated in
    that case).

    Args:
        name: Name to register the connector class under.
        cls: Connector subclass (the class object itself, not an instance)
            to associate with ``name``.

    Raises:
        TypeError: If ``cls`` is not a subclass of ``Connector``.
    """
    if name in ALL_CONNECTORS:
        return

    # Check isinstance(cls, type) first so that passing a non-class (e.g. a
    # connector instance) produces this descriptive TypeError rather than
    # issubclass()'s generic "arg 1 must be a class" TypeError.
    if not isinstance(cls, type) or not issubclass(cls, Connector):
        raise TypeError("Can only register Connector type.", cls)

    # Record it in the local registry in case we need to register everything
    # again in the global registry, for example in the event of cluster
    # restarts.
    ALL_CONNECTORS[name] = cls
@PublicAPI(stability="alpha")
def get_connector(name: str, ctx: ConnectorContext, params: Any = None) -> Connector:
    """Look up a registered connector class and build an instance from state.

    Args:
        name: Name the connector class was registered under.
        ctx: Connector context, forwarded to the class's ``from_state``.
        params: Serialized parameters of the connector, forwarded to the
            class's ``from_state``.

    Returns:
        The connector constructed by ``from_state(ctx, params)`` on the
        registered class.

    Raises:
        NameError: If no connector is registered under ``name``.
    """
    # TODO(jungong): revisit the order of this function's parameters.
    if name in ALL_CONNECTORS:
        connector_cls = ALL_CONNECTORS[name]
        return connector_cls.from_state(ctx, params)
    raise NameError("connector not found.", name)
|