open_ai_vision.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import base64
  2. import requests
  3. from common.log import logger
  4. from common import const, utils, memory
  5. from config import conf
  6. # OPENAI提供的图像识别接口
  7. class OpenAIVision(object):
  8. def do_vision_completion_if_need(self, session_id: str, query: str):
  9. img_cache = memory.USER_IMAGE_CACHE.get(session_id)
  10. if img_cache and conf().get("image_recognition"):
  11. response, err = self.vision_completion(query, img_cache)
  12. if err:
  13. return {"completion_tokens": 0, "content": f"识别图片异常, {err}"}
  14. memory.USER_IMAGE_CACHE[session_id] = None
  15. return {
  16. "total_tokens": response["usage"]["total_tokens"],
  17. "completion_tokens": response["usage"]["completion_tokens"],
  18. "content": response['choices'][0]["message"]["content"],
  19. }
  20. return None
  21. def vision_completion(self, query: str, img_cache: dict):
  22. msg = img_cache.get("msg")
  23. path = img_cache.get("path")
  24. msg.prepare()
  25. logger.info(f"[CHATGPT] query with images, path={path}")
  26. payload = {
  27. "model": const.GPT4_VISION_PREVIEW,
  28. "messages": self.build_vision_msg(query, path),
  29. "temperature": conf().get("temperature"),
  30. "top_p": conf().get("top_p", 1),
  31. "frequency_penalty": conf().get("frequency_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
  32. "presence_penalty": conf().get("presence_penalty", 0.0), # [-2,2]之间,该值越大则更倾向于产生不同的内容
  33. }
  34. headers = {"Authorization": "Bearer " + conf().get("open_ai_api_key", "")}
  35. # do http request
  36. base_url = conf().get("open_ai_api_base", "https://api.openai.com/v1")
  37. res = requests.post(url=base_url + "/chat/completions", json=payload, headers=headers,
  38. timeout=conf().get("request_timeout", 180))
  39. if res.status_code == 200:
  40. return res.json(), None
  41. else:
  42. logger.error(f"[CHATGPT] vision completion, status_code={res.status_code}, response={res.text}")
  43. return None, res.text
  44. def build_vision_msg(self, query: str, path: str):
  45. suffix = utils.get_path_suffix(path)
  46. with open(path, "rb") as file:
  47. base64_str = base64.b64encode(file.read()).decode('utf-8')
  48. messages = [{
  49. "role": "user",
  50. "content": [
  51. {
  52. "type": "text",
  53. "text": query
  54. },
  55. {
  56. "type": "image_url",
  57. "image_url": {
  58. "url": f"data:image/{suffix};base64,{base64_str}"
  59. }
  60. }
  61. ]
  62. }]
  63. return messages