get_llm_responses.py

# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
import os
import sys
import json
import openai
import anthropic
import multiprocessing as mp
import time
import wandb
from tenacity import retry, stop_after_attempt, wait_exponential
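
# NOTE: this script targets the pre-1.0 openai SDK (openai.ChatCompletion) and the
# legacy anthropic completions API (anthropic.HUMAN_PROMPT / max_tokens_to_sample);
# newer versions of both SDKs expose different interfaces.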


def encode_question(question, api_name):
    """Encode the question and per-API formatting instructions into a chat-style message list."""
    prompts = []
    if api_name == "torchhub":
        domains = "1. $DOMAIN is inferred from the task description and should include one of {Classification, Semantic Segmentation, Object Detection, Audio Separation, Video Classification, Text-to-Speech}."
    elif api_name == "huggingface":
        domains = "1. $DOMAIN should include one of {Multimodal Feature Extraction, Multimodal Text-to-Image, Multimodal Image-to-Text, Multimodal Text-to-Video, \
Multimodal Visual Question Answering, Multimodal Document Question Answer, Multimodal Graph Machine Learning, Computer Vision Depth Estimation, \
Computer Vision Image Classification, Computer Vision Object Detection, Computer Vision Image Segmentation, Computer Vision Image-to-Image, \
Computer Vision Unconditional Image Generation, Computer Vision Video Classification, Computer Vision Zero-Shot Image Classification, \
Natural Language Processing Text Classification, Natural Language Processing Token Classification, Natural Language Processing Table Question Answering, \
Natural Language Processing Question Answering, Natural Language Processing Zero-Shot Classification, Natural Language Processing Translation, \
Natural Language Processing Summarization, Natural Language Processing Conversational, Natural Language Processing Text Generation, Natural Language Processing Fill-Mask, \
Natural Language Processing Text2Text Generation, Natural Language Processing Sentence Similarity, Audio Text-to-Speech, Audio Automatic Speech Recognition, \
Audio Audio-to-Audio, Audio Audio Classification, Audio Voice Activity Detection, Tabular Tabular Classification, Tabular Tabular Regression, \
Reinforcement Learning Reinforcement Learning, Reinforcement Learning Robotics}"
    elif api_name == "tensorhub":
        domains = "1. $DOMAIN is inferred from the task description and should include one of {text-sequence-alignment, text-embedding, text-language-model, text-preprocessing, text-classification, text-generation, text-question-answering, text-retrieval-question-answering, text-segmentation, text-to-mel, image-classification, image-feature-vector, image-object-detection, image-segmentation, image-generator, image-pose-detection, image-rnn-agent, image-augmentation, image-classifier, image-style-transfer, image-aesthetic-quality, image-depth-estimation, image-super-resolution, image-deblurring, image-extrapolation, image-text-recognition, image-dehazing, image-deraining, image-enhancement, image-classification-logits, image-frame-interpolation, image-text-detection, image-denoising, image-others, video-classification, video-feature-extraction, video-generation, video-audio-text, video-text, audio-embedding, audio-event-classification, audio-command-detection, audio-paralinguists-classification, audio-speech-to-text, audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
    else:
        raise ValueError(f"Error: API name {api_name} is not supported.")
    prompt = question + "\nWrite a python program in 1 to 2 lines to call API in " + api_name + ".\n\nThe answer should follow the format: <<<domain>>>: $DOMAIN, <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE. Here are the requirements:\n" + domains + "\n2. The $API_CALL should have only 1 line of code that calls api.\n3. The $API_PROVIDER should be the programming framework used.\n4. $EXPLANATION should be a step-by-step explanation.\n5. The $CODE is the python code.\n6. Do not repeat the format in your answer."
    prompts.append({"role": "system", "content": "You are a helpful API writer who can write APIs based on requirements."})
    prompts.append({"role": "user", "content": prompt})
    return prompts
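
# A minimal sketch of what encode_question returns (the question text here is hypothetical):
#
#   encode_question("I want to classify the sound of a dog barking", "torchhub")
#   # -> [{"role": "system", "content": "You are a helpful API writer who can write APIs based on requirements."},
#   #     {"role": "user", "content": "I want to classify the sound of a dog barking\nWrite a python program in 1 to 2 lines to call API in torchhub. ..."}]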


# Bounded retries with exponential backoff; API errors propagate out of this function
# so that tenacity can retry, and after 5 failed attempts the final error is re-raised
# (reraise=True) and handled by the caller, process_entry.
@retry(wait=wait_exponential(multiplier=1, min=10, max=120), stop=stop_after_attempt(5), reraise=True)
def get_response(get_response_input, api_key):
    question, question_id, api_name, model = get_response_input
    question = encode_question(question, api_name)
    if "gpt" in model:
        openai.api_key = api_key
        responses = openai.ChatCompletion.create(
            model=model,
            messages=question,
            n=1,
            temperature=0,
        )
        response = responses['choices'][0]['message']['content']
    elif "claude" in model:
        client = anthropic.Anthropic(api_key=api_key)
        responses = client.completions.create(
            prompt=f"{anthropic.HUMAN_PROMPT} {question[0]['content']}{question[1]['content']}{anthropic.AI_PROMPT}",
            stop_sequences=[anthropic.HUMAN_PROMPT],
            model="claude-v1",
            max_tokens_to_sample=2048,
        )
        response = responses.completion.strip()
    else:
        print("Error: Model is not supported.")
        return None
    print("=>", question_id)  # simple progress marker
    return {'text': response, "question_id": question_id, "answer_id": "None", "model_id": model, "metadata": {}}


def process_entry(entry, api_key):
    question, question_id, api_name, model = entry
    try:
        result = get_response((question, question_id, api_name, model), api_key)
    except Exception as e:
        # Retries exhausted (or a non-retryable error): skip this question.
        print("Error:", e)
        result = None
    if wandb.run is not None:  # only log when wandb.init() was called (--use_wandb)
        wandb.log({"question_id_completed": question_id})
    return result


def write_result_to_file(result, output_file):
    global file_write_lock
    if result is None:  # skip failed requests instead of writing "null" lines
        return
    with file_write_lock:
        with open(output_file, "a") as outfile:
            json.dump(result, outfile)
            outfile.write("\n")
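
# Each line appended above is one self-contained JSON record mirroring the dict
# built in get_response, e.g. (field values are illustrative):
#   {"text": "<<<domain>>>: Classification, <<<api_call>>>: ...", "question_id": 1,
#    "answer_id": "None", "model_id": "gpt-3.5-turbo", "metadata": {}}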


def callback_with_lock(result, output_file):
    global file_write_lock
    write_result_to_file(result, output_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=None, help="which model you want to use for eval, only supports ['gpt*', 'claude*'] now")
    parser.add_argument("--api_key", type=str, default=None, help="the api key provided for calling")
    parser.add_argument("--output_file", type=str, default=None, help="the output file this script writes to")
    parser.add_argument("--question_data", type=str, default=None, help="path to the questions data file")
    parser.add_argument("--api_name", type=str, default=None, help="the api dataset name you are testing, only supports ['torchhub', 'tensorhub', 'huggingface'] now")
    parser.add_argument("--use_wandb", action='store_true', help="pass this argument to turn on Weights & Biases logging of the LLM responses")
    parser.add_argument("--wandb_project", type=str, default="gorilla-api", help="Weights & Biases project name")
    parser.add_argument("--wandb_entity", type=str, default=None, help="Weights & Biases entity name")
    args = parser.parse_args()

    if args.use_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            config={
                "api_name": args.api_name,
                "model": args.model,
                "question_data": args.question_data,
                "output_file": args.output_file
            }
        )

    start_time = time.time()

    # Read the question file: one JSON object per line with "text" and "question_id" fields
    questions = []
    question_ids = []
    with open(args.question_data, 'r') as f:
        for line in f:
            record = json.loads(line)
            questions.append(record["text"])
            question_ids.append(record["question_id"])

    if os.path.exists(args.output_file):
        print(f"\nExisting responses file found at: {args.output_file}, deleting it ...\n")
        os.remove(args.output_file)

    file_write_lock = mp.Lock()
    # A single-worker pool keeps requests sequential; raise the pool size to issue calls in parallel.
    with mp.Pool(1) as pool:
        results = []
        for question, question_id in zip(questions, question_ids):
            result = pool.apply_async(
                process_entry,
                args=((question, question_id, args.api_name, args.model), args.api_key),
                callback=lambda result: write_result_to_file(result, args.output_file),
            )
            results.append(result)
        pool.close()
        pool.join()

    end_time = time.time()
    elapsed_time = end_time - start_time
    print("Total time used: ", elapsed_time)

    if args.use_wandb:
        print("\nSaving all responses to Weights & Biases...\n")
        wandb.summary["elapsed_time_s"] = elapsed_time
        wandb.log({"elapsed_time_s": elapsed_time})

        line_count = 0
        with open(args.output_file, 'r') as file:
            for i, line in enumerate(file):
                data = json.loads(line.strip())
                if i == 0:
                    tbl = wandb.Table(columns=list(data.keys()))
                if data is not None:
                    tbl.add_data(*list(data.values()))
                    line_count += 1

        # Log the Table to W&B
        wandb.log({"llm_eval_responses": tbl})
        wandb.summary["response_count"] = line_count

        # Also log the results file as a W&B Artifact
        artifact_model_name = re.sub(r'[^a-zA-Z0-9-_.]', '-', args.model)
        wandb.log_artifact(args.output_file,
                           name=f"{args.api_name}-{artifact_model_name}-eval-responses",
                           type="eval-responses",
                           aliases=[f"{line_count}-responses"]
                           )
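
# Example invocation (file paths, API key, and model name are illustrative):
#   python get_llm_responses.py --model gpt-3.5-turbo --api_key $OPENAI_API_KEY \
#       --question_data questions_torchhub.jsonl --api_name torchhub \
#       --output_file gpt-3.5-turbo-torchhub-responses.jsonl --use_wandb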