123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133 |
- """This example script shows how to load a connector enabled policy,
- and adapt/use it with a different version of the environment.
- """
- import gymnasium as gym
- import numpy as np
- import os
- import tempfile
- from typing import Dict
- from ray.rllib.connectors.connector import ConnectorContext
- from ray.rllib.connectors.action.lambdas import register_lambda_action_connector
- from ray.rllib.connectors.agent.lambdas import register_lambda_agent_connector
- from ray.rllib.examples.connectors.prepare_checkpoint import (
- # For demo purpose only. Would normally not need this.
- create_appo_cartpole_checkpoint,
- )
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.policy import local_policy_inference
- from ray.rllib.utils.typing import (
- PolicyOutputType,
- StateBatches,
- TensorStructType,
- )
- # __sphinx_doc_begin__
class MyCartPole(gym.Env):
    """A mock CartPole environment.

    Wraps ``CartPole-v1`` and exposes 2 additional observation features
    (always ``8.0`` and ``6.0``) plus a 2-dim MultiDiscrete action space,
    of which only the first action actually drives the wrapped env.
    """

    def __init__(self):
        self._env = gym.make("CartPole-v1")
        # 4 real CartPole features + 2 fake extra features.
        self.observation_space = gym.spaces.Box(low=-10, high=10, shape=(6,))
        self.action_space = gym.spaces.MultiDiscrete(nvec=[2, 2])

    def step(self, actions):
        # Only the first of the two discrete actions is forwarded; the
        # second is a placeholder added by the action connector.
        action = actions[0]
        obs, reward, done, truncated, info = self._env.step(action)
        # Fake additional data points to the obs.
        obs = np.hstack((obs, [8.0, 6.0]))
        return obs, reward, done, truncated, info

    def reset(self, *, seed=None, options=None):
        # Bug fix: forward `seed`/`options` to the wrapped env (previously
        # accepted but silently ignored, so seeding had no effect).
        obs, info = self._env.reset(seed=seed, options=options)
        return np.hstack((obs, [8.0, 6.0])), info
- # Custom agent connector to drop the last 2 feature values.
def v2_to_v1_obs(data: Dict[str, TensorStructType]) -> Dict[str, TensorStructType]:
    """Trim the last 2 feature values off NEXT_OBS (mutates `data` in place)."""
    trimmed = data[SampleBatch.NEXT_OBS][:-2]
    data[SampleBatch.NEXT_OBS] = trimmed
    return data
- # Agent connector that adapts observations from the new CartPole env
- # into old format.
# Wrap the lambda in a named agent-connector class so it can be attached
# to the policy's agent connector pipeline below.
V2ToV1ObsAgentConnector = register_lambda_agent_connector(
    "V2ToV1ObsAgentConnector", v2_to_v1_obs
)
- # Custom action connector to add a placeholder action as the addtional action input.
def v1_to_v2_action(
    actions: TensorStructType, states: StateBatches, fetches: Dict
) -> PolicyOutputType:
    """Append a constant 0 as the placeholder second discrete action."""
    padded = np.hstack((actions, [0]))
    return padded, states, fetches
# Action connector that adapts action outputs from the old policy
# into new actions for the mock environment. Wrapped in a named
# action-connector class so it can be appended to the policy's
# action connector pipeline below.
V1ToV2ActionConnector = register_lambda_action_connector(
    "V1ToV2ActionConnector", v1_to_v2_action
)
def run(checkpoint_path, policy_id):
    """Restore a checkpointed policy and roll it out in MyCartPole.

    The restored policy was trained on standard CartPole; it is adapted
    to the new env by swapping in the custom agent/action connectors.
    """
    # Restore the trained policy from its checkpoint.
    policy = Policy.from_checkpoint(
        checkpoint=checkpoint_path,
        policy_ids=[policy_id],
    )

    # Build a connector context from the restored policy.
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)

    # During training, a FlattenDataAgentConnector added a batch dimension
    # to single observations. That is no longer needed here, so remove it
    # before installing the two adapter connectors.
    policy.agent_connectors.remove("FlattenDataAgentConnector")
    # Adapt incoming 6-feature observations to the old 4-feature format ...
    policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
    # ... and pad the policy's single action to the env's 2-action format.
    policy.action_connectors.append(V1ToV2ActionConnector(ctx))

    # Roll out one episode in the new environment.
    env = MyCartPole()
    obs, info = env.reset()
    step = 0
    terminated = False
    while not terminated:
        step += 1
        # local_policy_inference() runs the policy together with its
        # connector pipelines on a single observation.
        outputs = local_policy_inference(policy, "env_1", "agent_1", obs)
        assert len(outputs) == 1
        actions, _, _ = outputs[0]
        print(f"step {step}", obs, actions)
        obs, _, terminated, _, _ = env.step(actions)
- # __sphinx_doc_end__
- if __name__ == "__main__":
- with tempfile.TemporaryDirectory() as tmpdir:
- policy_id = "default_policy"
- # Note, this is just for demo purpose.
- # Normally, you would use a policy checkpoint from a real training run.
- create_appo_cartpole_checkpoint(tmpdir)
- policy_checkpoint_path = os.path.join(
- tmpdir,
- "policies",
- policy_id,
- )
- run(policy_checkpoint_path, policy_id)
|