# bridge_chatgpt_website.py
# Adapted from the https://github.com/GaiZhenbiao/ChuanhuChatGPT project
"""
This file mainly contains two functions.

Functions WITHOUT multi-threading capability:
1. predict: used for normal conversation; fully interactive; not thread-safe.

Functions WITH multi-threading capability:
2. predict_no_ui_long_connection: supports multi-threaded calls.
"""
  9. import json
  10. import time
  11. import gradio as gr
  12. import logging
  13. import traceback
  14. import requests
  15. import importlib
# config_private.py holds private secrets such as API keys and proxy URLs.
# On read, a private config_private file (not tracked by git) is checked first;
# if present, it overrides the regular config file.
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc

proxies, TIMEOUT_SECONDS, MAX_RETRY, API_ORG = \
    get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY', 'API_ORG')

# Canonical timeout message; generate_payload also compares history entries
# against this exact string, so it must not be altered.
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
                  '网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
  23. def get_full_error(chunk, stream_response):
  24. """
  25. 获取完整的从Openai返回的报错
  26. """
  27. while True:
  28. try:
  29. chunk += next(stream_response)
  30. except:
  31. break
  32. return chunk
  33. def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None, console_slience=False):
  34. """
  35. 发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免中途网线被掐。
  36. inputs:
  37. 是本次问询的输入
  38. sys_prompt:
  39. 系统静默prompt
  40. llm_kwargs:
  41. chatGPT的内部调优参数
  42. history:
  43. 是之前的对话列表
  44. observe_window = None:
  45. 用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
  46. """
  47. watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
  48. headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
  49. retry = 0
  50. while True:
  51. try:
  52. # make a POST request to the API endpoint, stream=False
  53. from .bridge_all import model_info
  54. endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
  55. response = requests.post(endpoint, headers=headers, proxies=proxies,
  56. json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
  57. except requests.exceptions.ReadTimeout as e:
  58. retry += 1
  59. traceback.print_exc()
  60. if retry > MAX_RETRY: raise TimeoutError
  61. if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
  62. stream_response = response.iter_lines()
  63. result = ''
  64. while True:
  65. try: chunk = next(stream_response).decode()
  66. except StopIteration:
  67. break
  68. except requests.exceptions.ConnectionError:
  69. chunk = next(stream_response).decode() # 失败了,重试一次?再失败就没办法了。
  70. if len(chunk)==0: continue
  71. if not chunk.startswith('data:'):
  72. error_msg = get_full_error(chunk.encode('utf8'), stream_response).decode()
  73. if "reduce the length" in error_msg:
  74. raise ConnectionAbortedError("OpenAI拒绝了请求:" + error_msg)
  75. else:
  76. raise RuntimeError("OpenAI拒绝了请求:" + error_msg)
  77. if ('data: [DONE]' in chunk): break # api2d 正常完成
  78. json_data = json.loads(chunk.lstrip('data:'))['choices'][0]
  79. delta = json_data["delta"]
  80. if len(delta) == 0: break
  81. if "role" in delta: continue
  82. if "content" in delta:
  83. result += delta["content"]
  84. if not console_slience: print(delta["content"], end='')
  85. if observe_window is not None:
  86. # 观测窗,把已经获取的数据显示出去
  87. if len(observe_window) >= 1: observe_window[0] += delta["content"]
  88. # 看门狗,如果超过期限没有喂狗,则终止
  89. if len(observe_window) >= 2:
  90. if (time.time()-observe_window[1]) > watch_dog_patience:
  91. raise RuntimeError("用户取消了程序。")
  92. else: raise RuntimeError("意外Json结构:"+delta)
  93. if json_data['finish_reason'] == 'content_filter':
  94. raise RuntimeError("由于提问含不合规内容被Azure过滤。")
  95. if json_data['finish_reason'] == 'length':
  96. raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
  97. return result
  98. def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream = True, additional_fn=None):
  99. """
  100. 发送至chatGPT,流式获取输出。
  101. 用于基础的对话功能。
  102. inputs 是本次问询的输入
  103. top_p, temperature是chatGPT的内部调优参数
  104. history 是之前的对话列表(注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误)
  105. chatbot 为WebUI中显示的对话列表,修改它,然后yeild出去,可以直接修改对话界面内容
  106. additional_fn代表点击的哪个按钮,按钮见functional.py
  107. """
  108. if additional_fn is not None:
  109. from core_functional import handle_core_functionality
  110. inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
  111. raw_input = inputs
  112. logging.info(f'[raw_input] {raw_input}')
  113. chatbot.append((inputs, ""))
  114. yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
  115. try:
  116. headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt, stream)
  117. except RuntimeError as e:
  118. chatbot[-1] = (inputs, f"您提供的api-key不满足要求,不包含任何可用于{llm_kwargs['llm_model']}的api-key。您可能选择了错误的模型或请求源。")
  119. yield from update_ui(chatbot=chatbot, history=history, msg="api-key不满足要求") # 刷新界面
  120. return
  121. history.append(inputs); history.append("")
  122. retry = 0
  123. while True:
  124. try:
  125. # make a POST request to the API endpoint, stream=True
  126. from .bridge_all import model_info
  127. endpoint = model_info[llm_kwargs['llm_model']]['endpoint']
  128. response = requests.post(endpoint, headers=headers, proxies=proxies,
  129. json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
  130. except:
  131. retry += 1
  132. chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
  133. retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
  134. yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
  135. if retry > MAX_RETRY: raise TimeoutError
  136. gpt_replying_buffer = ""
  137. is_head_of_the_stream = True
  138. if stream:
  139. stream_response = response.iter_lines()
  140. while True:
  141. try:
  142. chunk = next(stream_response)
  143. except StopIteration:
  144. # 非OpenAI官方接口的出现这样的报错,OpenAI和API2D不会走这里
  145. chunk_decoded = chunk.decode()
  146. error_msg = chunk_decoded
  147. chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
  148. yield from update_ui(chatbot=chatbot, history=history, msg="非Openai官方接口返回了错误:" + chunk.decode()) # 刷新界面
  149. return
  150. # print(chunk.decode()[6:])
  151. if is_head_of_the_stream and (r'"object":"error"' not in chunk.decode()):
  152. # 数据流的第一帧不携带content
  153. is_head_of_the_stream = False; continue
  154. if chunk:
  155. try:
  156. chunk_decoded = chunk.decode()
  157. # 前者是API2D的结束条件,后者是OPENAI的结束条件
  158. if 'data: [DONE]' in chunk_decoded:
  159. # 判定为数据流的结束,gpt_replying_buffer也写完了
  160. logging.info(f'[response] {gpt_replying_buffer}')
  161. break
  162. # 处理数据流的主体
  163. chunkjson = json.loads(chunk_decoded[6:])
  164. status_text = f"finish_reason: {chunkjson['choices'][0]['finish_reason']}"
  165. delta = chunkjson['choices'][0]["delta"]
  166. if "content" in delta:
  167. gpt_replying_buffer = gpt_replying_buffer + delta["content"]
  168. history[-1] = gpt_replying_buffer
  169. chatbot[-1] = (history[-2], history[-1])
  170. yield from update_ui(chatbot=chatbot, history=history, msg=status_text) # 刷新界面
  171. except Exception as e:
  172. yield from update_ui(chatbot=chatbot, history=history, msg="Json解析不合常规") # 刷新界面
  173. chunk = get_full_error(chunk, stream_response)
  174. chunk_decoded = chunk.decode()
  175. error_msg = chunk_decoded
  176. chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
  177. yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
  178. print(error_msg)
  179. return
  180. def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
  181. from .bridge_all import model_info
  182. openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
  183. if "reduce the length" in error_msg:
  184. if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
  185. history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
  186. max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token'])) # history至少释放二分之一
  187. chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. 本次输入过长, 或历史数据过长. 历史缓存数据已部分释放, 您可以请再次尝试. (若再次失败则更可能是因为输入过长.)")
  188. # history = [] # 清除历史
  189. elif "does not exist" in error_msg:
  190. chatbot[-1] = (chatbot[-1][0], f"[Local Message] Model {llm_kwargs['llm_model']} does not exist. 模型不存在, 或者您没有获得体验资格.")
  191. elif "Incorrect API key" in error_msg:
  192. chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key. OpenAI以提供了不正确的API_KEY为由, 拒绝服务. " + openai_website)
  193. elif "exceeded your current quota" in error_msg:
  194. chatbot[-1] = (chatbot[-1][0], "[Local Message] You exceeded your current quota. OpenAI以账户额度不足为由, 拒绝服务." + openai_website)
  195. elif "account is not active" in error_msg:
  196. chatbot[-1] = (chatbot[-1][0], "[Local Message] Your account is not active. OpenAI以账户失效为由, 拒绝服务." + openai_website)
  197. elif "associated with a deactivated account" in error_msg:
  198. chatbot[-1] = (chatbot[-1][0], "[Local Message] You are associated with a deactivated account. OpenAI以账户失效为由, 拒绝服务." + openai_website)
  199. elif "bad forward key" in error_msg:
  200. chatbot[-1] = (chatbot[-1][0], "[Local Message] Bad forward key. API2D账户额度不足.")
  201. elif "Not enough point" in error_msg:
  202. chatbot[-1] = (chatbot[-1][0], "[Local Message] Not enough point. API2D账户点数不足.")
  203. else:
  204. from toolbox import regular_txt_to_markdown
  205. tb_str = '```\n' + trimmed_format_exc() + '```'
  206. chatbot[-1] = (chatbot[-1][0], f"[Local Message] 异常 \n\n{tb_str} \n\n{regular_txt_to_markdown(chunk_decoded)}")
  207. return chatbot, history
  208. def generate_payload(inputs, llm_kwargs, history, system_prompt, stream):
  209. """
  210. 整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
  211. """
  212. if not is_any_api_key(llm_kwargs['api_key']):
  213. raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
  214. headers = {
  215. "Content-Type": "application/json",
  216. }
  217. conversation_cnt = len(history) // 2
  218. messages = [{"role": "system", "content": system_prompt}]
  219. if conversation_cnt:
  220. for index in range(0, 2*conversation_cnt, 2):
  221. what_i_have_asked = {}
  222. what_i_have_asked["role"] = "user"
  223. what_i_have_asked["content"] = history[index]
  224. what_gpt_answer = {}
  225. what_gpt_answer["role"] = "assistant"
  226. what_gpt_answer["content"] = history[index+1]
  227. if what_i_have_asked["content"] != "":
  228. if what_gpt_answer["content"] == "": continue
  229. if what_gpt_answer["content"] == timeout_bot_msg: continue
  230. messages.append(what_i_have_asked)
  231. messages.append(what_gpt_answer)
  232. else:
  233. messages[-1]['content'] = what_gpt_answer['content']
  234. what_i_ask_now = {}
  235. what_i_ask_now["role"] = "user"
  236. what_i_ask_now["content"] = inputs
  237. messages.append(what_i_ask_now)
  238. payload = {
  239. "model": llm_kwargs['llm_model'].strip('api2d-'),
  240. "messages": messages,
  241. "temperature": llm_kwargs['temperature'], # 1.0,
  242. "top_p": llm_kwargs['top_p'], # 1.0,
  243. "n": 1,
  244. "stream": stream,
  245. "presence_penalty": 0,
  246. "frequency_penalty": 0,
  247. }
  248. try:
  249. print(f" {llm_kwargs['llm_model']} : {conversation_cnt} : {inputs[:100]} ..........")
  250. except:
  251. print('输入中可能存在乱码。')
  252. return headers,payload