# api.py

import json
import logging

import requests

# from utils.common import Common
# from utils.logger import Configure_logger


class Langchain_ChatGLM:
    def __init__(self, data):
        # self.common = Common()

        # Log file path
        # file_path = "./log/log-" + self.common.get_bj_time(1) + ".txt"
        # Configure_logger(file_path)

        self.api_ip_port = data["api_ip_port"]
        self.chat_type = data["chat_type"]
        self.knowledge_base_id = data["knowledge_base_id"]
        self.history_enable = data["history_enable"]
        self.history_max_len = data["history_max_len"]

        self.history = []

    # Fetch the list of local knowledge bases
    def get_list_knowledge_base(self):
        url = self.api_ip_port + "/local_doc_qa/list_knowledge_base"

        try:
            response = requests.get(url)
            response.raise_for_status()  # Raise if the status code indicates an error
            ret = json.loads(response.content)

            logging.debug(ret)
            logging.info(f"Local knowledge base list: {ret['data']}")

            return ret['data']
        except Exception as e:
            logging.error(e)
            return None
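
    # NOTE (assumption, inferred from the ret['data'] access above): the
    # /local_doc_qa/list_knowledge_base endpoint is expected to return JSON such as
    # {"data": ["kb_name_1", "kb_name_2", ...]}; the exact schema depends on the
    # langchain-ChatGLM server version in use.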

    def get_resp(self, prompt):
        """Send the prompt to the endpoint matching chat_type and return the answer.

        Args:
            prompt (str): your question

        Returns:
            str: the text answer, or None on error
        """
        try:
            if self.chat_type == "模型":  # plain model chat
                data_json = {
                    "question": prompt,
                    "streaming": False,
                    "history": self.history
                }
                url = self.api_ip_port + "/chat"
            elif self.chat_type == "知识库":  # knowledge-base chat
                data_json = {
                    "knowledge_base_id": self.knowledge_base_id,
                    "question": prompt,
                    "streaming": False,
                    "history": self.history
                }
                url = self.api_ip_port + "/local_doc_qa/local_doc_chat"
            elif self.chat_type == "必应":  # Bing search chat
                data_json = {
                    "question": prompt,
                    "history": self.history
                }
                url = self.api_ip_port + "/local_doc_qa/bing_search_chat"
            else:  # unknown chat_type: fall back to plain model chat
                data_json = {
                    "question": prompt,
                    "streaming": False,
                    "history": self.history
                }
                url = self.api_ip_port + "/chat"

            response = requests.post(url=url, json=data_json)
            response.raise_for_status()  # Raise if the status code indicates an error
            ret = json.loads(response.content)

            logging.debug(ret)

            # The knowledge-base and Bing endpoints also return the source documents
            if self.chat_type == "知识库" or self.chat_type == "必应":
                logging.info(f'Sources: {ret["source_documents"]}')

            resp_content = ret['response']

            # If history is enabled, remember the conversation
            if self.history_enable:
                while True:
                    # Total character count across all strings in the nested history list
                    total_chars = sum(len(string) for sublist in self.history for string in sublist)
                    # If it exceeds the configured maximum, drop the oldest entry
                    if total_chars > self.history_max_len:
                        self.history.pop(0)
                    else:
                        # Append the newest entry from the server-returned history
                        self.history.append(ret['history'][-1])
                        break

            return resp_content
        except Exception as e:
            logging.error(e)
            return None
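

# NOTE (assumption, inferred from the fields read in get_resp above): the chat
# endpoints are expected to return JSON shaped roughly like
# {"response": "...", "history": [["question", "answer"], ...], "source_documents": [...]},
# where "source_documents" is only consulted for knowledge-base / Bing chat. The exact
# schema depends on the langchain-ChatGLM server version in use.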


if __name__ == '__main__':
    # Configure the log output format
    logging.basicConfig(
        level=logging.DEBUG,  # Set the log level; adjust as needed
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    data = {
        "api_ip_port": "http://127.0.0.1:7861",
        # 模型 (model) / 知识库 (knowledge base) / 必应 (Bing search)
        "chat_type": "必应",
        "knowledge_base_id": "ikaros",
        "history_enable": True,
        "history_max_len": 300
    }

    langchain_chatglm = Langchain_ChatGLM(data)

    if data["chat_type"] == "模型":
        logging.info(langchain_chatglm.get_resp("你可以扮演猫娘吗,每句话后面加个喵"))
        logging.info(langchain_chatglm.get_resp("早上好"))
    elif data["chat_type"] == "知识库":
        langchain_chatglm.get_list_knowledge_base()
        logging.info(langchain_chatglm.get_resp("伊卡洛斯喜欢谁"))
    # Please set BING_SUBSCRIPTION_KEY and BING_SEARCH_URL in the OS environment
    elif data["chat_type"] == "必应":
        logging.info(langchain_chatglm.get_resp("伊卡洛斯是谁"))