api.py

# based on https://github.com/isl-org/MiDaS
import cv2
import torch
import torch.nn as nn
from torchvision.transforms import Compose

from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
from ldm.modules.midas.midas.midas_net import MidasNet
from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

ISL_PATHS = {
    "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
    "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
    "midas_v21": "",
    "midas_v21_small": "",
}
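
# NOTE: the "midas_v21" / "midas_v21_small" checkpoint paths are left empty;
# they presumably have to be pointed at local weight files before those model
# types can load pretrained parameters.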


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def load_midas_transform(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load transform only
    if model_type == "dpt_large":  # DPT-Large
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    elif model_type == "midas_v21_small":
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    else:
        raise NotImplementedError(
            f"model_type '{model_type}' not implemented, "
            f"use one of: {list(ISL_PATHS.keys())}"
        )

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return transform
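

# A minimal usage sketch for the transform (illustrative addition, not part of
# the upstream MiDaS code). It assumes `img` is an HxWx3 RGB numpy array scaled
# to 0..1; the MiDaS transforms operate on a dict with an "image" key, as in
# MiDaS' run.py.
def _example_transform_usage(img, model_type="dpt_hybrid"):
    transform = load_midas_transform(model_type)
    sample = transform({"image": img})["image"]   # CHW float32 numpy array
    return torch.from_numpy(sample).unsqueeze(0)  # (1, 3, H', W') batch tensor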


def load_model(model_type):
    # https://github.com/isl-org/MiDaS/blob/master/run.py
    # load network
    model_path = ISL_PATHS[model_type]
    if model_type == "dpt_large":  # DPT-Large
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid":  # DPT-Hybrid
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3",
                               exportable=True, non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        raise NotImplementedError(
            f"model_type '{model_type}' not implemented, "
            f"use one of: {list(ISL_PATHS.keys())}"
        )

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    return model.eval(), transform
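
# load_model returns the network already switched to eval mode, paired with the
# exact preprocessing (resize mode, resolution, normalization) its weights
# expect, so the two cannot drift apart.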


class MiDaSInference(nn.Module):
    # torch.hub names of the corresponding models; only the ISL variants below
    # are accepted by this wrapper
    MODEL_TYPES_TORCH_HUB = [
        "DPT_Large",
        "DPT_Hybrid",
        "MiDaS_small",
    ]
    MODEL_TYPES_ISL = [
        "dpt_large",
        "dpt_hybrid",
        "midas_v21",
        "midas_v21_small",
    ]

    def __init__(self, model_type):
        super().__init__()
        assert model_type in self.MODEL_TYPES_ISL
        model, _ = load_model(model_type)
        self.model = model
        # freeze the mode: calling .train() on the wrapped model becomes a no-op
        self.model.train = disabled_train

    def forward(self, x):
        # NOTE: x is expected to have already been run through the matching MiDaS
        # transform during dataloading (starting from a 0..1 float64 numpy array);
        # this module does not store or apply the transform itself.
        with torch.no_grad():
            prediction = self.model(x)
            # the network outputs an (N, H, W) inverse-depth map; add a channel
            # dimension and resample it back to the input resolution
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=x.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
        assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
        return prediction
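

# A minimal end-to-end sketch (illustrative, not part of the upstream code).
# It assumes a local RGB image "example.jpg" and the "dpt_hybrid" checkpoint at
# the path listed in ISL_PATHS; both are placeholders.
if __name__ == "__main__":
    img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB) / 255.0
    batch = _example_transform_usage(img, "dpt_hybrid")  # 0..1 HWC -> (1, 3, H', W')
    depth = MiDaSInference("dpt_hybrid")(batch)          # (1, 1, H', W') inverse depth
    print(depth.shape, float(depth.min()), float(depth.max()))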