few_shot_schema_gpt.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. """
  2. Copyright 2023 Yingqiang Ge
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. """
  13. __author__ = "Yingqiang Ge"
  14. __copyright__ = "Copyright 2023, OpenAGI"
  15. __date__ = "2023/04/10"
  16. __license__ = "Apache 2.0"
  17. __version__ = "0.0.1"
  18. from sentence_transformers import SentenceTransformer, util
  19. import os
  20. # os.chdir('../')
  21. from benchmark_tasks.general_dataset import GeneralDataset
  22. from torch.utils.data import DataLoader
  23. import torch
  24. from benchmark_tasks.agi_utils import *
  25. import openai
  26. import numpy as np
  27. from IPython.utils import io
  28. import random
  29. from tqdm import tqdm
  30. from evaluate import load
  31. from torchvision import transforms
  32. from transformers import AutoModel, AutoFeatureExtractor
  33. from torchmetrics.multimodal import CLIPScore
  34. from benchmark_tasks.combine_model_seq import SeqCombine
  35. def run_few_gpt(args):
  36. """
  37. assign openagi data path
  38. """
  39. data_path = args.data_path
  40. device_list = args.device_list
  41. eval_device = args.eval_device
  42. batch_size = args.batch_size
  43. task_discriptions = txt_loader(data_path+"task_description.txt")
  44. # task_idx = [0,21,61,105,110,120,10,35,62,107,115]
  45. # test_task_idx = [2,3,10,15,20,35,45,55,65,70,70,90,106,107]
  46. test_task_idx = [2]
  47. test_dataloaders = []
  48. for i in test_task_idx:
  49. dataset = GeneralDataset(i, data_path)
  50. dataloader = DataLoader(dataset, batch_size=batch_size)
  51. test_dataloaders.append(dataloader)
  52. test_tasks = [task_discriptions[i].strip() for i in test_task_idx]
  53. # Training
  54. train_solution = []
  55. with open(data_path+'train_model_sequence.txt') as f:
  56. lines = f.readlines()
  57. for line in lines[:50]:
  58. train_solution.append(line)
  59. f.close()
  60. train_tasks = []
  61. with open(data_path+'train_task_description.txt') as f:
  62. lines = f.readlines()
  63. for line in lines[:50]:
  64. train_tasks.append(line)
  65. f.close()
  66. context = ""
  67. for i in range(len(train_tasks)):
  68. steps = ""
  69. for index,j in enumerate(train_solution[i].split(',')):
  70. steps += "Step " + str(index+1) + ":" + j.strip("\n") + ", \n"
  71. cur = "Problem: " + train_tasks[i] + "Solution:\n" + steps
  72. context += cur
  73. # print(context + "Problem: " + test_tasks[0]+"\nSoltuion: ")
  74. # device_list = ["cuda:1","cuda:2","cuda:3","cuda:4","cuda:5","cuda:7","cpu"]
  75. # device_list = ["cuda:3","cuda:4","cpu"]
  76. seqCombination = SeqCombine(args)
  77. # Load a pre-trained Vision Transformer model and its feature extractor
  78. vit_ckpt = "nateraw/vit-base-beans"
  79. vit = AutoModel.from_pretrained(vit_ckpt)
  80. vit.eval()
  81. vit_extractor = AutoFeatureExtractor.from_pretrained(vit_ckpt)
  82. f = transforms.ToPILImage()
  83. bertscore = load("bertscore")
  84. clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
  85. """
  86. assign openai api_key
  87. """
  88. openai.api_key = args.openai_key
  89. # device = "cuda:4"
  90. rewards = []
  91. clips = []
  92. berts = []
  93. similairies = []
  94. sentence_model = SentenceTransformer('all-MiniLM-L6-v2', device="cpu")
  95. # Testing
  96. for i, task_description in enumerate(tqdm(test_tasks)):
  97. task_rewards = []
  98. with torch.no_grad():
  99. completion = openai.ChatCompletion.create(model="gpt-3.5-turbo",\
  100. messages=[{"role": "user",\
  101. "content": context +\
  102. "Problem: " +\
  103. task_description +\
  104. "\nSoltuion: "}]
  105. )
  106. gpt_output = completion.choices[0].message['content'].split('\n')
  107. gpt_steps = []
  108. for l,j in enumerate(gpt_output):
  109. if j[0:4] == "Step":
  110. gpt_steps.append(gpt_output[l])
  111. module_list = match_module_seq(gpt_steps, sentence_model)
  112. # print(module_list)
  113. # break
  114. if len(module_list) > 0 and whole_module_seq_filter(module_list, test_task_idx[i]):
  115. seqCombination.construct_module_seq(module_list)
  116. for idx, batch in enumerate(test_dataloaders[i]):
  117. inputs = list(batch['input'][0])
  118. predictions = seqCombination.run_module_seq(inputs)
  119. if 0<=test_task_idx[i]<=14:
  120. outputs = list(batch['output'][0])
  121. dist = image_similarity(predictions, outputs, vit, vit_extractor)
  122. task_rewards.append(dist/100)
  123. elif 15<=test_task_idx[i]<=104 or 107<=task_idx[i]:
  124. outputs = list(batch['output'][0])
  125. f1 = np.mean(txt_eval(predictions, outputs, bertscore, device=eval_device))
  126. task_rewards.append(f1)
  127. else:
  128. clip_score = clip_score(predictions, inputs)
  129. task_rewards.append(clip_score.detach()/100)
  130. ave_task_reward = np.mean(task_rewards)
  131. seqCombination.close_module_seq()
  132. else:
  133. ave_task_reward = 0
  134. if 0 <=test_task_idx[i] <=14:
  135. similairies.append(ave_task_reward)
  136. elif 15<=test_task_idx[i]<=104 or 107<=test_task_idx[i]:
  137. berts.append(ave_task_reward)
  138. else:
  139. clips.append(ave_task_reward)
  140. rewards.append(ave_task_reward)
  141. print("Finished testing!")
  142. print("Evaluation Results: ", np.mean(clips), np.mean(berts), np.mean(similairies), np.mean(rewards))