com_google.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # encoding: utf-8
  2. # @Time : 2023/12/25
  3. # @Author : Spike
  4. # @Descr :
  5. import json
  6. import os
  7. import re
  8. import requests
  9. from typing import List, Dict, Tuple
  10. from toolbox import get_conf, encode_image, get_pictures_list
  11. proxies, TIMEOUT_SECONDS = get_conf("proxies", "TIMEOUT_SECONDS")
  12. """
  13. ========================================================================
  14. 第五部分 一些文件处理方法
  15. files_filter_handler 根据type过滤文件
  16. input_encode_handler 提取input中的文件,并解析
  17. file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本
  18. link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
  19. html_view_blank 超链接
  20. html_local_file 本地文件取相对路径
  21. to_markdown_tabs 文件list 转换为 md tab
  22. """
  23. def files_filter_handler(file_list):
  24. new_list = []
  25. filter_ = [
  26. "png",
  27. "jpg",
  28. "jpeg",
  29. "bmp",
  30. "svg",
  31. "webp",
  32. "ico",
  33. "tif",
  34. "tiff",
  35. "raw",
  36. "eps",
  37. ]
  38. for file in file_list:
  39. file = str(file).replace("file=", "")
  40. if os.path.exists(file):
  41. if str(os.path.basename(file)).split(".")[-1] in filter_:
  42. new_list.append(file)
  43. return new_list
  44. def input_encode_handler(inputs, llm_kwargs):
  45. if llm_kwargs["most_recent_uploaded"].get("path"):
  46. image_paths = get_pictures_list(llm_kwargs["most_recent_uploaded"]["path"])
  47. md_encode = []
  48. for md_path in image_paths:
  49. type_ = os.path.splitext(md_path)[1].replace(".", "")
  50. type_ = "jpeg" if type_ == "jpg" else type_
  51. md_encode.append({"data": encode_image(md_path), "type": type_})
  52. return inputs, md_encode
  53. def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
  54. new_list = []
  55. if not filter_:
  56. filter_ = [
  57. "png",
  58. "jpg",
  59. "jpeg",
  60. "bmp",
  61. "svg",
  62. "webp",
  63. "ico",
  64. "tif",
  65. "tiff",
  66. "raw",
  67. "eps",
  68. ]
  69. for file in file_list:
  70. if str(os.path.basename(file)).split(".")[-1] in filter_:
  71. new_list.append(html_local_img(file, md=md_type))
  72. elif os.path.exists(file):
  73. new_list.append(link_mtime_to_md(file))
  74. else:
  75. new_list.append(file)
  76. return new_list
  77. def link_mtime_to_md(file):
  78. link_local = html_local_file(file)
  79. link_name = os.path.basename(file)
  80. a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})"
  81. return a
  82. def html_local_file(file):
  83. base_path = os.path.dirname(__file__) # 项目目录
  84. if os.path.exists(str(file)):
  85. file = f'file={file.replace(base_path, ".")}'
  86. return file
  87. def html_local_img(__file, layout="left", max_width=None, max_height=None, md=True):
  88. style = ""
  89. if max_width is not None:
  90. style += f"max-width: {max_width};"
  91. if max_height is not None:
  92. style += f"max-height: {max_height};"
  93. __file = html_local_file(__file)
  94. a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
  95. if md:
  96. a = f"![{__file}]({__file})"
  97. return a
  98. def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False):
  99. """
  100. Args:
  101. head: 表头:[]
  102. tabs: 表值:[[列1], [列2], [列3], [列4]]
  103. alignment: :--- 左对齐, :---: 居中对齐, ---: 右对齐
  104. column: True to keep data in columns, False to keep data in rows (default).
  105. Returns:
  106. A string representation of the markdown table.
  107. """
  108. if column:
  109. transposed_tabs = list(map(list, zip(*tabs)))
  110. else:
  111. transposed_tabs = tabs
  112. # Find the maximum length among the columns
  113. max_len = max(len(column) for column in transposed_tabs)
  114. tab_format = "| %s "
  115. tabs_list = "".join([tab_format % i for i in head]) + "|\n"
  116. tabs_list += "".join([tab_format % alignment for i in head]) + "|\n"
  117. for i in range(max_len):
  118. row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
  119. row_data = file_manifest_filter_html(row_data, filter_=None)
  120. tabs_list += "".join([tab_format % i for i in row_data]) + "|\n"
  121. return tabs_list
  122. class GoogleChatInit:
  123. def __init__(self):
  124. self.url_gemini = "https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k"
  125. def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
  126. headers, payload = self.generate_message_payload(
  127. inputs, llm_kwargs, history, system_prompt
  128. )
  129. response = requests.post(
  130. url=self.url_gemini,
  131. headers=headers,
  132. data=json.dumps(payload),
  133. stream=True,
  134. proxies=proxies,
  135. timeout=TIMEOUT_SECONDS,
  136. )
  137. return response.iter_lines()
  138. def __conversation_user(self, user_input, llm_kwargs):
  139. what_i_have_asked = {"role": "user", "parts": []}
  140. if "vision" not in self.url_gemini:
  141. input_ = user_input
  142. encode_img = []
  143. else:
  144. input_, encode_img = input_encode_handler(user_input, llm_kwargs=llm_kwargs)
  145. what_i_have_asked["parts"].append({"text": input_})
  146. if encode_img:
  147. for data in encode_img:
  148. what_i_have_asked["parts"].append(
  149. {
  150. "inline_data": {
  151. "mime_type": f"image/{data['type']}",
  152. "data": data["data"],
  153. }
  154. }
  155. )
  156. return what_i_have_asked
  157. def __conversation_history(self, history, llm_kwargs):
  158. messages = []
  159. conversation_cnt = len(history) // 2
  160. if conversation_cnt:
  161. for index in range(0, 2 * conversation_cnt, 2):
  162. what_i_have_asked = self.__conversation_user(history[index], llm_kwargs)
  163. what_gpt_answer = {
  164. "role": "model",
  165. "parts": [{"text": history[index + 1]}],
  166. }
  167. messages.append(what_i_have_asked)
  168. messages.append(what_gpt_answer)
  169. return messages
  170. def generate_message_payload(
  171. self, inputs, llm_kwargs, history, system_prompt
  172. ) -> Tuple[Dict, Dict]:
  173. messages = [
  174. # {"role": "system", "parts": [{"text": system_prompt}]}, # gemini 不允许对话轮次为偶数,所以这个没有用,看后续支持吧。。。
  175. # {"role": "user", "parts": [{"text": ""}]},
  176. # {"role": "model", "parts": [{"text": ""}]}
  177. ]
  178. self.url_gemini = self.url_gemini.replace(
  179. "%m", llm_kwargs["llm_model"]
  180. ).replace("%k", get_conf("GEMINI_API_KEY"))
  181. header = {"Content-Type": "application/json"}
  182. if "vision" not in self.url_gemini: # 不是vision 才处理history
  183. messages.extend(
  184. self.__conversation_history(history, llm_kwargs)
  185. ) # 处理 history
  186. messages.append(self.__conversation_user(inputs, llm_kwargs)) # 处理用户对话
  187. payload = {
  188. "contents": messages,
  189. "generationConfig": {
  190. # "maxOutputTokens": 800,
  191. "stopSequences": str(llm_kwargs.get("stop", "")).split(" "),
  192. "temperature": llm_kwargs.get("temperature", 1),
  193. "topP": llm_kwargs.get("top_p", 0.8),
  194. "topK": 10,
  195. },
  196. }
  197. return header, payload
  198. if __name__ == "__main__":
  199. google = GoogleChatInit()
  200. # print(gootle.generate_message_payload('你好呀', {}, ['123123', '3123123'], ''))
  201. # gootle.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')