from fastapi import FastAPI
from transformers import pipeline

import ray
from ray import serve

app = FastAPI()


# Define our deployment.
@serve.deployment(num_replicas=2)
class GPT2:
    def __init__(self):
        # Each replica loads its own copy of the GPT-2 model.
        self.nlp_model = pipeline("text-generation", model="gpt2")

    async def predict(self, query: str):
        return self.nlp_model(query, max_length=50)

    async def __call__(self, request):
        # HTTP entrypoint: treat the raw request body as the query text.
        query = (await request.body()).decode("utf-8")
        return await self.predict(query)


@app.on_event("startup")  # Code to be run when the server starts.
async def startup_event():
    ray.init(address="auto")  # Connect to the running Ray cluster.
    # Start Ray Serve; http_host=None disables Serve's own HTTP proxy,
    # since FastAPI is handling HTTP for us.
    serve.start(http_host=None)

    # Deploy our GPT2 Deployment.
    GPT2.deploy()


@app.get("/generate")
async def generate(query: str):
    # Get a handle to our deployment so we can query it in Python.
    handle = GPT2.get_handle()
    return await handle.predict.remote(query)


@app.on_event("shutdown")  # Code to be run when the server shuts down.
async def shutdown_event():
    serve.shutdown()  # Shut down Ray Serve.
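

# Usage sketch, not part of the original listing. It assumes the code above
# is saved as `main.py` (a hypothetical filename), that a local Ray cluster
# is already running for ray.init(address="auto") to connect to, and that
# uvicorn serves on its default port 8000:
#
#   $ ray start --head   # start a local Ray cluster
#   $ uvicorn main:app   # launch the FastAPI app
#
# The endpoint could then be queried with any HTTP client, for example:
#
#   import requests
#   response = requests.get(
#       "http://127.0.0.1:8000/generate",
#       params={"query": "Ray Serve is"},
#   )
#   print(response.json())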