clip.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from typing import Any
  2. from ray.rllib.connectors.connector import (
  3. ActionConnector,
  4. ConnectorContext,
  5. )
  6. from ray.rllib.connectors.registry import register_connector
  7. from ray.rllib.utils.spaces.space_utils import clip_action, get_base_struct_from_space
  8. from ray.rllib.utils.typing import ActionConnectorDataType
  9. from ray.util.annotations import PublicAPI
  10. @PublicAPI(stability="alpha")
  11. class ClipActionsConnector(ActionConnector):
  12. def __init__(self, ctx: ConnectorContext):
  13. super().__init__(ctx)
  14. self._action_space_struct = get_base_struct_from_space(ctx.action_space)
  15. def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
  16. assert isinstance(
  17. ac_data.output, tuple
  18. ), "Action connector requires PolicyOutputType data."
  19. actions, states, fetches = ac_data.output
  20. return ActionConnectorDataType(
  21. ac_data.env_id,
  22. ac_data.agent_id,
  23. ac_data.input_dict,
  24. (clip_action(actions, self._action_space_struct), states, fetches),
  25. )
  26. def to_state(self):
  27. return ClipActionsConnector.__name__, None
  28. @staticmethod
  29. def from_state(ctx: ConnectorContext, params: Any):
  30. return ClipActionsConnector(ctx)
  31. register_connector(ClipActionsConnector.__name__, ClipActionsConnector)