# ray_serve.py — serve a GPT-2 text-generation model with Ray Serve behind a FastAPI app.
  1. from fastapi import FastAPI
  2. from transformers import pipeline
  3. # TODO: actually use this to predict something
  4. import ray
  5. from ray import serve
  6. app = FastAPI()
  7. # Define our deployment.
  8. @serve.deployment(num_replicas=2)
  9. class GPT2:
  10. def __init__(self):
  11. self.nlp_model = pipeline("text-generation", model="gpt2")
  12. async def predict(self, query: str):
  13. return self.nlp_model(query, max_length=50)
  14. async def __call__(self, request):
  15. return self.predict(await request.body())
  16. @app.on_event("startup") # Code to be run when the server starts.
  17. async def startup_event():
  18. ray.init(address="auto") # Connect to the running Ray cluster.
  19. serve.start(http_host=None) # Start the Ray Serve instance.
  20. # Deploy our GPT2 Deployment.
  21. GPT2.deploy()
  22. @app.get("/generate")
  23. async def generate(query: str):
  24. # Get a handle to our deployment so we can query it in Python.
  25. handle = GPT2.get_handle()
  26. return await handle.predict.remote(query)
  27. @app.on_event("shutdown") # Code to be run when the server shuts down.
  28. async def shutdown_event():
  29. serve.shutdown() # Shut down Ray Serve.