# automatic_label_ram_demo.py (11 KB)


  1. import argparse
  2. import os
  3. import numpy as np
  4. import json
  5. import torch
  6. import torchvision
  7. from PIL import Image
  8. import litellm
  9. # Grounding DINO
  10. import GroundingDINO.groundingdino.datasets.transforms as T
  11. from GroundingDINO.groundingdino.models import build_model
  12. from GroundingDINO.groundingdino.util.slconfig import SLConfig
  13. from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
  14. # segment anything
  15. from segment_anything import (
  16. build_sam,
  17. build_sam_hq,
  18. SamPredictor
  19. )
  20. import cv2
  21. import numpy as np
  22. import matplotlib.pyplot as plt
  23. # Recognize Anything Model & Tag2Text
  24. from ram.models import ram
  25. from ram import inference_ram
  26. import torchvision.transforms as TS
  27. # ChatGPT or nltk is required when using tags_chinese
  28. # import openai
  29. # import nltk
  def load_image(image_path):
      """Open an image file and build the Grounding DINO input tensor.

      Args:
          image_path: path of the image file to read.

      Returns:
          A ``(image_pil, image)`` pair: the image converted to RGB as a
          ``PIL.Image``, and the tensor produced by the
          ``RandomResize([800], max_size=1333)`` -> ``ToTensor`` ->
          ``Normalize`` pipeline below (the original code annotates it as
          ``3, h, w``).
      """
      pipeline = T.Compose([
          T.RandomResize([800], max_size=1333),
          T.ToTensor(),
          T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
      ])
      rgb_image = Image.open(image_path).convert("RGB")
      # The transform takes an (image, target) pair; no target is needed here.
      tensor_image, _unused_target = pipeline(rgb_image, None)
      return rgb_image, tensor_image
  def check_tags_chinese(tags_chinese, pred_phrases, max_tokens=100, model="gpt-3.5-turbo", openai_key=None):
      """Ask an LLM to correct the object counts mentioned in ``tags_chinese``.

      The detected phrases are tallied per label and the tally is printed.
      When an API key is available, the tally and ``tags_chinese`` are sent to
      the chat model, which is asked to revise the numbers.

      Args:
          tags_chinese: tag string to revise.
          pred_phrases: detected phrases, each formatted as ``"label(score)"``.
          max_tokens: completion token budget for the LLM call.
          model: model name passed to ``litellm.completion``.
          openai_key: API key for the model. When ``None``, falls back to a
              module-level ``openai_key`` global (set by this script's
              ``__main__`` block), which keeps existing callers working.

      Returns:
          The model's revised tag string, or ``tags_chinese`` unchanged when
          no key is available.
      """
      from collections import Counter

      if openai_key is None:
          # Historical behaviour: the key came from the script's global scope.
          # ``.get`` avoids a NameError when this module is merely imported.
          openai_key = globals().get("openai_key")

      labels = []
      for phrase in pred_phrases:
          # Strip the trailing "(score)"; rpartition keeps any '(' in the label.
          head, sep, _ = phrase.rpartition("(")
          labels.append(head if sep else phrase)
      # Counter keeps first-seen order, so the printed/prompted tally is
      # deterministic (iterating a set of strings is not).
      object_num = ", ".join(f"{count} {label}" for label, count in Counter(labels).items())
      print(f"Correct object number: {object_num}")

      if not openai_key:
          return tags_chinese

      prompt = [
          {
              'role': 'system',
              'content': 'Revise the number in the tags_chinese if it is wrong. '
              f'tags_chinese: {tags_chinese}. '
              f'True object number: {object_num}. '
              'Only give the revised tags_chinese: ',
          }
      ]
      # Pass the key explicitly; otherwise litellm only sees environment keys.
      response = litellm.completion(
          model=model,
          messages=prompt,
          temperature=0.6,
          max_tokens=max_tokens,
          api_key=openai_key,
      )
      reply = response['choices'][0]['message']['content']
      # The model sometimes answers "tags_chinese: xxx, xxx, xxx".
      return reply.split(':')[-1].strip()
  def load_model(model_config_path, model_checkpoint_path, device):
      """Build the Grounding DINO model from a config file and load its weights.

      Args:
          model_config_path: path of the config file read by ``SLConfig``.
          model_checkpoint_path: checkpoint file; its ``"model"`` entry holds
              the state dict.
          device: stored on the config as ``device`` before building.

      Returns:
          The model in eval mode. The load result (from ``strict=False``
          loading) is printed so mismatched keys are visible.
      """
      cfg = SLConfig.fromfile(model_config_path)
      cfg.device = device
      net = build_model(cfg)
      # Weights are loaded onto the CPU first; callers move the model later.
      state = torch.load(model_checkpoint_path, map_location="cpu")["model"]
      report = net.load_state_dict(clean_state_dict(state), strict=False)
      print(report)
      net.eval()
      return net
  def get_grounding_output(model, image, caption, box_threshold, text_threshold, device="cpu"):
      """Run Grounding DINO on a single image and keep the confident queries.

      Args:
          model: Grounding DINO model. It must expose ``tokenizer`` and return
              a dict with ``"pred_logits"`` and ``"pred_boxes"``.
          image: tensor for one image. A batch dimension is added here.
          caption: text prompt. It is lower-cased, stripped, and terminated
              with '.' before use.
          box_threshold: a query is kept when its maximum token score exceeds
              this value.
          text_threshold: token score a token needs to be part of the phrase.
          device: device the model and image are moved to.

      Returns:
          ``(boxes, scores, phrases)``:
          - ``boxes``: kept boxes as emitted by the model (``(num_filt, 4)``
            per the original comments).
          - ``scores``: 1-D tensor of their maximum scores.
          - ``phrases``: ``"phrase(score)"`` labels, with the score string
            truncated to 4 characters.
      """
      caption = caption.lower().strip()
      if not caption.endswith("."):
          caption += "."
      model = model.to(device)
      image = image.to(device)
      with torch.no_grad():
          outputs = model(image[None], captions=[caption])
      logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
      boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

      # Boolean-mask indexing already returns new tensors, so no clone is needed.
      max_scores = logits.max(dim=1)[0]
      keep = max_scores > box_threshold
      logits_filt = logits[keep]  # (num_filt, 256)
      boxes_filt = boxes[keep]  # (num_filt, 4)
      # Reuse the per-query maxima instead of recomputing them for every box.
      kept_scores = max_scores[keep].tolist()

      tokenizer = model.tokenizer
      tokenized = tokenizer(caption)
      pred_phrases = []
      scores = []
      for logit, score in zip(logits_filt, kept_scores):
          phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
          pred_phrases.append(phrase + f"({str(score)[:4]})")
          scores.append(score)
      return boxes_filt, torch.Tensor(scores), pred_phrases
  def show_mask(mask, ax, random_color=False):
      """Overlay a mask on ``ax`` as a semi-transparent RGBA image.

      Args:
          mask: array whose last two dimensions are ``(h, w)``; it is reshaped
              to ``(h, w, 1)`` and multiplied by the RGBA color.
          ax: axes object providing ``imshow``.
          random_color: pick a random RGB color instead of the fixed blue.
              Alpha is 0.6 in both cases.
      """
      if not random_color:
          rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
      else:
          rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
      height, width = mask.shape[-2:]
      overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
      ax.imshow(overlay)
  def show_box(box, ax, label):
      """Draw a green box outline with a text label on ``ax``.

      Args:
          box: corner coordinates ``(x0, y0, x1, y1)``, read by index 0..3.
          ax: Matplotlib axes to draw on.
          label: text placed at ``(x0, y0)``.
      """
      x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
      outline = plt.Rectangle(
          (x0, y0),
          x1 - x0,
          y1 - y0,
          edgecolor='green',
          facecolor=(0, 0, 0, 0),
          lw=2,
      )
      ax.add_patch(outline)
      ax.text(x0, y0, label)
  def save_mask_data(output_dir, tags_chinese, mask_list, box_list, label_list):
      """Write an index-mask image and a JSON description of the masks.

      Writes two files into ``output_dir``:
      - ``mask.jpg``: a rendering of an index map. Background is 0 and mask
        ``i`` gets value ``i + 1``; later masks overwrite earlier ones where
        they overlap.
      - ``label.json``: a background entry followed by one entry per label.
        Each entry holds its value, name, logit and box.

      Args:
          output_dir: directory the files are written to.
          tags_chinese: stored verbatim under ``"tags_chinese"``.
          mask_list: tensor of masks. Element ``[0]`` of each mask is used as
              the 2-D mask (presumably ``(N, 1, H, W)`` — TODO confirm).
          box_list: tensor of boxes, paired with ``label_list``.
          label_list: ``"name(score)"`` strings.
      """
      value = 0  # 0 for background
      mask_img = torch.zeros(mask_list.shape[-2:])
      for idx, mask in enumerate(mask_list):
          mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1  # noqa: E712
      plt.figure(figsize=(10, 10))
      plt.imshow(mask_img.numpy())
      plt.axis('off')
      plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
      # Release the figure so it does not stay open after saving.
      plt.close()

      json_data = {
          'tags_chinese': tags_chinese,
          'mask': [{
              'value': value,
              'label': 'background'
          }]
      }
      for label, box in zip(label_list, box_list):
          value += 1
          # Split on the LAST '(' so a '(' inside the phrase cannot break unpacking.
          name, _, logit = label.rpartition('(')
          json_data['mask'].append({
              'value': value,
              'label': name,
              'logit': float(logit[:-1]),  # the last char is ')'
              'box': box.numpy().tolist(),
          })
      with open(os.path.join(output_dir, 'label.json'), 'w') as f:
          json.dump(json_data, f)
  if __name__ == "__main__":
      parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
      parser.add_argument("--config", type=str, required=True, help="path to config file")
      parser.add_argument(
          "--ram_checkpoint", type=str, required=True, help="path to checkpoint file"
      )
      parser.add_argument(
          "--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
      )
      parser.add_argument(
          "--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
      )
      parser.add_argument(
          "--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file"
      )
      parser.add_argument(
          "--use_sam_hq", action="store_true", help="using sam-hq for prediction"
      )
      parser.add_argument("--input_image", type=str, required=True, help="path to image file")
      parser.add_argument("--split", default=",", type=str, help="split for text prompt")
      parser.add_argument("--openai_key", type=str, help="key for chatgpt")
      parser.add_argument("--openai_proxy", default=None, type=str, help="proxy for chatgpt")
      # required=True used to make the "outputs" default unreachable; the flag
      # is now optional (passing it explicitly still works as before).
      parser.add_argument(
          "--output_dir", "-o", type=str, default="outputs", help="output directory"
      )
      parser.add_argument("--box_threshold", type=float, default=0.25, help="box threshold")
      parser.add_argument("--text_threshold", type=float, default=0.2, help="text threshold")
      parser.add_argument("--iou_threshold", type=float, default=0.5, help="iou threshold")
      parser.add_argument(
          "--device", type=str, default="cpu", help="device to run on, e.g. 'cpu' or 'cuda' (default: cpu)"
      )
      args = parser.parse_args()
      # Without a checkpoint, build_sam_hq would silently produce an untrained model.
      if args.use_sam_hq and not args.sam_hq_checkpoint:
          parser.error("--use_sam_hq requires --sam_hq_checkpoint")

      # cfg
      config_file = args.config
      ram_checkpoint = args.ram_checkpoint
      grounded_checkpoint = args.grounded_checkpoint
      sam_checkpoint = args.sam_checkpoint
      sam_hq_checkpoint = args.sam_hq_checkpoint
      use_sam_hq = args.use_sam_hq
      image_path = args.input_image
      split = args.split
      # Module-level global: check_tags_chinese only calls the LLM when it is set.
      # --openai_proxy and --split are accepted but currently unused.
      openai_key = args.openai_key
      openai_proxy = args.openai_proxy
      output_dir = args.output_dir
      box_threshold = args.box_threshold
      text_threshold = args.text_threshold
      iou_threshold = args.iou_threshold
      device = args.device

      os.makedirs(output_dir, exist_ok=True)
      image_pil, image = load_image(image_path)
      model = load_model(config_file, grounded_checkpoint, device=device)
      # visualize raw image
      image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

      # Recognize Anything Model: tag the image at 384x384.
      normalize = TS.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
      transform = TS.Compose([
          TS.Resize((384, 384)),
          TS.ToTensor(), normalize
      ])
      ram_model = ram(pretrained=ram_checkpoint,
                      image_size=384,
                      vit='swin_l')
      ram_model.eval()
      ram_model = ram_model.to(device)
      raw_image = image_pil.resize((384, 384))
      raw_image = transform(raw_image).unsqueeze(0).to(device)
      res = inference_ram(raw_image, ram_model)
      # Currently ", " is better for detecting single tags,
      # while ". " is a little worse in some cases.
      tags = res[0].replace(' |', ',')
      tags_chinese = res[1].replace(' |', ',')
      print("Image Tags: ", res[0])
      print("图像标签: ", res[1])

      # run grounding dino model
      boxes_filt, scores, pred_phrases = get_grounding_output(
          model, image, tags, box_threshold, text_threshold, device=device
      )

      # initialize SAM
      if use_sam_hq:
          print("Initialize SAM-HQ Predictor")
          predictor = SamPredictor(build_sam_hq(checkpoint=sam_hq_checkpoint).to(device))
      else:
          predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
      # Give SAM exactly the pixels Grounding DINO saw. Re-reading the file with
      # cv2.imread can disagree with image_pil: OpenCV applies EXIF orientation,
      # while PIL's Image.open does not, which would misalign the boxes.
      # np.array (not asarray) yields a writable, contiguous HxWx3 uint8 RGB array.
      image = np.array(image_pil)
      predictor.set_image(image)

      # Boxes are normalized (cx, cy, w, h); scale to pixels and convert to
      # (x0, y0, x1, y1) using the same arithmetic as before, but on all rows at once.
      W, H = image_pil.size
      boxes_filt = boxes_filt * torch.Tensor([W, H, W, H])
      boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
      boxes_filt[:, 2:] += boxes_filt[:, :2]
      boxes_filt = boxes_filt.cpu()

      # use NMS to handle overlapped boxes
      print(f"Before NMS: {boxes_filt.shape[0]} boxes")
      nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
      boxes_filt = boxes_filt[nms_idx]
      pred_phrases = [pred_phrases[idx] for idx in nms_idx]
      print(f"After NMS: {boxes_filt.shape[0]} boxes")

      tags_chinese = check_tags_chinese(tags_chinese, pred_phrases)
      print(f"Revise tags_chinese with number: {tags_chinese}")

      transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
      masks, _, _ = predictor.predict_torch(
          point_coords=None,
          point_labels=None,
          boxes=transformed_boxes,
          multimask_output=False,
      )

      # draw output image
      plt.figure(figsize=(10, 10))
      plt.imshow(image)
      for mask in masks:
          show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
      for box, label in zip(boxes_filt, pred_phrases):
          show_box(box.numpy(), plt.gca(), label)
      plt.axis('off')
      plt.savefig(
          os.path.join(output_dir, "automatic_label_output.jpg"),
          bbox_inches="tight", dpi=300, pad_inches=0.0
      )
      save_mask_data(output_dir, tags_chinese, masks, boxes_filt, pred_phrases)