eval.py

import argparse
import json
import logging
import multiprocessing as mp
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

from dotenv import load_dotenv
from openai import RateLimitError
from openai.types.chat import ChatCompletionMessageParam
from tenacity import Retrying, retry, wait_exponential, retry_if_exception_type, before_sleep_log
from tqdm import tqdm

from client_utils import CompletionsCompleter, StatsCompleter, UsageStats, build_openai_client
from logconf import log_setup

load_dotenv()  # take environment variables from .env
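
# The OpenAI client is configured from environment variables using the prefix set by
# --env-prefix (default "EVAL"), i.e. EVAL_OPENAI_API_KEY and EVAL_OPENAI_BASE_URL.
# A minimal .env might look like the following; the values are placeholders, not real
# credentials:
#
#   EVAL_OPENAI_API_KEY=sk-...
#   EVAL_OPENAI_BASE_URL=https://api.openai.com/v1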


def get_args() -> argparse.Namespace:
    """
    Parses and returns the arguments specified on the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--question-file", type=str, required=True)
    parser.add_argument("--answer-file", type=str, default="answer.jsonl")
    parser.add_argument("--model", type=str, default="gpt-4", help="The model to evaluate")
    parser.add_argument("--mode", type=str, default="chat", help="The model API mode: 'chat' or 'completion'. Defaults to 'chat'.")
    parser.add_argument("--input-prompt-key", type=str, default="instruction", help="The column to use as the input prompt")
    parser.add_argument("--output-answer-key", type=str, default="answer", help="The column to use as the output answer")
    parser.add_argument("--workers", type=int, default=2, help="The number of worker threads to use to evaluate the dataset")
    parser.add_argument("--env-prefix", type=str, default="EVAL", help="The OpenAI env var prefix. Defaults to EVAL for EVAL_OPENAI_BASE_URL and EVAL_OPENAI_API_KEY")
    args = parser.parse_args()
    return args
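

# Example invocation (file names and model are placeholders):
#
#   python eval.py --question-file questions.jsonl --answer-file answers.jsonl \
#       --model gpt-4 --mode chat --workers 4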


if __name__ == "__main__":
    log_setup()
    logger = logging.getLogger('eval')

    args = get_args()

    model = args.model
    mode = args.mode
    prompt_key = args.input_prompt_key
    answer_key = args.output_answer_key
    logger.info(f"Using model: {model}")
    logger.info(f"Using mode: {mode}")
    logger.info(f"Using prompt key: {prompt_key}")
    logger.info(f"Using answer key: {answer_key}")

    client = build_openai_client(env_prefix=args.env_prefix)

    if mode not in ['chat', 'completion']:
        raise ValueError("Invalid --mode. Mode must be either 'chat' or 'completion'")

    # Chat or completion mode function
    complete = client.chat.completions.create if mode == 'chat' else client.completions.create

    # Wrap with retry decorator
    @retry(wait=wait_exponential(multiplier=1, min=10, max=120), reraise=True,
           retry=retry_if_exception_type(RateLimitError),
           before_sleep=before_sleep_log(logger, logging.INFO))
    def retry_complete(*args, **kwargs):
        return complete(*args, **kwargs)

    # Wrap with statistics completer
    completions_completer = StatsCompleter(retry_complete)

    # NOTE: this initial get_answer is superseded by the get_answer defined further
    # below, which uses the configured --input-prompt-key / --output-answer-key
    # instead of the hard-coded 'instruction' and 'model_answer' keys.
    def get_answer(input_json: dict[str, Any]) -> dict[str, Any]:
        message = [{"role": "user", "content": input_json['instruction']}]
        result = get_openai_response(message)
        input_json['model_answer'] = result
        return input_json

    # Evaluate a chat model
    def get_openai_response_chat(prompt: str | list[ChatCompletionMessageParam]) -> str | None:
        messages = [{"role": "user", "content": prompt}]
        response = completions_completer(
            model=model,
            messages=messages,
            temperature=0.2,
            max_tokens=1024,
            stop='<STOP>'
        )
        return response.choices[0].message.content

    # Evaluate a completion model
    def get_openai_response_completion(prompt: str) -> str | None:
        response = completions_completer(
            model=model,
            prompt=prompt,
            temperature=0.2,
            max_tokens=1024,
            stop='<STOP>'
        )
        return response.choices[0].text

    # Chat or completion mode function
    get_openai_response = get_openai_response_chat if mode == 'chat' else get_openai_response_completion

    def get_answer(input_json: dict[str, Any]) -> dict[str, Any]:
        prompt = input_json[prompt_key]
        try:
            result = get_openai_response(prompt)
            input_json[answer_key] = result
        except Exception as e:
            input_json['error'] = str(e)
        return input_json

    def write_result_to_file(
        result: dict[str, Any],
        write_file_name: str,
    ) -> None:
        global file_write_lock
        with file_write_lock:
            with open(write_file_name, "a") as outfile:
                json.dump(result, outfile)
                outfile.write("\n")

    write_file_name = args.answer_file
    if os.path.isfile(write_file_name):
        logger.info(f"Removing existing file: {write_file_name}")
        os.remove(write_file_name)

    num_workers = args.workers
    file_write_lock = mp.Lock()

    inputs = []
    question_file = args.question_file
    logger.info(f"Reading questions from: {question_file}")
    with open(question_file, 'r') as f:
        for line in f:
            inputs.append(json.loads(line))

    logger.info(f'Number of questions: {len(inputs)}')
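
    # Each line of the question file is expected to be a standalone JSON object that
    # holds the prompt under --input-prompt-key; the model's answer is written back
    # under --output-answer-key. With the defaults, an input/output pair might look
    # like this (illustrative values only):
    #
    #   {"instruction": "What is the capital of France?"}
    #   {"instruction": "What is the capital of France?", "answer": "Paris"}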

    start_time = time.time()
    usage_stats = UsageStats()
    tps = 0
    retrying: Retrying = retry_complete.retry

    with tqdm(total=len(inputs), unit="answers") as pbar:
        with ThreadPoolExecutor(num_workers) as executor:
            futures = [executor.submit(get_answer, input) for input in inputs]
            for future in as_completed(futures):
                result = future.result()
                stats = completions_completer.get_stats_and_reset()
                if stats:
                    tps = stats.total_tokens / stats.duration
                    usage_stats += stats
                retry_stats = retrying.statistics
                if len(retry_stats.keys()) > 0:
                    logger.info(f"retrying stats: {retry_stats}")
                pbar.set_postfix({'last tok/s': tps, 'avg tok/s': usage_stats.total_tokens / usage_stats.duration})
                pbar.update(1)
                write_result_to_file(result, write_file_name)

    end_time = time.time()
    logger.info(f"Wrote evaluation results to {write_file_name}")
    logger.info(f"total time used: {end_time - start_time}")