grounded_sam_demo.py

import argparse
import os
import sys
import json

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
from PIL import Image

sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "segment_anything"))

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# Segment Anything
from segment_anything import (
    sam_model_registry,
    sam_hq_model_registry,
    SamPredictor,
)


def load_image(image_path):
    # load image
    image_pil = Image.open(image_path).convert("RGB")

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image_pil, image


def load_model(model_config_path, model_checkpoint_path, bert_base_uncased_path, device):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    args.bert_base_uncased_path = bert_base_uncased_path
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    caption = caption.lower()
    caption = caption.strip()
    if not caption.endswith("."):
        caption = caption + "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # filter output by box threshold
    logits_filt = logits.clone()
    boxes_filt = boxes.clone()
    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
    logits_filt = logits_filt[filt_mask]  # num_filt, 256
    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4

    # get phrases
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    # build predictions
    pred_phrases = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)
    return boxes_filt, pred_phrases
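
# Note (added for readability): boxes_filt comes back in Grounding DINO's normalized
# (cx, cy, w, h) format, and each phrase carries its confidence as a "(0.xx)" suffix
# when with_logits=True; save_mask_data below relies on that suffix when it splits
# the label from the logit.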


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax, label):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, label)


def save_mask_data(output_dir, mask_list, box_list, label_list):
    value = 0  # 0 for background

    mask_img = torch.zeros(mask_list.shape[-2:])
    for idx, mask in enumerate(mask_list):
        mask_img[mask.cpu().numpy()[0]] = value + idx + 1
    plt.figure(figsize=(10, 10))
    plt.imshow(mask_img.numpy())
    plt.axis('off')
    plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)

    json_data = [{
        'value': value,
        'label': 'background'
    }]
    for label, box in zip(label_list, box_list):
        value += 1
        name, logit = label.split('(')
        logit = logit[:-1]  # strip the trailing ')'
        json_data.append({
            'value': value,
            'label': name,
            'logit': float(logit),
            'box': box.numpy().tolist(),
        })
    with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
        json.dump(json_data, f)
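
# For reference, mask.json has the shape built in save_mask_data above
# (the label, logit, and box values shown here are illustrative, not from a real run):
# [
#     {"value": 0, "label": "background"},
#     {"value": 1, "label": "bear", "logit": 0.72, "box": [x0, y0, x1, y1]}
# ]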


if __name__ == "__main__":

    parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
    parser.add_argument("--config", type=str, required=True, help="path to config file")
    parser.add_argument(
        "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument(
        "--sam_version", type=str, default="vit_h", required=False, help="SAM ViT version: vit_b / vit_l / vit_h"
    )
    parser.add_argument(
        "--sam_checkpoint", type=str, required=False, help="path to sam checkpoint file"
    )
    parser.add_argument(
        "--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file"
    )
    parser.add_argument(
        "--use_sam_hq", action="store_true", help="use sam-hq for prediction"
    )
    parser.add_argument("--input_image", type=str, required=True, help="path to image file")
    parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
    parser.add_argument(
        "--output_dir", "-o", type=str, required=True, help="output directory"
    )
    parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
    parser.add_argument("--device", type=str, default="cpu", help="device to run the models on, e.g. 'cpu' or 'cuda'")
    parser.add_argument("--bert_base_uncased_path", type=str, required=False, help="path to a local bert-base-uncased model (optional)")
    args = parser.parse_args()

    # cfg
    config_file = args.config  # path to the model config file
    grounded_checkpoint = args.grounded_checkpoint  # path to the model checkpoint
    sam_version = args.sam_version
    sam_checkpoint = args.sam_checkpoint
    sam_hq_checkpoint = args.sam_hq_checkpoint
    use_sam_hq = args.use_sam_hq
    image_path = args.input_image
    text_prompt = args.text_prompt
    output_dir = args.output_dir
    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    device = args.device
    bert_base_uncased_path = args.bert_base_uncased_path

    # make dir
    os.makedirs(output_dir, exist_ok=True)
    # load image
    image_pil, image = load_image(image_path)
    # load model
    model = load_model(config_file, grounded_checkpoint, bert_base_uncased_path, device=device)

    # visualize raw image
    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

    # run grounding dino model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold, device=device
    )

    # initialize SAM
    if use_sam_hq:
        predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
    else:
        predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    # convert boxes from normalized (cx, cy, w, h) to absolute (x0, y0, x1, y1)
    size = image_pil.size
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu()
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # draw output image
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box, label in zip(boxes_filt, pred_phrases):
        show_box(box.numpy(), plt.gca(), label)

    plt.axis('off')
    plt.savefig(
        os.path.join(output_dir, "grounded_sam_output.jpg"),
        bbox_inches="tight", dpi=300, pad_inches=0.0
    )

    save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
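
# Example invocation (a sketch; the checkpoint and image paths below are illustrative
# and depend on where the weights were downloaded):
#
#   python grounded_sam_demo.py \
#       --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
#       --grounded_checkpoint groundingdino_swint_ogc.pth \
#       --sam_checkpoint sam_vit_h_4b8939.pth \
#       --input_image assets/demo1.jpg \
#       --text_prompt "bear" \
#       --output_dir outputs \
#       --device cuda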