bytedance_coze_bot.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # encoding:utf-8
  2. import time
  3. from typing import List, Tuple
  4. import requests
  5. from requests import Response
  6. from bot.bot import Bot
  7. from bot.chatgpt.chat_gpt_session import ChatGPTSession
  8. from bot.session_manager import SessionManager
  9. from bridge.context import ContextType
  10. from bridge.reply import Reply, ReplyType
  11. from common.log import logger
  12. from config import conf
  13. class ByteDanceCozeBot(Bot):
  14. def __init__(self):
  15. super().__init__()
  16. self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "coze")
  17. def reply(self, query, context=None):
  18. # acquire reply content
  19. if context.type == ContextType.TEXT:
  20. logger.info("[COZE] query={}".format(query))
  21. session_id = context["session_id"]
  22. session = self.sessions.session_query(query, session_id)
  23. logger.debug("[COZE] session query={}".format(session.messages))
  24. reply_content, err = self._reply_text(session_id, session)
  25. if err is not None:
  26. logger.error("[COZE] reply error={}".format(err))
  27. return Reply(ReplyType.ERROR, "我暂时遇到了一些问题,请您稍后重试~")
  28. logger.debug(
  29. "[COZE] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
  30. session.messages,
  31. session_id,
  32. reply_content["content"],
  33. reply_content["completion_tokens"],
  34. )
  35. )
  36. return Reply(ReplyType.TEXT, reply_content["content"])
  37. else:
  38. reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
  39. return reply
  40. def _get_api_base_url(self):
  41. return conf().get("coze_api_base", "https://api.coze.cn/open_api/v2")
  42. def _get_headers(self):
  43. return {
  44. 'Authorization': f"Bearer {conf().get('coze_api_key', '')}"
  45. }
  46. def _get_payload(self, user: str, query: str, chat_history: List[dict]):
  47. coze_bot_id = conf().get('coze_bot_id', '')
  48. coze_bot_id = str(coze_bot_id)
  49. if not coze_bot_id:
  50. logger.error("[COZE] coze_bot_id is not set")
  51. raise Exception("coze_bot_id is not set")
  52. return {
  53. 'bot_id': coze_bot_id,
  54. "user": user,
  55. "query": query,
  56. "chat_history": chat_history,
  57. "stream": False
  58. }
  59. def _reply_text(self, session_id: str, session: ChatGPTSession, retry_count=0):
  60. try:
  61. query, chat_history = self._convert_messages_format(session.messages)
  62. base_url = self._get_api_base_url()
  63. chat_url = f'{base_url}/chat'
  64. headers = self._get_headers()
  65. payload = self._get_payload(session.session_id, query, chat_history)
  66. logger.debug("[COZE] headers={}, payload={}".format(headers, payload))
  67. response = requests.post(chat_url, headers=headers, json=payload)
  68. if response.status_code != 200:
  69. error_info = f"[COZE] response text={response.text} status_code={response.status_code}"
  70. logger.warn(error_info)
  71. return None, error_info
  72. answer, err = self._get_completion_content(response)
  73. if err is not None:
  74. return None, err
  75. completion_tokens, total_tokens = self._calc_tokens(session.messages, answer)
  76. return {
  77. "total_tokens": total_tokens,
  78. "completion_tokens": completion_tokens,
  79. "content": answer
  80. }, None
  81. except Exception as e:
  82. if retry_count < 2:
  83. time.sleep(3)
  84. logger.warn(f"[COZE] Exception: {repr(e)} 第{retry_count + 1}次重试")
  85. return self._reply_text(session_id, session, retry_count + 1)
  86. else:
  87. return None, f"[COZE] Exception: {repr(e)} 超过最大重试次数"
  88. def _convert_messages_format(self, messages) -> Tuple[str, List[dict]]:
  89. # [
  90. # {"role":"user","content":"你好","content_type":"text"},
  91. # {"role":"assistant","type":"answer","content":"你好,请问有什么可以帮助你的吗?","content_type":"text"}
  92. # ]
  93. chat_history = []
  94. for message in messages:
  95. role = message.get('role')
  96. if role == 'user':
  97. content = message.get('content')
  98. chat_history.append({"role": "user", "content": content, "content_type": "text"})
  99. elif role == 'assistant':
  100. content = message.get('content')
  101. chat_history.append({"role": "assistant", "type": "answer", "content": content, "content_type": "text"})
  102. elif role == 'system':
  103. # TODO: deal system message
  104. pass
  105. user_message = chat_history.pop()
  106. if user_message.get('role') != 'user' or user_message.get('content', '') == '':
  107. raise Exception('no user message')
  108. query = user_message.get('content')
  109. logger.debug("[COZE] converted coze messages: {}".format([item for item in chat_history]))
  110. logger.debug("[COZE] user content as query: {}".format(query))
  111. return query, chat_history
  112. def _get_completion_content(self, response: Response):
  113. json_response = response.json()
  114. if json_response['msg'] != 'success':
  115. return None, f"[COZE] Error: {json_response['msg']}"
  116. answer = None
  117. for message in json_response['messages']:
  118. if message.get('type') == 'answer':
  119. answer = message.get('content')
  120. break
  121. if not answer:
  122. return None, "[COZE] Error: empty answer"
  123. return answer, None
  124. def _calc_tokens(self, messages, answer):
  125. # 简单统计token
  126. completion_tokens = len(answer)
  127. prompt_tokens = 0
  128. for message in messages:
  129. prompt_tokens += len(message["content"])
  130. return completion_tokens, prompt_tokens + completion_tokens