bridge_google_gemini.py

# encoding: utf-8
# @Time   : 2023/12/21
# @Author : Spike
# @Descr  :
import json
import re
import os
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg, have_any_recent_upload_image_files, trimmed_format_exc

proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py. ' + \
                  'Check that the proxy server is reachable and that the proxy is configured as [protocol]://[address]:[port]; none of the three parts may be omitted.'


def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
                                  console_slience=False):
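    """
    Send `inputs` to Gemini and return the complete reply as a single string,
    without any UI refreshes.
    `observe_window`, if provided, is a shared list: index 0 receives the partial
    reply as it streams in, and index 1 holds a watchdog timestamp; if the caller
    stops updating it for more than `watch_dog_patience` seconds, the request is aborted.
    """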
    # Check that the API key is configured
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("Please configure GEMINI_API_KEY.")
    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience: 5 seconds is enough
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
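    # Each streamed chunk arrives as raw JSON text; the "text" field is pulled out with a
    # regex and re-parsed through json.loads so that escape sequences are properly unescaped.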
    for response in stream_response:
        results = response.decode()
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*?)"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except Exception:
                raise ValueError("Failed to parse the GEMINI message.")
            buffer = paraphrase['text']
            gpt_replying_buffer += buffer
            if observe_window is not None:
                if len(observe_window) >= 1:
                    observe_window[0] = gpt_replying_buffer
                if len(observe_window) >= 2:
                    if (time.time() - observe_window[1]) > watch_dog_patience:
                        raise RuntimeError("Program terminated.")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} Conversation error')
    return gpt_replying_buffer


def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
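    """
    Stream the Gemini reply into `chatbot` and `history`, yielding a UI refresh
    after each received chunk. For "vision" models, recently uploaded images are
    inlined into the prompt before the request is sent.
    """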
    # Check that the API key is configured
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("Please configure GEMINI_API_KEY.", chatbot=chatbot, history=history, delay=0)
        return
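    # Vision-capable models: inline any recently uploaded images into the prompt as HTML <img> tags.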
  46. if "vision" in llm_kwargs["llm_model"]:
  47. have_recent_file, image_paths = have_any_recent_upload_image_files(chatbot)
  48. def make_media_input(inputs, image_paths):
  49. for image_path in image_paths:
  50. inputs = inputs + f'<br/><br/><div align="center"><img src="file={os.path.abspath(image_path)}"></div>'
  51. return inputs
  52. if have_recent_file:
  53. inputs = make_media_input(inputs, image_paths)
  54. chatbot.append((inputs, ""))
  55. yield from update_ui(chatbot=chatbot, history=history)
  56. genai = GoogleChatInit()
  57. retry = 0
  58. while True:
  59. try:
  60. stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
  61. break
  62. except Exception as e:
  63. retry += 1
  64. chatbot[-1] = ((chatbot[-1][0], trimmed_format_exc()))
  65. yield from update_ui(chatbot=chatbot, history=history, msg="请求失败") # 刷新界面
  66. return
  67. gpt_replying_buffer = ""
  68. gpt_security_policy = ""
  69. history.extend([inputs, ''])
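    # Consume the stream chunk by chunk: "text" carries normal reply content,
    # while a "message" field signals an error payload from the API.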
    for response in stream_response:
        results = response.decode("utf-8")  # got tripped up by this decode step before
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*)"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except Exception:
                raise ValueError("Failed to parse the GEMINI message.")
            gpt_replying_buffer += paraphrase['text']  # handled via the json parsing library
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            history = history[:-2]  # do not keep the failed turn in the conversation
            chatbot[-1] = (inputs, gpt_replying_buffer + f"Conversation error, see the message below\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('Conversation error')
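    # An empty reply with no explicit error usually means Google's safety policy blocked
    # the answer; surface the raw response so the user can see what happened.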
    if not gpt_replying_buffer:
        history = history[:-2]  # do not keep the failed turn in the conversation
        chatbot[-1] = (inputs, gpt_replying_buffer + f"Google's safety policy was triggered and no answer was returned\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)


if __name__ == '__main__':
    import sys
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write a long story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)