# rltf_schema_flan_t5.py
"""
Copyright 2023 Yingqiang Ge

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
  13. __author__ = "Yingqiang Ge"
  14. __copyright__ = "Copyright 2023, OpenAGI"
  15. __date__ = "2023/04/10"
  16. __license__ = "Apache 2.0"
  17. __version__ = "0.0.1"
  18. from torch.utils.data import DataLoader
  19. import torch
  20. from transformers import (
  21. T5Tokenizer,
  22. T5ForConditionalGeneration,
  23. AutoModel,
  24. AutoFeatureExtractor,
  25. AutoConfig
  26. )
  27. import os
  28. from benchmark_tasks.generate_model_seq import SeqGen
  29. import torch.optim as optim
  30. from benchmark_tasks.general_dataset import GeneralDataset
  31. from benchmark_tasks.agi_utils import *
  32. from benchmark_tasks.combine_model_seq import SeqCombine
  33. import numpy as np
  34. from IPython.utils import io
  35. import random
  36. from tqdm import tqdm
  37. from evaluate import load
  38. from torchvision import transforms
  39. from torchmetrics.multimodal import CLIPScore
  40. import argparse
  41. from undecorated import undecorated
  42. from benchmark_tasks.finetune.utils import construct_optimizer
  43. from types import MethodType
  44. def run_rltf_flan_t5(args):
  45. """
  46. load training and test datasets
  47. """
  48. data_path = args.data_path
  49. llm_device = args.llm_device
  50. eval_device = args.eval_device
  51. num_seq = args.num_seq
  52. llm_name = args.llm_name
  53. epochs = args.epochs
  54. epsilon = args.epsilon
  55. decay_rate = args.decay_rate
  56. task_discriptions = txt_loader(data_path + "task_description.txt")
  57. training_task_idx = [7,20,30,40,50,60]
  58. # test_task_idx = [2,3,10,15,20,35,45,55,65,70,70,90,106,107]
  59. test_task_idx = [2]
  60. training_dataloaders = []
  61. test_dataloaders = []
  62. for i in training_task_idx:
  63. dataset = GeneralDataset(i, data_path)
  64. dataloader = DataLoader(dataset, batch_size=args.batch_size)
  65. training_dataloaders.append(dataloader)
  66. for j in test_task_idx:
  67. dataset = GeneralDataset(j,data_path)
  68. dataloader = DataLoader(dataset, batch_size=args.batch_size)
  69. test_dataloaders.append(dataloader)
  70. training_tasks = [task_discriptions[i].strip() for i in training_task_idx]
  71. test_tasks = [task_discriptions[j].strip() for j in test_task_idx]
  72. # print(training_tasks)
  73. clip_score = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
  74. # Load a pre-trained Vision Transformer model and its feature extractor
  75. vit_ckpt = "nateraw/vit-base-beans"
  76. vit = AutoModel.from_pretrained(vit_ckpt)
  77. vit.eval()
  78. vit_extractor = AutoFeatureExtractor.from_pretrained(vit_ckpt)
  79. f = transforms.ToPILImage()
  80. bertscore = load("bertscore")
  81. seqCombination = SeqCombine(args)
  82. tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-large')
  83. config = AutoConfig.from_pretrained('google/flan-t5-large')
  84. backbone_model = T5ForConditionalGeneration(config=config)
  85. backbone_model.load_state_dict(torch.load("benchmark_tasks/finetune/10_shot_finetuned.pt", map_location="cpu"))
  86. backbone_model = backbone_model.to(llm_device)
  87. seqGen = SeqGen(backbone_model, tokenizer, llm_device)
  88. generate_with_grad = undecorated(seqGen.model.generate)
  89. seqGen.model.generate_with_grad = MethodType(generate_with_grad, seqGen.model)
  90. # optimizer = optim.SGD(seqGen.model.parameters(), lr=0.0001, momentum=0.9)
  91. optimizer, scheduler = construct_optimizer(args, seqGen.model, 20)
  92. ## Training
  93. for e in range(epochs):
  94. baseline = 0
  95. rewards = []
  96. print('num of epoch ' + str(e+1))
  97. for i, task_description in enumerate(training_tasks):
  98. task_rewards = []
  99. # print(task_description)
  100. optimizer.zero_grad()
  101. generated_module_seq, log_prob = seqGen.generate_sequence([training_tasks[i]],\
  102. module_length=10, \
  103. beam_size=30, \
  104. num_seq=30,\
  105. top_k=5,\
  106. top_p=0.5,\
  107. temperature=0.9,\
  108. constraint=[0,100],\
  109. num_beam_groups=1)
  110. if random.random() >= epsilon:
  111. action = torch.argmax(torch.stack(log_prob).detach())
  112. else:
  113. action = torch.distributions.Categorical(torch.stack(log_prob).detach()).sample()
  114. # decrease epsilon by the decay rate after each step
  115. epsilon *= decay_rate
  116. module_list = generated_module_seq[action][:-1]
  117. if module_seq_filter(module_list, training_task_idx[i]):
  118. # print("Module Sequence: " + module_list)
  119. seqCombination.construct_module_seq(module_list)
  120. for idx, batch in enumerate(tqdm(training_dataloaders[i])):
  121. inputs = list(batch['input'][0])
  122. seqCombination.construct_module_seq(module_list)
  123. predictions = seqCombination.run_module_seq(inputs)
  124. if 0 <= training_task_idx[i] <= 14:
  125. outputs = list(batch['output'][0])
  126. dist = image_similarity(predictions, outputs, vit, vit_extractor)
  127. task_rewards.append(dist/100)
  128. elif 15 <= training_task_idx[i] <= 104 or 107 <= task_idx[i]:
  129. outputs = list(batch['output'][0])
  130. f1 = np.mean(txt_eval(predictions, outputs, bertscore, device=eval_device))
  131. task_rewards.append(f1)
  132. else:
  133. clip_score = score = clip_score(predictions, inputs)
  134. task_rewards.append(clip_score.detach()/100)
  135. ave_task_reward = np.mean(task_rewards)
  136. # print("Average reward on current task: " + str(ave_task_reward))
  137. rewards.append(ave_task_reward)
  138. seqCombination.close_module_seq()
  139. else:
  140. rewards.append(-1)
  141. avg_reward = np.mean(rewards)
  142. print("Average reward: " + str(avg_reward))
  143. loss = -log_prob[action] * (avg_reward - baseline)
  144. print("Loss: "+ str(loss.item()))
  145. loss.backward()
  146. optimizer.step()
  147. scheduler.step()
  148. # baseline = avg_reward
  149. print("Finished training!")
  150. ## Testing
  151. rewards = []
  152. clips = []
  153. berts = []
  154. similairies = []
  155. for i, task_description in enumerate(test_tasks):
  156. task_rewards = []
  157. with torch.no_grad():
  158. generated_module_seq, log_prob = seqGen.generate_sequence([test_tasks[i]],\
  159. module_length=10, \
  160. beam_size=30, \
  161. num_seq=30,\
  162. top_k=5,\
  163. top_p=0.5,\
  164. temperature=0.9,\
  165. constraint=[0,100],\
  166. num_beam_groups=1)
  167. action = torch.argmax(torch.stack(log_prob).detach())
  168. module_list = generated_module_seq[action][:-1]
  169. # print(task_description)
  170. # print("Module Sequence: " + module_list)
  171. if module_seq_filter(module_list, test_task_idx[i]):
  172. seqCombination.construct_module_seq(module_list)
  173. for idx, batch in enumerate(tqdm(test_dataloaders[i])):
  174. inputs = list(batch['input'][0])
  175. predictions = seqCombination.run_module_seq(inputs)
  176. if 0 <= test_task_idx[i] <= 14:
  177. outputs = list(batch['output'][0])
  178. dist = image_similarity(predictions, outputs, vit, vit_extractor)
  179. task_rewards.append(dist/100)
  180. elif 15 <= test_task_idx[i] <= 104 or 107 <= test_task_idx[i]:
  181. outputs = list(batch['output'][0])
  182. f1 = np.mean(txt_eval(predictions, outputs, bertscore))
  183. task_rewards.append(f1)
  184. else:
  185. score = clip_score(predictions, inputs)
  186. task_rewards.append(score.detach()/100)
  187. ave_task_reward = np.mean(task_rewards)
  188. seqCombination.close_module_seq()
  189. else:
  190. ave_task_reward = 0
  191. if 0 <= test_task_idx[i] <= 14:
  192. similairies.append(ave_task_reward)
  193. elif 15 <= test_task_idx[i] <= 104 or 107 <= test_task_idx[i]:
  194. berts.append(ave_task_reward)
  195. else:
  196. clips.append(ave_task_reward)
  197. rewards.append(ave_task_reward)
  198. print("Finished testing!")
  199. print("Experimental results: ", np.mean(clips), np.mean(berts), np.mean(similairies), np.mean(rewards))