# moonshot_session.py (1.9 KB)
  1. from bot.session_manager import Session
  2. from common.log import logger
  3. class MoonshotSession(Session):
  4. def __init__(self, session_id, system_prompt=None, model="moonshot-v1-128k"):
  5. super().__init__(session_id, system_prompt)
  6. self.model = model
  7. self.reset()
  8. def discard_exceeding(self, max_tokens, cur_tokens=None):
  9. precise = True
  10. try:
  11. cur_tokens = self.calc_tokens()
  12. except Exception as e:
  13. precise = False
  14. if cur_tokens is None:
  15. raise e
  16. logger.debug("Exception when counting tokens precisely for query: {}".format(e))
  17. while cur_tokens > max_tokens:
  18. if len(self.messages) > 2:
  19. self.messages.pop(1)
  20. elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant":
  21. self.messages.pop(1)
  22. if precise:
  23. cur_tokens = self.calc_tokens()
  24. else:
  25. cur_tokens = cur_tokens - max_tokens
  26. break
  27. elif len(self.messages) == 2 and self.messages[1]["role"] == "user":
  28. logger.warn("user message exceed max_tokens. total_tokens={}".format(cur_tokens))
  29. break
  30. else:
  31. logger.debug("max_tokens={}, total_tokens={}, len(messages)={}".format(max_tokens, cur_tokens,
  32. len(self.messages)))
  33. break
  34. if precise:
  35. cur_tokens = self.calc_tokens()
  36. else:
  37. cur_tokens = cur_tokens - max_tokens
  38. return cur_tokens
  39. def calc_tokens(self):
  40. return num_tokens_from_messages(self.messages, self.model)
  41. def num_tokens_from_messages(messages, model):
  42. tokens = 0
  43. for msg in messages:
  44. tokens += len(msg["content"])
  45. return tokens