from typing import Optional
import io
import base64
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image

from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img

# Load the icon-detection model and the icon-captioning model/processor.
yolo_model = get_yolo_model()
caption_model_processor = get_caption_model_processor('florence', device='cuda')  # alternatives: 'blip2-opt-2.7b-ui', 'phi3v_ui'
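# The YOLO model proposes candidate interactable regions and the caption model describes
# each detected element. A minimal sketch (an editorial assumption, not part of the original
# script) for running on machines without a GPU would be to pick the device dynamically:
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   caption_model_processor = get_caption_model_processor('florence', device=device)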
platform = 'pc'
if platform == 'pc':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 2,
        'thickness': 2,
    }
    BOX_TRESHOLD = 0.05
elif platform == 'web':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    }
    BOX_TRESHOLD = 0.05
elif platform == 'mobile':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    }
    BOX_TRESHOLD = 0.05
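# Note (editorial assumption, not stated in the original script): draw_bbox_config controls how the
# annotated overlay is rendered (label text scale, text thickness/padding, box line thickness), and
# BOX_TRESHOLD is the minimum detection confidence forwarded to get_som_labeled_img below.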
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
    <a href="https://arxiv.org/pdf/2408.00203">
        <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
    </a>
</div>
OmniParser is a screen parsing tool that converts general GUI screenshots into structured elements. **Trained models will be released soon.**
"""
DEVICE = torch.device('cuda')

# @spaces.GPU
# @torch.inference_mode()
# @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    image_input,
    prompt: Optional[str] = None
) -> tuple[Image.Image, str]:
    # Save the uploaded image so the OCR and detection utilities can read it from disk.
    image_path = "/home/yadonglu/sandbox/data/omniparser_demo/image_input.png"
    image_input.save(image_path)
    # Run OCR to collect text and bounding boxes in xyxy format.
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9})
    text, ocr_bbox = ocr_bbox_rslt
    print('prompt:', prompt)
    # Detect and caption interactable elements, merging them with the OCR boxes;
    # the annotated image comes back as a base64-encoded string.
    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_path, yolo_model, BOX_TRESHOLD=BOX_TRESHOLD, output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor, ocr_text=text,
        iou_threshold=0.3, prompt=prompt)
    image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    print('finish processing')
    parsed_content_list = '\n'.join(parsed_content_list)
    return image, parsed_content_list
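# Hypothetical usage sketch (not part of the original demo): process can also be called
# directly with a PIL image, bypassing the Gradio UI, e.g.
#   screenshot = Image.open('screenshot.png')   # 'screenshot.png' is an assumed example path
#   annotated, parsed = process(screenshot, prompt=None)
#   print(parsed)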
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            prompt_input_component = gr.Textbox(label='Prompt', placeholder='')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
            text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')

    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            prompt_input_component,
        ],
        outputs=[image_output_component, text_output_component]
    )

# demo.launch(debug=False, show_error=True, share=True)
demo.launch(share=True, server_port=7861, server_name='0.0.0.0')