# zhipu_ai_session.py (2.0 KB)
  1. from bot.session_manager import Session
  2. from common.log import logger
  class ZhipuAISession(Session):
      """Chat session for ZhipuAI (GLM) models.

      Token counts come from num_tokens_from_messages, which sums len() of each
      message's "content" value (a character count for string content), not a
      real tokenizer count.
      """

      def __init__(self, session_id, system_prompt=None, model="glm-4"):
          """Create a session.

          Args:
              session_id: identifier passed through to the Session base class.
              system_prompt: prompt passed through to the Session base class; a
                  warning is logged when it is empty or None.
              model: model name stored on self.model and passed to
                  num_tokens_from_messages by calc_tokens().
          """
          super().__init__(session_id, system_prompt)
          self.model = model
          # NOTE(review): reset() is defined on the Session base class (not visible
          # here); confirm whether calling it again after super().__init__ is needed.
          self.reset()
          if not system_prompt:
              logger.warn("[ZhiPu] `character_desc` can not be empty")

      def discard_exceeding(self, max_tokens, cur_tokens=None):
          """Drop old messages until the session's token count is at most max_tokens.

          Messages are always removed from index 1, so self.messages[0] is never
          dropped (presumably the system prompt installed by reset() -- confirm).
          Trimming stops early, with a log message, when only messages[0] and a
          user message remain, or when nothing else can be dropped. When only
          messages[0] and an assistant message remain, the assistant message is
          dropped and trimming stops even if the count is still too high.

          Args:
              max_tokens: token budget to trim the history down to.
              cur_tokens: caller-supplied token count, used only as a fallback
                  when calc_tokens() raises.

          Returns:
              The token count after trimming. It is recomputed with calc_tokens()
              when precise counting works. Otherwise it is a rough estimate
              derived from cur_tokens (see _recount_tokens).

          Raises:
              Whatever calc_tokens() raised, if it fails and cur_tokens is None.
          """
          precise = True
          try:
              cur_tokens = self.calc_tokens()
          except Exception as e:
              precise = False
              if cur_tokens is None:
                  # No fallback estimate available: re-raise the original error.
                  raise
              logger.debug("Exception when counting tokens precisely for query: {}".format(e))
          while cur_tokens > max_tokens:
              if len(self.messages) > 2:
                  self.messages.pop(1)
              elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
                  # Only messages[0] and an assistant reply remain: drop the reply
                  # and stop, regardless of the resulting count.
                  self.messages.pop(1)
                  cur_tokens = self._recount_tokens(precise, cur_tokens, max_tokens)
                  break
              elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
                  # Only messages[0] and a user message remain: drop nothing, warn, stop.
                  logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
                  break
              else:
                  # Fewer than two messages, or messages[1] has some other role:
                  # nothing is dropped.
                  logger.debug(
                      "max_tokens={}, total_tokens={}, len(messages)={}".format(
                          max_tokens, cur_tokens, len(self.messages)
                      )
                  )
                  break
              cur_tokens = self._recount_tokens(precise, cur_tokens, max_tokens)
          return cur_tokens

      def _recount_tokens(self, precise, cur_tokens, max_tokens):
          """Return the token count to use after a message has been dropped.

          Args:
              precise: whether calc_tokens() succeeded at the start of trimming.
              cur_tokens: the previous token count.
              max_tokens: the trimming budget.

          When precise is True, the count is recomputed with calc_tokens().
          Otherwise, max_tokens is subtracted from the previous count.
          NOTE(review): that fallback is a coarse heuristic -- it does not use the
          size of the message that was dropped.
          """
          if precise:
              return self.calc_tokens()
          return cur_tokens - max_tokens

      def calc_tokens(self):
          """Return the token count of self.messages for self.model.

          Delegates to num_tokens_from_messages. It may raise if a message lacks
          a "content" key or its content has no len().
          """
          return num_tokens_from_messages(self.messages, self.model)
  def num_tokens_from_messages(messages, model):
      """Approximate the token count of a list of chat messages.

      Sums len() of every message's "content" value. For string content this is
      a character count, not a tokenizer-based count. The model argument is not
      used by this implementation.

      Raises:
          KeyError: a message has no "content" key.
          TypeError: a "content" value has no len() (e.g. None).
      """
      return sum(len(message["content"]) for message in messages)