format.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from abc import ABC, abstractmethod
  2. import argparse
  3. from datasets import Dataset, load_dataset
  4. from typing import Dict, Literal, Any, get_args
  5. from logconf import log_setup
  6. import logging
"""
Converts raw HuggingFace Datasets into files suitable for fine-tuning completion and chat models.
"""
# File types the converted dataset can be exported to.
OutputDatasetType = Literal["parquet", "jsonl"]
outputDatasetTypes = list(get_args(OutputDatasetType))  # runtime list for argparse choices

# File types the input HuggingFace dataset can be loaded from.
InputDatasetType = Literal["arrow", "jsonl"]
inputDatasetTypes = list(get_args(InputDatasetType))  # runtime list for argparse choices

# Target structures a dataset can be converted to: raw HF, OpenAI completion,
# OpenAI chat, or an evaluation-oriented layout.
DatasetFormat = Literal["hf", "completion", "chat", "eval"]
datasetFormats = list(get_args(DatasetFormat))  # runtime list for argparse choices

# Default system prompt injected when exporting in the "chat" format.
# NOTE(review): `get_args` above is typing.get_args; it is shadowed by the
# CLI-parsing get_args() defined below — these lists must stay defined
# before that function.
default_chat_system_prompt = "The following is a conversation with an AI assistant. The assistant is helpful, clever, friendly and gives concise and accurate answers."
  17. def get_args() -> argparse.Namespace:
  18. """
  19. Parses and returns the arguments specified by the user's command
  20. """
  21. parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  22. parser.add_argument("--input", type=str, required=True, help="Input HuggingFace dataset file")
  23. parser.add_argument("--input-type", type=str, default="arrow", help="Format of the input dataset. Defaults to arrow.", choices=inputDatasetTypes)
  24. parser.add_argument("--output", type=str, required=True, help="Output file")
  25. parser.add_argument("--output-format", type=str, required=True, help="Format to convert the dataset to", choices=datasetFormats)
  26. parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
  27. parser.add_argument("--output-chat-system-prompt", type=str, default=default_chat_system_prompt, help="The system prompt to use when the output format is chat")
  28. parser.add_argument("--output-completion-prompt-column", type=str, default="prompt", help="The prompt column name to use for the completion format")
  29. parser.add_argument("--output-completion-completion-column", type=str, default="completion", help="The completion column name to use for the completion format")
  30. parser.add_argument("--output-completion-stop", type=str, default="<STOP>", help="The stop keyword to use for the completion format")
  31. args = parser.parse_args()
  32. return args
  33. class DatasetFormatter(ABC):
  34. """
  35. Base class for dataset formatters. Formatters rename columns, remove and add
  36. columns to match the expected target format structure. HF, Chat or Completion models file formats.
  37. https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
  38. """
  39. @abstractmethod
  40. def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
  41. pass
  42. class DatasetExporter(ABC):
  43. """
  44. Base class for dataset exporters. Exporters export dataset to different file types, JSONL, Parquet, ...
  45. """
  46. @abstractmethod
  47. def export(self, ds: Dataset, output_path: str):
  48. pass
  49. class DatasetConverter():
  50. """
  51. Entry point class. It resolves which DatasetFormatter and which DatasetExporter to use and runs them.
  52. """
  53. formats: Dict[DatasetFormat, DatasetFormatter]
  54. exporters: Dict[OutputDatasetType, Any]
  55. def __init__(self) -> None:
  56. self.formats = {
  57. "hf": HuggingFaceDatasetFormatter(),
  58. "completion": OpenAiCompletionDatasetFormatter(),
  59. "chat": OpenAiChatDatasetFormatter(),
  60. "eval": EvalDatasetFormatter(),
  61. }
  62. self.exporters = {
  63. "parquet": ParquetDatasetExporter(),
  64. "jsonl": JsonlDatasetExporter()
  65. }
  66. def convert(self, ds: Dataset, format: DatasetFormat, output_path: str, output_type: OutputDatasetType, params: Dict[str, str]):
  67. if not format in self.formats:
  68. raise Exception(f"Output Format {format} is not supported, pleased select one of {self.formats.keys()}")
  69. if not output_type in self.exporters:
  70. raise Exception(f"Output Type {output_type} is not supported, pleased select one of {self.exporters.keys()}")
  71. formatter = self.formats[format]
  72. newds = formatter.format(ds, **params)
  73. exporter = self.exporters[output_type]
  74. exporter.export(newds, output_path)
  75. class HuggingFaceDatasetFormatter(DatasetFormatter):
  76. """
  77. Returns the HuggingFace Dataset as is
  78. """
  79. def format(self, ds: Dataset) -> Dataset:
  80. return ds
  81. def _remove_all_columns_but(ds: Dataset, keep_columns) -> Dataset:
  82. """
  83. HF Dataset doesn't have a way to copy only specific columns of a Dataset so this help
  84. removes all columns but the ones specified.
  85. """
  86. remove_columns = list(ds.column_names)
  87. for keep in keep_columns:
  88. try:
  89. remove_columns.remove(keep)
  90. except ValueError:
  91. raise Exception(f"Column {keep} not found in {remove_columns}")
  92. ds = ds.remove_columns(remove_columns)
  93. return ds
  94. class OpenAiCompletionDatasetFormatter(DatasetFormatter):
  95. """
  96. Returns the Dataset in the OpenAI Completion Fine-tuning file format with two fields "prompt" and "completion".
  97. Field names can be customized because different systems have different expectations.
  98. https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
  99. """
  100. def format(self, ds: Dataset, prompt_column: str = 'prompt', completion_column : str = 'completion', stop: str = '<STOP>') -> Dataset:
  101. newds = ds.filter(lambda example: example['cot_answer'] and example['instruction'], desc="Filter out empty examples")
  102. newds = newds.rename_columns({'instruction': prompt_column})
  103. newds = newds.map(lambda examples: {completion_column: [answer + stop for answer in examples['cot_answer']]}, batched=True, desc=f"Rename fields and add {stop} token")
  104. return _remove_all_columns_but(newds, [prompt_column, completion_column])
  105. class OpenAiChatDatasetFormatter(OpenAiCompletionDatasetFormatter):
  106. """
  107. Returns the Dataset in the OpenAI Chat Fine-tuning file format with one field "messages".
  108. https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
  109. """
  110. def format(self, ds: Dataset, system_prompt: str, **params) -> Dataset:
  111. newds = super().format(ds, stop = "")
  112. def format_messages(row):
  113. messages = []
  114. if system_prompt:
  115. messages.append({ "role": "system", "content": system_prompt})
  116. messages.extend([{ "role": "user", "content": row['prompt']}, { "role": "assistant", "content": row['completion']}])
  117. chat_row = {"messages": messages}
  118. return chat_row
  119. newds = newds.map(format_messages)
  120. return _remove_all_columns_but(newds, ['messages'])
  121. def extract_final_answer(cot_answer: str) -> str:
  122. """
  123. Extracts the final answer from the cot_answer field
  124. """
  125. if cot_answer:
  126. return cot_answer.split("<ANSWER>: ")[-1]
  127. return None
  128. def extract_context(instruction: str) -> str:
  129. """
  130. Extracts the context from the instruction field.
  131. Keeps all <DOCUMENTS/> and removes the last line with the question.
  132. """
  133. return "\n".join(instruction.split("\n")[:-1])
  134. class EvalDatasetFormatter(DatasetFormatter):
  135. """
  136. Returns the Dataset in a format suitable for evaluation. Extracts final answer separates context from question.
  137. """
  138. def format(self, ds: Dataset) -> Dataset:
  139. newds = ds.filter(lambda example: example['cot_answer'] and example['instruction'] and example['context'], desc="Filter out empty examples")
  140. newds = newds.rename_columns({'context': 'context_sentences'})
  141. newds = newds.map(lambda examples: {"gold_final_answer": [extract_final_answer(answer) for answer in examples['cot_answer']]}, batched=True)
  142. keep_columns = ['question', 'gold_final_answer', 'context']
  143. if 'answer' in newds.column_names:
  144. [keep_columns.append(col) for col in ['answer', 'final_answer']]
  145. newds = newds.map(lambda examples: {"final_answer": [extract_final_answer(answer) for answer in examples['answer']]}, batched=True)
  146. newds = newds.map(lambda examples: {"context": [extract_context(instruction) for instruction in examples['instruction']]}, batched=True)
  147. return _remove_all_columns_but(newds, keep_columns)
  148. def append_extension(path: str, extension: str) -> str:
  149. suffix = "." + extension
  150. if not path.endswith(suffix):
  151. path = path + suffix
  152. return path
  153. class JsonlDatasetExporter(DatasetExporter):
  154. """
  155. Exports the Dataset to a JSONL file
  156. """
  157. def export(self, ds: Dataset, output_path: str):
  158. ds.to_json(append_extension(output_path, "jsonl"))
  159. class ParquetDatasetExporter(DatasetExporter):
  160. """
  161. Exports the Dataset to a Parquet file
  162. """
  163. def export(self, ds: Dataset, output_path: str):
  164. ds.to_parquet(append_extension(output_path, "parquet"))
  165. def main():
  166. """
  167. When raft.py is executed from the command line.
  168. """
  169. log_setup()
  170. args = get_args()
  171. input_type = args.input_type
  172. # datasets except json when loading jsonl files
  173. if input_type == "jsonl":
  174. input_type = "json"
  175. logger = logging.getLogger("raft")
  176. ds = load_dataset(input_type, data_files={"train": args.input})['train']
  177. logger.info(f"Dataset has {ds.num_rows} rows")
  178. formatter = DatasetConverter()
  179. format_params = {}
  180. if args.output_chat_system_prompt and args.output_format == "chat":
  181. format_params['system_prompt'] = args.output_chat_system_prompt
  182. if args.output_format == "completion":
  183. format_params['prompt_column'] = args.output_completion_prompt_column
  184. format_params['completion_column'] = args.output_completion_completion_column
  185. format_params['stop'] = args.output_completion_stop
  186. logger.info(f"Converting {args.input_type} file {args.input} to {args.output_type} {args.output_format} file {args.output}")
  187. formatter.convert(ds=ds, format=args.output_format, output_path=args.output, output_type=args.output_type, params=format_params)
  188. if __name__ == "__main__":
  189. main()