# agi_utils.py

  1. """
  2. Copyright 2023 Yingqiang Ge
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. """
  13. __author__ = "Yingqiang Ge"
  14. __copyright__ = "Copyright 2023, OpenAGI"
  15. __date__ = "2023/04/10"
  16. __license__ = "Apache 2.0"
  17. __version__ = "0.0.1"
  18. from evaluate import load
  19. import numpy as np
  20. import torch
  21. from sentence_transformers import SentenceTransformer, util
  22. def txt_eval(predictions, references, bertscore, device="cuda"):
  23. score = bertscore.compute(
  24. predictions=predictions,
  25. references=references,
  26. lang="en",
  27. model_type="microsoft/deberta-xlarge-mnli",
  28. device=device)["f1"]
  29. return score
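
# A minimal usage sketch for txt_eval (kept as a comment so importing this module
# stays side-effect free). The "bertscore" metric name comes from the Hugging Face
# `evaluate` package already imported above; running on CPU is an assumption made
# here for illustration only.
#
#   bertscore = load("bertscore")
#   f1_scores = txt_eval(["a cat sits on the mat"],
#                        ["a cat is sitting on the mat"],
#                        bertscore, device="cpu")
#   print(f1_scores)
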
def txt_loader(path):
    """Read a text file and return its lines as a list of strings."""
    text = []
    with open(path) as f:
        for line in f.readlines():
            text.append(line)
    return text
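
# Usage sketch for txt_loader (comment only; "descriptions.txt" is a hypothetical
# path standing in for whatever text file the caller loads, one entry per line):
#
#   lines = txt_loader("descriptions.txt")
#   print(len(lines), lines[0])
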
def image_similarity(im1, im2, model, extractor):
    """Embed two batches of images and return the mean distance between paired embeddings."""
    batch_size = len(im1)
    # Preprocess the two image batches
    img1 = extractor(im1, return_tensors="pt")
    img2 = extractor(im2, return_tensors="pt")
    # Compute their embeddings
    with torch.no_grad():
        emb1 = model(img1.pixel_values)[0].squeeze().numpy()
        emb2 = model(img2.pixel_values)[0].squeeze().numpy()
    # Average the Frobenius-norm distance between corresponding embeddings
    dist = np.mean(
        np.array([np.linalg.norm(emb1[i] - emb2[i], ord='fro') for i in range(batch_size)])
    )
    return dist
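
# Usage sketch for image_similarity (comment only). The ViT checkpoint below is an
# assumption, since the caller passes in whichever vision backbone and processor
# the project actually uses, and the image paths are hypothetical. Note that the
# function indexes embeddings per item, so it expects batches of two or more images.
#
#   from transformers import ViTImageProcessor, ViTModel
#   from PIL import Image
#
#   extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
#   vit = ViTModel.from_pretrained("google/vit-base-patch16-224")
#   batch_a = [Image.open(p).convert("RGB") for p in ("a1.png", "a2.png")]
#   batch_b = [Image.open(p).convert("RGB") for p in ("b1.png", "b2.png")]
#   print(image_similarity(batch_a, batch_b, vit, extractor))
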
def module_seq_filter(module_seq, task_id):
    """Check that the first module's input type and the last module's output type match the task."""
    io_dict = {
        "Colorization": ['image', 'image'],
        "Image Denoising": ['image', 'image'],
        "Image Deblurring": ['image', 'image'],
        "Image Super Resolution": ['image', 'image'],
        "Image Classification": ['image', 'text'],
        "Image Captioning": ['image', 'text'],
        "Object Detection": ['image', 'text'],
        "Text Summarization": ['text', 'text'],
        "Text Generation": ['text', 'text'],
        "Machine Translation": ['text', 'text'],
        "Fill Mask": ['text', 'text'],
        "Sentiment Analysis": ['text', 'text'],
        "Text to Image Generation": ['text', 'image'],
        "Question Answering": ['text-text', 'text'],
        "Visual Question Answering": ['image-text', 'text'],
    }
    module_seq_list = module_seq.split(", ")
    input_type = io_dict[module_seq_list[0]][0]
    output_type = io_dict[module_seq_list[-1]][1]
    if input_type == "image" and output_type == "image" and 0 <= task_id <= 14:
        return True
    elif input_type == "image" and output_type == "text" and 15 <= task_id <= 104:
        return True
    elif input_type == "text" and output_type == "image" and 105 <= task_id <= 107:
        return True
    elif input_type == "text" and output_type == "text" and 108 <= task_id <= 125:
        return True
    elif input_type == "image-text" and output_type == "text" and 126 <= task_id <= 170:
        return True
    elif input_type == "text-text" and output_type == "text" and 171 <= task_id <= 188:
        return True
    else:
        return False
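
# Usage sketch for module_seq_filter (comment only). The hard-coded task-id ranges
# group OpenAGI tasks by their expected input/output types; task_id=20 below is
# just an illustrative value from the image-to-text range.
#
#   module_seq_filter("Image Deblurring, Image Captioning", task_id=20)        # True
#   module_seq_filter("Text Summarization, Machine Translation", task_id=20)   # False
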
def whole_module_seq_filter(module_seq, task_id):
    """Check that adjacent modules are type-compatible and that the overall
    input/output types of the sequence match the task."""
    io_dict = {
        "Colorization": ['image', 'image'],
        "Image Denoising": ['image', 'image'],
        "Image Deblurring": ['image', 'image'],
        "Image Super Resolution": ['image', 'image'],
        "Image Classification": ['image', 'text'],
        "Image Captioning": ['image', 'text'],
        "Object Detection": ['image', 'text'],
        "Text Summarization": ['text', 'text'],
        "Text Generation": ['text', 'text'],
        "Machine Translation": ['text', 'text'],
        "Fill Mask": ['text', 'text'],
        "Sentiment Analysis": ['text', 'text'],
        "Text to Image Generation": ['text', 'image'],
        "Question Answering": ['text-text', 'text'],
        "Visual Question Answering": ['image-text', 'text'],
    }
    module_seq_list = module_seq.split(", ")

    # Condition 1: each module's output type feeds the next module's input type
    condition_1 = None
    for i, m in enumerate(module_seq_list):
        if i < len(module_seq_list) - 1 and io_dict[m][1] != io_dict[module_seq_list[i + 1]][0]:
            condition_1 = False
            break
        else:
            condition_1 = True

    # Condition 2: the sequence's overall input/output types match the task
    condition_2 = None
    input_type = io_dict[module_seq_list[0]][0]
    output_type = io_dict[module_seq_list[-1]][1]
    if input_type == "image" and output_type == "image" and 0 <= task_id <= 14:
        condition_2 = True
    elif input_type == "image" and output_type == "text" and 15 <= task_id <= 104:
        condition_2 = True
    elif input_type == "text" and output_type == "image" and 105 <= task_id <= 107:
        condition_2 = True
    elif input_type == "text" and output_type == "text" and 108 <= task_id <= 125:
        condition_2 = True
    elif input_type == "image-text" and output_type == "text" and 126 <= task_id <= 170:
        condition_2 = True
    elif input_type == "text-text" and output_type == "text" and 171 <= task_id <= 188:
        condition_2 = True
    else:
        condition_2 = False

    return condition_1 and condition_2
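
# Usage sketch for whole_module_seq_filter (comment only). Unlike module_seq_filter,
# this also requires every adjacent pair of modules to be type-compatible, so the
# reversed chain below fails: Image Captioning's text output cannot feed Image
# Deblurring's image input.
#
#   whole_module_seq_filter("Image Deblurring, Image Captioning", task_id=20)   # True
#   whole_module_seq_filter("Image Captioning, Image Deblurring", task_id=20)   # False
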
def match_module_seq(model_steps, sentence_model):
    """Map each natural-language step to the closest module name by sentence-embedding similarity."""
    candidate_modules = ["Image Classification", "Colorization", "Object Detection",
                         "Image Super Resolution", "Image Captioning", "Image Deblurring",
                         "Image Denoising", "Text to Image Generation", "Visual Question Answering",
                         "Sentiment Analysis", "Question Answering", "Text Summarization",
                         "Text Generation", "Machine Translation", "Fill Mask"]
    module_seq = ""
    for step in model_steps:
        # Compute embeddings for the step (repeated once per candidate) and the candidate names
        sentences1 = [step] * len(candidate_modules)
        embeddings1 = sentence_model.encode(sentences1, convert_to_tensor=True)
        embeddings2 = sentence_model.encode(candidate_modules, convert_to_tensor=True)
        # Cosine similarity between the step and every candidate module name (diagonal of the pairwise matrix)
        cosine_scores = util.cos_sim(embeddings1, embeddings2)
        similarities = torch.stack([cosine_scores[j][j] for j in range(len(candidate_modules))])
        # Append the best-matching module name
        module_index = torch.argmax(similarities).item()
        module_seq += candidate_modules[module_index] + ", "
    # Drop the trailing ", "
    module_seq = module_seq.strip()[:-1]
    return module_seq
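
# Usage sketch for match_module_seq (comment only). The SentenceTransformer
# checkpoint named below is an assumption; any model exposing .encode() works.
# The matched sequence depends on the embeddings, so the output is only indicative.
#
#   sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
#   steps = ["remove the blur from the image", "describe the image in one sentence"]
#   print(match_module_seq(steps, sentence_model))
#   # likely something like "Image Deblurring, Image Captioning"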