com_sparkapi.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. from toolbox import get_conf, get_pictures_list, encode_image
  2. import base64
  3. import datetime
  4. import hashlib
  5. import hmac
  6. import json
  7. from urllib.parse import urlparse
  8. import ssl
  9. from datetime import datetime
  10. from time import mktime
  11. from urllib.parse import urlencode
  12. from wsgiref.handlers import format_date_time
  13. import websocket
  14. import threading, time
# Sentinel answer text: when the message payload is rebuilt from history,
# question/answer pairs whose answer equals this string are dropped.
timeout_bot_msg = '[Local Message] Request timeout. Network error.'
  16. class Ws_Param(object):
  17. # 初始化
  18. def __init__(self, APPID, APIKey, APISecret, gpt_url):
  19. self.APPID = APPID
  20. self.APIKey = APIKey
  21. self.APISecret = APISecret
  22. self.host = urlparse(gpt_url).netloc
  23. self.path = urlparse(gpt_url).path
  24. self.gpt_url = gpt_url
  25. # 生成url
  26. def create_url(self):
  27. # 生成RFC1123格式的时间戳
  28. now = datetime.now()
  29. date = format_date_time(mktime(now.timetuple()))
  30. # 拼接字符串
  31. signature_origin = "host: " + self.host + "\n"
  32. signature_origin += "date: " + date + "\n"
  33. signature_origin += "GET " + self.path + " HTTP/1.1"
  34. # 进行hmac-sha256进行加密
  35. signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'), digestmod=hashlib.sha256).digest()
  36. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  37. authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  38. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  39. # 将请求的鉴权参数组合为字典
  40. v = {
  41. "authorization": authorization,
  42. "date": date,
  43. "host": self.host
  44. }
  45. # 拼接鉴权参数,生成url
  46. url = self.gpt_url + '?' + urlencode(v)
  47. # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
  48. return url
  49. class SparkRequestInstance():
  50. def __init__(self):
  51. XFYUN_APPID, XFYUN_API_SECRET, XFYUN_API_KEY = get_conf('XFYUN_APPID', 'XFYUN_API_SECRET', 'XFYUN_API_KEY')
  52. if XFYUN_APPID == '00000000' or XFYUN_APPID == '': raise RuntimeError('请配置讯飞星火大模型的XFYUN_APPID, XFYUN_API_KEY, XFYUN_API_SECRET')
  53. self.appid = XFYUN_APPID
  54. self.api_secret = XFYUN_API_SECRET
  55. self.api_key = XFYUN_API_KEY
  56. self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
  57. self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat"
  58. self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat"
  59. self.gpt_url_img = "wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image"
  60. self.time_to_yield_event = threading.Event()
  61. self.time_to_exit_event = threading.Event()
  62. self.result_buf = ""
  63. def generate(self, inputs, llm_kwargs, history, system_prompt, use_image_api=False):
  64. llm_kwargs = llm_kwargs
  65. history = history
  66. system_prompt = system_prompt
  67. import _thread as thread
  68. thread.start_new_thread(self.create_blocking_request, (inputs, llm_kwargs, history, system_prompt, use_image_api))
  69. while True:
  70. self.time_to_yield_event.wait(timeout=1)
  71. if self.time_to_yield_event.is_set():
  72. yield self.result_buf
  73. if self.time_to_exit_event.is_set():
  74. return self.result_buf
  75. def create_blocking_request(self, inputs, llm_kwargs, history, system_prompt, use_image_api):
  76. if llm_kwargs['llm_model'] == 'sparkv2':
  77. gpt_url = self.gpt_url_v2
  78. elif llm_kwargs['llm_model'] == 'sparkv3':
  79. gpt_url = self.gpt_url_v3
  80. else:
  81. gpt_url = self.gpt_url
  82. file_manifest = []
  83. if use_image_api and llm_kwargs.get('most_recent_uploaded'):
  84. if llm_kwargs['most_recent_uploaded'].get('path'):
  85. file_manifest = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
  86. if len(file_manifest) > 0:
  87. print('正在使用讯飞图片理解API')
  88. gpt_url = self.gpt_url_img
  89. wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
  90. websocket.enableTrace(False)
  91. wsUrl = wsParam.create_url()
  92. # 收到websocket连接建立的处理
  93. def on_open(ws):
  94. import _thread as thread
  95. thread.start_new_thread(run, (ws,))
  96. def run(ws, *args):
  97. data = json.dumps(gen_params(ws.appid, *ws.all_args, file_manifest))
  98. ws.send(data)
  99. # 收到websocket消息的处理
  100. def on_message(ws, message):
  101. data = json.loads(message)
  102. code = data['header']['code']
  103. if code != 0:
  104. print(f'请求错误: {code}, {data}')
  105. self.result_buf += str(data)
  106. ws.close()
  107. self.time_to_exit_event.set()
  108. else:
  109. choices = data["payload"]["choices"]
  110. status = choices["status"]
  111. content = choices["text"][0]["content"]
  112. ws.content += content
  113. self.result_buf += content
  114. if status == 2:
  115. ws.close()
  116. self.time_to_exit_event.set()
  117. self.time_to_yield_event.set()
  118. # 收到websocket错误的处理
  119. def on_error(ws, error):
  120. print("error:", error)
  121. self.time_to_exit_event.set()
  122. # 收到websocket关闭的处理
  123. def on_close(ws, *args):
  124. self.time_to_exit_event.set()
  125. # websocket
  126. ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
  127. ws.appid = self.appid
  128. ws.content = ""
  129. ws.all_args = (inputs, llm_kwargs, history, system_prompt)
  130. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  131. def generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest):
  132. conversation_cnt = len(history) // 2
  133. messages = []
  134. if file_manifest:
  135. base64_images = []
  136. for image_path in file_manifest:
  137. base64_images.append(encode_image(image_path))
  138. for img_s in base64_images:
  139. if img_s not in str(messages):
  140. messages.append({"role": "user", "content": img_s, "content_type": "image"})
  141. else:
  142. messages = [{"role": "system", "content": system_prompt}]
  143. if conversation_cnt:
  144. for index in range(0, 2*conversation_cnt, 2):
  145. what_i_have_asked = {}
  146. what_i_have_asked["role"] = "user"
  147. what_i_have_asked["content"] = history[index]
  148. what_gpt_answer = {}
  149. what_gpt_answer["role"] = "assistant"
  150. what_gpt_answer["content"] = history[index+1]
  151. if what_i_have_asked["content"] != "":
  152. if what_gpt_answer["content"] == "": continue
  153. if what_gpt_answer["content"] == timeout_bot_msg: continue
  154. messages.append(what_i_have_asked)
  155. messages.append(what_gpt_answer)
  156. else:
  157. messages[-1]['content'] = what_gpt_answer['content']
  158. what_i_ask_now = {}
  159. what_i_ask_now["role"] = "user"
  160. what_i_ask_now["content"] = inputs
  161. messages.append(what_i_ask_now)
  162. return messages
  163. def gen_params(appid, inputs, llm_kwargs, history, system_prompt, file_manifest):
  164. """
  165. 通过appid和用户的提问来生成请参数
  166. """
  167. domains = {
  168. "spark": "general",
  169. "sparkv2": "generalv2",
  170. "sparkv3": "generalv3",
  171. }
  172. domains_select = domains[llm_kwargs['llm_model']]
  173. if file_manifest: domains_select = 'image'
  174. data = {
  175. "header": {
  176. "app_id": appid,
  177. "uid": "1234"
  178. },
  179. "parameter": {
  180. "chat": {
  181. "domain": domains_select,
  182. "temperature": llm_kwargs["temperature"],
  183. "random_threshold": 0.5,
  184. "max_tokens": 4096,
  185. "auditing": "default"
  186. }
  187. },
  188. "payload": {
  189. "message": {
  190. "text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest)
  191. }
  192. }
  193. }
  194. return data