gen_mcts_dpo.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. from functools import lru_cache
  2. import os
  3. import json
  4. from glob import glob
  5. import random
  6. import re
  7. from tqdm import tqdm
  8. from itertools import groupby
  9. import numpy as np
  10. data_folders = [
  11. './jsons'
  12. ]
  13. pattern = re.compile(r'\-?\d+\.\d+|\-?\d+')
  14. @lru_cache(1024)
  15. def extract_label(text: str) -> str:
  16. if '\n####' in text:
  17. text = text.split('\n####')[-1].replace(',','')
  18. elif 'The answer is' in text:
  19. text = text.split('The answer is')[-1].replace(',','')
  20. numbers = pattern.findall(text)
  21. if not numbers:
  22. return None
  23. return numbers[0]
  24. def check(gt,ans):
  25. gt_label = extract_label(gt)
  26. ans_label = extract_label(ans)
  27. # print(gt_label,ans_label)
  28. if gt_label is None or ans_label is None:
  29. return False
  30. if ans_label == gt_label or abs(float(ans_label) - float(gt_label)) < 1e-5:
  31. return True
  32. else:
  33. return False
  34. final_json_list = []
  35. def get_json(query,good,bad,history=[]):
  36. return {'input':'','instruction':query,'output':[good,bad],'history':history}
  37. def get_node_id(answers,ans):
  38. return answers.index(ans)
  39. def get_oldest_father(answers,ans,childs):
  40. possible_fathers = []
  41. for possible_father in childs:
  42. if ans in childs[possible_father]:
  43. possible_fathers.append(possible_father)
  44. print(len(possible_fathers))
  45. possible_father_ids = []
  46. for possible_father in possible_fathers:
  47. possible_father_ids.append(get_node_id(answers,possible_father))
  48. return possible_fathers[possible_father_ids.index(min(possible_father_ids))]
  49. def get_fathers(answers,ans,childs):
  50. possible_fathers = []
  51. for possible_father in childs:
  52. if ans in childs[possible_father]:
  53. possible_fathers.append(possible_father)
  54. return possible_fathers
  55. def fix_loops(answers,fathers,childs):
  56. # 如果节点已经在fathers字典中,说明找到了环的起点
  57. fathers_no_loop = {}
  58. for node in childs:
  59. for child in childs[node]:
  60. if child not in fathers_no_loop:
  61. fathers_no_loop[child] = [node,]
  62. elif child in fathers_no_loop:
  63. fathers_no_loop[child].append(child)
  64. for ans in answers:
  65. if ans not in fathers_no_loop:
  66. fathers_no_loop[ans] = [None,]
  67. return fathers_no_loop
  68. def collect_paths(answers,gt,fathers,childs):
  69. gold = []
  70. for answer in answers:
  71. if check(gt,answer):
  72. gold.append(answer)
  73. paths = []
  74. for g in gold:
  75. if g is None:
  76. continue
  77. path = [g,]
  78. while g in fathers and g is not None:
  79. father = None
  80. for t in fathers:
  81. if t in path:
  82. continue
  83. else:
  84. father =t
  85. g = father
  86. if g is not None:
  87. path.append(g)
  88. else:
  89. break
  90. paths.append(path)
  91. return paths
  92. def rereward(paths,answers,gt,fathers,childs,gemma=0.9):
  93. structue_reward = {}
  94. for path in paths:
  95. for i,ix in enumerate(path):
  96. structue_reward[ix] = gemma**i
  97. path_list = []
  98. for path in paths:
  99. path_list.extend(path)
  100. gemma2 = 0.5*gemma
  101. root_reward = min(structue_reward.values())*gemma
  102. def get_reward(ans):
  103. if ans is None:
  104. structue_reward[ans] = root_reward
  105. return structue_reward[ans]
  106. if ans in path_list:
  107. return structue_reward[ans]
  108. if ans in structue_reward:
  109. return structue_reward[ans]
  110. if ans in fathers:
  111. if fathers[ans] is None:
  112. structue_reward[ans] = root_reward * gemma2
  113. return structue_reward[ans]
  114. if fathers[ans] in structue_reward:
  115. structue_reward[ans] = structue_reward[fathers[ans]] * gemma2
  116. return structue_reward[ans]
  117. else:
  118. structue_reward[ans] = get_reward(fathers[ans]) * gemma2
  119. return structue_reward[ans]
  120. for ans in answers:
  121. get_reward(ans)
  122. return structue_reward
  123. def get_refined_ans(history_bank,hints_list,answer_list):
  124. hints_map = {}
  125. for ans in history_bank:
  126. if len(history_bank[ans]) > 2:
  127. hint = history_bank[ans][-3]
  128. hints_map[hint] = ans
  129. for hint in hints_list:
  130. if hint not in hints_map:
  131. for history in history_bank.values():
  132. if hint in history:
  133. hints_map[hint] = history[history.index(hint)+2]
  134. break
  135. dummys = ["I Don't Know","I can't understand this question.","I can't help with this question.","I don't know how to solve this question.","I don't know the answer to this question.","I don't know the answer to this question, sorry."]
  136. startpoint = 1
  137. for dummy in dummys:
  138. if dummy in answer_list:
  139. startpoint = answer_list.index(dummy) + 1
  140. for hint in hints_list:
  141. if hint not in hints_map:
  142. hints_map[hint] = answer_list[hints_list.index(hint) + startpoint]
  143. return hints_map
  144. def collect_refine(paths,hints_reward_imp_bank,hints_map,structure_reward):
  145. re_hints_reward_imp_bank = {}
  146. for ans in hints_reward_imp_bank:
  147. if len(hints_reward_imp_bank[ans]) >= 2:
  148. re_hints_reward_imp_bank[ans] = []
  149. for hint,_ in hints_reward_imp_bank[ans]:
  150. reward0 = structure_reward[ans]
  151. refined_ans = hints_map[hint]
  152. reward1 = structure_reward[refined_ans]
  153. re_hints_reward_imp_bank[ans].append([hint,reward1-reward0])
  154. re_hints_reward_imp_bank[ans] = sorted(re_hints_reward_imp_bank[ans], key=lambda x: x[1], reverse=True)
  155. re_hints_reward_imp_bank[ans] = [random.choice(list(g))[0] for k, g in groupby(re_hints_reward_imp_bank[ans], key=lambda x: x[1])]
  156. return re_hints_reward_imp_bank
  157. def pair_importance_sampling(rewards, actions, nums):
  158. # Initialize an empty list to store the importance weights
  159. weights = []
  160. action_pairs = []
  161. # For each pair of actions
  162. for i in range(len(actions)):
  163. for j in range(i+1, len(actions)):
  164. # Calculate the difference in rewards
  165. reward_diff = abs(rewards[i] - rewards[j])
  166. # Use the reward difference as the weight for this pair
  167. weights.append(reward_diff)
  168. if rewards[i] >= rewards[j]:
  169. action_pairs.append([actions[i],actions[j]])
  170. else:
  171. action_pairs.append([actions[j],actions[i]])
  172. # Normalize the weights so they sum to 1
  173. weights = [weight / sum(weights) for weight in weights]
  174. action_pairs_index = list(range(len(action_pairs)))
  175. # Sample from the actions according to the weights
  176. sampled_actions_index = np.random.choice(action_pairs_index, size=nums, p=weights)
  177. sampled_actions = [action_pairs[index] for index in sampled_actions_index]
  178. return sampled_actions
  179. def refine_prompt(query,ans):
  180. q = f'Since we have a weak Answer, could you provide me with a relection or feedback to correct this answer better? Analyze this Answer Strictly and Critic, point out every flaw for ervery possible imperfect to minus every possible score!\nLet\'s think step by step.'
  181. return q
  182. for data_folder in data_folders:
  183. for file in tqdm(glob(data_folder+'/*')):
  184. # for file in tqdm(glob('/home/bingxing2/ailab/group/ai4phys/math/gsm8k-pathfinder-mistral7B-new-mcts-7/jsons/0913cc66580de7f71567cee17c96479a.json')):
  185. # print(file)
  186. data = json.load(open(file,'r'))
  187. gold = []
  188. for answer in data['answers_list']:
  189. if check(data['ground_truth'],answer):
  190. gold.append(answer)
  191. data['fathers'] = fix_loops(data['answers_list'],data['fathers'],data['childs'])
  192. golden_paths = collect_paths(data['answers_list'],data['ground_truth'],data['fathers'],data['childs'])
  193. structue_reward = rereward(golden_paths,data['answers_list'],data['ground_truth'],data['fathers'],data['childs'])
  194. hints_map = get_refined_ans(data['history_bank'],data['hints_list'],data['answers_list'])
  195. re_hints_reward_imp_bank = collect_refine(golden_paths,data['hints_reward_imp_bank'],hints_map,structue_reward)
  196. dpo_pairs = [] #q,good,bad
  197. for path in golden_paths:#golden path from right answer to wrong root answers
  198. if len(path) > 1:
  199. for i,ix in enumerate(path):
  200. # if ix in ["I Don't Know","I can't understand this question.","I can't help with this question.","I don't know how to solve this question.","I don't know the answer to this question.","I don't know the answer to this question, sorry."]:
  201. # path.remove(ix)
  202. if ix in gold and i != 0:
  203. path.remove(ix)
  204. if len(path) <= 1:
  205. golden_paths.remove(path)
  206. for path in golden_paths:
  207. if len(path) == 2:
  208. dpo_pairs.append([data['query'],path[0],path[-1]])
  209. else:
  210. pairs = pair_importance_sampling([structue_reward[node] for node in path],path,(len(path)**2)//2)
  211. pairs = [[data['query'],pair[0],pair[-1]] for pair in pairs]
  212. dpo_pairs.extend(pairs)
  213. # if i < 1:
  214. # continue
  215. # else:
  216. # dpo_pairs.append([data['query'],path[i-1],ix])
  217. # for ans in re_hints_reward_imp_bank:
  218. # if ans in ["I Don't Know","I can't understand this question.","I can't help with this question.","I don't know how to solve this question.","I don't know the answer to this question.","I don't know the answer to this question, sorry."]:
  219. # continue
  220. # if len(re_hints_reward_imp_bank[ans]) >= 2:
  221. # for i,hint in enumerate(re_hints_reward_imp_bank[ans]):
  222. # if i < 1:
  223. # continue
  224. # dpo_pairs.append([refine_prompt(data['query'],ans),re_hints_reward_imp_bank[ans][i-1],hint,[[data['query'],ans],]])
  225. for dpo_pair in dpo_pairs:
  226. final_json_list.append(get_json(*dpo_pair))
  227. with open('data_mistral7b_pathfinder_new_mcts_answers_10_percent.json','w') as f:
  228. random.shuffle(final_json_list)
  229. print(len(final_json_list))
  230. json.dump(final_json_list[:len(final_json_list)//100],f)