# omniparser.py
import base64
import io
from typing import Dict, Tuple, List

import torch
from PIL import Image
from ultralytics import YOLO

from utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_dino_model, get_yolo_model
  8. config = {
  9. 'som_model_path': 'finetuned_icon_detect.pt',
  10. 'device': 'cpu',
  11. 'caption_model_path': 'Salesforce/blip2-opt-2.7b',
  12. 'draw_bbox_config': {
  13. 'text_scale': 0.8,
  14. 'text_thickness': 2,
  15. 'text_padding': 3,
  16. 'thickness': 3,
  17. },
  18. 'BOX_TRESHOLD': 0.05
  19. }
  20. class Omniparser(object):
  21. def __init__(self, config: Dict):
  22. self.config = config
  23. self.som_model = get_yolo_model(model_path=config['som_model_path'])
  24. # self.caption_model_processor = get_caption_model_processor(config['caption_model_path'], device=cofig['device'])
  25. # self.caption_model_processor['model'].to(torch.float32)
  26. def parse(self, image_path: str):
  27. print('Parsing image:', image_path)
  28. ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9})
  29. text, ocr_bbox = ocr_bbox_rslt
  30. draw_bbox_config = self.config['draw_bbox_config']
  31. BOX_TRESHOLD = self.config['BOX_TRESHOLD']
  32. dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, self.som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=None, ocr_text=text,use_local_semantics=False)
  33. image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
  34. # formating output
  35. return_list = [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},
  36. 'text': parsed_content_list[i].split(': ')[1], 'type':'text'} for i, (k, coord) in enumerate(label_coordinates.items()) if i < len(parsed_content_list)]
  37. return_list.extend(
  38. [{'from': 'omniparser', 'shape': {'x':coord[0], 'y':coord[1], 'width':coord[2], 'height':coord[3]},
  39. 'text': 'None', 'type':'icon'} for i, (k, coord) in enumerate(label_coordinates.items()) if i >= len(parsed_content_list)]
  40. )
  41. return [image, return_list]
  42. parser = Omniparser(config)
  43. image_path = 'examples/pc_1.png'
  44. # time the parser
  45. import time
  46. s = time.time()
  47. image, parsed_content_list = parser.parse(image_path)
  48. device = config['device']
  49. print(f'Time taken for Omniparser on {device}:', time.time() - s)