conv_template.py 8.5 KB


  1. """
  2. Conversation prompt templates.
  3. Thanks to LMSYS for the template of this code.
  4. """
import dataclasses
from enum import auto, Enum
from typing import Any, Dict, List, Optional
  8. class SeparatorStyle(Enum):
  9. """Separator styles."""
  10. ADD_COLON_SINGLE = auto()
  11. ADD_COLON_TWO = auto()
  12. ADD_COLON_SPACE_SINGLE = auto()
  13. NO_COLON_SINGLE = auto()
  14. ADD_NEW_LINE_SINGLE = auto()
  15. DOLLY = auto()
  16. RWKV = auto()
  17. PHOENIX = auto()
  18. NEW_LINE = auto()
  19. @dataclasses.dataclass
  20. class Conversation:
  21. """A class that keeps all conversation history."""
  22. # The name of this template
  23. name: str
  24. # The system prompt
  25. system: str
  26. # Two roles
  27. roles: List[str]
  28. # All messages. Each item is (role, message).
  29. messages: List[List[str]]
  30. # The number of few shot examples
  31. offset: int
  32. # Separators
  33. sep_style: SeparatorStyle
  34. sep: str
  35. sep2: str = None
  36. # Stop criteria (the default one is EOS token)
  37. stop_str: str = None
  38. # Stops generation if meeting any token in this list
  39. stop_token_ids: List[int] = None
  40. def get_prompt(self) -> str:
  41. """Get the prompt for generation."""
  42. if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
  43. ret = self.system + self.sep
  44. for role, message in self.messages:
  45. if message:
  46. ret += role + ": " + message + self.sep
  47. else:
  48. ret += role + ":"
  49. return ret
  50. elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
  51. seps = [self.sep, self.sep2]
  52. ret = self.system + seps[0]
  53. for i, (role, message) in enumerate(self.messages):
  54. if message:
  55. ret += role + ": " + message + seps[i % 2]
  56. else:
  57. ret += role + ":"
  58. return ret
  59. elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
  60. ret = self.system + self.sep
  61. for role, message in self.messages:
  62. if message:
  63. ret += role + ": " + message + self.sep
  64. else:
  65. ret += role + ": " # must be end with a space
  66. return ret
  67. elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
  68. ret = self.system
  69. for role, message in self.messages:
  70. if message:
  71. ret += role + message + self.sep
  72. else:
  73. ret += role
  74. return ret
  75. elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
  76. ret = self.system + self.sep
  77. for role, message in self.messages:
  78. if message:
  79. ret += role + "\n" + message + self.sep
  80. else:
  81. ret += role + "\n"
  82. return ret
  83. elif self.sep_style == SeparatorStyle.DOLLY:
  84. seps = [self.sep, self.sep2]
  85. ret = self.system
  86. for i, (role, message) in enumerate(self.messages):
  87. if message:
  88. ret += role + ":\n" + message + seps[i % 2]
  89. if i % 2 == 1:
  90. ret += "\n\n"
  91. else:
  92. ret += role + ":\n"
  93. return ret
  94. elif self.sep_style == SeparatorStyle.RWKV:
  95. ret = self.system
  96. for i, (role, message) in enumerate(self.messages):
  97. if message:
  98. ret += (
  99. role
  100. + ": "
  101. + message.replace("\r\n", "\n").replace("\n\n", "\n")
  102. )
  103. ret += "\n\n"
  104. else:
  105. ret += role + ":"
  106. return ret
  107. elif self.sep_style == SeparatorStyle.PHOENIX:
  108. ret = self.system
  109. for role, message in self.messages:
  110. if message:
  111. ret += role + ": " + "<s>" + message + "</s>"
  112. else:
  113. ret += role + ": " + "<s>"
  114. return ret
  115. elif self.sep_style == SeparatorStyle.NEW_LINE:
  116. ret = self.system + self.sep
  117. for role, message in self.messages:
  118. if message:
  119. ret += role + "\n" + message + self.sep
  120. else:
  121. ret += role + "\n"
  122. return ret
  123. else:
  124. raise ValueError(f"Invalid style: {self.sep_style}")
  125. def append_message(self, role: str, message: str):
  126. """Append a new message."""
  127. self.messages.append([role, message])
  128. def update_last_message(self, message: str):
  129. """Update the last output.
  130. The last message is typically set to be None when constructing the prompt,
  131. so we need to update it in-place after getting the response from a model.
  132. """
  133. self.messages[-1][1] = message
  134. def to_gradio_chatbot(self):
  135. """Convert the conversation to gradio chatbot format"""
  136. ret = []
  137. for i, (role, msg) in enumerate(self.messages[self.offset :]):
  138. if i % 2 == 0:
  139. ret.append([msg, None])
  140. else:
  141. ret[-1][-1] = msg
  142. return ret
  143. def to_openai_api_messages(self):
  144. """Convert the conversation to OpenAI chat completion format."""
  145. ret = [{"role": "system", "content": self.system}]
  146. for i, (_, msg) in enumerate(self.messages[self.offset :]):
  147. if i % 2 == 0:
  148. ret.append({"role": "user", "content": msg})
  149. else:
  150. if msg is not None:
  151. ret.append({"role": "assistant", "content": msg})
  152. return ret
  153. def copy(self):
  154. return Conversation(
  155. name=self.name,
  156. system=self.system,
  157. roles=self.roles,
  158. messages=[[x, y] for x, y in self.messages],
  159. offset=self.offset,
  160. sep_style=self.sep_style,
  161. sep=self.sep,
  162. sep2=self.sep2,
  163. stop_str=self.stop_str,
  164. stop_token_ids=self.stop_token_ids,
  165. )
  166. def dict(self):
  167. return {
  168. "name": self.name,
  169. "system": self.system,
  170. "roles": self.roles,
  171. "messages": self.messages,
  172. "offset": self.offset,
  173. }
  174. # A global registry for all conversation templates
  175. conv_templates: Dict[str, Conversation] = {}
  176. def register_conv_template(template: Conversation, override: bool = False):
  177. """Register a new conversation template."""
  178. if not override:
  179. assert template.name not in conv_templates, f"{name} has been registered."
  180. conv_templates[template.name] = template
  181. def get_conv_template(name: str) -> Conversation:
  182. """Get a conversation template."""
  183. return conv_templates[name].copy()
  184. # Gorilla v0 template
  185. register_conv_template(
  186. Conversation(
  187. name="gorilla_v0",
  188. system="A chat between a curious user and an artificial intelligence assistant. "
  189. "The assistant gives helpful, detailed, and polite answers to the user's questions.",
  190. roles=("USER", "ASSISTANT"),
  191. messages=(),
  192. offset=0,
  193. sep_style=SeparatorStyle.ADD_COLON_TWO,
  194. sep="\n",
  195. sep2="</s>",
  196. )
  197. )
  198. # Falcon Template
  199. register_conv_template(
  200. Conversation(
  201. name="falcon",
  202. system="",
  203. # system="A chat between a curious user and an artificial intelligence assistant. "
  204. # "The assistant gives helpful, detailed, and polite answers to the user's questions.",
  205. roles=("User", "Assistant"),
  206. messages=(),
  207. offset=0,
  208. sep_style=SeparatorStyle.ADD_COLON_TWO,
  209. sep=" ",
  210. sep2="<|endoftext|>",
  211. )
  212. )
  213. # MPT Template
  214. register_conv_template(
  215. Conversation(
  216. name="mpt",
  217. system="""system
  218. - You are a helpful assistant chatbot trained by MosaicML.
  219. - You answer questions.
  220. - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
  221. - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.
  222. """,
  223. roles=("user", "assistant"),
  224. messages=(),
  225. offset=0,
  226. sep_style=SeparatorStyle.NEW_LINE,
  227. sep=" ",
  228. sep2="<|endoftext|>",
  229. stop_token_ids=[50278, 0],
  230. )
  231. )