# automatic_label_simple_demo.py

import cv2
import numpy as np
import supervision as sv
from typing import List
from PIL import Image
import torch
from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor

# Tag2Text / RAM
# from ram.models import tag2text_caption
from ram.models import ram
# from ram import inference_tag2text
from ram import inference_ram
import torchvision
import torchvision.transforms as TS
# Hyper-Params
SOURCE_IMAGE_PATH = "./assets/demo9.jpg"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "./sam_vit_h_4b8939.pth"
TAG2TEXT_CHECKPOINT_PATH = "./tag2text_swin_14m.pth"
RAM_CHECKPOINT_PATH = "./ram_swin_large_14m.pth"
TAG2TEXT_THRESHOLD = 0.64
BOX_THRESHOLD = 0.2
TEXT_THRESHOLD = 0.2
IOU_THRESHOLD = 0.5
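# TAG2TEXT_THRESHOLD is the tagging confidence cutoff (used by the commented-out Tag2Text path);
# BOX_THRESHOLD / TEXT_THRESHOLD filter GroundingDINO's box and phrase confidences;
# IOU_THRESHOLD is the overlap cutoff for the NMS step further down.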
# Building GroundingDINO inference model
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)

# Building SAM Model and SAM Predictor
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=DEVICE)  # keep SAM on the same device as the other models
sam_predictor = SamPredictor(sam)
# Tag2Text
# initialize Tag2Text
normalize = TS.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
transform = TS.Compose(
    [
        TS.Resize((384, 384)),
        TS.ToTensor(),
        normalize
    ]
)
DELETE_TAG_INDEX = []  # filter out attribute and action tags, which are difficult to ground
for idx in range(3012, 3429):
    DELETE_TAG_INDEX.append(idx)
# tag2text_model = tag2text_caption(
#     pretrained=TAG2TEXT_CHECKPOINT_PATH,
#     image_size=384,
#     vit='swin_b',
#     delete_tag_index=DELETE_TAG_INDEX
# )
# # threshold for tagging
# # we reduce the threshold to obtain more tags
# tag2text_model.threshold = TAG2TEXT_THRESHOLD
# tag2text_model.eval()
# tag2text_model = tag2text_model.to(DEVICE)
ram_model = ram(
    pretrained=RAM_CHECKPOINT_PATH,
    image_size=384,
    vit='swin_l'
)
ram_model.eval()
ram_model = ram_model.to(DEVICE)
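# RAM (Recognize Anything Model) generates an open-set list of image tags;
# those tags are used below as the class prompts for GroundingDINO.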
# load image
image = cv2.imread(SOURCE_IMAGE_PATH)  # bgr
image_pillow = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # rgb
image_pillow = image_pillow.resize((384, 384))
image_pillow = transform(image_pillow).unsqueeze(0).to(DEVICE)

specified_tags = 'None'
# res = inference_tag2text(image_pillow, tag2text_model, specified_tags)
res = inference_ram(image_pillow, ram_model)
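# res[0] is the English tag string, e.g. "tag1 | tag2 | ...", which is split into class names below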
# Currently ", " is better for detecting single tags,
# while ". " is slightly worse in some cases
AUTOMATIC_CLASSES = res[0].split(" | ")
print(f"Tags: {res[0].replace(' |', ',')}")
# detect objects
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=AUTOMATIC_CLASSES,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD
)
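# Note: predict_with_classes builds one text prompt from the class list and maps each detected
# phrase back to its best-matching class, so class_id indexes into AUTOMATIC_CLASSES.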
# NMS post process
print(f"Before NMS: {len(detections.xyxy)} boxes")
nms_idx = torchvision.ops.nms(
    torch.from_numpy(detections.xyxy),
    torch.from_numpy(detections.confidence),
    IOU_THRESHOLD
).numpy().tolist()
detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]
print(f"After NMS: {len(detections.xyxy)} boxes")
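# torchvision.ops.nms is class-agnostic, so overlapping boxes are suppressed even when they carry
# different tags; torchvision.ops.batched_nms with class_id could be used for per-class NMS instead.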
# annotate image with detections
box_annotator = sv.BoxAnnotator()
labels = [
    f"{AUTOMATIC_CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _, _
    in detections
]
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

# save the annotated grounding dino image
cv2.imwrite("groundingdino_auto_annotated_image.jpg", annotated_frame)
# Prompting SAM with detected boxes
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, logits = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)
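# With multimask_output=True SAM returns several mask hypotheses per box;
# segment() keeps the highest-scoring one for each detection.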
# convert detections to masks
detections.mask = segment(
    sam_predictor=sam_predictor,
    image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
    xyxy=detections.xyxy
)
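# SamPredictor.set_image expects RGB input by default, hence the BGR -> RGB conversion above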
# annotate image with detections
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()
labels = [
    f"{AUTOMATIC_CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _, _
    in detections
]
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

# save the annotated grounded-sam image
cv2.imwrite("ram_grounded_sam_auto_annotated_image.jpg", annotated_image)