history_processors.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import re
  2. from abc import abstractmethod
  3. from dataclasses import dataclass
  4. class FormatError(Exception):
  5. pass
  6. # ABSTRACT BASE CLASSES
  7. class HistoryProcessorMeta(type):
  8. _registry = {}
  9. def __new__(cls, name, bases, attrs):
  10. new_cls = super().__new__(cls, name, bases, attrs)
  11. if name != "HistoryProcessor":
  12. cls._registry[name] = new_cls
  13. return new_cls
  14. @dataclass
  15. class HistoryProcessor(metaclass=HistoryProcessorMeta):
  16. def __init__(self, *args, **kwargs):
  17. pass
  18. @abstractmethod
  19. def __call__(self, history: list[str]) -> list[str]:
  20. raise NotImplementedError
  21. @classmethod
  22. def get(cls, name, *args, **kwargs):
  23. try:
  24. return cls._registry[name](*args, **kwargs)
  25. except KeyError:
  26. raise ValueError(f"Model output parser ({name}) not found.")
# DEFINE NEW HISTORY PROCESSORS BELOW THIS LINE
  28. class DefaultHistoryProcessor(HistoryProcessor):
  29. def __call__(self, history):
  30. return history
  31. def last_n_history(history, n):
  32. if n <= 0:
  33. raise ValueError('n must be a positive integer')
  34. new_history = list()
  35. user_messages = len([entry for entry in history if (entry['role'] == 'user' and not entry.get('is_demo', False))])
  36. user_msg_idx = 0
  37. for entry in history:
  38. data = entry.copy()
  39. if data['role'] != 'user':
  40. new_history.append(entry)
  41. continue
  42. if data.get('is_demo', False):
  43. new_history.append(entry)
  44. continue
  45. else:
  46. user_msg_idx += 1
  47. if user_msg_idx == 1 or user_msg_idx in range(user_messages - n + 1, user_messages + 1):
  48. new_history.append(entry)
  49. else:
  50. data['content'] = f'Old output omitted ({len(entry["content"].splitlines())} lines)'
  51. new_history.append(data)
  52. return new_history
  53. class LastNObservations(HistoryProcessor):
  54. def __init__(self, n):
  55. self.n = n
  56. def __call__(self, history):
  57. return last_n_history(history, self.n)
  58. class Last2Observations(HistoryProcessor):
  59. def __call__(self, history):
  60. return last_n_history(history, 2)
  61. class Last5Observations(HistoryProcessor):
  62. def __call__(self, history):
  63. return last_n_history(history, 5)
  64. class ClosedWindowHistoryProcessor(HistoryProcessor):
  65. pattern = re.compile(r'^(\d+)\:.*?(\n|$)', re.MULTILINE)
  66. file_pattern = re.compile(r'\[File:\s+(.*)\s+\(\d+\s+lines\ total\)\]')
  67. def __call__(self, history):
  68. new_history = list()
  69. # For each value in history, keep track of which windows have been shown.
  70. # We want to mark windows that should stay open (they're the last window for a particular file)
  71. # Then we'll replace all other windows with a simple summary of the window (i.e. number of lines)
  72. windows = set()
  73. for entry in reversed(history):
  74. data = entry.copy()
  75. if data['role'] != 'user':
  76. new_history.append(entry)
  77. continue
  78. if data.get('is_demo', False):
  79. new_history.append(entry)
  80. continue
  81. matches = list(self.pattern.finditer(entry['content']))
  82. if len(matches) >= 1:
  83. file_match = self.file_pattern.search(entry['content'])
  84. if file_match:
  85. file = file_match.group(1)
  86. else:
  87. continue
  88. if file in windows:
  89. start = matches[0].start()
  90. end = matches[-1].end()
  91. data['content'] = (
  92. entry['content'][:start] +\
  93. f'Outdated window with {len(matches)} lines omitted...\n' +\
  94. entry['content'][end:]
  95. )
  96. windows.add(file)
  97. new_history.append(data)
  98. history = list(reversed(new_history))
  99. return history