serve_and_rllib.py

  1. """This example script shows how one can use Ray Serve to serve an already
  2. trained RLlib Policy (and its model) to serve action computations.
  3. For a complete tutorial, also see:
  4. https://docs.ray.io/en/master/serve/tutorials/rllib.html
  5. """
import argparse

import gym
import requests
from starlette.requests import Request

import ray
import ray.rllib.agents.dqn as dqn
from ray.rllib.env.wrappers.atari_wrappers import FrameStack, WarpFrame
from ray import serve

parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--train-iters", type=int, default=1)
parser.add_argument("--no-render", action="store_true")
args = parser.parse_args()


class ServeRLlibPolicy:
    """Callable class used by Ray Serve to handle async requests.

    All the necessary serving logic is implemented in here:
    - Creation and restoring of the (already trained) RLlib Trainer.
    - Calls to `trainer.compute_single_action` upon receiving an action
      request (with a current observation).
    """

    def __init__(self, config, checkpoint_path):
        # Create the Trainer.
        self.trainer = dqn.DQNTrainer(config=config)
        # Load an already trained state into the Trainer.
        self.trainer.restore(checkpoint_path)

    async def __call__(self, request: Request):
        json_input = await request.json()
        # Compute and return the action for the given observation.
        obs = json_input["observation"]
        action = self.trainer.compute_single_action(obs)
        return {"action": int(action)}
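

# For reference, the JSON exchanged with `ServeRLlibPolicy.__call__` looks
# roughly like this (a sketch; the observation below is abbreviated, the real
# payload is the nested-list form of the 84x84x4 stacked-frame array produced
# by the wrapped env in `__main__`):
#   Request body:  {"observation": [[[0, 0, 0, 0], ...], ...]}
#   Response body: {"action": 3}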


def train_rllib_policy(config):
    """Trains a DQNTrainer on MsPacman-v0 for n iterations.

    Saves the trained Trainer to disk and returns the checkpoint path.

    Returns:
        str: The saved checkpoint path from which to restore the DQNTrainer.
    """
    # Create trainer from config.
    trainer = dqn.DQNTrainer(config=config)

    # Train for n iterations, then save.
    for _ in range(args.train_iters):
        print(trainer.train())
    return trainer.save()
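

# Note: `trainer.save()` returns the path to the newly written checkpoint
# file, typically something like the following (the exact layout depends on
# the Ray version):
#   ~/ray_results/DQN_MsPacman-v0_.../checkpoint_000001/checkpoint-1
# This is the path that `ServeRLlibPolicy.__init__` passes to
# `trainer.restore()` above.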


if __name__ == "__main__":
    # Config for the served RLlib Policy/Trainer.
    config = {
        "framework": args.framework,
        # Local mode -> local env inside the Trainer not needed!
        "num_workers": 0,
        "env": "MsPacman-v0",
    }

    # Train the policy for some time, then save it and get the checkpoint
    # path.
    checkpoint_path = train_rllib_policy(config)

    ray.init(num_cpus=8)

    # Start Ray Serve (create the RLlib Policy service defined by
    # our `ServeRLlibPolicy` class above).
    client = serve.start()
    client.create_backend("backend", ServeRLlibPolicy, config, checkpoint_path)
    client.create_endpoint(
        "endpoint", backend="backend", route="/mspacman-rllib-policy")

    # Create the environment that we would like to receive
    # served actions for.
    env = FrameStack(WarpFrame(gym.make("MsPacman-v0"), 84), 4)
    obs = env.reset()

    while True:
        print("-> Requesting action for obs ...")
        # Send a request to Serve.
        resp = requests.get(
            "http://localhost:8000/mspacman-rllib-policy",
            json={"observation": obs.tolist()})
        response = resp.json()
        print("<- Received response {}".format(response))

        # Apply the action in the env.
        action = response["action"]
        obs, reward, done, _ = env.step(action)

        # If episode done -> reset to get initial observation of new episode.
        if done:
            obs = env.reset()

        # Render if necessary.
        if not args.no_render:
            env.render()