gorilla_eval.py

# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import math
import os
import warnings

import psutil
import torch
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LlamaTokenizer,
    LlamaForCausalLM,
    T5Tokenizer,
)

# Load Gorilla Model from HF
def load_model(
    model_path: str,
    device: str,
    num_gpus: int,
    max_gpu_memory: str = None,
    load_8bit: bool = False,
    cpu_offloading: bool = False,
):
    if device == "cpu":
        kwargs = {"torch_dtype": torch.float32}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus != 1:
            kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                # "sequential" is important when the GPUs do not all have the
                # same VRAM size
                kwargs["device_map"] = "sequential"
                # get_gpu_memory() is an external helper (not defined in this
                # file) that reports the available memory per GPU in GiB
                available_gpu_memory = get_gpu_memory(num_gpus)
                kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
    else:
        raise ValueError(f"Invalid device: {device}")

    if cpu_offloading:
        # raises an error on incompatible platforms
        from transformers import BitsAndBytesConfig

        if "max_memory" in kwargs:
            kwargs["max_memory"]["cpu"] = (
                str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB"
            )
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit_fp32_cpu_offload=cpu_offloading
        )
        kwargs["load_in_8bit"] = load_8bit
    elif load_8bit:
        if num_gpus != 1:
            warnings.warn(
                "8-bit quantization is not supported for multi-gpu inference."
            )
        else:
            # load_compress_model() is an external helper, not defined in this file
            return load_compress_model(
                model_path=model_path, device=device, torch_dtype=kwargs["torch_dtype"]
            )

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs,
    )
    return model, tokenizer

def get_questions(question_file):
    # Load questions file
    question_jsons = []
    with open(question_file, "r") as ques_file:
        for line in ques_file:
            question_jsons.append(line)
    return question_jsons

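# Each line of the question file is expected to be a JSON object whose fields
# match the keys read in run_eval() below. Illustrative example (not taken
# from any dataset):
#   {"question_id": 1, "text": "I want to generate a caption for an image."}
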
def run_eval(args, question_jsons):
    # Evaluate the model for answers
    model, tokenizer = load_model(
        args.model_path,
        args.device,
        args.num_gpus,
        args.max_gpu_memory,
        args.load_8bit,
        args.cpu_offloading,
    )
    if (
        args.device == "cuda" and args.num_gpus == 1 and not args.cpu_offloading
    ) or args.device == "mps":
        model.to(args.device)

    ans_jsons = []
    for i, line in enumerate(tqdm(question_jsons)):
        ques_json = json.loads(line)
        idx = ques_json["question_id"]
        prompt = ques_json["text"]
        prompt = "###USER: " + prompt + "###ASSISTANT: "
        input_ids = tokenizer([prompt]).input_ids
        output_ids = model.generate(
            torch.as_tensor(input_ids).to(args.device),
            do_sample=True,
            temperature=0.7,
            max_new_tokens=2048,
        )
        # Drop the prompt tokens and decode only the newly generated tokens
        output_ids = output_ids[0][len(input_ids[0]):]
        outputs = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        ans_jsons.append(
            {
                "question_id": idx,
                "questions": prompt,
                "response": outputs,
            }
        )

    # Write output to file
    with open(args.answer_file, "w") as ans_file:
        for line in ans_jsons:
            ans_file.write(json.dumps(line) + "\n")
    return ans_jsons

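# Each line of --answer-file is a JSON object mirroring the dict built in
# run_eval():
#   {"question_id": ..., "questions": "<formatted prompt>", "response": "<model output>"}
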
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path or Hugging Face hub id of the model weights",
    )
    parser.add_argument(
        "--question-file",
        type=str,
        required=True,
        help="JSONL file with one question object per line",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda", "mps"],
        default="cuda",
        help="The device type",
    )
    parser.add_argument(
        "--max-gpu-memory",
        type=str,
        help="The maximum memory per GPU, a string like '13GiB'",
    )
    parser.add_argument(
        "--load-8bit",
        action="store_true",
        help="Use 8-bit quantization",
    )
    parser.add_argument(
        "--cpu-offloading",
        action="store_true",
        help="Only when using 8-bit quantization: offload weights that do not fit on the GPU to the CPU",
    )
    parser.add_argument(
        "--answer-file",
        type=str,
        default="answer.jsonl",
        help="Output JSONL file for the model responses",
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        default=1,
        help="Number of GPUs to use",
    )
    args = parser.parse_args()

    questions_json = get_questions(args.question_file)
    run_eval(args, questions_json)
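
# Example invocation (the model path and file names below are placeholders,
# not values defined by this script):
#   python gorilla_eval.py \
#       --model-path path/to/gorilla-7b-hf \
#       --question-file questions.jsonl \
#       --answer-file answer.jsonl \
#       --device cuda \
#       --num-gpus 1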