utils.py

# from ultralytics import YOLO
import os
import io
import sys
import json
import time
import base64
import ast
import re

import requests
import cv2
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.ops import box_convert
from torchvision.transforms import ToPILImage
from typing import Tuple, List
from PIL import Image, ImageDraw, ImageFont
# %matplotlib inline
from matplotlib import pyplot as plt
import supervision as sv
import easyocr
from openai import AzureOpenAI

reader = easyocr.Reader(['en'])  # this needs to run only once to load the model into memory; add 'ch_sim' for simplified Chinese
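
# NOTE: `client` and `deployment` are used by run_api, call_gpt4v_new and get_pred_gptv
# below but are never defined in this file. The following is a minimal sketch, assuming
# an Azure OpenAI setup configured through environment variables (the variable names and
# the default deployment name are assumptions, not part of the original code).
try:
    client = AzureOpenAI(
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        api_key=os.environ["AZURE_OPENAI_API_KEY"],
        api_version=os.environ.get("AZURE_OPENAI_API_VERSION", "2024-02-15-preview"),
    )
except KeyError:
    client = None  # the GPT-4V helpers below will not work until this is configured
deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT", "gpt-4v")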

def get_caption_model_processor(model_name="Salesforce/blip2-opt-2.7b", device=None):
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if model_name == "Salesforce/blip2-opt-2.7b":
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b", device_map=None, torch_dtype=torch.float16
            # '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
        )
    elif model_name == "blip2-opt-2.7b-ui":
        from transformers import Blip2Processor, Blip2ForConditionalGeneration
        processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        if device == 'cpu':
            model = Blip2ForConditionalGeneration.from_pretrained(
                '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float32
            )
        else:
            model = Blip2ForConditionalGeneration.from_pretrained(
                '/home/yadonglu/sandbox/data/orca/blipv2_ui_merge', device_map=None, torch_dtype=torch.float16
            )
    elif model_name == "florence":
        from transformers import AutoProcessor, AutoModelForCausalLM
        processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
        if device == 'cpu':
            model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai", torch_dtype=torch.float32, trust_remote_code=True)
        else:
            model = AutoModelForCausalLM.from_pretrained("/home/yadonglu/sandbox/data/orca/florence-2-base-ft-fft_ep1_rai_win_ep5_fixed", torch_dtype=torch.float16, trust_remote_code=True).to(device)
    elif model_name == 'phi3v_ui':
        from transformers import AutoModelForCausalLM, AutoProcessor
        model_id = "microsoft/Phi-3-vision-128k-instruct"
        model = AutoModelForCausalLM.from_pretrained('/home/yadonglu/sandbox/data/orca/phi3v_ui', device_map=device, trust_remote_code=True, torch_dtype="auto")
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    elif model_name == 'phi3v':
        from transformers import AutoModelForCausalLM, AutoProcessor
        model_id = "microsoft/Phi-3-vision-128k-instruct"
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, trust_remote_code=True, torch_dtype="auto")
        processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    return {'model': model.to(device), 'processor': processor}
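
# Example: load the public BLIP-2 captioner. The other branches above point at local
# fine-tuned checkpoints (BLIP-2 UI, Florence-2, Phi-3-vision paths under /home/yadonglu/...)
# that are specific to the author's environment and are not expected to load elsewhere.
#   caption_model_processor = get_caption_model_processor("Salesforce/blip2-opt-2.7b", device="cuda")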

def get_yolo_model(model_path):
    from ultralytics import YOLO
    # Load the model.
    model = YOLO(model_path)
    return model
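
# Example (hypothetical weights path; point it at wherever the icon-detection checkpoint lives):
#   som_model = get_yolo_model('weights/icon_detect/best.pt')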

def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_model_processor, prompt=None):
    to_pil = ToPILImage()
    if ocr_bbox:
        non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
    else:
        non_ocr_boxes = filtered_boxes
    cropped_pil_image = []
    for i, coord in enumerate(non_ocr_boxes):
        xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
        ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
        cropped_image = image_source[ymin:ymax, xmin:xmax, :]
        cropped_pil_image.append(to_pil(cropped_image))

    model, processor = caption_model_processor['model'], caption_model_processor['processor']
    if not prompt:
        if 'florence' in model.config.name_or_path:
            prompt = "<CAPTION>"
        else:
            prompt = "The image shows"

    batch_size = 10  # number of samples per batch
    generated_texts = []
    device = model.device
    for i in range(0, len(cropped_pil_image), batch_size):
        batch = cropped_pil_image[i:i+batch_size]
        if model.device.type == 'cuda':
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
        else:
            inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
        if 'florence' in model.config.name_or_path:
            generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3, do_sample=False)
        else:
            generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1)  # temperature=0.01, do_sample=True,
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        generated_text = [gen.strip() for gen in generated_text]
        generated_texts.extend(generated_text)
    return generated_texts

def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
    to_pil = ToPILImage()
    if ocr_bbox:
        non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
    else:
        non_ocr_boxes = filtered_boxes
    cropped_pil_image = []
    for i, coord in enumerate(non_ocr_boxes):
        xmin, xmax = int(coord[0] * image_source.shape[1]), int(coord[2] * image_source.shape[1])
        ymin, ymax = int(coord[1] * image_source.shape[0]), int(coord[3] * image_source.shape[0])
        cropped_image = image_source[ymin:ymax, xmin:xmax, :]
        cropped_pil_image.append(to_pil(cropped_image))

    model, processor = caption_model_processor['model'], caption_model_processor['processor']
    device = model.device
    messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
    prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    batch_size = 5  # number of samples per batch
    generated_texts = []
    for i in range(0, len(cropped_pil_image), batch_size):
        images = cropped_pil_image[i:i+batch_size]
        image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
        inputs = {'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
        texts = [prompt] * len(images)
        for j, txt in enumerate(texts):
            inp = processor._convert_images_texts_to_inputs(image_inputs[j], txt, return_tensors="pt")
            inputs['input_ids'].append(inp['input_ids'])
            inputs['attention_mask'].append(inp['attention_mask'])
            inputs['pixel_values'].append(inp['pixel_values'])
            inputs['image_sizes'].append(inp['image_sizes'])
        # left-pad input_ids and attention_mask to the longest prompt in the batch
        max_len = max([x.shape[1] for x in inputs['input_ids']])
        for j, v in enumerate(inputs['input_ids']):
            inputs['input_ids'][j] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
            inputs['attention_mask'][j] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][j]], dim=1)
        inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}

        generation_args = {
            "max_new_tokens": 25,
            "temperature": 0.01,
            "do_sample": False,
        }
        generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
        # remove the prompt tokens from the generated sequence
        generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
        response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        response = [res.strip('\n').strip() for res in response]
        generated_texts.extend(response)
    return generated_texts

def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
    assert ocr_bbox is None or isinstance(ocr_bbox, list)

    def box_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    def intersection_area(box1, box2):
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        return max(0, x2 - x1) * max(0, y2 - y1)

    def IoU(box1, box2):
        intersection = intersection_area(box1, box2)
        union = box_area(box1) + box_area(box2) - intersection + 1e-6
        if box_area(box1) > 0 and box_area(box2) > 0:
            ratio1 = intersection / box_area(box1)
            ratio2 = intersection / box_area(box2)
        else:
            ratio1, ratio2 = 0, 0
        return max(intersection / union, ratio1, ratio2)

    boxes = boxes.tolist()
    filtered_boxes = []
    if ocr_bbox:
        filtered_boxes.extend(ocr_bbox)
    for i, box1 in enumerate(boxes):
        # drop box1 if it largely overlaps a smaller detection
        is_valid_box = True
        for j, box2 in enumerate(boxes):
            if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
                is_valid_box = False
                break
        if is_valid_box:
            # also drop the box if it overlaps any ocr bbox above the threshold
            if ocr_bbox:
                if not any(IoU(box1, box3) > iou_threshold for k, box3 in enumerate(ocr_bbox)):
                    filtered_boxes.append(box1)
            else:
                filtered_boxes.append(box1)
    return torch.tensor(filtered_boxes)
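
# Example: two heavily overlapping detections in normalized xyxy coordinates; the larger
# box is removed because it overlaps a smaller one above the IoU threshold.
#   boxes = torch.tensor([[0.10, 0.10, 0.50, 0.50],   # large box
#                         [0.12, 0.12, 0.30, 0.30]])  # smaller box contained in it
#   remove_overlap(boxes, iou_threshold=0.5)           # keeps only the smaller box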

def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
    # NOTE: this transform pipeline follows the GroundingDINO-style transforms API
    # (RandomResize([800], max_size=1333) and a Compose that takes (image, target));
    # the plain torchvision.transforms imported above as T does not expose this
    # signature, so T may need to point at that transforms module instead.
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_source = Image.open(image_path).convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed

def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
             text_padding=5, text_thickness=2, thickness=3) -> Tuple[np.ndarray, dict]:
    """Annotate an image with bounding boxes and labels.

    Parameters:
        image_source (np.ndarray): The source image to be annotated.
        boxes (torch.Tensor): Bounding box coordinates in normalized cxcywh format (scaled to pixels inside this function).
        logits (torch.Tensor): Confidence scores for each bounding box.
        phrases (List[str]): A list of labels for each bounding box.
        text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop, 0.4 for mind2web.

    Returns:
        Tuple[np.ndarray, dict]: The annotated image and a label -> xywh coordinates dict.
    """
    h, w, _ = image_source.shape
    boxes = boxes * torch.Tensor([w, h, w, h])
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
    detections = sv.Detections(xyxy=xyxy)
    # drawn labels are simply the box indices; `phrases` is used to key label_coordinates below
    labels = [f"{phrase}" for phrase in range(boxes.shape[0])]

    from util.box_annotator import BoxAnnotator
    box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding, text_thickness=text_thickness, thickness=thickness)  # 0.8 for mobile/web, 0.3 for desktop, 0.4 for mind2web
    annotated_frame = image_source.copy()
    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w, h))

    label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
    return annotated_frame, label_coordinates

def predict(model, image, caption, box_threshold, text_threshold):
    """Grounded object detection with a Hugging Face model/processor pair,
    used in place of the original implementation.
    """
    model, processor = model['model'], model['processor']
    device = model.device

    inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,    # e.g. 0.4
        text_threshold=text_threshold,  # e.g. 0.3
        target_sizes=[image.size[::-1]]
    )[0]
    boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
    return boxes, logits, phrases

def predict_yolo(model, image_path, box_threshold):
    """Run an ultralytics YOLO model and return boxes (xyxy, pixel space),
    confidences, and numeric phrase labels.
    """
    result = model.predict(
        source=image_path,
        conf=box_threshold,
        # iou=0.5,  # default 0.7
    )
    boxes = result[0].boxes.xyxy  # in pixel space
    conf = result[0].boxes.conf
    phrases = [str(i) for i in range(len(boxes))]
    return boxes, conf, phrases
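
# Example (hypothetical image path):
#   boxes, conf, phrases = predict_yolo(som_model, 'imgs/example.png', box_threshold=0.05)
#   # boxes is a torch.Tensor of xyxy pixel coordinates, conf the per-box confidence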

def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9, prompt=None):
    """Produce a set-of-mark labeled screenshot.

    ocr_bbox: list of bboxes in xyxy format (pixel coordinates).
    Returns the base64-encoded annotated image, the label -> xywh coordinates dict,
    and the merged parsed content (OCR text boxes followed by icon descriptions).
    """
    TEXT_PROMPT = "clickable buttons on the screen"
    # BOX_TRESHOLD = 0.02  # 0.05/0.02 for web and 0.1 for mobile
    TEXT_TRESHOLD = 0.01  # 0.9 # 0.01
    image_source = Image.open(img_path).convert("RGB")
    w, h = image_source.size
    if False:  # TODO: switch back to the grounding model if needed
        xyxy, logits, phrases = predict(model=model, image=image_source, caption=TEXT_PROMPT, box_threshold=BOX_TRESHOLD, text_threshold=TEXT_TRESHOLD)
    else:
        xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD)
    xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
    image_source = np.asarray(image_source)
    phrases = [str(i) for i in range(len(phrases))]

    # annotate the image with labels
    h, w, _ = image_source.shape
    if ocr_bbox:
        ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
        ocr_bbox = ocr_bbox.tolist()
    else:
        print('no ocr bbox!!!')
        ocr_bbox = None
    filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)

    # get parsed icon local semantics
    if use_local_semantics:
        caption_model = caption_model_processor['model']
        if 'phi3_v' in caption_model.config.model_type:
            parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
        else:
            parsed_content_icon = get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_model_processor, prompt=prompt)
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        icon_start = len(ocr_text)
        parsed_content_icon_ls = []
        for i, txt in enumerate(parsed_content_icon):
            parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
        parsed_content_merged = ocr_text + parsed_content_icon_ls
    else:
        ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
        parsed_content_merged = ocr_text

    filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
    phrases = [i for i in range(len(filtered_boxes))]

    # draw boxes
    if draw_bbox_config:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
    else:
        annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)

    pil_img = Image.fromarray(annotated_frame)
    buffered = io.BytesIO()
    pil_img.save(buffered, format="PNG")
    encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
    if output_coord_in_ratio:
        label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
        assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
    return encoded_image, label_coordinates, parsed_content_merged

def get_xywh(box):
    # box is an EasyOCR quadrilateral: [top-left, top-right, bottom-right, bottom-left]
    x, y, w, h = box[0][0], box[0][1], box[2][0] - box[0][0], box[2][1] - box[0][1]
    x, y, w, h = int(x), int(y), int(w), int(h)
    return x, y, w, h


def get_xyxy(box):
    x, y, xp, yp = box[0][0], box[0][1], box[2][0], box[2][1]
    x, y, xp, yp = int(x), int(y), int(xp), int(yp)
    return x, y, xp, yp


def get_xywh_yolo(box):
    # box is an xyxy box (e.g. from YOLO), converted to xywh
    x, y, w, h = box[0], box[1], box[2] - box[0], box[3] - box[1]
    x, y, w, h = int(x), int(y), int(w), int(h)
    return x, y, w, h
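
# Example: an EasyOCR quadrilateral [[top-left], [top-right], [bottom-right], [bottom-left]]
# in pixel coordinates (values are illustrative only):
#   quad = [[100, 40], [260, 40], [260, 80], [100, 80]]
#   get_xywh(quad)                     # -> (100, 40, 160, 40)
#   get_xyxy(quad)                     # -> (100, 40, 260, 80)
#   get_xywh_yolo([100, 40, 260, 80])  # -> (100, 40, 160, 40)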

def run_api(body, max_tokens=1024):
    '''
    API call; check https://platform.openai.com/docs/guides/vision for the latest API usage.
    '''
    max_num_trial = 3
    num_trial = 0
    while num_trial < max_num_trial:
        try:
            response = client.chat.completions.create(
                model=deployment,
                messages=body,
                temperature=0.01,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content
        except Exception:
            print('retry call gptv', num_trial)
            num_trial += 1
            time.sleep(10)
    return ''

def call_gpt4v_new(message_text, image_path=None, max_tokens=2048):
    if image_path:
        try:
            with open(image_path, "rb") as img_file:
                encoded_image = base64.b64encode(img_file.read()).decode('ascii')
        except Exception:
            # image_path may already be a base64-encoded image string
            encoded_image = image_path
        content = [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, {"type": "text", "text": message_text}]
    else:
        content = [{"type": "text", "text": message_text}]

    max_num_trial = 3
    num_trial = 0
    call_api_success = True
    while num_trial < max_num_trial:
        try:
            response = client.chat.completions.create(
                model=deployment,
                messages=[
                    {
                        "role": "system",
                        "content": [
                            {
                                "type": "text",
                                "text": "You are an AI assistant that is good at making plans and analyzing screens, and helping people find information."
                            },
                        ]
                    },
                    {
                        "role": "user",
                        "content": content
                    }
                ],
                temperature=0.01,
                max_tokens=max_tokens,
            )
            ans_1st_pass = response.choices[0].message.content
            break
        except Exception:
            print('retry call gptv', num_trial)
            num_trial += 1
            ans_1st_pass = ''
            time.sleep(10)
    if num_trial == max_num_trial:
        call_api_success = False
    return ans_1st_pass, call_api_success

def check_ocr_box(image_path, display_img=True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None):
    if easyocr_args is None:
        easyocr_args = {}
    result = reader.readtext(image_path, **easyocr_args)
    is_goal_filtered = False
    if goal_filtering:
        ocr_filter_fs = "Example 1:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Share', 0.949013667261589), ([[3068, 197], [3135, 197], [3135, 227], [3068, 227]], 'Link _', 0.3567054243152049), ([[3006, 321], [3178, 321], [3178, 354], [3006, 354]], 'Manage Access', 0.8800734456437066)] ``` \n Example 2:\n Based on task and ocr results, ```In summary, the task related bboxes are: [([[3060, 111], [3135, 111], [3135, 141], [3060, 141]], 'Search Google or type a URL', 0.949013667261589)] ```"
        message_text = f"Based on the task and ocr results which contains text+bounding box in a dictionary, please filter it so that it only contains the task related bboxes. Requirement: 1. first give a brief analysis. 2. provide an answer in the format: ```In summary, the task related bboxes are: ..```, you must put it inside ``` ```. Do not include any info after ```.\n {ocr_filter_fs}\n The task is: {goal_filtering}, the ocr results are: {str(result)}."
        prompt = [{"role": "system", "content": "You are an AI assistant that helps people find the correct way to operate computer or smartphone."}, {"role": "user", "content": message_text}]
        print('[Perform OCR filtering by goal] ongoing ...')
        # pred, _, _ = call_gpt4(prompt)
        # NOTE: the original code called an undefined call_gpt4v here; call_gpt4v_new has
        # the matching (answer, success) return signature.
        pred, _ = call_gpt4v_new(message_text)
        try:
            pred = pred.split('In summary, the task related bboxes are:')[-1].strip().strip('```')
            result = ast.literal_eval(pred)
            print('[Perform OCR filtering by goal] success!!! Filtered buttons: ', pred)
            is_goal_filtered = True
        except Exception:
            print('[Perform OCR filtering by goal] failed or unused!!!')

    coord = [item[0] for item in result]
    text = [item[1] for item in result]
    # confidence = [item[2] for item in result]
    if display_img:
        opencv_img = cv2.imread(image_path)
        opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; convert for matplotlib display
        bb = []
        for item in coord:
            x, y, a, b = get_xywh(item)
            bb.append((x, y, a, b))
            cv2.rectangle(opencv_img, (x, y), (x + a, y + b), (0, 255, 0), 2)
        # display the image
        plt.imshow(opencv_img)
    else:
        if output_bb_format == 'xywh':
            bb = [get_xywh(item) for item in coord]
        elif output_bb_format == 'xyxy':
            bb = [get_xyxy(item) for item in coord]
    return (text, bb), is_goal_filtered
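
# Example end-to-end usage (paths and thresholds are illustrative, not from the original file):
#   som_model = get_yolo_model('weights/icon_detect/best.pt')
#   caption_model_processor = get_caption_model_processor("Salesforce/blip2-opt-2.7b")
#   (ocr_text, ocr_bbox), _ = check_ocr_box('imgs/example.png', display_img=False, output_bb_format='xyxy')
#   encoded_image, label_coordinates, parsed_content = get_som_labeled_img(
#       'imgs/example.png', model=som_model, BOX_TRESHOLD=0.05,
#       ocr_bbox=ocr_bbox, ocr_text=ocr_text,
#       caption_model_processor=caption_model_processor)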

def get_pred_gptv(message_text, yolo_labled_img, label_coordinates, summarize_history=True, verbose=True, history=None, id_key='Click ID'):
    """Two-stage prediction:
    1. call gptv(yolo_labled_img, text bbox + task) -> ans_1st_pass
    2. parse the predicted box id and map it to a click point via label_coordinates
    """
    # Configuration
    encoded_image = yolo_labled_img

    # Payload for the request
    if not history:
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are an AI assistant that is great at interpreting screenshot and predict action."}]},
            {"role": "user", "content": [{"type": "text", "text": message_text}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}]}
        ]
    else:
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are an AI assistant that is great at interpreting screenshot and predict action."}]},
            history,
            {"role": "user", "content": [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}}, {"type": "text", "text": message_text}]}
        ]
    # payload kept from the earlier requests-based call path (unused with the openai client)
    payload = {
        "messages": messages,
        "temperature": 0.01,
        "top_p": 0.95,
        "max_tokens": 800
    }

    max_num_trial = 3
    num_trial = 0
    call_api_success = True
    while num_trial < max_num_trial:
        try:
            response = client.chat.completions.create(
                model=deployment,
                messages=messages,
                temperature=0.01,
                max_tokens=512,
            )
            ans_1st_pass = response.choices[0].message.content
            break
        except Exception:
            # was `requests.RequestException`, which would miss errors raised by the
            # openai client; broadened so the retry loop actually triggers
            print('retry call gptv', num_trial)
            num_trial += 1
            ans_1st_pass = ''
            time.sleep(30)
    if num_trial == max_num_trial:
        call_api_success = False
    if verbose:
        print('Answer by GPTV: ', ans_1st_pass)

    # extract by simple parsing
    try:
        match = re.search(r"```(.*?)```", ans_1st_pass, re.DOTALL)
        if match:
            result = match.group(1).strip()
            pred = result.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
            pred = ast.literal_eval(pred)
        else:
            pred = ans_1st_pass.split('In summary, the next action I will perform is:')[-1].strip().replace('\\', '')
            pred = ast.literal_eval(pred)
        if id_key in pred:
            icon_id = pred[id_key]
            bbox = label_coordinates[str(icon_id)]
            pred['click_point'] = [bbox[0] + bbox[2]/2, bbox[1] + bbox[3]/2]
    except Exception:
        print('gptv action regex extract fail!!!')
        print('ans_1st_pass:', ans_1st_pass)
        pred = {'action_type': 'CLICK', 'click_point': [0, 0], 'value': 'None', 'is_completed': False}

    step_pred_summary = None
    if summarize_history:
        step_pred_summary, _ = call_gpt4v_new('Summarize what action you decide to perform in the current step, in one sentence, and do not include any icon box number: ' + ans_1st_pass, max_tokens=128)
        print('step_pred_summary', step_pred_summary)
    return pred, [call_api_success, ans_1st_pass, None, step_pred_summary]