adapt_connector_policy.py

  1. """This example script shows how to load a connector enabled policy,
  2. and adapt/use it with a different version of the environment.
  3. """
  4. import gymnasium as gym
  5. import numpy as np
  6. import os
  7. import tempfile
  8. from typing import Dict
  9. from ray.rllib.connectors.connector import ConnectorContext
  10. from ray.rllib.connectors.action.lambdas import register_lambda_action_connector
  11. from ray.rllib.connectors.agent.lambdas import register_lambda_agent_connector
  12. from ray.rllib.examples.connectors.prepare_checkpoint import (
  13. # For demo purpose only. Would normally not need this.
  14. create_appo_cartpole_checkpoint,
  15. )
  16. from ray.rllib.policy.policy import Policy
  17. from ray.rllib.policy.sample_batch import SampleBatch
  18. from ray.rllib.utils.policy import local_policy_inference
  19. from ray.rllib.utils.typing import (
  20. PolicyOutputType,
  21. StateBatches,
  22. TensorStructType,
  23. )

# __sphinx_doc_begin__
class MyCartPole(gym.Env):
    """A mock CartPole environment.

    Gives 2 additional observation states and takes 2 discrete actions.
    """

    def __init__(self):
        self._env = gym.make("CartPole-v1")
        self.observation_space = gym.spaces.Box(low=-10, high=10, shape=(6,))
        self.action_space = gym.spaces.MultiDiscrete(nvec=[2, 2])

    def step(self, actions):
        # Take the first action.
        action = actions[0]
        obs, reward, terminated, truncated, info = self._env.step(action)
        # Fake additional data points onto the obs.
        obs = np.hstack((obs, [8.0, 6.0]))
        return obs, reward, terminated, truncated, info

    def reset(self, *, seed=None, options=None):
        # Forward seed and options to the wrapped env.
        obs, info = self._env.reset(seed=seed, options=options)
        return np.hstack((obs, [8.0, 6.0])), info
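
# A rough sketch of what this adapted env produces (values are illustrative only):
#   obs:     np.array([x, x_dot, theta, theta_dot, 8.0, 6.0])   # shape (6,)
#   actions: np.array([a0, a1]) with a0, a1 in {0, 1}; only a0 reaches CartPole.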
  43. # Custom agent connector to drop the last 2 feature values.
  44. def v2_to_v1_obs(data: Dict[str, TensorStructType]) -> Dict[str, TensorStructType]:
  45. data[SampleBatch.NEXT_OBS] = data[SampleBatch.NEXT_OBS][:-2]
  46. return data
  47. # Agent connector that adapts observations from the new CartPole env
  48. # into old format.
  49. V2ToV1ObsAgentConnector = register_lambda_agent_connector(
  50. "V2ToV1ObsAgentConnector", v2_to_v1_obs
  51. )
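
# For example (illustrative values only): an incoming 6-feature observation like
#   {SampleBatch.NEXT_OBS: [0.1, 0.2, 0.3, 0.4, 8.0, 6.0], ...}
# gets trimmed back to the 4-feature observation the old policy was trained on:
#   {SampleBatch.NEXT_OBS: [0.1, 0.2, 0.3, 0.4], ...}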

# Custom action connector to add a placeholder action as the additional action input.
def v1_to_v2_action(
    actions: TensorStructType, states: StateBatches, fetches: Dict
) -> PolicyOutputType:
    return np.hstack((actions, [0])), states, fetches


# Action connector that adapts action outputs from the old policy
# into new actions for the mock environment.
V1ToV2ActionConnector = register_lambda_action_connector(
    "V1ToV2ActionConnector", v1_to_v2_action
)
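
# For example (illustrative only): a single discrete action 1 from the old policy
# becomes [1, 0] for the new MultiDiscrete([2, 2]) action space, where the second
# value is a constant placeholder that the mock env ignores.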

def run(checkpoint_path, policy_id):
    # Restore policy.
    policy = Policy.from_checkpoint(
        checkpoint=checkpoint_path,
        policy_ids=[policy_id],
    )

    # Adapt policy trained for standard CartPole to the new env.
    ctx: ConnectorContext = ConnectorContext.from_policy(policy)

    # When this policy was trained, it relied on FlattenDataAgentConnector
    # to add a batch dimension to single observations.
    # This is not necessary anymore, so we first remove the previously used
    # FlattenDataAgentConnector.
    policy.agent_connectors.remove("FlattenDataAgentConnector")

    # We then add the two adapter connectors.
    policy.agent_connectors.prepend(V2ToV1ObsAgentConnector(ctx))
    policy.action_connectors.append(V1ToV2ActionConnector(ctx))
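
    # At this point the agent connector pipeline starts with our
    # V2ToV1ObsAgentConnector and the action connector pipeline ends with our
    # V1ToV2ActionConnector. The pipelines can be inspected, e.g. by printing
    # policy.agent_connectors and policy.action_connectors.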

    # Run CartPole.
    env = MyCartPole()
    obs, info = env.reset()
    done = False
    step = 0
    while not done:
        step += 1

        # Use local_policy_inference() to easily run the policy with observations.
        policy_outputs = local_policy_inference(policy, "env_1", "agent_1", obs)
        assert len(policy_outputs) == 1
        actions, _, _ = policy_outputs[0]
        print(f"step {step}", obs, actions)

        # Step the env; stop on either termination or truncation.
        obs, _, terminated, truncated, _ = env.step(actions)
        done = terminated or truncated
# __sphinx_doc_end__

if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        policy_id = "default_policy"

        # Note: this is just for demo purposes.
        # Normally, you would use a policy checkpoint from a real training run.
        create_appo_cartpole_checkpoint(tmpdir)
        policy_checkpoint_path = os.path.join(
            tmpdir,
            "policies",
            policy_id,
        )

        run(policy_checkpoint_path, policy_id)
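
# Expected console output (illustrative; actual numbers will vary) is one line
# per step along the lines of:
#   step 1 [ 0.03 -0.01  0.02  0.04  8.    6.  ] [1 0]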