# bridge_qianfan.py

import time, requests, json
from multiprocessing import Process, Pipe
from functools import wraps
from datetime import datetime, timedelta
from toolbox import get_conf, update_ui, is_any_api_key, select_api_key, what_keys, clip_history, trimmed_format_exc

model_name = '千帆大模型平台'
timeout_bot_msg = '[Local Message] Request timeout. Network error.'

def cache_decorator(timeout):
    cache = {}
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            key = (func.__name__, args, frozenset(kwargs.items()))
            # Check if result is already cached and not expired
            if key in cache:
                result, timestamp = cache[key]
                if datetime.now() - timestamp < timedelta(seconds=timeout):
                    return result
            # Call the function and cache the result
            result = func(*args, **kwargs)
            cache[key] = (result, datetime.now())
            return result
        return wrapper
    return decorator
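
# A minimal usage sketch of cache_decorator (slow_square is a hypothetical
# function, not part of this module): the first call executes the function,
# and a second call within the timeout window returns the cached value instead.
#
#   @cache_decorator(timeout=10)
#   def slow_square(x):
#       time.sleep(1)
#       return x * x
#
#   slow_square(3)   # computed, takes ~1 s
#   slow_square(3)   # returned from cache immediately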

@cache_decorator(timeout=3600)
def get_access_token():
    """
    Generate an authentication signature (access token) from the AK/SK pair.
    :return: access_token, or None on error
    """
    BAIDU_CLOUD_API_KEY, BAIDU_CLOUD_SECRET_KEY = get_conf('BAIDU_CLOUD_API_KEY', 'BAIDU_CLOUD_SECRET_KEY')
    if len(BAIDU_CLOUD_SECRET_KEY) == 0: raise RuntimeError("BAIDU_CLOUD_SECRET_KEY is not configured")
    if len(BAIDU_CLOUD_API_KEY) == 0: raise RuntimeError("BAIDU_CLOUD_API_KEY is not configured")
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials", "client_id": BAIDU_CLOUD_API_KEY, "client_secret": BAIDU_CLOUD_SECRET_KEY}
    # The token itself is cached for an hour by cache_decorator above.
    return str(requests.post(url, params=params).json().get("access_token"))
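
# For reference: a successful token response is JSON along the lines of
#   {"access_token": "24.xxxx...", "expires_in": ..., ...}
# (field set per Baidu's OAuth documentation); only access_token is used here.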

def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
    conversation_cnt = len(history) // 2
    if system_prompt == "": system_prompt = "Hello"
    # Qianfan's chat endpoints expect strictly alternating user/assistant turns,
    # so the system prompt is folded in as an opening user/assistant exchange.
    messages = [{"role": "user", "content": system_prompt}]
    messages.append({"role": "assistant", "content": 'Certainly!'})
    if conversation_cnt:
        for index in range(0, 2*conversation_cnt, 2):
            what_i_have_asked = {}
            what_i_have_asked["role"] = "user"
            what_i_have_asked["content"] = history[index] if history[index] != "" else "Hello"
            what_gpt_answer = {}
            what_gpt_answer["role"] = "assistant"
            what_gpt_answer["content"] = history[index+1] if history[index+1] != "" else "Hello"
            if what_i_have_asked["content"] != "":
                if what_gpt_answer["content"] == "": continue
                if what_gpt_answer["content"] == timeout_bot_msg: continue  # skip turns that timed out
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
            else:
                messages[-1]['content'] = what_gpt_answer['content']
    what_i_ask_now = {}
    what_i_ask_now["role"] = "user"
    what_i_ask_now["content"] = inputs
    messages.append(what_i_ask_now)
    return messages
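
# Shape of the payload this produces, e.g. for inputs="How?" and
# history=["Hi", "Hello!"]:
#   [{"role": "user",      "content": <system_prompt>},
#    {"role": "assistant", "content": "Certainly!"},
#    {"role": "user",      "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user",      "content": "How?"}]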

def generate_from_baidu_qianfan(inputs, llm_kwargs, history, system_prompt):
    BAIDU_CLOUD_QIANFAN_MODEL = get_conf('BAIDU_CLOUD_QIANFAN_MODEL')

    url_lib = {
        "ERNIE-Bot-4":      "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro",
        "ERNIE-Bot":        "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions",
        "ERNIE-Bot-turbo":  "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant",
        "BLOOMZ-7B":        "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
        "Llama-2-70B-Chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/llama_2_70b",
        "Llama-2-13B-Chat": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/llama_2_13b",
        "Llama-2-7B-Chat":  "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/llama_2_7b",
    }

    url = url_lib[BAIDU_CLOUD_QIANFAN_MODEL]
    url += "?access_token=" + get_access_token()

    payload = json.dumps({
        "messages": generate_message_payload(inputs, llm_kwargs, history, system_prompt),
        "stream": True
    })
    headers = {'Content-Type': 'application/json'}
    response = requests.request("POST", url, headers=headers, data=payload, stream=True)

    buffer = ""
    for line in response.iter_lines():
        if len(line) == 0: continue
        dec = line.decode()
        if dec.startswith('data:'): dec = dec[len('data:'):]  # strip the SSE "data:" prefix
        dec = json.loads(dec)
        if 'error_code' in dec:
            if "max length" in dec.get('error_msg', ''):
                raise ConnectionAbortedError(dec['error_msg'])  # context too long: token overflow
            raise RuntimeError(dec['error_msg'])
        buffer += dec['result']
        yield buffer
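
# Each non-empty SSE line from the endpoint carries one incremental chunk,
# roughly of the form
#   data: {"id": "...", "result": "<new text delta>", "is_end": false, ...}
# `result` holds only the newly generated delta, hence the accumulation into
# `buffer` so that callers always receive the full text so far.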

def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=[], console_slience=False):
    """
    ⭐ Multi-threaded entry point
    For the function contract, see request_llms/bridge_all.py
    """
    watch_dog_patience = 5
    response = ""
    for response in generate_from_baidu_qianfan(inputs, llm_kwargs, history, sys_prompt):
        if len(observe_window) >= 1:
            observe_window[0] = response
        if len(observe_window) >= 2:
            if (time.time() - observe_window[1]) > watch_dog_patience:
                raise RuntimeError("Program terminated.")
    return response
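
# observe_window is a small shared list: slot 0 mirrors the partial response for
# the caller to display, and slot 1 is a heartbeat timestamp the caller keeps
# refreshing. If the heartbeat goes stale for more than watch_dog_patience
# seconds, the worker assumes it was abandoned and aborts the generation.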

def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    """
    ⭐ Single-threaded entry point
    For the function contract, see request_llms/bridge_all.py
    """
    chatbot.append((inputs, ""))

    if additional_fn is not None:
        from core_functional import handle_core_functionality
        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)
    yield from update_ui(chatbot=chatbot, history=history)

    # Start receiving the reply
    response = f"[Local Message] Waiting for {model_name} to respond ..."
    try:
        for response in generate_from_baidu_qianfan(inputs, llm_kwargs, history, system_prompt):
            chatbot[-1] = (inputs, response)
            yield from update_ui(chatbot=chatbot, history=history)
    except ConnectionAbortedError as e:
        from .bridge_all import model_info
        if len(history) >= 2:
            history[-1] = ""; history[-2] = ""  # drop the overflowing turn: history[-2] is this input, history[-1] this output
        history = clip_history(inputs=inputs, history=history, tokenizer=model_info[llm_kwargs['llm_model']]['tokenizer'],
                               max_token_limit=(model_info[llm_kwargs['llm_model']]['max_token']))  # release at least half of the history
        chatbot[-1] = (chatbot[-1][0], "[Local Message] Reduce the length. The input (or the conversation history) is too long. Part of the cached history has been released; please try again. (If it fails again, the input itself is most likely too long.)")
        yield from update_ui(chatbot=chatbot, history=history, msg="Exception")  # refresh the UI
        return

    # Wrap up: if the generator produced nothing, replace the placeholder with an error notice
    if response == f"[Local Message] Waiting for {model_name} to respond ...":
        response = f"[Local Message] {model_name} returned an abnormal response ..."
    history.extend([inputs, response])
    yield from update_ui(chatbot=chatbot, history=history)
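
# Minimal configuration sketch (read via get_conf above; values are
# placeholders):
#   BAIDU_CLOUD_API_KEY       = "your Qianfan application API Key"
#   BAIDU_CLOUD_SECRET_KEY    = "your Qianfan application Secret Key"
#   BAIDU_CLOUD_QIANFAN_MODEL = "ERNIE-Bot"   # any key of url_lib above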