chatglm_sse.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import os
  2. from typing import List
  3. import uvicorn
  4. from fastapi import FastAPI, APIRouter
  5. from gluon_meson_components.models.chat_model import ChatModel
  6. from pydantic import BaseModel
  7. from util import convert_history
  8. from sse_starlette.sse import EventSourceResponse
  9. PLAIN_MODEL_TYPE = os.getenv('DEFAULT_PLAIN_MODEL_TYPE', 'chatglm2-6b')
  10. STREAM_MODEL_TYPE = os.getenv('DEFAULT_STREAM_MODEL_TYPE', f'{PLAIN_MODEL_TYPE}_streaming')
  11. FLASH_ENV = os.getenv('FLASK_ENV')
  12. APP_PORT = 8000
  13. chat_model = ChatModel()
  14. app = FastAPI()
  15. router = APIRouter()
# One entry of the conversation sent by the client; ChatCommand.messages is a
# list of these. Documented with comments rather than a docstring so the
# generated schema is left exactly as it was.
class MessageInChat(BaseModel):
    # Who produced the message. Any string is accepted here; the set of
    # meaningful values is decided by convert_history (not visible in this file).
    role: str
    # Message text. parse_request uses the text of the last entry as the prompt.
    message: str
# One message inside a ChatResponse. Documented with comments rather than a
# docstring so the generated schema is left exactly as it was.
class MessageInResponseChat(BaseModel):
    # Author of the message; the streaming endpoint always sets 'assistant'.
    role: str
    # Message text; the streaming endpoint puts the newly generated slice of
    # the reply here.
    content: str
# Request body of the chat endpoint. Documented with comments rather than a
# docstring so the generated schema is left exactly as it was.
class ChatCommand(BaseModel):
    # Conversation so far. parse_request treats the last entry as the prompt
    # and everything before it as history.
    messages: List[MessageInChat]
    # Model identifier forwarded to the chat model; defaults to the
    # module-level STREAM_MODEL_TYPE when the client omits it.
    model: str = STREAM_MODEL_TYPE
# Payload serialised to JSON for each streamed event. Documented with comments
# rather than a docstring so the generated schema is left exactly as it was.
class ChatResponse(BaseModel):
    # Messages carried by this event; the streaming endpoint sends exactly one.
    choices: List[MessageInResponseChat]
    # Model identifier the reply was generated with (echoed from the request).
    model: str
  28. def parse_request(chat_command: ChatCommand):
  29. model = chat_command.model
  30. messages = chat_command.messages
  31. text = messages.pop().message
  32. history = convert_history([i.__dict__ for i in messages])
  33. return model, text, history
  34. @router.post('/messages/stream')
  35. def stream(chat_command: ChatCommand):
  36. model, text, history = parse_request(chat_command)
  37. history = list(map(tuple, history))
  38. def handle():
  39. previous_response_len = 0
  40. for response in chat_model.chat_single_streaming(text=text, model_type=model, history=history):
  41. response_data = response.response[previous_response_len:]
  42. previous_response_len = len(response.response)
  43. yield ChatResponse(choices=[MessageInResponseChat(role='assistant', content=response_data)],
  44. model=model).json()
  45. yield {'data': '[DONE]'}
  46. return EventSourceResponse(handle())
  47. app.include_router(router)
  48. if __name__ == '__main__':
  49. if FLASH_ENV == 'production':
  50. print(f"Starting server in production mode at port {APP_PORT}")
  51. uvicorn.run(app, host="0.0.0.0", port=APP_PORT)
  52. else:
  53. uvicorn.run(app, host="0.0.0.0", port=APP_PORT, log_level='debug')