gradio_demo.py

from typing import Optional
import io
import base64

import gradio as gr
import torch
from PIL import Image

from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img

# Load the detection and captioning models once at import time.
yolo_model = get_yolo_model()
caption_model_processor = get_caption_model_processor('florence', device='cuda')  # alternatives: 'blip2-opt-2.7b-ui', 'phi3v_ui'

platform = 'pc'
# Per-platform drawing settings for the annotated output image.
# (BOX_TRESHOLD keeps its original spelling to match the get_som_labeled_img keyword.)
if platform == 'pc':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 2,
        'thickness': 2,
    }
    BOX_TRESHOLD = 0.05
elif platform == 'web':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    }
    BOX_TRESHOLD = 0.05
elif platform == 'mobile':
    draw_bbox_config = {
        'text_scale': 0.8,
        'text_thickness': 2,
        'text_padding': 3,
        'thickness': 3,
    }
    BOX_TRESHOLD = 0.05
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
    <a href="https://arxiv.org/pdf/2408.00203">
        <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
    </a>
</div>

OmniParser is a screen parsing tool that converts a general GUI screenshot into structured elements. **Trained models will be released soon.**
"""
DEVICE = torch.device('cuda')

# @spaces.GPU
# @torch.inference_mode()
# @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process(
    image_input: Image.Image,
    prompt: Optional[str] = None,
) -> tuple[Image.Image, str]:
    # Persist the upload so the OCR and detection utilities can read it from disk.
    image_path = 'image_input.png'  # was a hard-coded absolute path in the author's sandbox
    image_input.save(image_path)
    # OCR pass: collect the recognized text and its bounding boxes (xyxy format).
    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
        image_path, display_img=False, output_bb_format='xyxy', goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9})
    text, ocr_bbox = ocr_bbox_rslt
    print('prompt:', prompt)
    # Detect interactable elements, merge them with the OCR boxes, and caption each region.
    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_path, yolo_model, BOX_TRESHOLD=BOX_TRESHOLD, output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor, ocr_text=text,
        iou_threshold=0.3, prompt=prompt)
    # The labeled image comes back base64-encoded; decode it into a PIL image.
    image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    print('finish processing')
    return image, '\n'.join(parsed_content_list)
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(type='pil', label='Upload image')
            prompt_input_component = gr.Textbox(label='Prompt', placeholder='')
            submit_button_component = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
            text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')

    submit_button_component.click(
        fn=process,
        inputs=[image_input_component, prompt_input_component],
        outputs=[image_output_component, text_output_component],
    )

# demo.launch(debug=False, show_error=True, share=True)
demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
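
For a quick sanity check without the browser UI, process can also be called directly. A minimal sketch, not part of the original script: it assumes the blocking demo.launch(...) call above is commented out so the module can be imported, and that a local screenshot.png (a hypothetical file name) exists:

    from PIL import Image
    from gradio_demo import process  # assumes demo.launch(...) is disabled first

    screenshot = Image.open('screenshot.png')   # hypothetical input screenshot
    labeled_image, parsed_elements = process(screenshot, prompt=None)
    labeled_image.save('labeled_screenshot.png')  # annotated image with numbered boxes
    print(parsed_elements)                        # one parsed element per line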