toolbox.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052
  1. import importlib
  2. import time
  3. import inspect
  4. import re
  5. import os
  6. import base64
  7. import gradio
  8. import shutil
  9. import glob
  10. import uuid
  11. from loguru import logger
  12. from functools import wraps
  13. from textwrap import dedent
  14. from shared_utils.config_loader import get_conf
  15. from shared_utils.config_loader import set_conf
  16. from shared_utils.config_loader import set_multi_conf
  17. from shared_utils.config_loader import read_single_conf_with_lru_cache
  18. from shared_utils.advanced_markdown_format import format_io
  19. from shared_utils.advanced_markdown_format import markdown_convertion
  20. from shared_utils.key_pattern_manager import select_api_key
  21. from shared_utils.key_pattern_manager import is_any_api_key
  22. from shared_utils.key_pattern_manager import what_keys
  23. from shared_utils.connect_void_terminal import get_chat_handle
  24. from shared_utils.connect_void_terminal import get_plugin_handle
  25. from shared_utils.connect_void_terminal import get_plugin_default_kwargs
  26. from shared_utils.connect_void_terminal import get_chat_default_kwargs
  27. from shared_utils.text_mask import apply_gpt_academic_string_mask
  28. from shared_utils.text_mask import build_gpt_academic_masked_string
  29. from shared_utils.text_mask import apply_gpt_academic_string_mask_langbased
  30. from shared_utils.text_mask import build_gpt_academic_masked_string_langbased
  31. from shared_utils.map_names import map_friendly_names_to_model
  32. from shared_utils.map_names import map_model_to_friendly_names
  33. from shared_utils.map_names import read_one_api_model_name
  34. from shared_utils.handle_upload import html_local_file
  35. from shared_utils.handle_upload import html_local_img
  36. from shared_utils.handle_upload import file_manifest_filter_type
  37. from shared_utils.handle_upload import extract_archive
  38. from typing import List
