raft_local.py

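"""
raft_local.py: builds a RAFT-style fine-tuning dataset locally using Hugging Face models.

The script reads a source document (pdf, txt, json, or api docs), splits it into fixed-size
chunks, and for each chunk generates questions with a seq2seq model and answers with an
extractive QA model. Each data point combines the question, its oracle chunk, and randomly
sampled distractor chunks, and the collection is saved as a `datasets.Dataset`. Unless --fast
is passed, progress is checkpointed every N chunks so interrupted runs can resume.
"""
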
import logging
from typing import Literal
import argparse
import json
import PyPDF2
import random
import os
import shutil
from math import ceil
from datasets import Dataset, concatenate_datasets
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("huggingface_script")

# Document type literals
DocType = Literal["api", "pdf", "json", "txt"]

# Every N chunks, save a checkpoint
N = 15


def get_args() -> argparse.Namespace:
    """
    Parses and returns the command line arguments specified by the user.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--datapath", type=str, default="", help="The path at which the document is located")
    parser.add_argument("--output", type=str, default="./", help="The path at which to save the dataset")
    parser.add_argument("--output-format", type=str, default="hf", help="Format to convert the dataset to. Defaults to hf.")
    parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.")
    parser.add_argument("--distractors", type=int, default=3, help="The number of distractor documents to include per data point / triplet")
    parser.add_argument("--p", type=float, default=1.0, help="The probability that the oracle document is included in the context")
    parser.add_argument("--questions", type=int, default=5, help="The number of data points / triplets to generate per chunk")
    parser.add_argument("--chunk_size", type=int, default=512, help="The size of each chunk in characters")
    parser.add_argument("--doctype", type=str, default="pdf", help="The type of the document", choices=["pdf", "txt", "json", "api"])
    parser.add_argument("--fast", action="store_true", help="Run the script in fast mode (no recovery implemented)")
    args = parser.parse_args()
    return args
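
# Example invocation (illustrative paths and values):
#   python raft_local.py --datapath ./docs/manual.pdf --doctype pdf \
#       --chunk_size 512 --questions 5 --distractors 3 --p 0.8 --output ./raft_dataset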


def get_chunks(file_path: str, doctype: DocType = "pdf", chunk_size: int = 512) -> list[str]:
    """
    Takes in a `file_path` and `doctype`, retrieves the document, breaks it down into chunks of
    `chunk_size` characters, and returns the chunks as a list of strings.
    """
    chunks = []

    logger.info(f"Retrieving chunks from {file_path} of type {doctype}")

    if doctype == "api":
        # Load API documentation; each entry becomes its own chunk
        with open(file_path) as f:
            api_docs_json = json.load(f)
        chunks = [str(api_doc_json) for api_doc_json in api_docs_json]
    else:
        if doctype == "json":
            # Load JSON document
            with open(file_path, 'r') as f:
                data = json.load(f)
            text = data["text"]
        elif doctype == "pdf":
            # Load PDF and extract text
            text = ""
            with open(file_path, 'rb') as file:
                reader = PyPDF2.PdfReader(file)
                num_pages = len(reader.pages)
                for page_num in range(num_pages):
                    page = reader.pages[page_num]
                    text += page.extract_text()
        elif doctype == "txt":
            # Load plain text document
            with open(file_path, 'r') as file:
                text = file.read()
        else:
            raise TypeError("Document is not one of the accepted types: api, pdf, json, txt")

        # Split the text into chunks
        num_chunks = ceil(len(text) / chunk_size)
        logger.info(f"Splitting text into {num_chunks} chunks.")
        for i in range(0, len(text), chunk_size):
            chunks.append(text[i:i + chunk_size])

    return chunks
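
# Example (illustrative): get_chunks("manual.pdf", "pdf", 512) yields a list of 512-character
# slices of the extracted PDF text; for doctype "api" the input JSON is expected to be a list,
# and each entry becomes one chunk.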


def generate_instructions_hf(chunk: str, x: int = 5, model_name: str = "t5-small") -> list[str]:
    """
    Uses a Hugging Face model to generate `x` questions based on the given text chunk, utilizing the GPU if available.
    """
    # Load the Hugging Face model and tokenizer for question generation
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    input_text = f"Generate questions based on the following text: {chunk}"
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding="longest").to(device)

    outputs = model.generate(
        inputs.input_ids,
        max_length=64,
        num_beams=x,            # Using beam search with `x` beams
        num_return_sequences=x  # Returning `x` sequences
    )

    questions = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    return questions
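
# Note: the tokenizer and model are reloaded on every call, which keeps the function
# self-contained but is slow over many chunks; caching them at module level would be a
# possible (unimplemented) optimization.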


def generate_label_hf(question: str, context: str, model_name: str = "deepset/roberta-base-squad2") -> str:
    """
    Uses a Hugging Face model to generate an answer to the given question based on the context, utilizing the GPU if available.
    """
    # Load the Hugging Face model and tokenizer for question-answering
    question_answering_pipeline = pipeline("question-answering", model=model_name, device=0 if torch.cuda.is_available() else -1)
    result = question_answering_pipeline(question=question, context=context)
    return result['answer']
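
# Note: deepset/roberta-base-squad2 is an extractive QA model, so the "cot_answer" field
# filled in below is a span copied from the oracle chunk rather than a free-form
# chain-of-thought answer.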


def add_chunk_to_dataset(
    chunks: list[str],
    chunk: str,
    doctype: DocType = "api",
    x: int = 5,
    num_distract: int = 3,
    p: float = 0.8,
    model_name_qg: str = "t5-small",
    model_name_qa: str = "deepset/roberta-base-squad2"
) -> None:
    """
    Given a chunk, create {Q, A, D} triplets and add them to the dataset using Hugging Face models.
    """
    global ds
    i = chunks.index(chunk)

    # Use the Hugging Face model to generate questions
    qs = generate_instructions_hf(chunk, x, model_name=model_name_qg)
    for q in qs:
        datapt = {
            "id": None,
            "type": None,
            "question": None,
            "context": None,
            "oracle_context": None,
            "cot_answer": None
        }

        datapt["id"] = f"seed_task_{0 if not ds else ds.num_rows}"
        datapt["type"] = "api call" if doctype == "api" else "general"
        datapt["question"] = q

        # Create distractor documents
        docs = [chunk]
        indices = list(range(0, len(chunks)))
        indices.remove(i)
        for j in random.sample(indices, num_distract):
            docs.append(chunks[j])

        # Decide whether to add oracle document
        oracle = random.uniform(0, 1) < p
        if not oracle:
            docs[0] = chunks[random.sample(indices, 1)[0]]
        random.shuffle(docs)

        d = {
            "title": ["placeholder_title"] * (num_distract + 1),
            "sentences": docs
        }
        datapt["context"] = d
        datapt["oracle_context"] = chunk

        # Add the answer generated by the Hugging Face model
        datapt["cot_answer"] = generate_label_hf(q, chunk, model_name=model_name_qa)

        # Construct model instruction
        context = ""
        for doc in docs:
            context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
        context += q
        datapt["instruction"] = context

        # Add to dataset
        if not ds:
            # Initialize dataset
            datapt["id"] = [datapt["id"]]
            datapt["type"] = [datapt["type"]]
            datapt["question"] = [datapt["question"]]
            datapt["context"] = [datapt["context"]]
            datapt["oracle_context"] = [datapt["oracle_context"]]
            datapt["cot_answer"] = [datapt["cot_answer"]]
            datapt["instruction"] = [datapt["instruction"]]
            ds = Dataset.from_dict(datapt)
        else:
            ds = ds.add_item(datapt)
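
# Resulting dataset columns: id, type, question, context ({"title", "sentences"}),
# oracle_context, cot_answer, and instruction (the <DOCUMENT>-wrapped context followed
# by the question).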


def save_checkpoint(state, filename):
    """
    Saves the current state of processing to a file for recovery.
    """
    with open(filename, 'w') as f:
        f.write(str(state))


def load_checkpoint(filename):
    """
    Loads the processing state from a checkpoint file.
    """
    with open(filename, 'r') as f:
        return int(f.read())
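
# The checkpoint file holds only the index of the last chunk handed off for processing;
# an interrupted run resumes from the most recent multiple of N (see main below).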


def main():
    global ds

    # Get command line arguments
    args = get_args()

    CHUNK_SIZE = args.chunk_size
    NUM_DISTRACT_DOCS = args.distractors

    # Split the document into chunks
    chunks = get_chunks(args.datapath, args.doctype, CHUNK_SIZE)

    ds = None
    num_chunks = len(chunks)

    if not args.fast:
        start = 0
        if os.path.exists("checkpoint.txt"):
            start = load_checkpoint("checkpoint.txt")

        for i in range((start // N) * N, len(chunks)):
            chunk = chunks[i]
            save_checkpoint(i, "checkpoint.txt")

            perc = ceil(i / num_chunks * 100)
            logger.info(f"Adding chunk {i}/{num_chunks} ({perc}%)")

            add_chunk_to_dataset(chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, args.p)

            if (i + 1) % N == 0:
                ds.save_to_disk(args.output + "-checkpoints-" + str(i))
                ds = None

        if ds:
            ds.save_to_disk(args.output + "-checkpoints-last")

        # Gather all checkpoint shards back into a single dataset
        ds_list = []
        for filename in os.listdir(os.path.dirname(args.output)):
            if "-checkpoints-" in filename:
                for f in os.listdir(os.path.dirname(args.output) + "/" + filename):
                    if f.endswith(".arrow"):
                        ds_list.append(Dataset.from_file(os.path.dirname(args.output) + "/" + filename + "/" + f))
        ds = concatenate_datasets(ds_list)
    else:
        for i, chunk in enumerate(chunks):
            perc = ceil(i / num_chunks * 100)
            logger.info(f"Adding chunk {i}/{num_chunks} ({perc}%)")
            add_chunk_to_dataset(chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, args.p)

    # Save the final dataset
    ds.save_to_disk(args.output)

    # Save as .jsonl format (dummy functionality)
    # Implement a conversion function if needed, this is just a placeholder
    logger.info("Converting dataset to the desired format...")

    # Clean up checkpoint files once the final dataset has been written
    if not args.fast:
        os.remove("checkpoint.txt")
        for filename in os.listdir(os.path.dirname(args.output)):
            if "-checkpoints-" in filename:
                shutil.rmtree(os.path.dirname(args.output) + "/" + filename)


if __name__ == "__main__":
    logger.info("Starting the Hugging Face processing script...")
    main()