run_with_earlystopping.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894
  1. import copy
  2. from curses.ascii import isalpha, isdigit
  3. import math
  4. import multiprocessing
  5. import os
  6. import re
  7. import socket
  8. import sys
  9. from datasets import load_dataset
  10. import hashlib
  11. import json
  12. import random
  13. from functools import lru_cache
  14. import numpy as np
  15. from tqdm import tqdm
  16. import time
  17. from retry import retry
  18. import random
  19. # MODEL_NAME = 'meta-llama/Llama-2-7b-chat-hf'
  20. # MODEL_NAME = 'mistralai/Mistral-7B-Instruct-v0.2'
  21. # MODEL_NAME = 'meta-llama/Meta-Llama-3-8B-Instruct'
  22. # MODEL_NAME = 'google/gemma-1.1-7b-it'
  23. # MODEL_NAME = 'test-lora'
  24. # MODEL_NAME = '/home/bingxing2/ailab/group/ai4phys/EXPORT/new_mistral_7b_4'
  25. MODEL_NAME = ''
  26. # DATA_NAME = 'meta-math-40k-pathfinder-mistral7B'
  27. # DATA_NAME = 'meta-math-40k-pathfinder-llama2_7B'
  28. # DATA_NAME = 'meta-math-40k-testtime-llama2_7B'
  29. # DATA_NAME = 'gsm8k-rs-llama2_7B'
  30. # DATA_NAME = 'meta-math-40k-testtime-mistral7B'
  31. # DATA_NAME = 'gsm8k-rs-mistral7B'
  32. # DATA_NAME = 'gsm8k-sample-testtime-mistral-dpo-7'
  33. # DATA_NAME = 'gsm8k-testtime-mistral_7B_pathfinder_0'
  34. # DATA_NAME = 'MATH-rs-mistral7B'
  35. # DATA_NAME = 'gsm8k-pathfinder-gemma7b-new-mcts-8'
  36. # DATA_NAME = 'gsmhard-pathfinder-llama3-8b-new-mcts-8'
  37. # DATA_NAME = 'olympiadbench-pathfinder-llama3-8b-new-mcts-8'
  38. # DATA_NAME = 'GAIC-pathfinder-llama3-8b-new-mcts-8'
  39. # DATA_NAME = 'MATH-pathfinder-llama3-8b-new-mcts-8'
  40. # DATA_NAME = 'AIME-pathfinder-llama3-8b-mcts-2'
  41. # DATA_NAME = 'gsm8k-testtime-pathfinder-mistral7B-mcts-2'
  42. # DATA_NAME = 'gsm8k-testtime-pathfinder-pureseq-mistral7B-5'
  43. DATA_NAME = ''
  44. if MODEL_NAME == '':
  45. MODEL_NAME = sys.argv[1]
  46. if DATA_NAME == '':
  47. DATA_NAME = sys.argv[2]
  48. def last_boxed_only_string(string):
  49. idx = string.rfind('\\boxed')
  50. if idx < 0:
  51. idx = string.rfind('\\fbox')
  52. if idx < 0:
  53. return None
  54. i = idx
  55. right_brace_idx = None
  56. num_left_braces_open = 0
  57. while i < len(string):
  58. if string[i] == '{':
  59. num_left_braces_open += 1
  60. if string[i] == '}':
  61. num_left_braces_open -= 1
  62. if num_left_braces_open == 0:
  63. right_brace_idx = i
  64. break
  65. i += 1
  66. if right_brace_idx is None:
  67. retval = None
  68. else:
  69. retval = string[idx:right_brace_idx + 1]
  70. return retval
  71. def remove_boxed(s):
  72. left = '\\boxed{'
  73. try:
  74. assert s[:len(left)] == left
  75. assert s[-1] == '}'
  76. return s[len(left):-1]
  77. except Exception:
  78. return None
  79. def extract_boxed_answer(pred_str, strip_double_curly_brace=False):
  80. boxed_str = last_boxed_only_string(pred_str)
  81. if boxed_str is None:
  82. return None
  83. answer = remove_boxed(boxed_str)
  84. if answer is None:
  85. return None
  86. if strip_double_curly_brace:
  87. match = re.match('^\{(.*)\}$', answer) # noqa: W605
  88. if match:
  89. answer = match.group(1)
  90. return answer
  91. class Extractor:
  92. def extract_matching_bracket(cls, target_str: str):
  93. if not target_str:
  94. return target_str
  95. current_nest_level = 1
  96. for i, ch in enumerate(target_str):
  97. if ch == '{':
  98. current_nest_level += 1
  99. elif ch == '}':
  100. current_nest_level -= 1
  101. if current_nest_level == 0:
  102. break
  103. return target_str[:i]
  104. def clean(cls, target_str: str):
  105. opt = target_str.strip().replace('{{', '{').replace('}}', '}')
  106. if not opt:
  107. return opt
  108. if opt[-1] == '.' or opt[-1] == '。':
  109. return opt[:-1]
  110. return opt
  111. def extract_answer(cls, pred: str, extract_last_num=False):
  112. if pred.find('The final answer is ') >= 0:
  113. x = pred[pred.find('The final answer is ') +
  114. len('The final answer is '):]
  115. x = x[1:x.find('$.')]
  116. # print(x)
  117. return cls.clean(x)
  118. if pred.find('\n\nQuestion:') >= 0:
  119. pred = pred.split('\n\nQuestion:')[0]
  120. if pred.find('The answer is'):
  121. pred = pred[pred.find('The answer is') + len('The answer is'):]
  122. return cls.clean(pred)
  123. if pred.find('# Answer') >= 0:
  124. return cls.clean(pred[pred.find('# Answer') + len('# Answer'):])
  125. if pred.find('The answer is:') >= 0:
  126. return cls.clean(pred[pred.find('The answer is:') +
  127. len('The answer is:'):])
  128. if pred.find('####') >= 0:
  129. return cls.clean(pred[pred.find('####') + 4:])
  130. left = '\\boxed{'
  131. if pred.find(left) >= 0:
  132. pred = pred[pred.find(left) + len(left):]
  133. return cls.clean(cls.extract_matching_bracket(pred))
  134. if extract_last_num:
  135. nums = []
  136. opt = ''
  137. def contain_digit(opt):
  138. for ch in opt:
  139. if ch.isdigit():
  140. return True
  141. return False
  142. for ch in pred:
  143. if ch.isdigit() or ch in ' ,.':
  144. opt = opt + ch
  145. else:
  146. if contain_digit(opt):
  147. nums.append(opt)
  148. opt = ''
  149. if contain_digit(opt):
  150. return cls.clean(opt)
  151. if nums:
  152. return cls.clean(nums[-1])
  153. return None
  154. def fix_fracs(string):
  155. substrs = string.split('\\frac')
  156. new_str = substrs[0]
  157. if len(substrs) > 1:
  158. substrs = substrs[1:]
  159. for substr in substrs:
  160. new_str += '\\frac'
  161. if substr[0] == '{':
  162. new_str += substr
  163. else:
  164. try:
  165. assert len(substr) >= 2
  166. except AssertionError:
  167. return string
  168. a = substr[0]
  169. b = substr[1]
  170. if b != '{':
  171. if len(substr) > 2:
  172. post_substr = substr[2:]
  173. new_str += '{' + a + '}{' + b + '}' + post_substr
  174. else:
  175. new_str += '{' + a + '}{' + b + '}'
  176. else:
  177. if len(substr) > 2:
  178. post_substr = substr[2:]
  179. new_str += '{' + a + '}' + b + post_substr
  180. else:
  181. new_str += '{' + a + '}' + b
  182. string = new_str
  183. return string
  184. def fix_a_slash_b(string):
  185. if len(string.split('/')) != 2:
  186. return string
  187. a = string.split('/')[0]
  188. b = string.split('/')[1]
  189. try:
  190. a = int(a)
  191. b = int(b)
  192. assert string == '{}/{}'.format(a, b)
  193. new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
  194. return new_string
  195. except AssertionError:
  196. return string
  197. def remove_right_units(string):
  198. # "\\text{ " only ever occurs (at least in the val set)
  199. if '\\text{ ' in string:
  200. splits = string.split('\\text{ ')
  201. assert len(splits) == 2
  202. return splits[0]
  203. else:
  204. return string
  205. def fix_sqrt(string):
  206. if '\\sqrt' not in string:
  207. return string
  208. splits = string.split('\\sqrt')
  209. new_string = splits[0]
  210. for split in splits[1:]:
  211. if split[0] != '{':
  212. a = split[0]
  213. new_substr = '\\sqrt{' + a + '}' + split[1:]
  214. else:
  215. new_substr = '\\sqrt' + split
  216. new_string += new_substr
  217. return new_string
  218. def strip_string(string):
  219. # linebreaks
  220. string = string.replace('\n', '')
  221. # remove inverse spaces
  222. string = string.replace('\\!', '')
  223. # replace \\ with \
  224. string = string.replace('\\\\', '\\')
  225. # replace tfrac and dfrac with frac
  226. string = string.replace('tfrac', 'frac')
  227. string = string.replace('dfrac', 'frac')
  228. # remove \left and \right
  229. string = string.replace('\\left', '')
  230. string = string.replace('\\right', '')
  231. # Remove circ (degrees)
  232. string = string.replace('^{\\circ}', '')
  233. string = string.replace('^\\circ', '')
  234. # remove dollar signs
  235. string = string.replace('\\$', '')
  236. # remove units (on the right)
  237. string = remove_right_units(string)
  238. # remove percentage
  239. string = string.replace('\\%', '')
  240. string = string.replace('\%', '') # noqa: W605
  241. string = string.replace(' .', ' 0.')
  242. string = string.replace('{.', '{0.')
  243. # if empty, return empty string
  244. if len(string) == 0:
  245. return string
  246. if string[0] == '.':
  247. string = '0' + string
  248. # to consider: get rid of e.g. "k = " or "q = " at beginning
  249. if len(string.split('=')) == 2:
  250. if len(string.split('=')[0]) <= 2:
  251. string = string.split('=')[1]
  252. # fix sqrt3 --> sqrt{3}
  253. string = fix_sqrt(string)
  254. # remove spaces
  255. string = string.replace(' ', '')
  256. string = fix_fracs(string)
  257. # manually change 0.5 --> \frac{1}{2}
  258. if string == '0.5':
  259. string = '\\frac{1}{2}'
  260. string = fix_a_slash_b(string)
  261. string = string.replace('x \\in', '').strip() # noqa: W605
  262. # a_b == a, a_{b} == a_b for bit conversion
  263. if string.find('_') >= 0:
  264. p = string.split('_')
  265. p[1] = p[1].replace('{', '').replace('}', '')
  266. string = '_'.join(p)
  267. # 10800 == 10,800; we only deal with single number
  268. if string.strip().find(' ') == -1 and string.find('(') == -1:
  269. string = string.replace(',', '')
  270. return string
  271. def is_equiv(str1, str2, verbose=False):
  272. if str1 is None and str2 is None:
  273. # print("WARNING: Both None")
  274. return False
  275. if str1 is None or str2 is None:
  276. return False
  277. try:
  278. ss1 = strip_string(str1)
  279. ss2 = strip_string(str2)
  280. return ss1 == ss2
  281. except Exception:
  282. return str1 == str2
  283. if not os.path.exists(DATA_NAME):
  284. os.mkdir(DATA_NAME)
  285. if not os.path.exists(f'{DATA_NAME}/jsons'):
  286. os.mkdir(f'{DATA_NAME}/jsons')
  287. if 'testtime' in DATA_NAME:
  288. if 'gsm8k' in DATA_NAME:
  289. if 'sample' in DATA_NAME:
  290. dataset = load_dataset("gsm8k",'main',split='test')
  291. # dataset = dataset.shuffle()
  292. dataset = dataset.select(range(130))
  293. else:
  294. dataset = load_dataset("gsm8k",'main',split='test')
  295. elif 'MATH' in DATA_NAME:
  296. dataset = load_dataset("lighteval/MATH",'all',split='test')
  297. else:
  298. if 'gsmhard' in DATA_NAME:
  299. dataset = load_dataset("reasoning-machines/gsm-hard",split='train')
  300. elif 'gsm8k' in DATA_NAME:
  301. if not 'mcts' in DATA_NAME:
  302. dataset = load_dataset("gsm8k",'main',split='train')
  303. else:
  304. dataset = load_dataset("gsm8k",'main',split='test')
  305. elif 'level5' in DATA_NAME:
  306. dataset = load_dataset("lighteval/MATH",'all',split='test',trust_remote_code=True)
  307. dataset = dataset.filter(lambda example: example["level"].endswith("5"))
  308. elif 'MATH' in DATA_NAME and not'level5' in DATA_NAME:
  309. dataset = load_dataset("lighteval/MATH",'all',split='test',trust_remote_code=True)
  310. elif 'AIME' in DATA_NAME:
  311. dataset = load_dataset("qq8933/AIME_1983_2024",split='train')
  312. elif 'olympiadbench' in DATA_NAME:
  313. dataset = load_dataset("lmms-lab/OlympiadBench",split='test_en')
  314. dataset = dataset.filter(lambda example:len(example["images"]) == 0 and example['final_answer'] is not None and len(example['final_answer']) == 1)
  315. elif 'meta-math' in DATA_NAME:
  316. dataset = load_dataset("meta-math/MetaMathQA-40K",split='train')
  317. elif 'GAIC' in DATA_NAME:
  318. dataset = load_dataset("qq8933/AGI_Odyssey_MATH_GAIC_2024")
  319. elif 'mathinstruct' in DATA_NAME:
  320. dataset = load_dataset('TIGER-Lab/MathInstruct',split='train')
  321. else:
  322. dataset = load_dataset('json',data_files=f'/home/bingxing2/ailab/group/ai4phys/math/data_mistral_var_sft.json')
  323. dataset.shuffle()
  324. from openai import OpenAI
  325. # generation_lock = multiprocessing.Lock()
  326. # def generate(prompt,):
  327. # with generation_lock:
  328. # ret = generate_in(prompt,)
  329. # return ret
  330. # client = OpenAI(
  331. # base_url="http://10.140.24.132:10087/v1",
  332. # api_key="token-abc123",
  333. # )
  334. clients = []
  335. times = time.time()
  336. # def get_clients():
  337. # global clients
  338. # lines = open('/mnt/petrelfs/zhangdi1/reasoningpath/math/server.csv','r').readlines()
  339. # for line in lines:
  340. # if len(line) < 3:
  341. # continue
  342. # node,port,model = line.split(',')
  343. # ip = '.'.join(node.split('-')[-4:])
  344. # client = OpenAI(
  345. # base_url=f"http://{ip}:{port}/v1",
  346. # api_key="token-abc123",
  347. # )
  348. # try:
  349. # client.chat.completions.create(
  350. # model=MODEL_NAME,
  351. # messages=[
  352. # {"role": "user", "content": 'hi'}#+'\nBe concisely and clearly in no more than 50 words.'
  353. # ],
  354. # # max_tokens=min(len(prompt)+128,8000),
  355. # temperature=0.95,
  356. # timeout=10
  357. # )
  358. # print(len(clients)+1)
  359. # clients.append(client)
  360. # except:
  361. # pass
  362. from concurrent.futures import ThreadPoolExecutor
  363. def create_client(line):
  364. global clients
  365. if len(line) < 3:
  366. return
  367. node,port,model = line.split(',')
  368. ip = socket.gethostbyname(node)
  369. print(ip)
  370. client = OpenAI(
  371. base_url=f"http://{ip}:{port}/v1",
  372. api_key="token-abc123",
  373. )
  374. try:
  375. client.chat.completions.create(
  376. model=MODEL_NAME,
  377. messages=[
  378. {"role": "user", "content": 'hi'}#+'\nBe concisely and clearly in no more than 50 words.'
  379. ],
  380. # max_tokens=min(len(prompt)+128,8000),
  381. temperature=0.95,#0.5 if 'testtime' in DATA_NAME else random.uniform(0,1)
  382. timeout=15
  383. )
  384. print(len(clients)+1)
  385. clients.append(client)
  386. except:
  387. pass
  388. def get_clients():
  389. global clients
  390. lines = open('./server.csv','r').readlines()
  391. with ThreadPoolExecutor() as executor:
  392. executor.map(create_client, lines)
  393. def get_client():
  394. global clients,times
  395. # if time.time() - times > 1800:
  396. # clients = []
  397. # get_clients()
  398. # times = time.time()
  399. return random.choice(clients)
  400. @retry()
  401. def generate(prompt,history=[],timeout = 150,truncate=True):
  402. if 'testtime' in DATA_NAME:
  403. timeout=150
  404. print('awaiting response...')
  405. time0 = time.time()
  406. history_ = [{"role": "user" if i %2 ==0 else 'assistant', "content": h} for i,h in enumerate(history)]
  407. if truncate:
  408. history_ = history_[-2:]
  409. completion = get_client().chat.completions.create(
  410. model=MODEL_NAME,
  411. messages=history_+[
  412. # dict(role='user', content="Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?\nLet's think step by step\nAnswer:"),
  413. # dict(role='assistant', content="Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.\nFor the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.\nAngelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.\nHowever, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.\nThey also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.\nAnd they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.\nSo Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.\nThey want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75\nThey will need to plan to study 4 days to allow for all the time they need.\nThe answer is 4\n"),
  414. # dict(role='user', content="Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?\nLet's think step by step\nAnswer:"),
  415. # dict(role='assistant', content="Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.\nHis team also scores 6 3 pointers, meaning they scored 8*3= 24 points in 3 pointers\nThey scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.\nAll together his team scored 50+24+10= 84 points\nMark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.\nHis opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.\nThey also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.\nAll together Mark's opponents scored 100+12+5=117 points\nThe total score for the game is both team's scores added together, so it is 84+117=201 points\nThe answer is 201\n"),
  416. # dict(role='user', content="Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?\nLet's think step by step\nAnswer:"),
  417. # dict(role='assistant', content="When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24\nThe total number of marbles she'll have is 60+24 = 84\nIf Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.\nIf Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.\nThe total number of frisbees she'll have will increase to 30+12 = 42\nBella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards\nIf she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.\nThe total number of deck cards she'll have is 10+4 = 14\nTogether, Bella will have a total of 14+42+84 = 140 items\nThe answer is 140\n"),
  418. # dict(role='user', content="Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?\nLet's think step by step\nAnswer:"),
  419. # dict(role='assistant', content="For the first three baskets, the number of apples and oranges in one basket is 9+15=24\nIn total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.\nSince there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.\nThe number of apples in the fourth basket is 9-2=7\nThere are also 15-2=13 oranges in the fourth basket\nThe combined number of oranges and apples in the fourth basket is 13+7=20\nThe fourth basket also contains 14-2=12 bananas.\nIn total, the fourth basket has 20+12=32 fruits.\nThe four baskets together have 32+114=146 fruits.\nThe answer is 146\n"),
  420. {"role": "user", "content": prompt}#
  421. ],
  422. # max_tokens=min(len(prompt)+128,8000),
  423. temperature=0.95,#0.5 if 'testtime' in DATA_NAME else random.uniform(0,1),
  424. timeout = timeout
  425. )
  426. print(f'response received! time taken: {time.time()-time0} seconds.')
  427. return completion.choices[0].message.content,list(history)+[prompt,completion.choices[0].message.content]
  428. @retry()
  429. def cal_reward(question,ans):
  430. query = f'Question: {question}\nAnswer:{ans}\nAnalyze this Answer Strictly and Critic, point out every flaw for ervery possible imperfect to minus every possible score! You need to be very harsh and mean in calculating grades, and never give full marks to ensure that the marks are authoritative. \nOutput a score between [-100,+100], ig. from -100 to +100. \nResponse format:\n[Analyst]...[Score]...'
  431. ret = generate(query)
  432. score = ret[0].split('Score')[-1]
  433. scores = pattern.findall(score)
  434. if not scores:
  435. raise Exception('no')
  436. else:
  437. ret = float(scores[-1])
  438. # if abs(ret - 100.0) < 1e-5:
  439. # ret = 50.0
  440. if ret >= 95:
  441. ret = 50
  442. # elif ret <= -100:
  443. # ret = -50
  444. return ret
  445. @retry()
  446. def get_weak_answer(question,new_len=0,ans_format=''):
  447. query = f'Question: {question}\nThe response should begin with [reasoning process]...[Verification]... and end with {ans_format}\nLet\'s think step by step.'
  448. return generate(query,timeout=90)
  449. def get_weak_hints(question,weak_answer,ground_truth_label=None,new_len=0,history=[],alreadygood=False,ans_format=''):
  450. query = f'Question: {question}\nSince we have a weak Answer, could you provide me with a relection or feedback to correct this answer better? Analyze this Answer Strictly and Critic, point out every flaw for ervery possible imperfect to minus every possible score!\nLet\'s think step by step.'
  451. return generate(query,history)
  452. def get_better_answer(question,weak_answer,hint,new_len=0,history=[],ans_format=''):
  453. query = f'Question: {question}\nPlease refine the your answer according to your Reflection or Feedback. The response should begin with [reasoning process]...[Verification]... and end with end with {ans_format}\nLet\'s think step by step.'
  454. return generate(query,history)
  455. def get_gt_hints(question,ground_truth,new_len=0):
  456. query = f"Question: {question}\nGround Truth:{ground_truth}\nAccording to ground truth answer we have, Could you descript the thought process of ground truth answer, please don’t give me the answer, just the thought process?"
  457. return generate(query)
  458. datas = []
  459. pattern = re.compile(r'\-?\d+\.\d+|\-?\d+')
  460. extractor_0 = Extractor()
  461. @lru_cache(1024)
  462. def extract_label(text: str,type='') -> str:
  463. if 'gsm' not in DATA_NAME and type != 'digit':
  464. if '####' in text:
  465. text = text.split('####')[-1]
  466. elif 'The answer is' in text:
  467. text = text.split('The answer is')[-1]
  468. if '####' in text:
  469. text = text.split('####')[-1]
  470. if 'box' in text:
  471. return extract_boxed_answer(text)
  472. else:
  473. return text
  474. if '\n####' in text:
  475. text = text.split('\n####')[-1].replace(',','')
  476. elif 'The answer is' in text:
  477. text = text.split('The answer is')[-1].replace(',','')
  478. numbers = pattern.findall(text)
  479. if not numbers:
  480. return None
  481. if '\n####' in text or 'The answer is' in text:
  482. return numbers[0]
  483. else :
  484. return numbers[-1]
  485. @lru_cache(1024)
  486. def check(gt,ans):
  487. gt_label = extract_label(gt)
  488. if gt_label.isdigit():
  489. type = 'digit'
  490. elif gt_label.isupper() and gt_label.isalpha():
  491. type = 'option'
  492. elif gt_label.lower() in ['yes','no']:
  493. gt_label = gt_label.lower()
  494. type = 'yesorno'
  495. else :
  496. type = 'formula'
  497. ans_label = extract_label(ans,type)
  498. if ans_label:
  499. if type == 'option':
  500. ans_label = ans_label.strip()[0]
  501. elif type == 'yesorno':
  502. ans_label = ans_label.lower()
  503. elif type == 'formula':
  504. ans_label = ans_label.replace('$','')
  505. print(gt_label,ans_label)
  506. if 'gsm' not in DATA_NAME and type != 'digit':
  507. return is_equiv(gt_label,ans_label)
  508. print(gt_label,ans_label)
  509. if gt_label is None or ans_label is None:
  510. return False
  511. if ans_label == gt_label or abs(float(ans_label) - float(gt_label)) < 1e-5:
  512. return True
  513. else:
  514. return False
  515. def hamming_distance(str1, str2):
  516. if len(str1) != len(str2):
  517. raise ValueError("Strings must be of the same length")
  518. return sum(el1 != el2 for el1, el2 in zip(str1[::-1], str2[::-1]))
  519. def simple_reward(gt,ans):
  520. gt_f = format(float(extract_label(gt)),'.5f')
  521. ans_f = format(float(extract_label(ans)),'.5f')
  522. return -hamming_distance(gt_f,ans_f)
  523. def sort_answers_and_rewards(answers, rewards):
  524. # Zip answers and rewards together
  525. answer_reward_pairs = zip(answers, rewards)
  526. # Sort pairs by rewards
  527. sorted_pairs = sorted(answer_reward_pairs, key=lambda x: x[1], reverse=True)
  528. # Extract sorted answers and rewards
  529. sorted_answers = [pair[0] for pair in sorted_pairs]
  530. sorted_rewards = [pair[1] for pair in sorted_pairs]
  531. return sorted_answers, sorted_rewards
  532. def filter_mature_node(childs, to_explore, to_explore_reward,max_expand=3):
  533. filterd_to_explore = []
  534. avg_reward = {node: (min(to_explore_reward[node]) + np.mean(to_explore_reward[node])) / 2 for node in to_explore}
  535. for node in to_explore:
  536. if len(childs.get(node,[])) < max_expand or max([avg_reward.get(child,-999) for child in childs.get(node,[])]) < avg_reward.get(node,-999):
  537. filterd_to_explore.append(node)
  538. return filterd_to_explore
  539. def get_best_explore_from_ucb(to_explore, ucb_bank):
  540. # 初始化最佳节点和最高UCB值
  541. best_node = None
  542. highest_ucb = float('-inf')
  543. # 遍历所有待探索的节点
  544. for node in to_explore:
  545. ucb_value = ucb_bank.get(node, float('-inf'))
  546. if ucb_value > highest_ucb:
  547. highest_ucb = ucb_value
  548. best_node = node
  549. return best_node
  550. def compute_ucb(r_c, N_n, N_c, C):
  551. return r_c + C * math.sqrt(math.log(N_n + 1) / (N_c + 1e-5))
  552. def update_ucb(fathers, childs, to_explore, to_explore_reward, ucb_bank, C=1.4,gamma=0.85):
  553. # 计算所有节点的访问次数
  554. visit_count = {node: len(to_explore_reward[node]) for node in to_explore}
  555. # 计算所有节点的平均奖励
  556. # avg_reward = {node: sum(to_explore_reward[node]) / len(to_explore_reward[node]) for node in to_explore}
  557. avg_reward = {node: (min(to_explore_reward[node]) + np.mean(to_explore_reward[node])) / 2 for node in to_explore}
  558. # 获取所有叶子节点
  559. leaves = set(to_explore) - set(fathers.values())
  560. # 更新所有叶子节点的UCB值
  561. for leaf in leaves:
  562. # ucb_bank[leaf] = avg_reward[leaf]
  563. ucb_bank[leaf] = compute_ucb(avg_reward[leaf],len(to_explore_reward.get(fathers.get(leaf,None),[])),len(to_explore_reward.get(leaf,[])),C)
  564. # 从叶子节点向上更新父节点的UCB值
  565. nodes_to_update = list(leaves)
  566. while nodes_to_update:
  567. new_nodes_to_update = set()
  568. for node in nodes_to_update:
  569. father = fathers.get(node)
  570. if father is not None:
  571. if father not in ucb_bank:
  572. new_nodes_to_update.add(father)
  573. if father in ucb_bank:
  574. # 计算父节点的UCB值
  575. ucb_values = []
  576. child_reward = []
  577. for child in childs[father]:
  578. ucb_values.append(ucb_bank[child])
  579. child_reward.append(avg_reward[child])
  580. father_reward = (avg_reward[father] + max(child_reward))/2
  581. ucb_bank[father] = compute_ucb(father_reward,len(to_explore_reward.get(fathers.get(father,None),[])),len(to_explore_reward.get(father,[])),C)
  582. nodes_to_update = list(new_nodes_to_update)
  583. def step(query,weak_answer,ground_truth_label=None,history=[],alreadygood=False,ans_format=''):
  584. hints,history = get_weak_hints(query,weak_answer,ground_truth_label=ground_truth_label,history=history,alreadygood=alreadygood,ans_format=ans_format)
  585. answer,history = get_better_answer(query,weak_answer,hints,history=history,ans_format=ans_format)
  586. return hints,answer,history
  587. def main_loop(query,ground_truth,max_iter=16,ans_format=''):
  588. to_explore = []
  589. to_explore_reward = {}
  590. history_bank = {}
  591. hints_bank = {}
  592. ucb_bank = {}
  593. fathers = {}
  594. childs = {}
  595. def sampling_reward(answer):
  596. if answer not in to_explore_reward:
  597. to_explore_reward[answer] = []
  598. reward = cal_reward(query,answer)
  599. # if check(ground_truth,answer):
  600. # reward += 100
  601. to_explore_reward[answer].append(reward)
  602. def add_to_hints_bank(hints,weak_answer):
  603. if weak_answer not in hints_bank:
  604. hints_bank[weak_answer] = []
  605. hints_bank[weak_answer].append(hints)
  606. def add_to_childs(father,child):
  607. if father not in childs:
  608. childs[father] = []
  609. childs[father].append(child)
  610. hints_reward_imp_bank = {}
  611. def add_to_hints_reward_imp_bank(hints,weak_answer,reward,answer):
  612. if weak_answer not in hints_reward_imp_bank:
  613. hints_reward_imp_bank[weak_answer] = []
  614. hints_reward_imp_bank[weak_answer].append((hints,reward,answer))
  615. ground_truth_label = extract_label(ground_truth)
  616. ###get weak answer###
  617. weak_answer,history = get_weak_answer(query,ans_format=ans_format)
  618. history_bank[weak_answer] = tuple(history)
  619. answers_list = [weak_answer,]
  620. to_explore = [weak_answer,]
  621. childs[weak_answer] = []
  622. fathers[weak_answer] = None
  623. # to_explore_reward = [cal_reward(query,weak_answer),]
  624. sampling_reward(weak_answer)
  625. ##add total-bad answer###
  626. # if check(ground_truth,weak_answer):
  627. # return
  628. if True:#not check(ground_truth,weak_answer):
  629. total_bad = random.choice(["I Don't Know","I can't understand this question.","I can't help with this question.","I don't know how to solve this question.","I don't know the answer to this question.","I don't know the answer to this question, sorry."])
  630. total_bad_history = copy.deepcopy(history)
  631. total_bad_history[-1] = total_bad
  632. history_bank[total_bad] = tuple(total_bad_history)
  633. answers_list += [total_bad,]
  634. to_explore += [total_bad,]
  635. childs[total_bad] = []
  636. fathers[total_bad] = None
  637. # to_explore_reward = [cal_reward(query,weak_answer),]
  638. sampling_reward(total_bad)
  639. hints_list = []
  640. if check(ground_truth,weak_answer) :#and 'testtime' in DATA_NAME
  641. return hints_list,answers_list,to_explore,to_explore_reward,hints_bank,history_bank,hints_reward_imp_bank,fathers,childs,ucb_bank
  642. patient = 0 if 'testtime' not in DATA_NAME else 0
  643. alpha = 0.45
  644. update_ucb(fathers=fathers,childs=childs,to_explore=to_explore,to_explore_reward=to_explore_reward,ucb_bank=ucb_bank)
  645. for i in range(max_iter):
  646. print('iteration:',i)
  647. filterd_to_explore = filter_mature_node(childs, to_explore, to_explore_reward)
  648. weak_answer = get_best_explore_from_ucb(filterd_to_explore, ucb_bank)
  649. sampling_reward(weak_answer)
  650. hints,answer,history = step(query,weak_answer,history=history_bank[weak_answer],ans_format=ans_format)
  651. add_to_hints_bank(hints,weak_answer)
  652. history_bank[answer] = tuple(history)
  653. to_explore.append(answer)
  654. sampling_reward(answer)
  655. fathers[answer] = weak_answer
  656. childs[answer] = []
  657. add_to_childs(weak_answer,answer)
  658. answers_list.append(answer)
  659. hints_list.append(hints)
  660. if check(ground_truth,answer) and 'testtime' in DATA_NAME:
  661. return hints_list,answers_list,to_explore,to_explore_reward,hints_bank,history_bank,hints_reward_imp_bank,fathers,childs,ucb_bank
  662. elif check(ground_truth,answer) and 'testtime' not in DATA_NAME:
  663. if patient <= 0:
  664. return hints_list,answers_list,to_explore,to_explore_reward,hints_bank,history_bank,hints_reward_imp_bank,fathers,childs,ucb_bank
  665. patient -= 1
  666. update_ucb(fathers=fathers,childs=childs,to_explore=to_explore,to_explore_reward=to_explore_reward,ucb_bank=ucb_bank)
  667. add_to_hints_reward_imp_bank(hints,weak_answer,min(to_explore_reward.get(answer)) - min(to_explore_reward.get(weak_answer)),answer)#ucb_bank[answer] - ucb_bank[weak_answer]
  668. return hints_list,answers_list,to_explore,to_explore_reward,hints_bank,history_bank,hints_reward_imp_bank,fathers,childs,ucb_bank
  669. def tryfunc(example):
  670. try:
  671. if os.path.exists(f'{DATA_NAME}/jsons/{hashlib.md5(str(example).encode()).hexdigest()}.json.lock'):
  672. return
  673. else:
  674. os.system(f'touch {DATA_NAME}/jsons/{hashlib.md5(str(example).encode()).hexdigest()}.json.lock')
  675. func(example)
  676. if os.path.exists(f'{DATA_NAME}/jsons/{hashlib.md5(str(example).encode()).hexdigest()}.json.lock'):
  677. os.system(f'rm {DATA_NAME}/jsons/{hashlib.md5(str(example).encode()).hexdigest()}.json.lock')
  678. except:
  679. print(example)
  680. pass
  681. # for example in tqdm(dataset['train']):
  682. def func(example):
  683. if os.path.exists(f'{DATA_NAME}/jsons/{hashlib.md5(str(example).encode()).hexdigest()}.json'):
  684. # return json.load(open(f'{DATA_NAME}/jsons/{hashlib.md5(str(example).encode()).hexdigest()}'))
  685. return {}
  686. if 'instruction' in example and 'output' in example:
  687. query = example['instruction'] + '\n' + example['input']
  688. ground_truth = example['output']
  689. elif 'context' in example and 'question' in example:
  690. if example['context']:
  691. query = example['context'] + '\n' + example['question']
  692. else:
  693. query = example['question']
  694. ground_truth = example['final_answer'][0].replace('$','')
  695. elif 'GAIC' in DATA_NAME :
  696. query = example['problem']
  697. ground_truth = example['answer']
  698. else:
  699. if 'query' in example:
  700. query = example['query']
  701. elif 'problem' in example:
  702. query = example['problem']
  703. elif 'input' in example:
  704. query = example['input']
  705. elif 'Question' in example:
  706. query = example['Question']
  707. else:
  708. query = example['question']
  709. if 'response' in example:
  710. ground_truth = example['response']
  711. elif 'solution' in example:
  712. ground_truth = example['solution']
  713. elif 'target' in example:
  714. ground_truth = str(example['target'])
  715. elif 'Answer' in example:
  716. ground_truth = example['Answer']
  717. else:
  718. ground_truth = example['answer']
  719. if 'gsm' in DATA_NAME:
  720. ans_format = r'"[Final Answer] The answer is [answer] \n#### [answer]"'
  721. else:
  722. if extract_label(ground_truth).isdigit():
  723. ans_format = r'"[Final Answer] The answer is [number] \n#### [number]"'
  724. elif extract_label(ground_truth).isalpha() and extract_label(ground_truth).isupper():
  725. ans_format = r'"[Final Answer] The answer is \\boxed{[option]} \n#### [option]"'
  726. elif extract_label(ground_truth).lower() in ['yes','no']:
  727. ans_format = r'"[Final Answer] The answer is \\boxed{[Yes or No]} \n#### [Yes or No]"'
  728. else:
  729. ans_format = r'"[Final Answer] The answer is \\boxed{[answer formula]} \n#### [answer formula]"'
  730. # new_len = len(ground_truth)
  731. hints_prompt = f'Question: {query}\nCould you provide me with the thought process to solve this problem, but please don’t give me the answer or calculation, just the thought process?'
  732. max_iter = 16
  733. if 'meta-math' in DATA_NAME:
  734. max_iter = 8
  735. if 'testtime' in DATA_NAME:
  736. max_iter = 2
  737. hints_list,answers_list,to_explore,to_explore_reward,hints_bank,history_bank,hints_reward_imp_bank,fathers,childs,ucb_bank = main_loop(query,ground_truth,max_iter=max_iter,ans_format=ans_format)
  738. if len(answers_list) <= 1 and 'rs' in DATA_NAME:
  739. return
  740. else:
  741. if not 'testtime' in DATA_NAME:
  742. # gt_hints = get_gt_hints(query,ground_truth)
  743. gt_hints = ''
  744. pass
  745. else:
  746. gt_hints = ''
  747. data = {
  748. 'query':query,
  749. 'ground_truth':ground_truth,
  750. 'hints_list':hints_list,
  751. 'answers_list':answers_list,
  752. 'ground_truth_hints':gt_hints,
  753. 'hints_prompt':hints_prompt,
  754. 'to_explore':to_explore,
  755. 'to_explore_reward':to_explore_reward,
  756. 'hints_bank':hints_bank,
  757. 'history_bank':history_bank,
  758. 'hints_reward_imp_bank':hints_reward_imp_bank,
  759. 'fathers':fathers,
  760. 'childs':childs,
  761. 'ucb_bank':ucb_bank,
  762. }
  763. if 'rs' in DATA_NAME and not check(ground_truth,answers_list[-1]):
  764. return
  765. with open(f'{DATA_NAME}/jsons/{hashlib.md5(str(example).encode()).hexdigest()}.json','w+') as f:
  766. json.dump(data,f,indent=4,ensure_ascii=False)
  767. return data
  768. if __name__ == '__main__':
  769. get_clients()
  770. # while True:
  771. # try:
  772. # datas = dataset.map(func,num_proc=len(clients)*8)
  773. datas = dataset.map(func,num_proc=32)
  774. # except :
  775. # continue
  776. # break
  777. # datas.save_to_disk('meta-math-40k-weak-better-mistral7B-data')