# ast_eval_hf.py
  1. # Copyright 2023 https://github.com/ShishirPatil/gorilla
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import argparse
  15. import json
  16. from tree_sitter import Language, Parser
  17. # Get all the subtrees given a root_node
  18. def get_all_sub_trees(root_node):
  19. node_stack = []
  20. sub_tree_sexp_list = []
  21. depth = 1
  22. text = root_node.text
  23. node_stack.append([root_node, depth])
  24. while len(node_stack) != 0:
  25. cur_node, cur_depth = node_stack.pop()
  26. if cur_node.child_count > 0:
  27. sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text])
  28. else:
  29. sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
  30. for child_node in cur_node.children:
  31. if len(child_node.children) != 0:
  32. depth = cur_depth + 1
  33. node_stack.append([child_node, depth])
  34. return sub_tree_sexp_list
  35. # Parse the program into AST trees
  36. def ast_parse(candidate, lang='python'):
  37. LANGUAGE = Language('codebleu/parser/my-languages.so', lang)
  38. parser = Parser()
  39. parser.set_language(LANGUAGE)
  40. candidate_tree = parser.parse(bytes(candidate,'utf8')).root_node
  41. return candidate_tree
  42. # Get all the arguments in the ast tree
  43. def get_args(node):
  44. if node.child_count == 0:
  45. return []
  46. args_list = []
  47. for child in node.children[0].children[0].children[1].children:
  48. if "=" in child.text.decode():
  49. args_list.append(child.children[2].text)
  50. elif child.text.decode() != "(" and child.text.decode() != ")" and child.text.decode() != ",":
  51. args_list.append(child.text)
  52. return args_list
  53. # Check if there is an api match
  54. def ast_check(candidate_subtree_list, base_tree_list):
  55. """
  56. Check if there is an API match between candidate subtrees and base trees.
  57. Args:
  58. candidate_subtree_list (list): A list of candidate subtrees with their depths and text contents.
  59. base_tree_list (list): A list of base trees to compare against.
  60. Returns:
  61. int: The index of the matching base tree in base_tree_list if a match is found, -1 otherwise.
  62. """
  63. for idx, base_tree in enumerate(base_tree_list):
  64. if base_tree.children[0].children[0].child_count == 0:
  65. continue
  66. api_name = base_tree.children[0].children[0].children[0].text
  67. for candidate_tree in candidate_subtree_list:
  68. if candidate_tree[3] == api_name:
  69. break
  70. # Now we have a sub-tree
  71. candidate_tree = candidate_tree[2]
  72. args_list = get_args(base_tree)
  73. if len(args_list) == 0:
  74. continue
  75. ast_match = True
  76. for arg in args_list:
  77. if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
  78. ast_match = False
  79. break
  80. if ast_match:
  81. return idx
  82. return -1
  83. # Parse the dataset
  84. def parse_dataset(args):
  85. # Read the api datasest
  86. api_database = []
  87. with open(args.api_dataset, 'r') as f:
  88. for line in f:
  89. api_database.append(json.loads(line))
  90. # Read the question answer pair datasest
  91. qa_pairs = []
  92. with open(args.apibench, 'r') as f:
  93. for line in f:
  94. qa_pairs.append(json.loads(line)["api_data"])
  95. # Read the language model response datasest
  96. llm_responses = []
  97. with open(args.llm_responses, 'r') as f:
  98. for line in f:
  99. llm_responses.append(json.loads(line))
  100. # Parse all apis to ast trees
  101. ast_database = []
  102. for data in api_database:
  103. ast_tree = ast_parse(data['api_call'])
  104. ast_database.append(ast_tree)
  105. return api_database, qa_pairs, llm_responses, ast_database
  106. def main(args):
  107. # Read datsets
  108. api_database, qa_pairs, llm_responses, ast_database = parse_dataset(args)
  109. # Check correctness
  110. total_correct = 0
  111. total_hallucination = 0
  112. for idx, response in enumerate(llm_responses):
  113. try:
  114. output = response['text']
  115. except:
  116. print('Error: cannot parse line ', idx)
  117. continue
  118. # Index the "api_call" domain
  119. output = output.split("api_call")
  120. if len(output) == 1:
  121. # print('Error: line ', idx, ' is not the right format')
  122. # continue
  123. api_call = output[0]
  124. else:
  125. # Parse the output
  126. output = output[1].split("api_provider")[0]
  127. if ":" not in output:
  128. start = 0
  129. else:
  130. start = output.index(":")
  131. if ")" not in output:
  132. end = -2
  133. else:
  134. end = output.rindex(")")
  135. api_call = output[start+2:end+1]
  136. # Parse the api_call into AST tree
  137. ast_tree = ast_parse(api_call)
  138. # Search for a subtree
  139. ast_subtree_list = get_all_sub_trees(ast_tree)
  140. # Check which ast tree is matching
  141. database_index = ast_check(ast_subtree_list, ast_database)
  142. # We cannot index this ast in our database
  143. if database_index == -1:
  144. total_hallucination += 1
  145. continue
  146. # We index our reference api_call
  147. ref_api_call = api_database[database_index]
  148. # Check for functionality
  149. if ref_api_call['domain'] == qa_pairs[response['question_id'] - 1]['domain']:
  150. total_correct += 1
  151. else:
  152. pass
  153. if args.use_wandb:
  154. import wandb
  155. if args.wandb_run_id is not None:
  156. wandb.init(project=args.wandb_project, entity=args.wandb_entity, id=args.wandb_run_id, resume="must")
  157. else:
  158. wandb.init(project=args.wandb_project, entity=args.wandb_entity)
  159. wandb.summary['final_functionality_accuracy'] = total_correct / len(llm_responses)
  160. wandb.summary['final_hallucination'] = total_hallucination/len(llm_responses)
  161. print('Final Functionality accuracy: ', total_correct / len(llm_responses))
  162. print('Final hallucination: ', total_hallucination/len(llm_responses))
  163. if __name__ == "__main__":
  164. parser = argparse.ArgumentParser()
  165. parser.add_argument("--api_dataset", type=str, default=None, help="path to your api dataset")
  166. parser.add_argument("--apibench", type=str, default=None, help="path to your apibench dataset including the question and answer pairs")
  167. parser.add_argument("--llm_responses", type=str, default=None, help="path to the language model responses")
  168. parser.add_argument("--use_wandb", action='store_true', help="pass this argument to turn on Weights & Biases logging of the LLM responses")
  169. parser.add_argument("--wandb_project", type=str, default="gorilla-api", help="Weights & Biases project name")
  170. parser.add_argument("--wandb_entity", type=str, default=None, help="Weights & Biases entity name")
  171. parser.add_argument("--wandb_run_id", type=str, default=None, help="pass W&B run id to append results to that run, otherwise a new W&B run is logged")
  172. args = parser.parse_args()
  173. main(args)