policy_server_input.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. from http.server import HTTPServer, SimpleHTTPRequestHandler
  2. import logging
  3. import queue
  4. from socketserver import ThreadingMixIn
  5. import threading
  6. import time
  7. import traceback
  8. import ray.cloudpickle as pickle
  9. from ray.rllib.env.policy_client import PolicyClient, \
  10. _create_embedded_rollout_worker
  11. from ray.rllib.offline.input_reader import InputReader
  12. from ray.rllib.policy.sample_batch import SampleBatch
  13. from ray.rllib.utils.annotations import override, PublicAPI
  14. logger = logging.getLogger(__name__)
  class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
      """REST policy server that acts as an offline data source.

      This launches a multi-threaded server that listens on the specified host
      and port to serve policy requests and forward experiences to RLlib. For
      high performance experience collection, it implements InputReader.

      For an example, run `examples/serving/cartpole_server.py` along
      with `examples/serving/cartpole_client.py --inference-mode=local|remote`.

      Examples:
          >>> pg = PGTrainer(
          ...     env="CartPole-v0", config={
          ...         "input": lambda ioctx:
          ...             PolicyServerInput(ioctx, addr, port),
          ...         "num_workers": 0,  # Run just 1 server, in the trainer.
          ...     })
          >>> while True:
          >>>     pg.train()

          >>> client = PolicyClient("localhost:9900", inference_mode="local")
          >>> eps_id = client.start_episode()
          >>> action = client.get_action(eps_id, obs)
          >>> ...
          >>> client.log_returns(eps_id, reward)
          >>> ...
          >>> client.log_returns(eps_id, reward)
      """

      @PublicAPI
      def __init__(self, ioctx, address, port, idle_timeout=3.0):
          """Create a PolicyServerInput.

          This class implements rllib.offline.InputReader, and can be used with
          any Trainer by configuring

              {"num_workers": 0,
               "input": lambda ioctx: PolicyServerInput(ioctx, addr, port)}

          Note that by setting num_workers: 0, the trainer will only create one
          rollout worker / PolicyServerInput. Clients can connect to the launched
          server using rllib.env.PolicyClient.

          Args:
              ioctx (IOContext): IOContext provided by RLlib.
              address (str): Server addr (e.g., "localhost").
              port (int): Server port (e.g., 9900).
              idle_timeout (float): Interval in seconds at which an empty
                  SampleBatch is put on the samples queue by the heart-beat
                  thread, so that `next()` does not block forever when no
                  client sends data.

          Raises:
              OSError: Re-raised if the HTTP server cannot be created on
                  `address`:`port`.
          """
          self.rollout_worker = ioctx.worker
          self.samples_queue = queue.Queue()
          self.metrics_queue = queue.Queue()
          self.idle_timeout = idle_timeout

          def get_metrics():
              # Drain everything currently in the metrics queue without
              # blocking.
              completed = []
              while True:
                  try:
                      completed.append(self.metrics_queue.get_nowait())
                  except queue.Empty:
                      break
              return completed

          # Forwards client-reported rewards directly into the local rollout
          # worker. This is a bit of a hack since it is patching the get_metrics
          # function of the sampler.
          if self.rollout_worker.sampler is not None:
              self.rollout_worker.sampler.get_metrics = get_metrics

          # Create a request handler that receives commands from the clients
          # and sends data and metrics into the queues.
          handler = _make_handler(self.rollout_worker, self.samples_queue,
                                  self.metrics_queue)
          try:
              # NOTE(review): the one-second pauses around server creation are
              # kept from the original code; presumably they give a previously
              # used port time to be released - confirm before removing.
              # The module-level `time` import is used (no local re-import).
              time.sleep(1)
              HTTPServer.__init__(self, (address, port), handler)
          except OSError:
              print(f"Creating a PolicyServer on {address}:{port} failed!")
              time.sleep(1)
              raise

          logger.info("Starting connector server at "
                      f"{self.server_name}:{self.server_port}")

          # Start the serving thread, listening on socket and handling commands.
          serving_thread = threading.Thread(
              name="server", target=self.serve_forever, daemon=True)
          serving_thread.start()

          # Start a dummy thread that puts empty SampleBatches on the queue, just
          # in case we don't receive anything from clients (or there aren't
          # any). The latter would block sample collection entirely otherwise,
          # even if other workers' PolicyServerInput receive incoming data from
          # actual clients.
          heart_beat_thread = threading.Thread(
              name="heart-beat",
              target=self._put_empty_sample_batch_every_n_sec,
              daemon=True)
          heart_beat_thread.start()

      @override(InputReader)
      def next(self):
          """Return the next item from the samples queue, blocking if empty."""
          return self.samples_queue.get()

      def _put_empty_sample_batch_every_n_sec(self):
          """Heart-beat loop: enqueue an empty SampleBatch periodically."""
          # Places an empty SampleBatch every `idle_timeout` seconds onto the
          # `samples_queue`. This avoids hanging of all RolloutWorkers parallel
          # to this one in case this PolicyServerInput does not have incoming
          # data (e.g. no client connected).
          while True:
              time.sleep(self.idle_timeout)
              self.samples_queue.put(SampleBatch())
  111. def _make_handler(rollout_worker, samples_queue, metrics_queue):
  112. # Only used in remote inference mode. We must create a new rollout worker
  113. # then since the original worker doesn't have the env properly wrapped in
  114. # an ExternalEnv interface.
  115. child_rollout_worker = None
  116. inference_thread = None
  117. lock = threading.Lock()
  118. def setup_child_rollout_worker():
  119. nonlocal lock
  120. nonlocal child_rollout_worker
  121. nonlocal inference_thread
  122. with lock:
  123. if child_rollout_worker is None:
  124. (child_rollout_worker,
  125. inference_thread) = _create_embedded_rollout_worker(
  126. rollout_worker.creation_args(), report_data)
  127. child_rollout_worker.set_weights(rollout_worker.get_weights())
  128. def report_data(data):
  129. nonlocal child_rollout_worker
  130. batch = data["samples"]
  131. batch.decompress_if_needed()
  132. samples_queue.put(batch)
  133. for rollout_metric in data["metrics"]:
  134. metrics_queue.put(rollout_metric)
  135. if child_rollout_worker is not None:
  136. child_rollout_worker.set_weights(rollout_worker.get_weights(),
  137. rollout_worker.get_global_vars())
  138. class Handler(SimpleHTTPRequestHandler):
  139. def __init__(self, *a, **kw):
  140. super().__init__(*a, **kw)
  141. def do_POST(self):
  142. content_len = int(self.headers.get("Content-Length"), 0)
  143. raw_body = self.rfile.read(content_len)
  144. parsed_input = pickle.loads(raw_body)
  145. try:
  146. response = self.execute_command(parsed_input)
  147. self.send_response(200)
  148. self.end_headers()
  149. self.wfile.write(pickle.dumps(response))
  150. except Exception:
  151. self.send_error(500, traceback.format_exc())
  152. def execute_command(self, args):
  153. command = args["command"]
  154. response = {}
  155. # Local inference commands:
  156. if command == PolicyClient.GET_WORKER_ARGS:
  157. logger.info("Sending worker creation args to client.")
  158. response["worker_args"] = rollout_worker.creation_args()
  159. elif command == PolicyClient.GET_WEIGHTS:
  160. logger.info("Sending worker weights to client.")
  161. response["weights"] = rollout_worker.get_weights()
  162. response["global_vars"] = rollout_worker.get_global_vars()
  163. elif command == PolicyClient.REPORT_SAMPLES:
  164. logger.info("Got sample batch of size {} from client.".format(
  165. args["samples"].count))
  166. report_data(args)
  167. # Remote inference commands:
  168. elif command == PolicyClient.START_EPISODE:
  169. setup_child_rollout_worker()
  170. assert inference_thread.is_alive()
  171. response["episode_id"] = (
  172. child_rollout_worker.env.start_episode(
  173. args["episode_id"], args["training_enabled"]))
  174. elif command == PolicyClient.GET_ACTION:
  175. assert inference_thread.is_alive()
  176. response["action"] = child_rollout_worker.env.get_action(
  177. args["episode_id"], args["observation"])
  178. elif command == PolicyClient.LOG_ACTION:
  179. assert inference_thread.is_alive()
  180. child_rollout_worker.env.log_action(
  181. args["episode_id"], args["observation"], args["action"])
  182. elif command == PolicyClient.LOG_RETURNS:
  183. assert inference_thread.is_alive()
  184. if args["done"]:
  185. child_rollout_worker.env.log_returns(
  186. args["episode_id"], args["reward"], args["info"],
  187. args["done"])
  188. else:
  189. child_rollout_worker.env.log_returns(
  190. args["episode_id"], args["reward"], args["info"])
  191. elif command == PolicyClient.END_EPISODE:
  192. assert inference_thread.is_alive()
  193. child_rollout_worker.env.end_episode(args["episode_id"],
  194. args["observation"])
  195. else:
  196. raise ValueError("Unknown command: {}".format(command))
  197. return response
  198. return Handler