# dashscope_session.py (2.0 KB)
  1. from bot.session_manager import Session
  2. from common.log import logger
  3. class DashscopeSession(Session):
  4. def __init__(self, session_id, system_prompt=None, model="qwen-turbo"):
  5. super().__init__(session_id)
  6. self.reset()
  7. def discard_exceeding(self, max_tokens, cur_tokens=None):
  8. precise = True
  9. try:
  10. cur_tokens = self.calc_tokens()
  11. except Exception as e:
  12. precise = False
  13. if cur_tokens is None:
  14. raise e
  15. logger.debug("Exception when counting tokens precisely for query: {}".format(e))
  16. while cur_tokens > max_tokens:
  17. if len(self.messages) > 2:
  18. self.messages.pop(1)
  19. elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
  20. self.messages.pop(1)
  21. if precise:
  22. cur_tokens = self.calc_tokens()
  23. else:
  24. cur_tokens = cur_tokens - max_tokens
  25. break
  26. elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
  27. logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
  28. break
  29. else:
  30. logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
  31. len(self.messages)))
  32. break
  33. if precise:
  34. cur_tokens = self.calc_tokens()
  35. else:
  36. cur_tokens = cur_tokens - max_tokens
  37. return cur_tokens
  38. def calc_tokens(self):
  39. return num_tokens_from_messages(self.messages)
  40. def num_tokens_from_messages(messages):
  41. # 只是大概,具体计算规则:https://help.aliyun.com/zh/dashscope/developer-reference/token-api?spm=a2c4g.11186623.0.0.4d8b12b0BkP3K9
  42. tokens = 0
  43. for msg in messages:
  44. tokens += len(msg["content"])
  45. return tokens