# yolov5_utils.py

import math
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    # Scales img(bs,3,y,x) by ratio, constrained to a gs-multiple
    if ratio == 1.0:
        return img
    else:
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # new size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
        if not same_shape:  # pad/crop img
            h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # value = imagenet mean
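
# Example (illustrative, not part of the original file): with ratio=0.5 a
# (1, 3, 256, 416) batch is resized to (128, 208), then padded up to gs=32
# multiples: ceil(208 / 32) * 32 = 224, so the result is (1, 3, 128, 224).
#   x = torch.zeros(1, 3, 256, 416)
#   scale_img(x, ratio=0.5).shape  # torch.Size([1, 3, 128, 224])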

def fuse_conv_and_bn(conv, bn):
    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

    return fusedconv
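
# Illustrative sketch (added for clarity, not part of the original file): in
# eval mode the fused layer should reproduce conv -> bn to within float
# tolerance. The helper name `_example_fuse_check` is hypothetical.
def _example_fuse_check():
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8).eval()  # use running stats, as at inference time
    fused = fuse_conv_and_bn(conv, bn)
    x = torch.randn(1, 3, 32, 32)
    with torch.no_grad():
        assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)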

def check_anchor_order(m):
    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
    a = m.anchors.prod(-1).view(-1)  # anchor area
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
    if da.sign() != ds.sign():  # orders differ: flip anchors to match stride order
        m.anchors[:] = m.anchors.flip(0)

def initialize_weights(model):
    for m in model.modules():
        t = type(m)
        if t is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif t is nn.BatchNorm2d:
            m.eps = 1e-3
            m.momentum = 0.03
        elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
            m.inplace = True

def make_divisible(x, divisor):
    # Returns x rounded up to the nearest multiple of divisor
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor
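
# Example: make_divisible(97, 8) -> 104, since ceil(97 / 8) * 8 = 104.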

def intersect_dicts(da, db, exclude=()):
    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
    return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
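
# Usage sketch (assumption: a checkpoint file and an existing `model`, neither
# defined in this file): keep only entries whose keys and shapes match the
# current model, skip anchors, then load non-strictly.
#   csd = torch.load('yolov5s.pt')['model'].float().state_dict()
#   csd = intersect_dicts(csd, model.state_dict(), exclude=('anchor',))
#   model.load_state_dict(csd, strict=False)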

def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
    # Check current version vs. required version
    from packaging import version
    current, minimum = (version.parse(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    if hard:  # assert min requirements met
        assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
    return result
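
# Example: returns True/False for a soft check, or raises when hard=True.
#   check_version(torch.__version__, minimum='1.7.0', name='torch ')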

class Colors:
    # Ultralytics color palette https://ultralytics.com/
    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hex = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
               '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb('#' + c) for c in hex]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):  # rgb order (PIL)
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
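
# Example: Colors() cycles through the palette by index.
#   colors = Colors()
#   colors(0)            # (255, 56, 56)  RGB for 'FF3838'
#   colors(0, bgr=True)  # (56, 56, 255)  BGR, e.g. for OpenCV drawing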

def box_iou(box1, box2):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in box1 and box2
    """

    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])

    area1 = box_area(box1.T)
    area2 = box_area(box2.T)

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
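
# Example: two 10x10 boxes sharing a 5x5 corner overlap give
# IoU = 25 / (100 + 100 - 25) ≈ 0.143.
#   a = torch.tensor([[0., 0., 10., 10.]])
#   b = torch.tensor([[5., 5., 15., 15.]])
#   box_iou(a, b)  # tensor([[0.1429]])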

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=(), max_det=300):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
        list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """
    if isinstance(prediction, np.ndarray):
        prediction = torch.from_numpy(prediction)

    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain, process next image
        if not x.shape[0]:
            continue

        # Compute conf
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # class offsets
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
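
# Usage sketch (assumption: `model` is a loaded YOLOv5 model and `im` a
# preprocessed (1, 3, H, W) float tensor; neither is defined in this file):
#   pred = model(im)[0]  # raw (bs, n_anchors, 5 + nc) predictions
#   dets = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
#   for det in dets:  # one (n, 6) tensor per image
#       for *xyxy, conf, cls in det:
#           print([float(v) for v in xyxy], float(conf), int(cls))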

def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y
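
# Example: a 20x10 box centered at (50, 50) in corner form.
#   xywh2xyxy(torch.tensor([[50., 50., 20., 10.]]))  # tensor([[40., 45., 60., 55.]])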

DEFAULT_LANG_LIST = ['eng', 'ja']

def draw_bbox(pred, img, lang_list=None):
    # Draw class-colored boxes and '<lang><index>' labels for (n, 6) [x1, y1, x2, y2, conf, cls] detections
    if lang_list is None:
        lang_list = DEFAULT_LANG_LIST
    lw = max(round(sum(img.shape) / 2 * 0.003), 2)  # line width
    pred = pred.astype(np.int32)
    colors = Colors()
    img = np.copy(img)
    for ii, obj in enumerate(pred):
        p1, p2 = (obj[0], obj[1]), (obj[2], obj[3])
        label = lang_list[obj[-1]] + str(ii + 1)
        cv2.rectangle(img, p1, p2, colors(obj[-1], bgr=True), lw, lineType=cv2.LINE_AA)
        t_w, t_h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=lw)[0]
        cv2.putText(img, label, (p1[0], p1[1] + t_h + 2), 0, lw / 3, colors(obj[-1], bgr=True), max(lw - 1, 1),
                    cv2.LINE_AA)
    return img
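
# Usage sketch (assumption: `dets` is one (n, 6) [x1, y1, x2, y2, conf, cls]
# tensor from non_max_suppression and `img` an HxWx3 BGR uint8 array):
#   annotated = draw_bbox(dets.cpu().numpy(), img)
#   cv2.imwrite('annotated.jpg', annotated)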