normalize.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from typing import Any
  2. from ray.rllib.connectors.connector import (
  3. ActionConnector,
  4. ConnectorContext,
  5. )
  6. from ray.rllib.connectors.registry import register_connector
  7. from ray.rllib.utils.spaces.space_utils import (
  8. get_base_struct_from_space,
  9. unsquash_action,
  10. )
  11. from ray.rllib.utils.typing import ActionConnectorDataType
  12. from ray.util.annotations import PublicAPI
  13. @PublicAPI(stability="alpha")
  14. class NormalizeActionsConnector(ActionConnector):
  15. def __init__(self, ctx: ConnectorContext):
  16. super().__init__(ctx)
  17. self._action_space_struct = get_base_struct_from_space(ctx.action_space)
  18. def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType:
  19. assert isinstance(
  20. ac_data.output, tuple
  21. ), "Action connector requires PolicyOutputType data."
  22. actions, states, fetches = ac_data.output
  23. return ActionConnectorDataType(
  24. ac_data.env_id,
  25. ac_data.agent_id,
  26. ac_data.input_dict,
  27. (unsquash_action(actions, self._action_space_struct), states, fetches),
  28. )
  29. def to_state(self):
  30. return NormalizeActionsConnector.__name__, None
  31. @staticmethod
  32. def from_state(ctx: ConnectorContext, params: Any):
  33. return NormalizeActionsConnector(ctx)
  34. register_connector(NormalizeActionsConnector.__name__, NormalizeActionsConnector)