box_annotator.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. from typing import List, Optional, Union, Tuple
  2. import cv2
  3. import numpy as np
  4. from supervision.detection.core import Detections
  5. from supervision.draw.color import Color, ColorPalette
  6. class BoxAnnotator:
  7. """
  8. A class for drawing bounding boxes on an image using detections provided.
  9. Attributes:
  10. color (Union[Color, ColorPalette]): The color to draw the bounding box,
  11. can be a single color or a color palette
  12. thickness (int): The thickness of the bounding box lines, default is 2
  13. text_color (Color): The color of the text on the bounding box, default is white
  14. text_scale (float): The scale of the text on the bounding box, default is 0.5
  15. text_thickness (int): The thickness of the text on the bounding box,
  16. default is 1
  17. text_padding (int): The padding around the text on the bounding box,
  18. default is 5
  19. """
  20. def __init__(
  21. self,
  22. color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
  23. thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
  24. text_color: Color = Color.BLACK,
  25. text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
  26. text_thickness: int = 2, #1, # 2 for demo
  27. text_padding: int = 10,
  28. avoid_overlap: bool = True,
  29. ):
  30. self.color: Union[Color, ColorPalette] = color
  31. self.thickness: int = thickness
  32. self.text_color: Color = text_color
  33. self.text_scale: float = text_scale
  34. self.text_thickness: int = text_thickness
  35. self.text_padding: int = text_padding
  36. self.avoid_overlap: bool = avoid_overlap
  37. def annotate(
  38. self,
  39. scene: np.ndarray,
  40. detections: Detections,
  41. labels: Optional[List[str]] = None,
  42. skip_label: bool = False,
  43. image_size: Optional[Tuple[int, int]] = None,
  44. ) -> np.ndarray:
  45. """
  46. Draws bounding boxes on the frame using the detections provided.
  47. Args:
  48. scene (np.ndarray): The image on which the bounding boxes will be drawn
  49. detections (Detections): The detections for which the
  50. bounding boxes will be drawn
  51. labels (Optional[List[str]]): An optional list of labels
  52. corresponding to each detection. If `labels` are not provided,
  53. corresponding `class_id` will be used as label.
  54. skip_label (bool): Is set to `True`, skips bounding box label annotation.
  55. Returns:
  56. np.ndarray: The image with the bounding boxes drawn on it
  57. Example:
  58. ```python
  59. import supervision as sv
  60. classes = ['person', ...]
  61. image = ...
  62. detections = sv.Detections(...)
  63. box_annotator = sv.BoxAnnotator()
  64. labels = [
  65. f"{classes[class_id]} {confidence:0.2f}"
  66. for _, _, confidence, class_id, _ in detections
  67. ]
  68. annotated_frame = box_annotator.annotate(
  69. scene=image.copy(),
  70. detections=detections,
  71. labels=labels
  72. )
  73. ```
  74. """
  75. font = cv2.FONT_HERSHEY_SIMPLEX
  76. for i in range(len(detections)):
  77. x1, y1, x2, y2 = detections.xyxy[i].astype(int)
  78. class_id = (
  79. detections.class_id[i] if detections.class_id is not None else None
  80. )
  81. idx = class_id if class_id is not None else i
  82. color = (
  83. self.color.by_idx(idx)
  84. if isinstance(self.color, ColorPalette)
  85. else self.color
  86. )
  87. cv2.rectangle(
  88. img=scene,
  89. pt1=(x1, y1),
  90. pt2=(x2, y2),
  91. color=color.as_bgr(),
  92. thickness=self.thickness,
  93. )
  94. if skip_label:
  95. continue
  96. text = (
  97. f"{class_id}"
  98. if (labels is None or len(detections) != len(labels))
  99. else labels[i]
  100. )
  101. text_width, text_height = cv2.getTextSize(
  102. text=text,
  103. fontFace=font,
  104. fontScale=self.text_scale,
  105. thickness=self.text_thickness,
  106. )[0]
  107. if not self.avoid_overlap:
  108. text_x = x1 + self.text_padding
  109. text_y = y1 - self.text_padding
  110. text_background_x1 = x1
  111. text_background_y1 = y1 - 2 * self.text_padding - text_height
  112. text_background_x2 = x1 + 2 * self.text_padding + text_width
  113. text_background_y2 = y1
  114. # text_x = x1 - self.text_padding - text_width
  115. # text_y = y1 + self.text_padding + text_height
  116. # text_background_x1 = x1 - 2 * self.text_padding - text_width
  117. # text_background_y1 = y1
  118. # text_background_x2 = x1
  119. # text_background_y2 = y1 + 2 * self.text_padding + text_height
  120. else:
  121. text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2 = get_optimal_label_pos(self.text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size)
  122. cv2.rectangle(
  123. img=scene,
  124. pt1=(text_background_x1, text_background_y1),
  125. pt2=(text_background_x2, text_background_y2),
  126. color=color.as_bgr(),
  127. thickness=cv2.FILLED,
  128. )
  129. # import pdb; pdb.set_trace()
  130. box_color = color.as_rgb()
  131. luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
  132. text_color = (0,0,0) if luminance > 160 else (255,255,255)
  133. cv2.putText(
  134. img=scene,
  135. text=text,
  136. org=(text_x, text_y),
  137. fontFace=font,
  138. fontScale=self.text_scale,
  139. # color=self.text_color.as_rgb(),
  140. color=text_color,
  141. thickness=self.text_thickness,
  142. lineType=cv2.LINE_AA,
  143. )
  144. return scene
  145. def box_area(box):
  146. return (box[2] - box[0]) * (box[3] - box[1])
  147. def intersection_area(box1, box2):
  148. x1 = max(box1[0], box2[0])
  149. y1 = max(box1[1], box2[1])
  150. x2 = min(box1[2], box2[2])
  151. y2 = min(box1[3], box2[3])
  152. return max(0, x2 - x1) * max(0, y2 - y1)
  153. def IoU(box1, box2, return_max=True):
  154. intersection = intersection_area(box1, box2)
  155. union = box_area(box1) + box_area(box2) - intersection
  156. if box_area(box1) > 0 and box_area(box2) > 0:
  157. ratio1 = intersection / box_area(box1)
  158. ratio2 = intersection / box_area(box2)
  159. else:
  160. ratio1, ratio2 = 0, 0
  161. if return_max:
  162. return max(intersection / union, ratio1, ratio2)
  163. else:
  164. return intersection / union
  165. def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
  166. """ check overlap of text and background detection box, and get_optimal_label_pos,
  167. pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
  168. Threshold: default to 0.3
  169. """
  170. def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
  171. is_overlap = False
  172. for i in range(len(detections)):
  173. detection = detections.xyxy[i].astype(int)
  174. if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
  175. is_overlap = True
  176. break
  177. # check if the text is out of the image
  178. if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
  179. is_overlap = True
  180. return is_overlap
  181. # if pos == 'top left':
  182. text_x = x1 + text_padding
  183. text_y = y1 - text_padding
  184. text_background_x1 = x1
  185. text_background_y1 = y1 - 2 * text_padding - text_height
  186. text_background_x2 = x1 + 2 * text_padding + text_width
  187. text_background_y2 = y1
  188. is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
  189. if not is_overlap:
  190. return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
  191. # elif pos == 'outer left':
  192. text_x = x1 - text_padding - text_width
  193. text_y = y1 + text_padding + text_height
  194. text_background_x1 = x1 - 2 * text_padding - text_width
  195. text_background_y1 = y1
  196. text_background_x2 = x1
  197. text_background_y2 = y1 + 2 * text_padding + text_height
  198. is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
  199. if not is_overlap:
  200. return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
  201. # elif pos == 'outer right':
  202. text_x = x2 + text_padding
  203. text_y = y1 + text_padding + text_height
  204. text_background_x1 = x2
  205. text_background_y1 = y1
  206. text_background_x2 = x2 + 2 * text_padding + text_width
  207. text_background_y2 = y1 + 2 * text_padding + text_height
  208. is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
  209. if not is_overlap:
  210. return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
  211. # elif pos == 'top right':
  212. text_x = x2 - text_padding - text_width
  213. text_y = y1 - text_padding
  214. text_background_x1 = x2 - 2 * text_padding - text_width
  215. text_background_y1 = y1 - 2 * text_padding - text_height
  216. text_background_x2 = x2
  217. text_background_y2 = y1
  218. is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
  219. if not is_overlap:
  220. return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
  221. return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2