123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233 |
- from http.server import HTTPServer, SimpleHTTPRequestHandler
- import logging
- import queue
- from socketserver import ThreadingMixIn
- import threading
- import time
- import traceback
- import ray.cloudpickle as pickle
- from ray.rllib.env.policy_client import PolicyClient, \
- _create_embedded_rollout_worker
- from ray.rllib.offline.input_reader import InputReader
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.annotations import override, PublicAPI
- logger = logging.getLogger(__name__)
class PolicyServerInput(ThreadingMixIn, HTTPServer, InputReader):
    """REST policy server that acts as an offline data source.

    This launches a multi-threaded server that listens on the specified host
    and port to serve policy requests and forward experiences to RLlib. For
    high performance experience collection, it implements InputReader.

    For an example, run `examples/serving/cartpole_server.py` along
    with `examples/serving/cartpole_client.py --inference-mode=local|remote`.

    Examples:
        >>> pg = PGTrainer(
        ...     env="CartPole-v0", config={
        ...         "input": lambda ioctx:
        ...             PolicyServerInput(ioctx, addr, port),
        ...         "num_workers": 0,  # Run just 1 server, in the trainer.
        ...     })
        >>> while True:
        >>>     pg.train()

        >>> client = PolicyClient("localhost:9900", inference_mode="local")
        >>> eps_id = client.start_episode()
        >>> action = client.get_action(eps_id, obs)
        >>> ...
        >>> client.log_returns(eps_id, reward)
        >>> ...
        >>> client.log_returns(eps_id, reward)
    """

    @PublicAPI
    def __init__(self, ioctx, address, port, idle_timeout=3.0):
        """Create a PolicyServerInput.

        This class implements rllib.offline.InputReader, and can be used with
        any Trainer by configuring

            {"num_workers": 0,
             "input": lambda ioctx: PolicyServerInput(ioctx, addr, port)}

        Note that by setting num_workers: 0, the trainer will only create one
        rollout worker / PolicyServerInput. Clients can connect to the launched
        server using rllib.env.PolicyClient.

        Args:
            ioctx (IOContext): IOContext provided by RLlib.
            address (str): Server addr (e.g., "localhost").
            port (int): Server port (e.g., 9900).
            idle_timeout (float): Interval in seconds at which an empty
                SampleBatch is placed on the samples queue, so that `next()`
                does not block indefinitely while no client sends data.

        Raises:
            OSError: If the HTTP server cannot be created on
                `address`:`port` (re-raised after printing a message).
        """
        self.rollout_worker = ioctx.worker
        self.samples_queue = queue.Queue()
        self.metrics_queue = queue.Queue()
        self.idle_timeout = idle_timeout

        def get_metrics():
            # Drain everything currently in the metrics queue without
            # blocking and hand it back as a list.
            completed = []
            while True:
                try:
                    completed.append(self.metrics_queue.get_nowait())
                except queue.Empty:
                    break
            return completed

        # Forwards client-reported rewards directly into the local rollout
        # worker. This is a bit of a hack since it is patching the get_metrics
        # function of the sampler.
        if self.rollout_worker.sampler is not None:
            self.rollout_worker.sampler.get_metrics = get_metrics

        # Create a request handler that receives commands from the clients
        # and sends data and metrics into the queues.
        handler = _make_handler(self.rollout_worker, self.samples_queue,
                                self.metrics_queue)
        try:
            # NOTE(review): the one-second pauses are kept from the original
            # code; presumably they give a previously bound socket time to
            # be released -- confirm before removing.
            time.sleep(1)
            HTTPServer.__init__(self, (address, port), handler)
        except OSError:
            print(f"Creating a PolicyServer on {address}:{port} failed!")
            time.sleep(1)
            raise

        logger.info("Starting connector server at "
                    f"{self.server_name}:{self.server_port}")

        # Start the serving thread, listening on socket and handling commands.
        serving_thread = threading.Thread(
            name="server", target=self.serve_forever, daemon=True)
        serving_thread.start()

        # Start a dummy thread that puts empty SampleBatches on the queue, just
        # in case we don't receive anything from clients (or there aren't
        # any). The latter would block sample collection entirely otherwise,
        # even if other workers' PolicyServerInput receive incoming data from
        # actual clients.
        heart_beat_thread = threading.Thread(
            name="heart-beat",
            target=self._put_empty_sample_batch_every_n_sec,
            daemon=True)
        heart_beat_thread.start()

    @override(InputReader)
    def next(self):
        """Return the next item from the samples queue, blocking until one
        is available.

        Items are either batches reported by clients or the empty
        SampleBatches inserted by the heart-beat thread.
        """
        return self.samples_queue.get()

    def _put_empty_sample_batch_every_n_sec(self):
        """Heart-beat loop: enqueue an empty SampleBatch every
        `idle_timeout` seconds. Runs forever; meant for a daemon thread."""
        # Places an empty SampleBatch every `idle_timeout` seconds onto the
        # `samples_queue`. This avoids hanging of all RolloutWorkers parallel
        # to this one in case this PolicyServerInput does not have incoming
        # data (e.g. no client connected).
        while True:
            time.sleep(self.idle_timeout)
            self.samples_queue.put(SampleBatch())
