moonshot_bot.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # encoding:utf-8
  2. import time
  3. import openai
  4. import openai.error
  5. from bot.bot import Bot
  6. from bot.session_manager import SessionManager
  7. from bridge.context import ContextType
  8. from bridge.reply import Reply, ReplyType
  9. from common.log import logger
  10. from config import conf, load_config
  11. from .moonshot_session import MoonshotSession
  12. import requests
  13. # Moonshot AI chat model API
  14. class MoonshotBot(Bot):
  15. def __init__(self):
  16. super().__init__()
  17. self.sessions = SessionManager(MoonshotSession, model=conf().get("model") or "moonshot-v1-128k")
  18. model = conf().get("model") or "moonshot-v1-128k"
  19. if model == "moonshot":
  20. model = "moonshot-v1-32k"
  21. self.args = {
  22. "model": model, # 对话模型的名称
  23. "temperature": conf().get("temperature", 0.3), # 如果设置,值域须为 [0, 1] 我们推荐 0.3,以达到较合适的效果。
  24. "top_p": conf().get("top_p", 1.0), # 使用默认值
  25. }
  26. self.api_key = conf().get("moonshot_api_key")
  27. self.base_url = conf().get("moonshot_base_url", "https://api.moonshot.cn/v1/chat/completions")
  28. def reply(self, query, context=None):
  29. # acquire reply content
  30. if context.type == ContextType.TEXT:
  31. logger.info("[MOONSHOT_AI] query={}".format(query))
  32. session_id = context["session_id"]
  33. reply = None
  34. clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"])
  35. if query in clear_memory_commands:
  36. self.sessions.clear_session(session_id)
  37. reply = Reply(ReplyType.INFO, "记忆已清除")
  38. elif query == "#清除所有":
  39. self.sessions.clear_all_session()
  40. reply = Reply(ReplyType.INFO, "所有人记忆已清除")
  41. elif query == "#更新配置":
  42. load_config()
  43. reply = Reply(ReplyType.INFO, "配置已更新")
  44. if reply:
  45. return reply
  46. session = self.sessions.session_query(query, session_id)
  47. logger.debug("[MOONSHOT_AI] session query={}".format(session.messages))
  48. model = context.get("moonshot_model")
  49. new_args = self.args.copy()
  50. if model:
  51. new_args["model"] = model
  52. # if context.get('stream'):
  53. # # reply in stream
  54. # return self.reply_text_stream(query, new_query, session_id)
  55. reply_content = self.reply_text(session, args=new_args)
  56. logger.debug(
  57. "[MOONSHOT_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(
  58. session.messages,
  59. session_id,
  60. reply_content["content"],
  61. reply_content["completion_tokens"],
  62. )
  63. )
  64. if reply_content["completion_tokens"] == 0 and len(reply_content["content"]) > 0:
  65. reply = Reply(ReplyType.ERROR, reply_content["content"])
  66. elif reply_content["completion_tokens"] > 0:
  67. self.sessions.session_reply(reply_content["content"], session_id, reply_content["total_tokens"])
  68. reply = Reply(ReplyType.TEXT, reply_content["content"])
  69. else:
  70. reply = Reply(ReplyType.ERROR, reply_content["content"])
  71. logger.debug("[MOONSHOT_AI] reply {} used 0 tokens.".format(reply_content))
  72. return reply
  73. else:
  74. reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type))
  75. return reply
  76. def reply_text(self, session: MoonshotSession, args=None, retry_count=0) -> dict:
  77. """
  78. call openai's ChatCompletion to get the answer
  79. :param session: a conversation session
  80. :param session_id: session id
  81. :param retry_count: retry count
  82. :return: {}
  83. """
  84. try:
  85. headers = {
  86. "Content-Type": "application/json",
  87. "Authorization": "Bearer " + self.api_key
  88. }
  89. body = args
  90. body["messages"] = session.messages
  91. # logger.debug("[MOONSHOT_AI] response={}".format(response))
  92. # logger.info("[MOONSHOT_AI] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"]))
  93. res = requests.post(
  94. self.base_url,
  95. headers=headers,
  96. json=body
  97. )
  98. if res.status_code == 200:
  99. response = res.json()
  100. return {
  101. "total_tokens": response["usage"]["total_tokens"],
  102. "completion_tokens": response["usage"]["completion_tokens"],
  103. "content": response["choices"][0]["message"]["content"]
  104. }
  105. else:
  106. response = res.json()
  107. error = response.get("error")
  108. logger.error(f"[MOONSHOT_AI] chat failed, status_code={res.status_code}, "
  109. f"msg={error.get('message')}, type={error.get('type')}")
  110. result = {"completion_tokens": 0, "content": "提问太快啦,请休息一下再问我吧"}
  111. need_retry = False
  112. if res.status_code >= 500:
  113. # server error, need retry
  114. logger.warn(f"[MOONSHOT_AI] do retry, times={retry_count}")
  115. need_retry = retry_count < 2
  116. elif res.status_code == 401:
  117. result["content"] = "授权失败,请检查API Key是否正确"
  118. elif res.status_code == 429:
  119. result["content"] = "请求过于频繁,请稍后再试"
  120. need_retry = retry_count < 2
  121. else:
  122. need_retry = False
  123. if need_retry:
  124. time.sleep(3)
  125. return self.reply_text(session, args, retry_count + 1)
  126. else:
  127. return result
  128. except Exception as e:
  129. logger.exception(e)
  130. need_retry = retry_count < 2
  131. result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"}
  132. if need_retry:
  133. return self.reply_text(session, args, retry_count + 1)
  134. else:
  135. return result