util.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import logging
  2. from typing import Any, Tuple, TYPE_CHECKING
  3. from ray.rllib.connectors.action.clip import ClipActionsConnector
  4. from ray.rllib.connectors.action.immutable import ImmutableActionsConnector
  5. from ray.rllib.connectors.action.lambdas import ConvertToNumpyConnector
  6. from ray.rllib.connectors.action.normalize import NormalizeActionsConnector
  7. from ray.rllib.connectors.action.pipeline import ActionConnectorPipeline
  8. from ray.rllib.connectors.agent.clip_reward import ClipRewardAgentConnector
  9. from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector
  10. from ray.rllib.connectors.agent.pipeline import AgentConnectorPipeline
  11. from ray.rllib.connectors.agent.state_buffer import StateBufferConnector
  12. from ray.rllib.connectors.agent.view_requirement import ViewRequirementAgentConnector
  13. from ray.rllib.connectors.connector import Connector, ConnectorContext
  14. from ray.rllib.connectors.registry import get_connector
  15. from ray.rllib.connectors.agent.mean_std_filter import (
  16. MeanStdObservationFilterAgentConnector,
  17. ConcurrentMeanStdObservationFilterAgentConnector,
  18. )
  19. from ray.util.annotations import PublicAPI, DeveloperAPI
  20. from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector
  21. if TYPE_CHECKING:
  22. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  23. from ray.rllib.policy.policy import Policy
  24. logger = logging.getLogger(__name__)
  25. def __preprocessing_enabled(config: "AlgorithmConfig"):
  26. if config._disable_preprocessor_api:
  27. return False
  28. # Same conditions as in RolloutWorker.__init__.
  29. if config.is_atari and config.preprocessor_pref == "deepmind":
  30. return False
  31. if config.preprocessor_pref is None:
  32. return False
  33. return True
  34. def __clip_rewards(config: "AlgorithmConfig"):
  35. # Same logic as in RolloutWorker.__init__.
  36. # We always clip rewards for Atari games.
  37. return config.clip_rewards or config.is_atari
  38. @PublicAPI(stability="alpha")
  39. def get_agent_connectors_from_config(
  40. ctx: ConnectorContext,
  41. config: "AlgorithmConfig",
  42. ) -> AgentConnectorPipeline:
  43. connectors = []
  44. clip_rewards = __clip_rewards(config)
  45. if clip_rewards is True:
  46. connectors.append(ClipRewardAgentConnector(ctx, sign=True))
  47. elif type(clip_rewards) == float:
  48. connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards)))
  49. if __preprocessing_enabled(config):
  50. connectors.append(ObsPreprocessorConnector(ctx))
  51. # Filters should be after observation preprocessing
  52. filter_connector = get_synced_filter_connector(
  53. ctx,
  54. )
  55. # Configuration option "NoFilter" results in `filter_connector==None`.
  56. if filter_connector:
  57. connectors.append(filter_connector)
  58. connectors.extend(
  59. [
  60. StateBufferConnector(ctx),
  61. ViewRequirementAgentConnector(ctx),
  62. ]
  63. )
  64. return AgentConnectorPipeline(ctx, connectors)
  65. @PublicAPI(stability="alpha")
  66. def get_action_connectors_from_config(
  67. ctx: ConnectorContext,
  68. config: "AlgorithmConfig",
  69. ) -> ActionConnectorPipeline:
  70. """Default list of action connectors to use for a new policy.
  71. Args:
  72. ctx: context used to create connectors.
  73. config: The AlgorithmConfig object.
  74. """
  75. connectors = [ConvertToNumpyConnector(ctx)]
  76. if config.get("normalize_actions", False):
  77. connectors.append(NormalizeActionsConnector(ctx))
  78. if config.get("clip_actions", False):
  79. connectors.append(ClipActionsConnector(ctx))
  80. connectors.append(ImmutableActionsConnector(ctx))
  81. return ActionConnectorPipeline(ctx, connectors)
  82. @PublicAPI(stability="alpha")
  83. def create_connectors_for_policy(policy: "Policy", config: "AlgorithmConfig"):
  84. """Util to create agent and action connectors for a Policy.
  85. Args:
  86. policy: Policy instance.
  87. config: Algorithm config dict.
  88. """
  89. ctx: ConnectorContext = ConnectorContext.from_policy(policy)
  90. assert (
  91. policy.agent_connectors is None and policy.action_connectors is None
  92. ), "Can not create connectors for a policy that already has connectors."
  93. policy.agent_connectors = get_agent_connectors_from_config(ctx, config)
  94. policy.action_connectors = get_action_connectors_from_config(ctx, config)
  95. logger.info("Using connectors:")
  96. logger.info(policy.agent_connectors.__str__(indentation=4))
  97. logger.info(policy.action_connectors.__str__(indentation=4))
  98. @PublicAPI(stability="alpha")
  99. def restore_connectors_for_policy(
  100. policy: "Policy", connector_config: Tuple[str, Tuple[Any]]
  101. ) -> Connector:
  102. """Util to create connector for a Policy based on serialized config.
  103. Args:
  104. policy: Policy instance.
  105. connector_config: Serialized connector config.
  106. """
  107. ctx: ConnectorContext = ConnectorContext.from_policy(policy)
  108. name, params = connector_config
  109. return get_connector(name, ctx, params)
  110. # We need this filter selection mechanism temporarily to remain compatible to old API
  111. @DeveloperAPI
  112. def get_synced_filter_connector(ctx: ConnectorContext):
  113. filter_specifier = ctx.config.get("observation_filter")
  114. if filter_specifier == "MeanStdFilter":
  115. return MeanStdObservationFilterAgentConnector(ctx, clip=None)
  116. elif filter_specifier == "ConcurrentMeanStdFilter":
  117. return ConcurrentMeanStdObservationFilterAgentConnector(ctx, clip=None)
  118. elif filter_specifier == "NoFilter":
  119. return None
  120. else:
  121. raise Exception("Unknown observation_filter: " + str(filter_specifier))
  122. @DeveloperAPI
  123. def maybe_get_filters_for_syncing(rollout_worker, policy_id):
  124. # As long as the historic filter synchronization mechanism is in
  125. # place, we need to put filters into self.filters so that they get
  126. # synchronized
  127. policy = rollout_worker.policy_map[policy_id]
  128. if not policy.agent_connectors:
  129. return
  130. filter_connectors = policy.agent_connectors[SyncedFilterAgentConnector]
  131. # There can only be one filter at a time
  132. if not filter_connectors:
  133. return
  134. assert len(filter_connectors) == 1, (
  135. "ConnectorPipeline has multiple connectors of type "
  136. "SyncedFilterAgentConnector but can only have one."
  137. )
  138. rollout_worker.filters[policy_id] = filter_connectors[0].filter