def _make_handler(rollout_worker, samples_queue, metrics_queue):
    """Build the HTTP request handler class used by the policy server.

    Args:
        rollout_worker: The local rollout worker; supplies creation args,
            weights and global vars that are sent to clients or copied to
            the embedded child worker.
        samples_queue: Queue that receives the sample batches reported by
            clients.
        metrics_queue: Queue that receives the rollout metrics reported by
            clients.

    Returns:
        A `SimpleHTTPRequestHandler` subclass whose `do_POST` unpickles the
        request body, dispatches on its "command" entry and answers with the
        pickled response dict (or with a 500 error carrying the traceback).
    """
    # Only used in remote inference mode. We must create a new rollout worker
    # then since the original worker doesn't have the env properly wrapped in
    # an ExternalEnv interface.
    child_rollout_worker = None
    inference_thread = None
    # Guards the lazy, one-time creation of the child worker; requests are
    # handled by multiple threads.
    lock = threading.Lock()

    def setup_child_rollout_worker():
        """Create the embedded child worker on first use and seed it with
        the current weights of the local rollout worker."""
        nonlocal child_rollout_worker
        nonlocal inference_thread

        with lock:
            if child_rollout_worker is None:
                (child_rollout_worker,
                 inference_thread) = _create_embedded_rollout_worker(
                     rollout_worker.creation_args(), report_data)
                child_rollout_worker.set_weights(rollout_worker.get_weights())

    def report_data(data):
        """Forward reported samples and metrics into the queues and, if a
        child worker exists, refresh its weights and global vars."""
        batch = data["samples"]
        batch.decompress_if_needed()
        samples_queue.put(batch)
        for rollout_metric in data["metrics"]:
            metrics_queue.put(rollout_metric)
        if child_rollout_worker is not None:
            child_rollout_worker.set_weights(rollout_worker.get_weights(),
                                             rollout_worker.get_global_vars())

    class Handler(SimpleHTTPRequestHandler):
        """Handles pickled PolicyClient commands sent via HTTP POST."""

        def do_POST(self):
            """Read, decode and execute one client command.

            Any failure while parsing the request, executing the command or
            serializing the response is answered with HTTP 500 and the
            formatted traceback.
            """
            try:
                # The default belongs to `.get()`, not to `int()`: a missing
                # header is treated as an empty body. (Previously the 0 was
                # passed as the `base` argument of `int()`.)
                content_len = int(self.headers.get("Content-Length", 0))
                raw_body = self.rfile.read(content_len)
                # SECURITY: unpickling data received over the network can
                # execute arbitrary code. Only expose this server to trusted
                # clients.
                parsed_input = pickle.loads(raw_body)
                response = self.execute_command(parsed_input)
                # Serialize before sending the status line so that a pickling
                # error still results in a clean 500 response.
                payload = pickle.dumps(response)
            except Exception:
                self.send_error(500, traceback.format_exc())
                return
            self.send_response(200)
            self.end_headers()
            self.wfile.write(payload)

        def execute_command(self, args):
            """Dispatch on `args["command"]` and return the response dict.

            Raises:
                ValueError: If the command is not one of the PolicyClient
                    commands handled below.
            """
            command = args["command"]
            response = {}

            # Local inference commands:
            if command == PolicyClient.GET_WORKER_ARGS:
                logger.info("Sending worker creation args to client.")
                response["worker_args"] = rollout_worker.creation_args()
            elif command == PolicyClient.GET_WEIGHTS:
                logger.info("Sending worker weights to client.")
                response["weights"] = rollout_worker.get_weights()
                response["global_vars"] = rollout_worker.get_global_vars()
            elif command == PolicyClient.REPORT_SAMPLES:
                logger.info("Got sample batch of size {} from client.".format(
                    args["samples"].count))
                report_data(args)

            # Remote inference commands:
            elif command == PolicyClient.START_EPISODE:
                # The child worker is created lazily by the first episode.
                setup_child_rollout_worker()
                assert inference_thread.is_alive()
                response["episode_id"] = (
                    child_rollout_worker.env.start_episode(
                        args["episode_id"], args["training_enabled"]))
            elif command == PolicyClient.GET_ACTION:
                assert inference_thread.is_alive()
                response["action"] = child_rollout_worker.env.get_action(
                    args["episode_id"], args["observation"])
            elif command == PolicyClient.LOG_ACTION:
                assert inference_thread.is_alive()
                child_rollout_worker.env.log_action(
                    args["episode_id"], args["observation"], args["action"])
            elif command == PolicyClient.LOG_RETURNS:
                assert inference_thread.is_alive()
                # Only pass `done` along when it is truthy; otherwise the
                # env's own default for that argument applies.
                if args["done"]:
                    child_rollout_worker.env.log_returns(
                        args["episode_id"], args["reward"], args["info"],
                        args["done"])
                else:
                    child_rollout_worker.env.log_returns(
                        args["episode_id"], args["reward"], args["info"])
            elif command == PolicyClient.END_EPISODE:
                assert inference_thread.is_alive()
                child_rollout_worker.env.end_episode(args["episode_id"],
                                                     args["observation"])
            else:
                raise ValueError("Unknown command: {}".format(command))
            return response

    return Handler
|