# generate_model_seq_llama.py

"""
Copyright 2023 Yingqiang Ge

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

__author__ = "Wenyue Hua, Yingqiang Ge"
__copyright__ = "Copyright 2023, OpenAGI"
__date__ = "2023/05/13"
__license__ = "Apache 2.0"
__version__ = "0.0.1"

from typing import Dict, List
from types import MethodType
import os

import torch
from undecorated import undecorated
from peft import PeftModel, PeftModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
class Trie(object):
    def __init__(self, sequences: List[List[int]] = []):
        self.trie_dict = {}
        self.len = 0
        if sequences:
            for sequence in sequences:
                Trie._add_to_trie(sequence, self.trie_dict)
                self.len += 1

        self.append_trie = None
        self.bos_token_id = None

    def append(self, trie, bos_token_id):
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]):
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict):
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        if sequence:
            if sequence[0] not in trie_dict:
                trie_dict[sequence[0]] = {}
            Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]])

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie=None,
        bos_token_id: int = None,
    ):
        if len(prefix_sequence) == 0:
            output = list(trie_dict.keys())
            if append_trie and bos_token_id in output:
                output.remove(bos_token_id)
                output += list(append_trie.trie_dict.keys())
            return output
        elif prefix_sequence[0] in trie_dict:
            return Trie._get_from_trie(
                prefix_sequence[1:],
                trie_dict[prefix_sequence[0]],
                append_trie,
                bos_token_id,
            )
        else:
            if append_trie:
                return append_trie.get(prefix_sequence)
            else:
                return []

    def __iter__(self):
        def _traverse(prefix_sequence, trie_dict):
            if trie_dict:
                for next_token in trie_dict:
                    yield from _traverse(
                        prefix_sequence + [next_token], trie_dict[next_token]
                    )
            else:
                yield prefix_sequence

        return _traverse([], self.trie_dict)

    def __len__(self):
        return self.len

    def __getitem__(self, value):
        return self.get(value)
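
# A minimal usage sketch for Trie (illustrative only; the token ids below are
# made up and do not correspond to a real tokenizer):
#
#   trie = Trie([[1, 100, 29892], [1, 100, 200, 29892]])
#   trie.get([1])        # -> [100]          tokens allowed after the prefix [1]
#   trie.get([1, 100])   # -> [29892, 200]   both stored continuations remain open
#   trie.get([1, 999])   # -> []             prefix not present in the trie
#   len(trie)            # -> 2              number of stored sequences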
class SeqGen:
    def __init__(self, model, tokenizer):
        i2i_tasks = {
            "input": "image",
            "output": "image",
            "task_list": [
                "Colorization,",
                "Image Denoising,",
                "Image Deblurring,",
                "Image Super Resolution,",
            ],
        }
        i2t_tasks = {
            "input": "image",
            "output": "text",
            "task_list": [
                "Image Classification,",
                "Image Captioning,",
                "Object Detection,",
            ],
        }
        t2t_tasks = {
            "input": "text",
            "output": "text",
            "task_list": [
                "Text Summarization,",
                "Text Generation,",
                "Machine Translation,",
                "Fill Mask,",
                "Sentiment Analysis,",
            ],
        }
        t2i_tasks = {
            "input": "text",
            "output": "image",
            "task_list": ["Text to Image Generation,"],
        }
        tt2t_tasks = {
            "input": "text+text",
            "output": "text",
            "task_list": [
                "Question Answering,",
            ],
        }
        it2t_tasks = {
            "input": "image+text",
            "output": "text",
            "task_list": [
                "Visual Question Answering,",
            ],
        }
        self.candidates = [
            i2i_tasks,
            i2t_tasks,
            t2t_tasks,
            t2i_tasks,
            # tt2t_tasks,
            # i2it_tasks,
            # it2i_tasks,
            # it2t_tasks,
        ]
        self.model = model
        self.tokenizer = tokenizer
    # Note: the hard-coded token ids below refer to the LLaMA/vicuna tokenizer
    # used in __main__: 29892 is ",", 313 is "(", 1723 is ")", 1 is the BOS
    # token <s>, and 0 is <unk>.
    def find_last_task(self, sentence):
        # return the token ids of the most recently completed task name
        if sentence.count(29892) == 1:
            last_cand = sentence[1 : sentence.index(29892) + 1]
            if 1723 in last_cand:
                last_cand.remove(1723)
            if 313 in last_cand:
                last_cand.remove(313)
            return last_cand
        indices = [i for i, c in enumerate(sentence) if c == 29892]
        last_cand = sentence[indices[-2] + 1 : indices[-1] + 1]
        if 1723 in last_cand:
            last_cand.remove(1723)
        if 313 in last_cand:
            last_cand.remove(313)
        if 1 in last_cand:
            last_cand.remove(1)
        if 0 in last_cand:
            last_cand.remove(0)
        return last_cand
    def count_parallel_length(self, sentence):
        if sentence.count(313) == 2 and sentence.count(1723) == 2:
            left_parenthesis_position = [i for i, c in enumerate(sentence) if c == 313]
            right_parenthesis_position = [i for i, c in enumerate(sentence) if c == 1723]
            first_parallel_length = sentence[
                left_parenthesis_position[0] : right_parenthesis_position[0]
            ].count(29892)
            second_parallel_length = sentence[
                left_parenthesis_position[1] : right_parenthesis_position[1]
            ].count(29892)
            rest = sentence[right_parenthesis_position[1] :].count(29892)
            if rest == 0:
                return 0
            return max(first_parallel_length, second_parallel_length) + rest
        elif sentence.count(1723) == 0:
            return sentence.count(29892)
        else:
            assert sentence.count(1723) == 1
            second_parallel = sentence[sentence.index(1723) + 1 :]
            return second_parallel.count(29892)
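
    # Worked example (comments only): for a decoded plan shaped like
    # "(A,B,)(C,)D,E," the first "()" holds 2 commas, the second holds 1, and
    # the tail after the closing ")" holds 2, so the effective plan length is
    # max(2, 1) + 2 = 4.  With no ")" at all, the length is simply the comma
    # count, i.e. the number of completed tasks.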
    def find_second_task(self, sentence):
        assert 1723 in sentence
        end_position = sentence.index(1723)
        start_positions = [1] + [i for i, c in enumerate(sentence[:end_position]) if c == 29892]
        start_position = start_positions[-2]
        second_cand = sentence[start_position:end_position]
        if 1723 in second_cand:
            second_cand.remove(1723)
        if 313 in second_cand:
            second_cand.remove(313)
        if second_cand[0] == 29892:
            second_cand = second_cand[1:]
        return second_cand
    def check_two_input_types(self, sentence):
        first_cand = self.find_last_task(sentence)
        first_cand = self.tokenizer.decode(first_cand).strip()
        first_input_type = [
            candidate_list
            for candidate_list in self.candidates
            if first_cand in candidate_list["task_list"]
        ][0]["output"]
        second_cand = self.find_second_task(sentence)
        second_cand = self.tokenizer.decode(second_cand).strip()
        second_input_type = [
            candidate_list
            for candidate_list in self.candidates
            if second_cand in candidate_list["task_list"]
        ][0]["output"]
        # find corresponding list
        one_candidate_list = [
            candidate
            for candidate_list in [
                candidate_list["task_list"]
                for candidate_list in self.candidates
                if second_input_type + "+" + first_input_type == candidate_list["input"]
                or first_input_type + "+" + second_input_type == candidate_list["input"]
            ]
            for candidate in candidate_list
        ]
        # remove candidates that occurred
        remove_repetition = [
            candidate
            for candidate in one_candidate_list
            if candidate not in self.tokenizer.decode(sentence)
        ]
        one_candidate_trie = Trie(
            [self.tokenizer.encode("{}".format(e)) for e in remove_repetition]
        )
        indices = [i for i, c in enumerate(sentence) if c == 1723]
        sentence = sentence[indices[-1] + 1 :]
        trie_out = one_candidate_trie.get([1] + sentence)
        return trie_out
    def after_one_cand(self, sentence):
        one_cand = self.find_last_task(sentence)
        one_cand = self.tokenizer.decode(one_cand)
        input_type = [
            candidate_list
            for candidate_list in self.candidates
            if one_cand.strip() in candidate_list["task_list"]
        ][0]["output"]
        # find corresponding list
        one_candidate_list = [
            candidate
            for candidate_list in [
                candidate_list["task_list"]
                for candidate_list in self.candidates
                if candidate_list["input"] == input_type
            ]
            for candidate in candidate_list
        ]
        if sentence.count(1723) == 0 or sentence.count(1723) == 2:
            remove_repetition = [
                candidate
                for candidate in one_candidate_list
                if candidate not in self.tokenizer.decode(sentence)
            ]
        else:
            assert sentence.count(1723) == 1
            sentence = sentence[sentence.index(1723) + 1 :]
            remove_repetition = [
                candidate
                for candidate in one_candidate_list
                if candidate not in self.tokenizer.decode(sentence)
            ]
        return remove_repetition
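
    # Example (comments only): if the last completed task is "Image Deblurring,"
    # (an image -> image module), after_one_cand returns every module whose
    # "input" type is "image" and whose name has not already appeared in the
    # decoded sentence, so a module is never offered to the beam twice.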
    def llama_prefix_allowed_tokens_fn(self, module_length, input_ids):
        all_candidates = [
            a for candidate_list in self.candidates for a in candidate_list["task_list"]
        ]

        # pad all tasks to the same length
        def prefix_allowed_tokens(batch_id, sentence):
            sentence = sentence.tolist()
            # remove given prompts
            prompt_length = input_ids.shape[-1]  # len(input_ids[batch_id])
            new_sentence_list = sentence[prompt_length:]
            if len(new_sentence_list) <= 1:
                # allow both a plain task name and a "("-prefixed task name as the
                # first step, so that parenthesized (parallel) plans stay reachable
                all_candidate_trie = Trie(
                    [self.tokenizer.encode("{}".format(e)) for e in all_candidates]
                    + [self.tokenizer.encode("({}".format(e)) for e in all_candidates]
                )
                trie_out = all_candidate_trie.get(new_sentence_list)
                return trie_out
            else:
                sentence = torch.tensor(new_sentence_list)
                if sentence[1] == 313:
                    return parenthesis_prefix_allowed_tokens(batch_id, sentence)
                else:
                    return without_parenthesis_prefix_allowed_tokens(batch_id, sentence)
        def without_parenthesis_prefix_allowed_tokens(batch_id, sentence):
            sentence = sentence.tolist()
            if self.tokenizer.decode(sentence).count(",") == 0:
                all_candidate_trie = Trie(
                    [self.tokenizer.encode("{}".format(e)) for e in all_candidates]
                )
                trie_out = all_candidate_trie.get(sentence)
            elif sentence[-1] == 29892 and sentence.count(29892) != module_length:
                one_cand = self.find_last_task(sentence)
                one_cand = self.tokenizer.decode(one_cand)
                next_input_type = [
                    candidate_list
                    for candidate_list in self.candidates
                    if one_cand.strip() in candidate_list["task_list"]
                ][0]["output"]
                # find corresponding list
                one_candidate_list = [
                    candidate
                    for candidate_list in [
                        candidate_list["task_list"]
                        for candidate_list in self.candidates
                        if candidate_list["input"] == next_input_type
                    ]
                    for candidate in candidate_list
                ]
                # remove candidates that occurred
                remove_repetition = [
                    candidate
                    for candidate in one_candidate_list
                    if candidate not in self.tokenizer.decode(sentence)
                ] + ["</s>"]
                one_candidate_trie = Trie(
                    [self.tokenizer.encode("{}".format(e)) for e in remove_repetition]
                )
                trie_out = one_candidate_trie.get([])
            elif sentence[-1] != 29892 and sentence.count(29892) != module_length:
                # if sentence[-1] == 0 and sentence[sentence.index(0) - 1] != 6:
                #     sentence = sentence[: sentence.index(0)]
                # print(tokenizer.decode(sentence))
                one_cand = self.find_last_task(sentence)
                one_cand = self.tokenizer.decode(one_cand).strip()
                input_type = [
                    candidate_list
                    for candidate_list in self.candidates
                    if one_cand in candidate_list["task_list"]
                ][0]["output"]
                # find corresponding list
                one_candidate_list = [
                    candidate
                    for candidate_list in [
                        candidate_list["task_list"]
                        for candidate_list in self.candidates
                        if candidate_list["input"] == input_type
                    ]
                    for candidate in candidate_list
                ]
                # remove candidates that occurred
                remove_repetition = [
                    candidate
                    for candidate in one_candidate_list
                    if candidate not in self.tokenizer.decode(sentence)
                ]
                one_candidate_trie = Trie(
                    [self.tokenizer.encode("{}".format(e)) for e in remove_repetition]
                )
                indices = [i for i, c in enumerate(sentence) if c == 29892]
                sentence = sentence[indices[-1] + 1 :]
                # a = one_candidate_trie.get([0] + sentence)
                # for b in a:
                #     print(tokenizer.decode(b))
                # print("***")
                trie_out = one_candidate_trie.get(sentence)
            elif sentence.count(29892) == module_length:
                candidate_trie = Trie(
                    [self.tokenizer.encode("{}".format(e)) for e in ["</s>"]]
                )
                trie_out = candidate_trie.get([1])
            return trie_out
        def parenthesis_prefix_allowed_tokens(batch_id, sentence):
            # print(sentence)
            # print(tokenizer.decode(sentence))
            # print("***")
            sentence = sentence.tolist()
            # either begin of sentence, or finish one ()
            if sentence.count(29892) == 0 or (
                sentence[-1] == 1723  # )
                and sentence.count(313) == 1  # has one ()
                and sentence.count(1723) == 1
            ):
                all_candidate_trie = Trie(
                    [self.tokenizer.encode("({}".format(e)) for e in all_candidates]
                )
                if 313 not in sentence:
                    trie_out = all_candidate_trie.get([1] + sentence)
                else:
                    trie_out = all_candidate_trie.get(sentence)
            elif sentence[-1] != 29892 and self.count_parallel_length(sentence) < module_length:
                if sentence[-1] == 1723 and sentence.count(1723) == 2:
                    # check two input types and generate without any () in the future
                    trie_out = self.check_two_input_types(sentence)
                else:
                    # keep generating the unfinished task, can generate ) or not, not necessary
                    if sentence.count(1723) == 0:
                        remove_repetition = self.after_one_cand(sentence)
                        one_candidate_trie = Trie(
                            [
                                self.tokenizer.encode("{}".format(e))
                                for e in remove_repetition
                            ]
                        )
                        indices = [i for i, c in enumerate(sentence) if c == 29892]
                        sentence = sentence[indices[-1] + 1 :]
                        trie_out = one_candidate_trie.get([1] + sentence)
                    elif sentence.count(1723) == 1:
                        # the first task of the second parallel
                        rebegin_sentence = sentence[sentence.index(1723) + 1 :]
                        if 29892 not in rebegin_sentence:
                            all_candidate_trie = Trie(
                                [
                                    self.tokenizer.encode("({}".format(e))
                                    for e in all_candidates
                                ]
                            )
                            sentence = sentence[sentence.index(1723) + 2 :]
                            trie_out = all_candidate_trie.get([1] + sentence)
                        else:
                            # the non-first task of the second parallel
                            remove_repetition = self.after_one_cand(sentence)
                            if sentence[-1] == 1723 and sentence.count(1723) == 1:
                                one_candidate_trie = Trie(
                                    [[1]]
                                    + [
                                        self.tokenizer.encode("{}".format(e))
                                        for e in remove_repetition
                                    ]
                                )
                            else:
                                one_candidate_trie = Trie(
                                    [
                                        self.tokenizer.encode("{}".format(e))
                                        for e in remove_repetition
                                    ]
                                )
                            indices = [i for i, c in enumerate(sentence) if c == 29892]
                            sentence = sentence[indices[-1] + 1 :]
                            trie_out = one_candidate_trie.get([1] + sentence)
                    else:
                        right_parentheses_position = [
                            index for index, c in enumerate(sentence) if c == 1723
                        ]
                        if 29892 not in sentence[right_parentheses_position[-1] :]:
                            # generate the first task after ()()
                            trie_out = self.check_two_input_types(sentence)
                        else:
                            # find the last task and generate without any () in the future
                            remove_repetition = self.after_one_cand(sentence)
                            one_candidate_trie = Trie(
                                [
                                    self.tokenizer.encode("{}".format(e))
                                    for e in remove_repetition
                                ]
                            )
                            indices = [i for i, c in enumerate(sentence) if c == 29892]
                            sentence = sentence[indices[-1] + 2 :]
                            trie_out = one_candidate_trie.get([1] + sentence)
            elif sentence[-1] == 29892 and self.count_parallel_length(sentence) < module_length:
                if sentence.count(313) - sentence.count(1723) == 1:
                    # need to generate candidates with ), without ), or directly )
                    remove_repetition = self.after_one_cand(sentence)
                    if self.count_parallel_length(sentence) + 1 >= module_length:
                        trie_out = [1, 1723]
                    else:
                        one_candidate_trie = Trie(
                            [[1, 1723]]
                            + [
                                self.tokenizer.encode("{}".format(e))
                                for e in remove_repetition
                            ]
                        )
                        indices = [i for i, c in enumerate(sentence) if c == 29892]
                        sentence = sentence[indices[-1] + 1 :]
                        trie_out = one_candidate_trie.get([1] + sentence)
                else:
                    # outside of (), find the last task and keep generating
                    remove_repetition = self.after_one_cand(sentence)
                    one_candidate_trie = Trie(
                        [self.tokenizer.encode("{}".format(e)) for e in remove_repetition]
                    )
                    indices = [i for i, c in enumerate(sentence) if c == 29892]
                    sentence = sentence[indices[-1] + 1 :]
                    trie_out = one_candidate_trie.get([1] + sentence)
            else:
                assert self.count_parallel_length(sentence) >= module_length
                if sentence.count(313) - sentence.count(1723) == 1:
                    trie_out = [1, 1723]
                else:
                    trie_out = [1]
            return trie_out

        return prefix_allowed_tokens
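
    # How the closure above is consumed (a sketch mirroring the call in
    # generate_sequence below): HuggingFace generate() invokes
    # prefix_allowed_tokens_fn(batch_id, generated_ids) at every decoding step
    # and restricts the vocabulary to the returned token ids, e.g.
    #
    #   fn = seq_gen.llama_prefix_allowed_tokens_fn(module_length, input_ids)
    #   out = model.generate(
    #       input_ids=input_ids,
    #       prefix_allowed_tokens_fn=fn,
    #       num_beams=2,
    #       max_length=80,
    #   )
    #
    # ("seq_gen" is an illustrative SeqGen instance, not a name defined here.)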
    def generate_sequence(
        self,
        input_ids,
        module_length,
        num_beams,
        num_return_sequences,
        **kwargs,
    ):
        prefix_allowed_tokens = self.llama_prefix_allowed_tokens_fn(module_length, input_ids)

        output = self.model.generate_with_grad(
            input_ids=input_ids,
            max_length=80,
            min_length=2,
            prefix_allowed_tokens_fn=prefix_allowed_tokens,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            return_dict_in_generate=True,
            output_scores=True,
            output_hidden_states=True,
        )

        input_length = input_ids.shape[-1]
        output_ids = output["sequences"]
        # print(output_ids)
        output_sequence = [
            s.replace("<pad>", "").replace("</s>", "")
            for s in self.tokenizer.batch_decode(output_ids)
        ]
        # print(output_sequence)

        # B * length tuple of (num_beams * vocab_size) tensor
        scores = output["scores"]
        # print(scores)

        if num_beams > 1:
            if num_return_sequences == 1:
                length = len(scores)
                number_of_output_ids = output_ids[0].tolist().count(29892)
                logprob = 0
                # B * num_beams * length
                beam_indices = output["beam_indices"][0]
                for l in range(length):
                    beam_index = beam_indices[l]
                    # print(tokenizer.decode(output_ids[0][l]))
                    # print(scores[l][beam_index][output_ids[0][l]])
                    exponential_score = (
                        torch.exp(scores[l][beam_index]) + 1e-10
                    )  # unnormalized prob
                    # print(exponential_score[output_ids[0][l]])
                    normalized_score = (
                        exponential_score / exponential_score.sum()
                    )  # normalized prob
                    # print(normalized_score[output_ids[0][l]])
                    prob_score = torch.log(normalized_score)  # normalized log prob
                    if self.tokenizer.decode(output_ids[0][l + input_length]) == "</s>":
                        continue
                    else:
                        logprob += prob_score[output_ids[0][l + input_length]]
                loss = logprob / number_of_output_ids
            else:
                loss = []
                for i in range(num_return_sequences):
                    one_length = len(scores)
                    number_of_output_ids = output_ids[i].tolist().count(29892)
                    if number_of_output_ids == 0:
                        number_of_output_ids += 1
                    logprob = 0
                    # B * num_return_sequences * length
                    beam_indices = output["beam_indices"][i]
                    for l in range(one_length):
                        if l + input_length >= len(output_ids[i]):
                            break
                        # print(tokenizer.decode(output_ids[i][l]))
                        beam_index = beam_indices[l]
                        # print(scores[l][beam_index])
                        exponential_score = (
                            torch.exp(scores[l][beam_index]) + 1e-10
                        )  # unnormalized prob
                        # print(exponential_score)
                        normalized_score = (
                            exponential_score / exponential_score.sum()
                        )  # normalized prob
                        # print(normalized_score)
                        prob_score = torch.log(normalized_score)  # normalized log prob
                        # print(prob_score)
                        if self.tokenizer.decode(output_ids[i][l + input_length]) == "</s>":
                            continue
                        else:
                            logprob += prob_score[output_ids[i][l + input_length]]
                    loss.append(logprob / number_of_output_ids)
        else:
            logprob = 0
            number_of_output_ids = output_ids[0].tolist().count(29892)
            length = len(scores)
            # print(length)
            for l in range(length):
                exponential_score = torch.exp(scores[l][0]) + 1e-10  # unnormalized prob
                normalized_score = exponential_score / (
                    exponential_score.sum()
                )  # normalized prob
                prob_score = torch.log(normalized_score)  # normalized log prob
                if self.tokenizer.decode(output_ids[0][l + input_length]) == "</s>":
                    continue
                else:
                    logprob += prob_score[output_ids[0][l + input_length]]
            loss = logprob / number_of_output_ids

        return output_sequence, loss
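
# A minimal, assumed usage sketch (not part of this file): because decoding goes
# through generate_with_grad, the length-normalized log-probability returned as
# `loss` keeps its computation graph, so it could drive a REINFORCE-style update
# in which `reward` scores the decoded task sequence:
#
#   output_sequence, logprob = seq_gen.generate_sequence(
#       input_ids, module_length, num_beams, num_return_sequences
#   )
#   (-logprob * reward).backward()
#   optimizer.step()
#   optimizer.zero_grad()
#
# Here `seq_gen`, `reward`, and `optimizer` are illustrative names only.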
if __name__ == "__main__":
    base_model = "eachadea/vicuna-7b-1.1"
    load_8bit = True
    max_memory_mapping = {
        0: "0GB",
        1: "0GB",
        2: "0GB",
        3: "0GB",
        4: "24GB",
        5: "24GB",
        6: "0GB",
        7: "0GB",
    }

    tokenizer = AutoTokenizer.from_pretrained(
        "eachadea/vicuna-7b-1.1",
        cache_dir="/common/users/yg334/LLAMA/huggingface/cache",
    )
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    model = AutoModelForCausalLM.from_pretrained(
        "eachadea/vicuna-7b-1.1",
        cache_dir="/common/users/yg334/LLAMA/huggingface/cache",
        device_map="auto",
        max_memory=max_memory_mapping,
    )

    # expose a generate() variant with its no_grad decorator stripped, so the
    # scores of the generated sequence keep their gradients
    generate_with_grad = undecorated(model.generate)
    model.generate_with_grad = MethodType(generate_with_grad, model)

    lora_weights = "/common/users/yg334/lora-vicuna"
    model = PeftModelForCausalLM.from_pretrained(
        model,
        lora_weights,
        torch_dtype=torch.float16,
    )

    input_s = [
        "### Human: Given low-resolutioned noisy blurry gray image, how to return the regular image step by step? \n### Assistant:"
    ]
    input_ids = tokenizer.batch_encode_plus(
        input_s, padding="longest", return_tensors="pt"
    )["input_ids"].cuda()
    print(tokenizer.batch_decode(input_ids))

    module_length = 10
    num_beams = 2
    num_return_sequences = 1

    sq = SeqGen(model, tokenizer)
    prefix_allowed_tokens = sq.llama_prefix_allowed_tokens_fn(module_length, input_ids)

    # output = model.generate_with_grad(
    #     input_ids=input_ids,
    #     max_length=70,
    #     min_length=1,
    #     prefix_allowed_tokens_fn=prefix_allowed_tokens,
    #     num_beams=2,
    #     num_return_sequences=2,
    #     return_dict_in_generate=True,
    #     output_scores=True,
    #     output_hidden_states=True,
    # )
    # output_ids = output["sequences"][0][1:]
    # output_sequence = (
    #     tokenizer.decode(output_ids).replace("<pad>", "").replace("</s>", "")
    # )
    # print(output_sequence)

    output_sequence, loss = sq.generate_sequence(
        input_ids, module_length, num_beams, num_return_sequences
    )
    print(output_sequence)
    print(loss)