# Shorthand for os.path.join, used for path construction throughout this module.
pj = os.path.join
# Fallback user identity used when no authenticated user name is available
# (e.g. when request.username is None).
default_user_name = "default_user"
# Part 1: input/output plumbing for function plugins. The string below indexes
# ChatBotWithCookies, ArgsGeneralWrapper, update_ui, CatchException, HotReload
# and trimmed_format_exc (descriptions kept in the original Chinese).
"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第一部分
函数插件输入输出接驳区
- ChatBotWithCookies: 带Cookies的Chatbot类,为实现更多强大的功能做基础
- ArgsGeneralWrapper: 装饰器函数,用于重组输入参数,改变输入参数的顺序与结构
- update_ui: 刷新界面用 yield from update_ui(chatbot, history)
- CatchException: 将插件中出的所有问题显示在界面上
- HotReload: 实现插件的热更新
- trimmed_format_exc: 打印traceback,为了安全而隐藏绝对地址
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
  53. class ChatBotWithCookies(list):
  54. def __init__(self, cookie):
  55. """
  56. cookies = {
  57. 'top_p': top_p,
  58. 'temperature': temperature,
  59. 'lock_plugin': bool,
  60. "files_to_promote": ["file1", "file2"],
  61. "most_recent_uploaded": {
  62. "path": "uploaded_path",
  63. "time": time.time(),
  64. "time_str": "timestr",
  65. }
  66. }
  67. """
  68. self._cookies = cookie
  69. def write_list(self, list):
  70. for t in list:
  71. self.append(t)
  72. def get_list(self):
  73. return [t for t in self]
  74. def get_cookies(self):
  75. return self._cookies
  76. def get_user(self):
  77. return self._cookies.get("user_name", default_user_name)
  78. def ArgsGeneralWrapper(f):
  79. """
  80. 装饰器函数ArgsGeneralWrapper,用于重组输入参数,改变输入参数的顺序与结构。
  81. 该装饰器是大多数功能调用的入口。
  82. 函数示意图:https://mermaid.live/edit#pako:eNqNVFtPGkEY_StkntoEDQtLoTw0sWqapjQxVWPabmOm7AiEZZcsQ9QiiW012qixqdeqqIn10geBh6ZR8PJnmAWe-hc6l3VhrWnLEzNzzvnO953ZyYOYoSIQAWOaMR5LQBN7hvoU3UN_g5iu7imAXEyT4wUF3Pd0dT3y9KGYYUJsmK8V0GPGs0-QjkyojZgwk0Fm82C2dVghX08U8EaoOHjOfoEMU0XmADRhOksVWnNLjdpM82qFzB6S5Q_WWsUhuqCc3JtAsVR_OoMnhyZwXgHWwbS1d4gnsLVZJp-P6mfVxveqAgqC70Jz_pQCOGDKM5xFdNNPDdilF6uSU_hOYqu4a3MHYDZLDzq5fodrC3PWcEaFGPUaRiqJWK_W9g9rvRITa4dhy_0nw67SiePMp3oSR6PPn41DGgllkvkizYwsrmtaejTFd8V4yekGmT1zqrt4XGlAy8WTuiPULF01LksZvukSajfQQRAxmYi5S0D81sDcyzapVdn6sYFHkjhhGyel3frVQnvsnbR23lEjlhIlaOJiFPWzU5G4tfNJo8ejwp47-TbvJkKKZvmxA6SKo16oaazJysfG6klr9T0pbTW2ZqzlL_XaT8fYbQLXe4mSmvoCZXMaa7FePW6s7jVqK9bujvse3WFjY5_Z4KfsA4oiPY4T7Drvn1tLJTbG1to1qR79ulgk89-oJbvZzbIwJty6u20LOReWa9BvwserUd9s9MIKc3x5TUWEoAhUyJK5y85w_yG-dFu_R9waoU7K581y8W_qLle35-rG9Nxcrz8QHRsc0K-r9NViYRT36KsFvCCNzDRMqvSVyzOKAnACpZECIvSvCs2UAhS9QHEwh43BST0GItjMIS_I8e-sLwnj9A262cxA_ZVh0OUY1LJiDSJ5MAEiUijYLUtBORR6KElyQPaCSRDpksNSd8AfluSgHPaFC17wjrOlbgbzyyFf4IFPDvoD_sJvnkdK-g
  83. """
  84. def decorated(request: gradio.Request, cookies:dict, max_length:int, llm_model:str,
  85. txt:str, txt2:str, top_p:float, temperature:float, chatbot:list,
  86. history:list, system_prompt:str, plugin_advanced_arg:dict, *args):
  87. txt_passon = txt
  88. if txt == "" and txt2 != "": txt_passon = txt2
  89. # 引入一个有cookie的chatbot
  90. if request.username is not None:
  91. user_name = request.username
  92. else:
  93. user_name = default_user_name
  94. embed_model = get_conf("EMBEDDING_MODEL")
  95. cookies.update({
  96. 'top_p': top_p,
  97. 'api_key': cookies['api_key'],
  98. 'llm_model': llm_model,
  99. 'embed_model': embed_model,
  100. 'temperature': temperature,
  101. 'user_name': user_name,
  102. })
  103. llm_kwargs = {
  104. 'api_key': cookies['api_key'],
  105. 'llm_model': llm_model,
  106. 'embed_model': embed_model,
  107. 'top_p': top_p,
  108. 'max_length': max_length,
  109. 'temperature': temperature,
  110. 'client_ip': request.client.host,
  111. 'most_recent_uploaded': cookies.get('most_recent_uploaded')
  112. }
  113. if isinstance(plugin_advanced_arg, str):
  114. plugin_kwargs = {"advanced_arg": plugin_advanced_arg}
  115. else:
  116. plugin_kwargs = plugin_advanced_arg
  117. chatbot_with_cookie = ChatBotWithCookies(cookies)
  118. chatbot_with_cookie.write_list(chatbot)
  119. if cookies.get('lock_plugin', None) is None:
  120. # 正常状态
  121. if len(args) == 0: # 插件通道
  122. yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, request)
  123. else: # 对话通道,或者基础功能通道
  124. yield from f(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, *args)
  125. else:
  126. # 处理少数情况下的特殊插件的锁定状态
  127. module, fn_name = cookies['lock_plugin'].split('->')
  128. f_hot_reload = getattr(importlib.import_module(module, fn_name), fn_name)
  129. yield from f_hot_reload(txt_passon, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, system_prompt, request)
  130. # 判断一下用户是否错误地通过对话通道进入,如果是,则进行提醒
  131. final_cookies = chatbot_with_cookie.get_cookies()
  132. # len(args) != 0 代表“提交”键对话通道,或者基础功能通道
  133. if len(args) != 0 and 'files_to_promote' in final_cookies and len(final_cookies['files_to_promote']) > 0:
  134. chatbot_with_cookie.append(
  135. ["检测到**滞留的缓存文档**,请及时处理。", "请及时点击“**保存当前对话**”获取所有滞留文档。"])
  136. yield from update_ui(chatbot_with_cookie, final_cookies['history'], msg="检测到被滞留的缓存文档")
  137. return decorated
  138. def update_ui(chatbot:ChatBotWithCookies, history, msg="正常", **kwargs): # 刷新界面
  139. """
  140. 刷新用户界面
  141. """
  142. assert isinstance(
  143. chatbot, ChatBotWithCookies
  144. ), "在传递chatbot的过程中不要将其丢弃。必要时, 可用clear将其清空, 然后用for+append循环重新赋值。"
  145. cookies = chatbot.get_cookies()
  146. # 备份一份History作为记录
  147. cookies.update({"history": history})
  148. # 解决插件锁定时的界面显示问题
  149. if cookies.get("lock_plugin", None):
  150. label = (
  151. cookies.get("llm_model", "")
  152. + " | "
  153. + "正在锁定插件"
  154. + cookies.get("lock_plugin", None)
  155. )
  156. chatbot_gr = gradio.update(value=chatbot, label=label)
  157. if cookies.get("label", "") != label:
  158. cookies["label"] = label # 记住当前的label
  159. elif cookies.get("label", None):
  160. chatbot_gr = gradio.update(value=chatbot, label=cookies.get("llm_model", ""))
  161. cookies["label"] = None # 清空label
  162. else:
  163. chatbot_gr = chatbot
  164. yield cookies, chatbot_gr, history, msg
  165. def update_ui_lastest_msg(lastmsg:str, chatbot:ChatBotWithCookies, history:list, delay=1, msg="正常"): # 刷新界面
  166. """
  167. 刷新用户界面
  168. """
  169. if len(chatbot) == 0:
  170. chatbot.append(["update_ui_last_msg", lastmsg])
  171. chatbot[-1] = list(chatbot[-1])
  172. chatbot[-1][-1] = lastmsg
  173. yield from update_ui(chatbot=chatbot, history=history, msg=msg)
  174. time.sleep(delay)
  175. def trimmed_format_exc():
  176. import os, traceback
  177. str = traceback.format_exc()
  178. current_path = os.getcwd()
  179. replace_path = "."
  180. return str.replace(current_path, replace_path)
  181. def trimmed_format_exc_markdown():
  182. return '\n\n```\n' + trimmed_format_exc() + '```'
  183. class FriendlyException(Exception):
  184. def generate_error_html(self):
  185. return dedent(f"""
  186. <div class="center-div" style="color: crimson;text-align: center;">
  187. {"<br>".join(self.args)}
  188. </div>
  189. """)
  190. def CatchException(f):
  191. """
  192. 装饰器函数,捕捉函数f中的异常并封装到一个生成器中返回,并显示到聊天当中。
  193. """
  194. @wraps(f)
  195. def decorated(main_input:str, llm_kwargs:dict, plugin_kwargs:dict,
  196. chatbot_with_cookie:ChatBotWithCookies, history:list, *args, **kwargs):
  197. try:
  198. yield from f(main_input, llm_kwargs, plugin_kwargs, chatbot_with_cookie, history, *args, **kwargs)
  199. except FriendlyException as e:
  200. tb_str = '```\n' + trimmed_format_exc() + '```'
  201. if len(chatbot_with_cookie) == 0:
  202. chatbot_with_cookie.clear()
  203. chatbot_with_cookie.append(["插件调度异常:\n" + tb_str, None])
  204. chatbot_with_cookie[-1] = [chatbot_with_cookie[-1][0], e.generate_error_html()]
  205. yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常') # 刷新界面
  206. except Exception as e:
  207. tb_str = '```\n' + trimmed_format_exc() + '```'
  208. if len(chatbot_with_cookie) == 0:
  209. chatbot_with_cookie.clear()
  210. chatbot_with_cookie.append(["插件调度异常", "异常原因"])
  211. chatbot_with_cookie[-1] = [chatbot_with_cookie[-1][0], f"[Local Message] 插件调用出错: \n\n{tb_str} \n"]
  212. yield from update_ui(chatbot=chatbot_with_cookie, history=history, msg=f'异常 {e}') # 刷新界面
  213. return decorated
  214. def HotReload(f):
  215. """
  216. HotReload的装饰器函数,用于实现Python函数插件的热更新。
  217. 函数热更新是指在不停止程序运行的情况下,更新函数代码,从而达到实时更新功能。
  218. 在装饰器内部,使用wraps(f)来保留函数的元信息,并定义了一个名为decorated的内部函数。
  219. 内部函数通过使用importlib模块的reload函数和inspect模块的getmodule函数来重新加载并获取函数模块,
  220. 然后通过getattr函数获取函数名,并在新模块中重新加载函数。
  221. 最后,使用yield from语句返回重新加载过的函数,并在被装饰的函数上执行。
  222. 最终,装饰器函数返回内部函数。这个内部函数可以将函数的原始定义更新为最新版本,并执行函数的新版本。
  223. """
  224. if get_conf("PLUGIN_HOT_RELOAD"):
  225. @wraps(f)
  226. def decorated(*args, **kwargs):
  227. fn_name = f.__name__
  228. f_hot_reload = getattr(importlib.reload(inspect.getmodule(f)), fn_name)
  229. yield from f_hot_reload(*args, **kwargs)
  230. return decorated
  231. else:
  232. return f
