# bridge_jittorllms_rwkv.py

from transformers import AutoModel, AutoTokenizer
import time
import threading
import importlib
from toolbox import update_ui, get_conf
from multiprocessing import Process, Pipe

load_message = "jittorllms is not loaded yet; loading takes a while. Note: avoid mixing several jittor models, otherwise video memory may overflow and cause stalls. Depending on the settings in `config.py`, jittorllms consumes a large amount of memory (CPU) or video memory (GPU), which may freeze a low-end machine ..."
#################################################################################
class GetGLMHandle(Process):
    """Daemon subprocess that hosts the jittorllms (chatrwkv) model and talks
    to the main process over a multiprocessing Pipe."""
    def __init__(self):
        super().__init__(daemon=True)
        self.parent, self.child = Pipe()
        self.jittorllms_model = None
        self.info = ""
        self.local_history = []
        self.success = True
        self.check_dependency()
        self.start()
        self.threadLock = threading.Lock()
    def check_dependency(self):
        try:
            import pandas
            self.info = "Dependency check passed"
            self.success = True
        except Exception:
            from toolbox import trimmed_format_exc
            self.info = r"Missing jittorllms dependencies. To use jittorllms, besides the basic pip dependencies, you also need to run " +\
                        r"`pip install -r request_llms/requirements_jittorllms.txt -i https://pypi.jittor.org/simple -I` and " +\
                        r"`git clone https://gitlink.org.cn/jittor/JittorLLMs.git --depth 1 request_llms/jittorllms` " +\
                        r"(run both commands from the project root directory). " +\
                        r"Warning: installing the jittorllms dependencies will completely break the existing pytorch environment; a docker environment is recommended! " + trimmed_format_exc()
            self.success = False
    def ready(self):
        return self.jittorllms_model is not None
    def run(self):
        # Runs in the subprocess.
        # First run: load the model parameters.
        def validate_path():
            import os, sys
            # Scrub the CUDA bin directory from PATH so jittor does not pick up
            # an incompatible system CUDA toolchain.
            env = os.environ.get("PATH", "")
            os.environ["PATH"] = env.replace('/cuda/bin', '/x/bin')
            root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
            os.chdir(root_dir_assume + '/request_llms/jittorllms')
            sys.path.append(root_dir_assume + '/request_llms/jittorllms')
        validate_path()  # validate path so you can run from the base directory
        def load_model():
            import types
            try:
                if self.jittorllms_model is None:
                    device = get_conf('LOCAL_MODEL_DEVICE')
                    from .jittorllms.models import get_model
                    # available_models = ["chatglm", "pangualpha", "llama", "chatrwkv"]
                    args_dict = {'model': 'chatrwkv'}
                    print('self.jittorllms_model = get_model(types.SimpleNamespace(**args_dict))')
                    self.jittorllms_model = get_model(types.SimpleNamespace(**args_dict))
                    print('done get model')
            except Exception:
                self.child.send('[Local Message] Call jittorllms fail: unable to load jittorllms parameters.')
                raise RuntimeError("Unable to load jittorllms parameters!")
        print('load_model')
        load_model()
        # Enter the task-waiting loop.
        print('Entering task-waiting state')
        while True:
            # Block until the main process sends a request.
            kwargs = self.child.recv()
            query = kwargs['query']
            history = kwargs['history']
            # Reset the model when the caller starts a fresh conversation.
            if len(self.local_history) > 0 and len(history) == 0:
                print('Reset triggered')
                self.jittorllms_model.reset()
            self.local_history.append(query)
            print('Message received, starting request')
            try:
                for response in self.jittorllms_model.stream_chat(query, history):
                    print(response)
                    self.child.send(response)
            except Exception:
                from toolbox import trimmed_format_exc
                print(trimmed_format_exc())
                self.child.send('[Local Message] Call jittorllms fail.')
            # Request finished; signal the main process and start the next loop.
            self.child.send('[Finish]')
    def stream_chat(self, **kwargs):
        # Runs in the main process.
        self.threadLock.acquire()
        try:
            self.parent.send(kwargs)
            while True:
                res = self.parent.recv()
                if res != '[Finish]':
                    yield res
                else:
                    break
        finally:
            # Release the lock even if the consumer abandons the generator.
            self.threadLock.release()
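
# A minimal sketch (hypothetical caller, not part of this module) of the wire
# protocol used above: the main process sends one kwargs dict through the Pipe,
# then drains streamed partial responses until the '[Finish]' sentinel arrives.
#
#     handle = GetGLMHandle()
#     for partial in handle.stream_chat(query="Hello", history=[]):
#         print(partial)  # each message is the full response generated so far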

rwkv_glm_handle = None
#################################################################################
def predict_no_ui_long_connection(inputs, llm_kwargs, history=None, sys_prompt="", observe_window=None, console_slience=False):
    """
    Multi-threaded entry point.
    See request_llms/bridge_all.py for the function contract.
    """
    global rwkv_glm_handle
    if history is None: history = []
    if observe_window is None: observe_window = []
    if rwkv_glm_handle is None:
        rwkv_glm_handle = GetGLMHandle()
        if len(observe_window) >= 1: observe_window[0] = load_message + "\n\n" + rwkv_glm_handle.info
        if not rwkv_glm_handle.success:
            error = rwkv_glm_handle.info
            rwkv_glm_handle = None
            raise RuntimeError(error)

    # jittorllms has no sys_prompt interface, so fold the prompt into history.
    history_feedin = []
    for i in range(len(history)//2):
        history_feedin.append([history[2*i], history[2*i+1]])

    watch_dog_patience = 5  # watchdog patience; 5 seconds is enough
    response = ""
    for response in rwkv_glm_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=sys_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
        print(response)
        if len(observe_window) >= 1: observe_window[0] = response
        if len(observe_window) >= 2:
            if (time.time()-observe_window[1]) > watch_dog_patience:
                raise RuntimeError("Program terminated.")
    return response
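
# A minimal sketch (hypothetical caller, not part of this module) of how a
# worker thread drives the function above. The watchdog only fires when the
# caller keeps a timestamp in observe_window[1] and stops refreshing it:
#
#     llm_kwargs = {'max_length': 2048, 'top_p': 0.8, 'temperature': 0.9}
#     observe_window = ["", time.time()]  # [0] = latest text, [1] = heartbeat
#     answer = predict_no_ui_long_connection("Hello", llm_kwargs,
#                                            history=[], sys_prompt="",
#                                            observe_window=observe_window)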

def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=None, system_prompt='', stream=True, additional_fn=None):
    """
    Single-threaded entry point.
    See request_llms/bridge_all.py for the function contract.
    """
    if history is None: history = []
    chatbot.append((inputs, ""))

    global rwkv_glm_handle
    if rwkv_glm_handle is None:
        rwkv_glm_handle = GetGLMHandle()
        chatbot[-1] = (inputs, load_message + "\n\n" + rwkv_glm_handle.info)
        yield from update_ui(chatbot=chatbot, history=[])
        if not rwkv_glm_handle.success:
            rwkv_glm_handle = None
            return

    if additional_fn is not None:
        from core_functional import handle_core_functionality
        inputs, history = handle_core_functionality(additional_fn, inputs, history, chatbot)

    # Rebuild the flat history list into [question, answer] pairs,
    # e.g. ['q1', 'a1', 'q2', 'a2'] -> [['q1', 'a1'], ['q2', 'a2']].
    history_feedin = []
    for i in range(len(history)//2):
        history_feedin.append([history[2*i], history[2*i+1]])

    # Start receiving the jittorllms reply.
    response = "[Local Message] Waiting for jittorllms response ..."
    for response in rwkv_glm_handle.stream_chat(query=inputs, history=history_feedin, system_prompt=system_prompt, max_length=llm_kwargs['max_length'], top_p=llm_kwargs['top_p'], temperature=llm_kwargs['temperature']):
        chatbot[-1] = (inputs, response)
        yield from update_ui(chatbot=chatbot, history=history)

    # Summarize the output: if the placeholder was never replaced, report an error.
    if response == "[Local Message] Waiting for jittorllms response ...":
        response = "[Local Message] jittorllms response exception ..."
    history.extend([inputs, response])
    yield from update_ui(chatbot=chatbot, history=history)
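
if __name__ == '__main__':
    # Hypothetical smoke test (not part of the original module): drives the
    # non-UI path end to end. The llm_kwargs values below are placeholder
    # assumptions; only the three keys read above are required.
    demo_llm_kwargs = {'max_length': 2048, 'top_p': 0.8, 'temperature': 0.9}
    print(predict_no_ui_long_connection("Hello", demo_llm_kwargs))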