generate_model_seq.py
  1. """
  2. Copyright 2023 Yingqiang Ge
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. """
  13. __author__ = "Wenyue Hua, Yingqiang Ge"
  14. __copyright__ = "Copyright 2023, OpenAGI"
  15. __date__ = "2023/04/12"
  16. __license__ = "Apache 2.0"
  17. __version__ = "0.0.1"
  18. from typing import Dict, List
  19. from types import MethodType
  20. import torch
  21. from undecorated import undecorated
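
# Note: `SeqGen.generate_sequence` below calls `self.model.generate_with_grad`,
# which is not a stock HuggingFace method. The `undecorated` / `MethodType`
# imports above suggest it is created by stripping the `@torch.no_grad()`
# decorator from `model.generate` and binding the result back onto the model.
# A minimal sketch of that assumed setup, performed by the caller:
#
#     generate_with_grad = undecorated(model.generate)
#     model.generate_with_grad = MethodType(generate_with_grad, model)
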
class Trie(object):
    """Prefix tree over token-id sequences.

    `get(prefix)` returns the list of token ids that may immediately follow
    `prefix` in any of the stored sequences.
    """

    def __init__(self, sequences: List[List[int]] = []):
        self.trie_dict = {}
        self.len = 0
        if sequences:
            for sequence in sequences:
                Trie._add_to_trie(sequence, self.trie_dict)
                self.len += 1

        self.append_trie = None
        self.bos_token_id = None

    def append(self, trie, bos_token_id):
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]):
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict):
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        if sequence:
            if sequence[0] not in trie_dict:
                trie_dict[sequence[0]] = {}
            Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]])

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie=None,
        bos_token_id: int = None,
    ):
        if len(prefix_sequence) == 0:
            output = list(trie_dict.keys())
            if append_trie and bos_token_id in output:
                output.remove(bos_token_id)
                output += list(append_trie.trie_dict.keys())
            return output
        elif prefix_sequence[0] in trie_dict:
            return Trie._get_from_trie(
                prefix_sequence[1:],
                trie_dict[prefix_sequence[0]],
                append_trie,
                bos_token_id,
            )
        else:
            if append_trie:
                return append_trie.get(prefix_sequence)
            else:
                return []

    def __iter__(self):
        def _traverse(prefix_sequence, trie_dict):
            if trie_dict:
                for next_token in trie_dict:
                    yield from _traverse(
                        prefix_sequence + [next_token], trie_dict[next_token]
                    )
            else:
                yield prefix_sequence

        return _traverse([], self.trie_dict)

    def __len__(self):
        return self.len

    def __getitem__(self, value):
        return self.get(value)
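
# Illustrative example (not part of the original module): a Trie built from
# token-id sequences answers "which token ids may follow this prefix?", which
# is the contract expected by HuggingFace's `prefix_allowed_tokens_fn`:
#
#     trie = Trie([[0, 5, 6], [0, 7, 8]])
#     trie.get([])      # -> [0]
#     trie.get([0])     # -> [5, 7]
#     trie.get([0, 5])  # -> [6]
#     trie.get([0, 9])  # -> []  (not a valid prefix)
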
  93. """
  94. def prefix_allowed_tokens_fn(candidate_trie):
  95. def prefix_allowed_tokens(batch_id, sentence):
  96. sentence = sentence.tolist()
  97. if sentence[-1] == 6: # last token is ","
  98. trie_out = candidate_trie.get([0])
  99. elif 6 not in sentence: # "," is not in the generated sentence
  100. trie_out = candidate_trie.get(sentence)
  101. else:
  102. assert 6 in sentence
  103. indices = [i for i, c in enumerate(sentence) if c == 6]
  104. sentence = sentence[indices[-1] + 1 :]
  105. trie_out = candidate_trie.get([0] + sentence)
  106. return trie_out
  107. return prefix_allowed_tokens
  108. """


class SeqGen:
    def __init__(self, model, tokenizer, device):
        # Candidate modules grouped by input/output modality. Every task name
        # ends with "," (token id 6), which marks the boundary between modules
        # in a generated sequence.
        i2i_tasks = {
            "input": "image",
            "output": "image",
            "task_list": [
                "Colorization,",
                "Image Denoising,",
                "Image Deblurring,",
                "Image Super Resolution,",
            ],
        }
        i2t_tasks = {
            "input": "image",
            "output": "text",
            "task_list": [
                "Image Classification,",
                "Image Captioning,",
                "Object Detection,",
            ],
        }
        t2t_tasks = {
            "input": "text",
            "output": "text",
            "task_list": [
                "Text Summarization,",
                "Text Generation,",
                "Machine Translation,",
                "Fill Mask,",
                "Sentiment Analysis,",
            ],
        }
        t2i_tasks = {
            "input": "text",
            "output": "image",
            "task_list": ["Text to Image Generation,"],
        }
        tt2t_tasks = {
            "input": "text+text",
            "output": "text",
            "task_list": [
                "Question Answering,",
            ],
        }
        it2t_tasks = {
            "input": "image+text",
            "output": "text",
            "task_list": [
                "Visual Question Answering,",
            ],
        }
        self.candidates = [
            i2i_tasks,
            i2t_tasks,
            t2t_tasks,
            t2i_tasks,
            # tt2t_tasks,
            # i2it_tasks,
            # it2i_tasks,
            # it2t_tasks,
        ]
        self.device = device
        self.model = model  # .to(self.device)
        self.tokenizer = tokenizer

    def find_last_task(self, sentence):
        # Token id 6 is "," in the T5 tokenizer, so it delimits completed task
        # names; return the token ids of the most recently completed task,
        # including its trailing ",".
        if sentence.count(6) == 1:
            last_cand = sentence[1 : sentence.index(6) + 1]
            return last_cand
        indices = [i for i, c in enumerate(sentence) if c == 6]
        last_cand = sentence[indices[-2] + 1 : indices[-1] + 1]
        return last_cand

    def t5_prefix_allowed_tokens_fn(self, module_length, constraint):
        all_candidates = [
            a for candidate_list in self.candidates for a in candidate_list["task_list"]
        ]

        def prefix_allowed_tokens(batch_id, sentence):
            sentence = sentence.tolist()
            # print(tokenizer.decode(sentence))
            if sentence.count(6) == 0:
                # No task completed yet: any candidate task name may be started.
                all_candidate_trie = Trie(
                    [[0] + self.tokenizer.encode("{}".format(e)) for e in all_candidates]
                )
                trie_out = all_candidate_trie.get(sentence)
            elif sentence[-1] == 6 and sentence.count(6) != module_length:
                # A task just finished ("," was emitted): the next task must take
                # the previous task's output modality as input, and tasks that
                # already occurred are excluded.
                one_cand = self.find_last_task(sentence)
                one_cand = self.tokenizer.decode(one_cand)
                next_input_type = [
                    candidate_list
                    for candidate_list in self.candidates
                    if one_cand in candidate_list["task_list"]
                ][0]["output"]
                # find the corresponding candidate lists
                one_candidate_list = [
                    candidate
                    for task_list in [
                        candidate_list["task_list"]
                        for candidate_list in self.candidates
                        if candidate_list["input"] == next_input_type
                    ]
                    for candidate in task_list
                ]
                # remove candidates that already occurred
                remove_repetition = [
                    candidate
                    for candidate in one_candidate_list
                    if candidate not in self.tokenizer.decode(sentence)
                ]
                # print(remove_repetition)
                # if sentence.count(6) == 1:
                #     remove_repetition_ = remove_repetition[constraint[0]:constraint[1]] + ["</s>"]
                #     one_candidate_trie = Trie(
                #         [[0] + self.tokenizer.encode("{}".format(e)) for e in remove_repetition_]
                #     )
                # else:
                #     one_candidate_trie = Trie(
                #         [[0] + self.tokenizer.encode("{}".format(e)) for e in remove_repetition + ["</s>"]]
                #     )
                one_candidate_trie = Trie(
                    [[0] + self.tokenizer.encode("{}".format(e)) for e in remove_repetition + ["</s>"]]
                )
                trie_out = one_candidate_trie.get([0])
            elif sentence[-1] != 6 and sentence.count(6) != module_length:
                # In the middle of a task name: only allow completions that are
                # consistent with the partially generated name.
                one_cand = self.find_last_task(sentence)
                one_cand = self.tokenizer.decode(one_cand)
                input_type = [
                    candidate_list
                    for candidate_list in self.candidates
                    if one_cand in candidate_list["task_list"]
                ][0]["output"]
                # find the corresponding candidate lists
                one_candidate_list = [
                    candidate
                    for task_list in [
                        candidate_list["task_list"]
                        for candidate_list in self.candidates
                        if candidate_list["input"] == input_type
                    ]
                    for candidate in task_list
                ]
                # remove candidates that already occurred
                remove_repetition = [
                    candidate
                    for candidate in one_candidate_list
                    if candidate not in self.tokenizer.decode(sentence)
                ]
                one_candidate_trie = Trie(
                    [[0] + self.tokenizer.encode("{}".format(e)) for e in remove_repetition]
                )
                indices = [i for i, c in enumerate(sentence) if c == 6]
                sentence = sentence[indices[-1] + 1 :]
                trie_out = one_candidate_trie.get([0] + sentence)
            elif sentence.count(6) == module_length:
                # module_length tasks have been generated: only "</s>" may follow.
                candidate_trie = Trie(
                    [[0] + self.tokenizer.encode("{}".format(e)) for e in ["</s>"]]
                )
                trie_out = candidate_trie.get([0])
            return trie_out

        return prefix_allowed_tokens
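
    # Worked example (illustrative decoding state, not original code): with
    # module_length=3, after the model has emitted "Image Deblurring," the
    # allowed continuations are the image-input tasks that have not appeared
    # yet (e.g. "Colorization,", "Image Classification,", ...) plus "</s>";
    # once three commas have been produced, only "</s>" is permitted.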

    def generate_sequence(
        self,
        input_s,
        module_length=5,
        beam_size=4,
        num_seq=1,
        top_k=5,
        top_p=0.9,
        temperature=0.7,
        constraint=[0, 100],
        num_beam_groups=2,
    ):
        output_sequences = []
        log_probs = []
        # output_scores = []
        # output_results = []
        # for input_s in input_sentences:
        input_ids = self.tokenizer.batch_encode_plus(
            input_s, padding="longest", return_tensors="pt"
        )["input_ids"]
        input_ids = input_ids.to(self.device)
        prefix_allowed_tokens = self.t5_prefix_allowed_tokens_fn(
            module_length, constraint=constraint
        )
        output = self.model.generate_with_grad(
            input_ids,
            max_length=80,
            min_length=1,
            prefix_allowed_tokens_fn=prefix_allowed_tokens,
            num_beams=beam_size,
            num_return_sequences=num_seq,
            return_dict_in_generate=True,
            output_scores=True,
            output_hidden_states=True,
            renormalize_logits=True,
            # do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            early_stopping=True,
            # no_repeat_ngram_size=no_repeat_ngram_size,
            num_beam_groups=num_beam_groups,
        )
        # print(output["sequences"])
        output_ids = output["sequences"][:, 1:]
        output_result = [s for s in self.tokenizer.batch_decode(output_ids)]
        output_sequence = [
            s.replace("<pad>", "").replace("</s>", "")
            for s in self.tokenizer.batch_decode(output_ids)
        ]
        output_sequences.append(output_sequence)
        # B * length tuple of (beam_size * vocab_size) tensors
        scores = output["scores"]
        if beam_size > 1:
            output_score = output.sequences_scores
            if num_seq == 1:
                length = output_ids.size(-1)
                logprob = 0
                # B * beam_size * length
                beam_indices = output["beam_indices"][0]
                for l in range(length):
                    beam_index = beam_indices[l]
                    score = scores[l][beam_index]
                    # score = torch.exp(scores[l][beam_index])  # unnormalized prob
                    # score /= score.sum()  # normalized prob
                    # score = torch.log(score)  # normalized log prob
                    if self.tokenizer.decode(output_ids[0][l]) == "</s>":
                        continue
                    logprob += score[output_ids[0][l]]
                # else:
                #     logprob = 0
                loss = logprob  # /length
                log_probs.append([loss])
            else:
                loss = []
                for i in range(num_seq):
                    if 0 in output_ids[i]:
                        # length up to the first pad token (id 0)
                        one_length = output_ids[i][
                            : (output_ids[i] == 0).nonzero(as_tuple=True)[0].tolist()[0]
                        ].size(-1)
                    else:
                        one_length = output_ids[i].size(-1)
                    logprob = 0
                    # B * num_seq * length
                    beam_indices = output["beam_indices"][i]
                    for l in range(one_length):
                        beam_index = beam_indices[l]
                        score = scores[l][beam_index]
                        # score = torch.exp(scores[l][beam_index])  # unnormalized prob
                        # score /= score.sum()  # normalized prob
                        # score = torch.log(score)  # normalized log prob
                        if self.tokenizer.decode(output_ids[i][l]) == "</s>":
                            continue
                        logprob += score[output_ids[i][l]]
                    loss.append(logprob)  # /one_length
                log_probs.append(loss)
        else:
            logprob = 0
            length = output_ids.size(-1)
            for l in range(length):
                score = scores[l][0]
                if self.tokenizer.decode(output_ids[0][l]) == "</s>":
                    continue
                logprob += score[output_ids[0][l]]
            loss = logprob  # /length
            log_probs.append([loss])
        # return output_sequence, loss, output_score, output_result  # , prob, output_score
        return output_sequence, loss
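

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). It assumes a T5-style seq2seq
# checkpoint; the model name "t5-base" and the prompt are placeholders, and
# `generate_with_grad` is attached as described near the imports above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer, T5ForConditionalGeneration

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)

    # Strip the @torch.no_grad() decorator from `generate` so the
    # log-probability accumulated in generate_sequence keeps its gradient graph.
    generate_with_grad = undecorated(model.generate)
    model.generate_with_grad = MethodType(generate_with_grad, model)

    seq_gen = SeqGen(model, tokenizer, device)
    sequences, log_prob = seq_gen.generate_sequence(
        ["Given a blurry grayscale image, how to return the regular image step by step?"],
        module_length=3,
        beam_size=4,
        num_seq=1,
    )
    print(sequences, log_prob)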