123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- """
- Conversation prompt templates.
- Thanks to LMSYS for the template of this code.
- """
- import dataclasses
- from enum import auto, Enum
- from typing import List, Any, Dict
class SeparatorStyle(Enum):
    """Separator styles.

    Each member selects one prompt layout in ``Conversation.get_prompt``.
    Member values come from ``auto()`` in declaration order.
    """

    # system + sep, then "ROLE: message" + sep for each turn
    ADD_COLON_SINGLE = auto()
    # like ADD_COLON_SINGLE, but alternates sep / sep2 between turns
    ADD_COLON_TWO = auto()
    # like ADD_COLON_SINGLE, but an empty turn renders as "ROLE: " (trailing space)
    ADD_COLON_SPACE_SINGLE = auto()
    # system, then role + message + sep per turn (no colon, no sep after system)
    NO_COLON_SINGLE = auto()
    # system + sep, then "ROLE\nmessage" + sep for each turn
    ADD_NEW_LINE_SINGLE = auto()
    # "ROLE:\nmessage" with alternating sep / sep2, plus "\n\n" after odd-indexed turns
    DOLLY = auto()
    # "ROLE: message" followed by "\n\n", with newlines inside messages normalised
    RWKV = auto()
    # "ROLE: <s>message</s>" for each turn
    PHOENIX = auto()
    # rendered exactly like ADD_NEW_LINE_SINGLE
    NEW_LINE = auto()
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Holds the system prompt, the two role names and the ordered
    ``[role, message]`` turns of a chat, and renders them into a single
    prompt string according to ``sep_style`` (see ``get_prompt``).
    """

    # The name of this template
    name: str
    # The system prompt
    system: str
    # Two roles, e.g. ("USER", "ASSISTANT")
    roles: List[str]
    # All messages. Each item is [role, message]; the message may be None
    # as a placeholder for a reply that has not been generated yet.
    messages: List[List[str]]
    # The number of few shot examples; the first `offset` messages are
    # skipped by to_gradio_chatbot and to_openai_api_messages.
    offset: int
    # Separators
    sep_style: SeparatorStyle
    sep: str
    # Second separator; only the ADD_COLON_TWO and DOLLY styles read it.
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        Concatenates ``system`` and every message in ``messages`` using the
        layout selected by ``sep_style``. A falsy message (such as the
        ``None`` placeholder for the turn still to be generated) emits only
        the role prefix, so the prompt ends where the model should continue.

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: if ``sep_style`` is not one of the styles handled here.
        """
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            # Even-indexed turns end with sep, odd-indexed turns with sep2.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "  # must end with a space
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style in (
            SeparatorStyle.ADD_NEW_LINE_SINGLE,
            SeparatorStyle.NEW_LINE,
        ):
            # Both styles render identically: "ROLE\nmessage" + sep per turn.
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        # Extra blank line after every odd-indexed turn.
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = self.system
            for role, message in self.messages:
                if message:
                    # Convert CRLF to LF and replace each "\n\n" pair with a
                    # single "\n" (presumably so the "\n\n" turn delimiter
                    # below stays recognisable — confirm).
                    ret += (
                        role
                        + ": "
                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
                    )
                    ret += "\n\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + "</s>"
                else:
                    ret += role + ": " + "<s>"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role: str, message: str):
        """Append a new ``[role, message]`` turn to ``messages``."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the
        prompt, so we need to update it in-place after getting the response
        from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Returns a list of ``[user_msg, reply]`` pairs built from the messages
        after ``offset``; even-indexed messages open a pair and odd-indexed
        messages fill its second slot.
        """
        ret = []
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format.

        The system prompt comes first; messages after ``offset`` alternate
        user / assistant, and a ``None`` assistant placeholder is omitted.
        """
        ret = [{"role": "system", "content": self.system}]
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        """Return a copy that can be mutated without affecting this instance.

        ``messages`` is rebuilt as a fresh list of fresh ``[role, message]``
        pairs, and ``stop_token_ids`` is copied too, so changes made to the
        copy never leak back into the original (e.g. a registered template).
        """
        return Conversation(
            name=self.name,
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            # Previously the same list object was shared with the original.
            stop_token_ids=(
                list(self.stop_token_ids)
                if self.stop_token_ids is not None
                else None
            ),
        )

    def dict(self):
        """Return a plain-dict summary of the conversation.

        Note: ``messages`` is returned by reference, not copied.
        """
        return {
            "name": self.name,
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
# Global registry of conversation templates, keyed by template name; filled
# by register_conv_template and read by get_conv_template.
conv_templates: Dict[str, Conversation] = dict()
def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template.

    Args:
        template: The template to store in ``conv_templates`` under
            ``template.name``.
        override: If True, replace any template already registered under the
            same name instead of failing.

    Raises:
        AssertionError: if ``override`` is False and a template with the same
            name has already been registered.
    """
    if not override and template.name in conv_templates:
        # Raised explicitly rather than via `assert` so the duplicate check
        # also runs under `python -O`; the message now uses template.name
        # (the old message referenced an undefined `name`, which produced a
        # NameError instead of the intended AssertionError).
        raise AssertionError(f"{template.name} has been registered.")
    conv_templates[template.name] = template
def get_conv_template(name: str) -> Conversation:
    """Return a fresh copy of the template registered under ``name``.

    The registered object itself is never handed out, so callers can append
    messages to the returned conversation without touching the registry.

    Raises:
        KeyError: if no template has been registered under ``name``.
    """
    registered = conv_templates[name]
    return registered.copy()
# Gorilla v0 template: "ROLE: message" turns (ADD_COLON_TWO), where
# even-indexed turns end with "\n" and odd-indexed turns end with "</s>".
register_conv_template(
    Conversation(
        name="gorilla_v0",
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep="\n",
        sep2="</s>",
        roles=("USER", "ASSISTANT"),
        system=(
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        ),
        messages=(),
        offset=0,
    )
)
# Falcon template, registered with an empty system prompt. For ADD_COLON_TWO,
# get_prompt still appends `sep` after the (empty) system prompt, so rendered
# prompts begin with a single space.
register_conv_template(
    Conversation(
        name="falcon",
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="<|endoftext|>",
        roles=("User", "Assistant"),
        system="",
        messages=(),
        offset=0,
    )
)
# MPT template. The NEW_LINE style renders each turn as "ROLE\nmessage" + sep;
# NOTE(review): get_prompt never reads sep2 for NEW_LINE, so the sep2 value
# below has no effect on the rendered prompt.
register_conv_template(
    Conversation(
        name="mpt",
        sep_style=SeparatorStyle.NEW_LINE,
        sep=" ",
        sep2="<|endoftext|>",
        stop_token_ids=[50278, 0],
        roles=("user", "assistant"),
        system=(
            "system\n"
            "- You are a helpful assistant chatbot trained by MosaicML.\n"
            "- You answer questions.\n"
            "- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n"
            "- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.\n"
        ),
        messages=(),
        offset=0,
    )
)
|