google_gemini_bot.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. """
  2. Google gemini bot
  3. @author zhayujie
  4. @Date 2023/12/15
  5. """
  6. # encoding:utf-8
  7. from bot.bot import Bot
  8. import google.generativeai as genai
  9. from bot.session_manager import SessionManager
  10. from bridge.context import ContextType, Context
  11. from bridge.reply import Reply, ReplyType
  12. from common.log import logger
  13. from config import conf
  14. from bot.chatgpt.chat_gpt_session import ChatGPTSession
  15. from bot.baidu.baidu_wenxin_session import BaiduWenxinSession
  16. from google.generativeai.types import HarmCategory, HarmBlockThreshold
  17. # OpenAI对话模型API (可用)
  18. class GoogleGeminiBot(Bot):
  19. def __init__(self):
  20. super().__init__()
  21. self.api_key = conf().get("gemini_api_key")
  22. # 复用chatGPT的token计算方式
  23. self.sessions = SessionManager(ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo")
  24. self.model = conf().get("model") or "gemini-pro"
  25. if self.model == "gemini":
  26. self.model = "gemini-pro"
  27. def reply(self, query, context: Context = None) -> Reply:
  28. try:
  29. if context.type != ContextType.TEXT:
  30. logger.warn(f"[Gemini] Unsupported message type, type={context.type}")
  31. return Reply(ReplyType.TEXT, None)
  32. logger.info(f"[Gemini] query={query}")
  33. session_id = context["session_id"]
  34. session = self.sessions.session_query(query, session_id)
  35. gemini_messages = self._convert_to_gemini_messages(self.filter_messages(session.messages))
  36. logger.debug(f"[Gemini] messages={gemini_messages}")
  37. genai.configure(api_key=self.api_key)
  38. model = genai.GenerativeModel(self.model)
  39. # 添加安全设置
  40. safety_settings = {
  41. HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
  42. HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
  43. HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
  44. HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
  45. }
  46. # 生成回复,包含安全设置
  47. response = model.generate_content(
  48. gemini_messages,
  49. safety_settings=safety_settings
  50. )
  51. if response.candidates and response.candidates[0].content:
  52. reply_text = response.candidates[0].content.parts[0].text
  53. logger.info(f"[Gemini] reply={reply_text}")
  54. self.sessions.session_reply(reply_text, session_id)
  55. return Reply(ReplyType.TEXT, reply_text)
  56. else:
  57. # 没有有效响应内容,可能内容被屏蔽,输出安全评分
  58. logger.warning("[Gemini] No valid response generated. Checking safety ratings.")
  59. if hasattr(response, 'candidates') and response.candidates:
  60. for rating in response.candidates[0].safety_ratings:
  61. logger.warning(f"Safety rating: {rating.category} - {rating.probability}")
  62. error_message = "No valid response generated due to safety constraints."
  63. self.sessions.session_reply(error_message, session_id)
  64. return Reply(ReplyType.ERROR, error_message)
  65. except Exception as e:
  66. logger.error(f"[Gemini] Error generating response: {str(e)}", exc_info=True)
  67. error_message = "Failed to invoke [Gemini] api!"
  68. self.sessions.session_reply(error_message, session_id)
  69. return Reply(ReplyType.ERROR, error_message)
  70. def _convert_to_gemini_messages(self, messages: list):
  71. res = []
  72. for msg in messages:
  73. if msg.get("role") == "user":
  74. role = "user"
  75. elif msg.get("role") == "assistant":
  76. role = "model"
  77. elif msg.get("role") == "system":
  78. role = "user"
  79. else:
  80. continue
  81. res.append({
  82. "role": role,
  83. "parts": [{"text": msg.get("content")}]
  84. })
  85. return res
  86. @staticmethod
  87. def filter_messages(messages: list):
  88. res = []
  89. turn = "user"
  90. if not messages:
  91. return res
  92. for i in range(len(messages) - 1, -1, -1):
  93. message = messages[i]
  94. role = message.get("role")
  95. if role == "system":
  96. res.insert(0, message)
  97. continue
  98. if role != turn:
  99. continue
  100. res.insert(0, message)
  101. if turn == "user":
  102. turn = "assistant"
  103. elif turn == "assistant":
  104. turn = "user"
  105. return res