# Part 2: miscellaneous helpers. The string below indexes them (descriptions in Chinese).
"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第二部分
其他小工具:
- write_history_to_file: 将结果写入markdown文件中
- regular_txt_to_markdown: 将普通文本转换为Markdown格式的文本。
- report_exception: 向chatbot中添加简单的意外错误信息
- text_divide_paragraph: 将文本按照段落分隔符分割开,生成带有段落标签的HTML代码。
- markdown_convertion: 用多种方式组合,将markdown转化为好看的html
- format_io: 接管gradio默认的markdown处理方式
- on_file_uploaded: 处理文件的上传(自动解压)
- on_report_generated: 将生成的报告自动投射到文件上传区
- clip_history: 当历史上下文过长时,自动截断
- get_conf: 获取设置
- select_api_key: 根据当前的模型类别,抽取可用的api-key
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
  250. def get_reduce_token_percent(text:str):
  251. """
  252. * 此函数未来将被弃用
  253. """
  254. try:
  255. # text = "maximum context length is 4097 tokens. However, your messages resulted in 4870 tokens"
  256. pattern = r"(\d+)\s+tokens\b"
  257. match = re.findall(pattern, text)
  258. EXCEED_ALLO = 500 # 稍微留一点余地,否则在回复时会因余量太少出问题
  259. max_limit = float(match[0]) - EXCEED_ALLO
  260. current_tokens = float(match[1])
  261. ratio = max_limit / current_tokens
  262. assert ratio > 0 and ratio < 1
  263. return ratio, str(int(current_tokens - max_limit))
  264. except:
  265. return 0.5, "不详"
  266. def write_history_to_file(
  267. history:list, file_basename:str=None, file_fullname:str=None, auto_caption:bool=True
  268. ):
  269. """
  270. 将对话记录history以Markdown格式写入文件中。如果没有指定文件名,则使用当前时间生成文件名。
  271. """
  272. import os
  273. import time
  274. if file_fullname is None:
  275. if file_basename is not None:
  276. file_fullname = pj(get_log_folder(), file_basename)
  277. else:
  278. file_fullname = pj(get_log_folder(), f"GPT-Academic-{gen_time_str()}.md")
  279. os.makedirs(os.path.dirname(file_fullname), exist_ok=True)
  280. with open(file_fullname, "w", encoding="utf8") as f:
  281. f.write("# GPT-Academic Report\n")
  282. for i, content in enumerate(history):
  283. try:
  284. if type(content) != str:
  285. content = str(content)
  286. except:
  287. continue
  288. if i % 2 == 0 and auto_caption:
  289. f.write("## ")
  290. try:
  291. f.write(content)
  292. except:
  293. # remove everything that cannot be handled by utf8
  294. f.write(content.encode("utf-8", "ignore").decode())
  295. f.write("\n\n")
  296. res = os.path.abspath(file_fullname)
  297. return res
  298. def regular_txt_to_markdown(text:str):
  299. """
  300. 将普通文本转换为Markdown格式的文本。
  301. """
  302. text = text.replace("\n", "\n\n")
  303. text = text.replace("\n\n\n", "\n\n")
  304. text = text.replace("\n\n\n", "\n\n")
  305. return text
  306. def report_exception(chatbot:ChatBotWithCookies, history:list, a:str, b:str):
  307. """
  308. 向chatbot中添加错误信息
  309. """
  310. chatbot.append((a, b))
  311. history.extend([a, b])
  312. def find_free_port()->int:
  313. """
  314. 返回当前系统中可用的未使用端口。
  315. """
  316. import socket
  317. from contextlib import closing
  318. with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
  319. s.bind(("", 0))
  320. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  321. return s.getsockname()[1]
  322. def find_recent_files(directory:str)->List[str]:
  323. """
  324. Find files that is created with in one minutes under a directory with python, write a function
  325. """
  326. import os
  327. import time
  328. current_time = time.time()
  329. one_minute_ago = current_time - 60
  330. recent_files = []
  331. if not os.path.exists(directory):
  332. os.makedirs(directory, exist_ok=True)
  333. for filename in os.listdir(directory):
  334. file_path = pj(directory, filename)
  335. if file_path.endswith(".log"):
  336. continue
  337. created_time = os.path.getmtime(file_path)
  338. if created_time >= one_minute_ago:
  339. if os.path.isdir(file_path):
  340. continue
  341. recent_files.append(file_path)
  342. return recent_files
  343. def file_already_in_downloadzone(file:str, user_path:str):
  344. try:
  345. parent_path = os.path.abspath(user_path)
  346. child_path = os.path.abspath(file)
  347. if os.path.samefile(os.path.commonpath([parent_path, child_path]), parent_path):
  348. return True
  349. else:
  350. return False
  351. except:
  352. return False
  353. def promote_file_to_downloadzone(file:str, rename_file:str=None, chatbot:ChatBotWithCookies=None):
  354. # 将文件复制一份到下载区
  355. import shutil
  356. if chatbot is not None:
  357. user_name = get_user(chatbot)
  358. else:
  359. user_name = default_user_name
  360. if not os.path.exists(file):
  361. raise FileNotFoundError(f"文件{file}不存在")
  362. user_path = get_log_folder(user_name, plugin_name=None)
  363. if file_already_in_downloadzone(file, user_path):
  364. new_path = file
  365. else:
  366. user_path = get_log_folder(user_name, plugin_name="downloadzone")
  367. if rename_file is None:
  368. rename_file = f"{gen_time_str()}-{os.path.basename(file)}"
  369. new_path = pj(user_path, rename_file)
  370. # 如果已经存在,先删除
  371. if os.path.exists(new_path) and not os.path.samefile(new_path, file):
  372. os.remove(new_path)
  373. # 把文件复制过去
  374. if not os.path.exists(new_path):
  375. shutil.copyfile(file, new_path)
  376. # 将文件添加到chatbot cookie中
  377. if chatbot is not None:
  378. if "files_to_promote" in chatbot._cookies:
  379. current = chatbot._cookies["files_to_promote"]
  380. else:
  381. current = []
  382. if new_path not in current: # 避免把同一个文件添加多次
  383. chatbot._cookies.update({"files_to_promote": [new_path] + current})
  384. return new_path
  385. def disable_auto_promotion(chatbot:ChatBotWithCookies):
  386. chatbot._cookies.update({"files_to_promote": []})
  387. return
  388. def del_outdated_uploads(outdate_time_seconds:float, target_path_base:str=None):
  389. if target_path_base is None:
  390. user_upload_dir = get_conf("PATH_PRIVATE_UPLOAD")
  391. else:
  392. user_upload_dir = target_path_base
  393. current_time = time.time()
  394. one_hour_ago = current_time - outdate_time_seconds
  395. # Get a list of all subdirectories in the user_upload_dir folder
  396. # Remove subdirectories that are older than one hour
  397. for subdirectory in glob.glob(f"{user_upload_dir}/*"):
  398. subdirectory_time = os.path.getmtime(subdirectory)
  399. if subdirectory_time < one_hour_ago:
  400. try:
  401. shutil.rmtree(subdirectory)
  402. except:
  403. pass
  404. return
  405. def to_markdown_tabs(head: list, tabs: list, alignment=":---:", column=False, omit_path=None):
  406. """
  407. Args:
  408. head: 表头:[]
  409. tabs: 表值:[[列1], [列2], [列3], [列4]]
  410. alignment: :--- 左对齐, :---: 居中对齐, ---: 右对齐
  411. column: True to keep data in columns, False to keep data in rows (default).
  412. Returns:
  413. A string representation of the markdown table.
  414. """
  415. if column:
  416. transposed_tabs = list(map(list, zip(*tabs)))
  417. else:
  418. transposed_tabs = tabs
  419. # Find the maximum length among the columns
  420. max_len = max(len(column) for column in transposed_tabs)
  421. tab_format = "| %s "
  422. tabs_list = "".join([tab_format % i for i in head]) + "|\n"
  423. tabs_list += "".join([tab_format % alignment for i in head]) + "|\n"
  424. for i in range(max_len):
  425. row_data = [tab[i] if i < len(tab) else "" for tab in transposed_tabs]
  426. row_data = file_manifest_filter_type(row_data, filter_=None)
  427. # for dat in row_data:
  428. # if (omit_path is not None) and os.path.exists(dat):
  429. # dat = os.path.relpath(dat, omit_path)
  430. tabs_list += "".join([tab_format % i for i in row_data]) + "|\n"
  431. return tabs_list
  432. def on_file_uploaded(
  433. request: gradio.Request, files:List[str], chatbot:ChatBotWithCookies,
  434. txt:str, txt2:str, checkboxes:List[str], cookies:dict
  435. ):
  436. """
  437. 当文件被上传时的回调函数
  438. """
  439. if len(files) == 0:
  440. return chatbot, txt
  441. # 创建工作路径
  442. user_name = default_user_name if not request.username else request.username
  443. time_tag = gen_time_str()
  444. target_path_base = get_upload_folder(user_name, tag=time_tag)
  445. os.makedirs(target_path_base, exist_ok=True)
  446. # 移除过时的旧文件从而节省空间&保护隐私
  447. outdate_time_seconds = 3600 # 一小时
  448. del_outdated_uploads(outdate_time_seconds, get_upload_folder(user_name))
  449. # 逐个文件转移到目标路径
  450. upload_msg = ""
  451. for file in files:
  452. file_origin_name = os.path.basename(file.orig_name)
  453. this_file_path = pj(target_path_base, file_origin_name)
  454. shutil.move(file.name, this_file_path)
  455. upload_msg += extract_archive(
  456. file_path=this_file_path, dest_dir=this_file_path + ".extract"
  457. )
  458. # 整理文件集合 输出消息
  459. files = glob.glob(f"{target_path_base}/**/*", recursive=True)
  460. moved_files = [fp for fp in files]
  461. max_file_to_show = 10
  462. if len(moved_files) > max_file_to_show:
  463. moved_files = moved_files[:max_file_to_show//2] + [f'... ( 📌省略{len(moved_files) - max_file_to_show}个文件的显示 ) ...'] + \
  464. moved_files[-max_file_to_show//2:]
  465. moved_files_str = to_markdown_tabs(head=["文件"], tabs=[moved_files], omit_path=target_path_base)
  466. chatbot.append(
  467. [
  468. "我上传了文件,请查收",
  469. f"[Local Message] 收到以下文件 (上传到路径:{target_path_base}): " +
  470. f"\n\n{moved_files_str}" +
  471. f"\n\n调用路径参数已自动修正到: \n\n{txt}" +
  472. f"\n\n现在您点击任意函数插件时,以上文件将被作为输入参数" +
  473. upload_msg,
  474. ]
  475. )
  476. txt, txt2 = target_path_base, ""
  477. if "浮动输入区" in checkboxes:
  478. txt, txt2 = txt2, txt
  479. # 记录近期文件
  480. cookies.update(
  481. {
  482. "most_recent_uploaded": {
  483. "path": target_path_base,
  484. "time": time.time(),
  485. "time_str": time_tag,
  486. }
  487. }
  488. )
  489. return chatbot, txt, txt2, cookies
  490. def generate_file_link(report_files:List[str]):
  491. file_links = ""
  492. for f in report_files:
  493. file_links += (
  494. f'<br/><a href="file={os.path.abspath(f)}" target="_blank">{f}</a>'
  495. )
  496. return file_links
  497. def on_report_generated(cookies:dict, files:List[str], chatbot:ChatBotWithCookies):
  498. if "files_to_promote" in cookies:
  499. report_files = cookies["files_to_promote"]
  500. cookies.pop("files_to_promote")
  501. else:
  502. report_files = []
  503. if len(report_files) == 0:
  504. return cookies, None, chatbot
  505. file_links = ""
  506. for f in report_files:
  507. file_links += (
  508. f'<br/><a href="file={os.path.abspath(f)}" target="_blank">{f}</a>'
  509. )
  510. chatbot.append(["报告如何远程获取?", f"报告已经添加到右侧“文件下载区”(可能处于折叠状态),请查收。{file_links}"])
  511. return cookies, report_files, chatbot
  512. def load_chat_cookies():
  513. API_KEY, LLM_MODEL, AZURE_API_KEY = get_conf(
  514. "API_KEY", "LLM_MODEL", "AZURE_API_KEY"
  515. )
  516. AZURE_CFG_ARRAY, NUM_CUSTOM_BASIC_BTN = get_conf(
  517. "AZURE_CFG_ARRAY", "NUM_CUSTOM_BASIC_BTN"
  518. )
  519. # deal with azure openai key
  520. if is_any_api_key(AZURE_API_KEY):
  521. if is_any_api_key(API_KEY):
  522. API_KEY = API_KEY + "," + AZURE_API_KEY
  523. else:
  524. API_KEY = AZURE_API_KEY
  525. if len(AZURE_CFG_ARRAY) > 0:
  526. for azure_model_name, azure_cfg_dict in AZURE_CFG_ARRAY.items():
  527. if not azure_model_name.startswith("azure"):
  528. raise ValueError("AZURE_CFG_ARRAY中配置的模型必须以azure开头")
  529. AZURE_API_KEY_ = azure_cfg_dict["AZURE_API_KEY"]
  530. if is_any_api_key(AZURE_API_KEY_):
  531. if is_any_api_key(API_KEY):
  532. API_KEY = API_KEY + "," + AZURE_API_KEY_
  533. else:
  534. API_KEY = AZURE_API_KEY_
  535. customize_fn_overwrite_ = {}
  536. for k in range(NUM_CUSTOM_BASIC_BTN):
  537. customize_fn_overwrite_.update(
  538. {
  539. "自定义按钮"
  540. + str(k + 1): {
  541. "Title": r"",
  542. "Prefix": r"请在自定义菜单中定义提示词前缀.",
  543. "Suffix": r"请在自定义菜单中定义提示词后缀",
  544. }
  545. }
  546. )
  547. EMBEDDING_MODEL = get_conf("EMBEDDING_MODEL")
  548. return {
  549. "api_key": API_KEY,
  550. "llm_model": LLM_MODEL,
  551. "embed_model": EMBEDDING_MODEL,
  552. "customize_fn_overwrite": customize_fn_overwrite_,
  553. }
  554. def clear_line_break(txt):
  555. txt = txt.replace("\n", " ")
  556. txt = txt.replace(" ", " ")
  557. txt = txt.replace(" ", " ")
  558. return txt
  559. class DummyWith:
  560. """
  561. 这段代码定义了一个名为DummyWith的空上下文管理器,
  562. 它的作用是……额……就是不起作用,即在代码结构不变得情况下取代其他的上下文管理器。
  563. 上下文管理器是一种Python对象,用于与with语句一起使用,
  564. 以确保一些资源在代码块执行期间得到正确的初始化和清理。
  565. 上下文管理器必须实现两个方法,分别为 __enter__()和 __exit__()。
  566. 在上下文执行开始的情况下,__enter__()方法会在代码块被执行前被调用,
  567. 而在上下文执行结束时,__exit__()方法则会被调用。
  568. """
  569. def __enter__(self):
  570. return self
  571. def __exit__(self, exc_type, exc_value, traceback):
  572. return
  573. def run_gradio_in_subpath(demo, auth, port, custom_path):
  574. """
  575. 把gradio的运行地址更改到指定的二次路径上
  576. """
  577. def is_path_legal(path: str) -> bool:
  578. """
  579. check path for sub url
  580. path: path to check
  581. return value: do sub url wrap
  582. """
  583. if path == "/":
  584. return True
  585. if len(path) == 0:
  586. logger.info(
  587. "ilegal custom path: {}\npath must not be empty\ndeploy on root url".format(
  588. path
  589. )
  590. )
  591. return False
  592. if path[0] == "/":
  593. if path[1] != "/":
  594. logger.info("deploy on sub-path {}".format(path))
  595. return True
  596. return False
  597. logger.info(
  598. "ilegal custom path: {}\npath should begin with '/'\ndeploy on root url".format(
  599. path
  600. )
  601. )
  602. return False
  603. if not is_path_legal(custom_path):
  604. raise RuntimeError("Ilegal custom path")
  605. import uvicorn
  606. import gradio as gr
  607. from fastapi import FastAPI
  608. app = FastAPI()
  609. if custom_path != "/":
  610. @app.get("/")
  611. def read_main():
  612. return {"message": f"Gradio is running at: {custom_path}"}
  613. app = gr.mount_gradio_app(app, demo, path=custom_path)
  614. uvicorn.run(app, host="0.0.0.0", port=port) # , auth=auth
  615. def clip_history(inputs, history, tokenizer, max_token_limit):
  616. """
  617. reduce the length of history by clipping.
  618. this function search for the longest entries to clip, little by little,
  619. until the number of token of history is reduced under threshold.
  620. 通过裁剪来缩短历史记录的长度。
  621. 此函数逐渐地搜索最长的条目进行剪辑,
  622. 直到历史记录的标记数量降低到阈值以下。
  623. """
  624. import numpy as np
  625. from request_llms.bridge_all import model_info
  626. def get_token_num(txt):
  627. return len(tokenizer.encode(txt, disallowed_special=()))
  628. input_token_num = get_token_num(inputs)
  629. if max_token_limit < 5000:
  630. output_token_expect = 256 # 4k & 2k models
  631. elif max_token_limit < 9000:
  632. output_token_expect = 512 # 8k models
  633. else:
  634. output_token_expect = 1024 # 16k & 32k models
  635. if input_token_num < max_token_limit * 3 / 4:
  636. # 当输入部分的token占比小于限制的3/4时,裁剪时
  637. # 1. 把input的余量留出来
  638. max_token_limit = max_token_limit - input_token_num
  639. # 2. 把输出用的余量留出来
  640. max_token_limit = max_token_limit - output_token_expect
  641. # 3. 如果余量太小了,直接清除历史
  642. if max_token_limit < output_token_expect:
  643. history = []
  644. return history
  645. else:
  646. # 当输入部分的token占比 > 限制的3/4时,直接清除历史
  647. history = []
  648. return history
  649. everything = [""]
  650. everything.extend(history)
  651. n_token = get_token_num("\n".join(everything))
  652. everything_token = [get_token_num(e) for e in everything]
  653. # 截断时的颗粒度
  654. delta = max(everything_token) // 16
  655. while n_token > max_token_limit:
  656. where = np.argmax(everything_token)
  657. encoded = tokenizer.encode(everything[where], disallowed_special=())
  658. clipped_encoded = encoded[: len(encoded) - delta]
  659. everything[where] = tokenizer.decode(clipped_encoded)[
  660. :-1
  661. ] # -1 to remove the may-be illegal char
  662. everything_token[where] = get_token_num(everything[where])
  663. n_token = get_token_num("\n".join(everything))
  664. history = everything[1:]
  665. return history
# Part 3: more helpers (zip_folder, gen_time_str, ProxyNetworkActivate, objdump/objload);
# the string below indexes them (descriptions in Chinese).
"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
第三部分
其他小工具:
- zip_folder: 把某个路径下所有文件压缩,然后转移到指定的另一个路径中(gpt写的)
- gen_time_str: 生成时间戳
- ProxyNetworkActivate: 临时地启动代理网络(如果有)
- objdump/objload: 快捷的调试函数
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
"""
  676. def zip_folder(source_folder, dest_folder, zip_name):
  677. import zipfile
  678. import os
  679. # Make sure the source folder exists
  680. if not os.path.exists(source_folder):
  681. logger.info(f"{source_folder} does not exist")
  682. return
  683. # Make sure the destination folder exists
  684. if not os.path.exists(dest_folder):
  685. logger.info(f"{dest_folder} does not exist")
  686. return
  687. # Create the name for the zip file
  688. zip_file = pj(dest_folder, zip_name)
  689. # Create a ZipFile object
  690. with zipfile.ZipFile(zip_file, "w", zipfile.ZIP_DEFLATED) as zipf:
  691. # Walk through the source folder and add files to the zip file
  692. for foldername, subfolders, filenames in os.walk(source_folder):
  693. for filename in filenames:
  694. filepath = pj(foldername, filename)
  695. zipf.write(filepath, arcname=os.path.relpath(filepath, source_folder))
  696. # Move the zip file to the destination folder (if it wasn't already there)
  697. if os.path.dirname(zip_file) != dest_folder:
  698. os.rename(zip_file, pj(dest_folder, os.path.basename(zip_file)))
  699. zip_file = pj(dest_folder, os.path.basename(zip_file))
  700. logger.info(f"Zip file created at {zip_file}")
  701. def zip_result(folder):
  702. t = gen_time_str()
  703. zip_folder(folder, get_log_folder(), f"{t}-result.zip")
  704. return pj(get_log_folder(), f"{t}-result.zip")
  705. def gen_time_str():
  706. import time
  707. return time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
  708. def get_log_folder(user=default_user_name, plugin_name="shared"):
  709. if user is None:
  710. user = default_user_name
  711. PATH_LOGGING = get_conf("PATH_LOGGING")
  712. if plugin_name is None:
  713. _dir = pj(PATH_LOGGING, user)
  714. else:
  715. _dir = pj(PATH_LOGGING, user, plugin_name)
  716. if not os.path.exists(_dir):
  717. os.makedirs(_dir)
  718. return _dir
  719. def get_upload_folder(user=default_user_name, tag=None):
  720. PATH_PRIVATE_UPLOAD = get_conf("PATH_PRIVATE_UPLOAD")
  721. if user is None:
  722. user = default_user_name
  723. if tag is None or len(tag) == 0:
  724. target_path_base = pj(PATH_PRIVATE_UPLOAD, user)
  725. else:
  726. target_path_base = pj(PATH_PRIVATE_UPLOAD, user, tag)
  727. return target_path_base
  728. def is_the_upload_folder(string):
  729. PATH_PRIVATE_UPLOAD = get_conf("PATH_PRIVATE_UPLOAD")
  730. pattern = r"^PATH_PRIVATE_UPLOAD[\\/][A-Za-z0-9_-]+[\\/]\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}$"
  731. pattern = pattern.replace("PATH_PRIVATE_UPLOAD", PATH_PRIVATE_UPLOAD)
  732. if re.match(pattern, string):
  733. return True
  734. else:
  735. return False
  736. def get_user(chatbotwithcookies:ChatBotWithCookies):
  737. return chatbotwithcookies._cookies.get("user_name", default_user_name)
  738. class ProxyNetworkActivate:
  739. """
  740. 这段代码定义了一个名为ProxyNetworkActivate的空上下文管理器, 用于给一小段代码上代理
  741. """
  742. def __init__(self, task=None) -> None:
  743. self.task = task
  744. if not task:
  745. # 不给定task, 那么我们默认代理生效
  746. self.valid = True
  747. else:
  748. # 给定了task, 我们检查一下
  749. from toolbox import get_conf
  750. WHEN_TO_USE_PROXY = get_conf("WHEN_TO_USE_PROXY")
  751. self.valid = task in WHEN_TO_USE_PROXY
  752. def __enter__(self):
  753. if not self.valid:
  754. return self
  755. from toolbox import get_conf
  756. proxies = get_conf("proxies")
  757. if "no_proxy" in os.environ:
  758. os.environ.pop("no_proxy")
  759. if proxies is not None:
  760. if "http" in proxies:
  761. os.environ["HTTP_PROXY"] = proxies["http"]
  762. if "https" in proxies:
  763. os.environ["HTTPS_PROXY"] = proxies["https"]
  764. return self
  765. def __exit__(self, exc_type, exc_value, traceback):
  766. os.environ["no_proxy"] = "*"
  767. if "HTTP_PROXY" in os.environ:
  768. os.environ.pop("HTTP_PROXY")
  769. if "HTTPS_PROXY" in os.environ:
  770. os.environ.pop("HTTPS_PROXY")
  771. return
  772. def Singleton(cls):
  773. """
  774. 一个单实例装饰器
  775. """
  776. _instance = {}
  777. def _singleton(*args, **kargs):
  778. if cls not in _instance:
  779. _instance[cls] = cls(*args, **kargs)
  780. return _instance[cls]
  781. return _singleton
  782. def get_pictures_list(path):
  783. file_manifest = [f for f in glob.glob(f"{path}/**/*.jpg", recursive=True)]
  784. file_manifest += [f for f in glob.glob(f"{path}/**/*.jpeg", recursive=True)]
  785. file_manifest += [f for f in glob.glob(f"{path}/**/*.png", recursive=True)]
  786. return file_manifest
  787. def have_any_recent_upload_image_files(chatbot:ChatBotWithCookies, pop:bool=False):
  788. _5min = 5 * 60
  789. if chatbot is None:
  790. return False, None # chatbot is None
  791. if pop:
  792. most_recent_uploaded = chatbot._cookies.pop("most_recent_uploaded", None)
  793. else:
  794. most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
  795. # most_recent_uploaded 是一个放置最新上传图像的路径
  796. if not most_recent_uploaded:
  797. return False, None # most_recent_uploaded is None
  798. if time.time() - most_recent_uploaded["time"] < _5min:
  799. path = most_recent_uploaded["path"]
  800. file_manifest = get_pictures_list(path)
  801. if len(file_manifest) == 0:
  802. return False, None
  803. return True, file_manifest # most_recent_uploaded is new
  804. else:
  805. return False, None # most_recent_uploaded is too old
  806. # Claude3 model supports graphic context dialogue, reads all images
  807. def every_image_file_in_path(chatbot:ChatBotWithCookies):
  808. if chatbot is None:
  809. return False, [] # chatbot is None
  810. most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
  811. if not most_recent_uploaded:
  812. return False, [] # most_recent_uploaded is None
  813. path = most_recent_uploaded["path"]
  814. file_manifest = get_pictures_list(path)
  815. if len(file_manifest) == 0:
  816. return False, []
  817. return True, file_manifest
  818. # Function to encode the image
  819. def encode_image(image_path):
  820. with open(image_path, "rb") as image_file:
  821. return base64.b64encode(image_file.read()).decode("utf-8")
  822. def get_max_token(llm_kwargs):
  823. from request_llms.bridge_all import model_info
  824. return model_info[llm_kwargs["llm_model"]]["max_token"]
  825. def check_packages(packages=[]):
  826. import importlib.util
  827. for p in packages:
  828. spam_spec = importlib.util.find_spec(p)
  829. if spam_spec is None:
  830. raise ModuleNotFoundError
  831. def map_file_to_sha256(file_path):
  832. import hashlib
  833. with open(file_path, 'rb') as file:
  834. content = file.read()
  835. # Calculate the SHA-256 hash of the file contents
  836. sha_hash = hashlib.sha256(content).hexdigest()
  837. return sha_hash
  838. def check_repeat_upload(new_pdf_path, pdf_hash):
  839. '''
  840. 检查历史上传的文件是否与新上传的文件相同,如果相同则返回(True, 重复文件路径),否则返回(False,None)
  841. '''
  842. from toolbox import get_conf
  843. import PyPDF2
  844. user_upload_dir = os.path.dirname(os.path.dirname(new_pdf_path))
  845. file_name = os.path.basename(new_pdf_path)
  846. file_manifest = [f for f in glob.glob(f'{user_upload_dir}/**/{file_name}', recursive=True)]
  847. for saved_file in file_manifest:
  848. with open(new_pdf_path, 'rb') as file1, open(saved_file, 'rb') as file2:
  849. reader1 = PyPDF2.PdfFileReader(file1)
  850. reader2 = PyPDF2.PdfFileReader(file2)
  851. # 比较页数是否相同
  852. if reader1.getNumPages() != reader2.getNumPages():
  853. continue
  854. # 比较每一页的内容是否相同
  855. for page_num in range(reader1.getNumPages()):
  856. page1 = reader1.getPage(page_num).extractText()
  857. page2 = reader2.getPage(page_num).extractText()
  858. if page1 != page2:
  859. continue
  860. maybe_project_dir = glob.glob('{}/**/{}'.format(get_log_folder(), pdf_hash + ".tag"), recursive=True)
  861. if len(maybe_project_dir) > 0:
  862. return True, os.path.dirname(maybe_project_dir[0])
  863. # 如果所有页的内容都相同,返回 True
  864. return False, None
  865. def log_chat(llm_model: str, input_str: str, output_str: str):
  866. try:
  867. if output_str and input_str and llm_model:
  868. uid = str(uuid.uuid4().hex)
  869. input_str = input_str.rstrip('\n')
  870. output_str = output_str.rstrip('\n')
  871. logger.bind(chat_msg=True).info(dedent(
  872. """
  873. ╭──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
  874. [UID]
  875. {uid}
  876. [Model]
  877. {llm_model}
  878. [Query]
  879. {input_str}
  880. [Response]
  881. {output_str}
  882. ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
  883. """).format(uid=uid, llm_model=llm_model, input_str=input_str, output_str=output_str))
  884. except:
  885. logger.error(trimmed_format_exc())