run_connector_policy.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. """This example script shows how to load a connector enabled policy,
  2. and use it in a serving/inference setting.
  3. """
  4. import gymnasium as gym
  5. import os
  6. import tempfile
  7. from ray.rllib.examples.connectors.prepare_checkpoint import (
  8. # For demo purpose only. Would normally not need this.
  9. create_appo_cartpole_checkpoint,
  10. )
  11. from ray.rllib.policy.policy import Policy
  12. from ray.rllib.utils.policy import local_policy_inference
  13. def run(checkpoint_path, policy_id):
  14. # __sphinx_doc_begin__
  15. # Restore policy.
  16. policy = Policy.from_checkpoint(
  17. checkpoint=checkpoint_path,
  18. policy_ids=[policy_id],
  19. )
  20. # Run CartPole.
  21. env = gym.make("CartPole-v1")
  22. obs, info = env.reset()
  23. terminated = truncated = False
  24. step = 0
  25. while not terminated and not truncated:
  26. step += 1
  27. # Use local_policy_inference() to run inference, so we do not have to
  28. # provide policy states or extra fetch dictionaries.
  29. # "env_1" and "agent_1" are dummy env and agent IDs to run connectors with.
  30. policy_outputs = local_policy_inference(
  31. policy, "env_1", "agent_1", obs, explore=False
  32. )
  33. assert len(policy_outputs) == 1
  34. action, _, _ = policy_outputs[0]
  35. print(f"step {step}", obs, action)
  36. # Step environment forward one more step.
  37. obs, _, terminated, truncated, _ = env.step(action)
  38. # __sphinx_doc_end__
  39. if __name__ == "__main__":
  40. with tempfile.TemporaryDirectory() as tmpdir:
  41. policy_id = "default_policy"
  42. # Note, this is just for demo purpose.
  43. # Normally, you would use a policy checkpoint from a real training run.
  44. create_appo_cartpole_checkpoint(tmpdir)
  45. policy_checkpoint_path = os.path.join(
  46. tmpdir,
  47. "policies",
  48. policy_id,
  49. )
  50. run(policy_checkpoint_path, policy_id)