# -*- coding: utf-8 -*-
"""
# --------------------------------------------
# Super-Resolution
# --------------------------------------------
#
# Kai Zhang (cskaizhang@gmail.com)
# https://github.com/cszn
# From 2019/03--2021/08
# --------------------------------------------
"""

import numpy as np
import cv2
import torch
from functools import partial
import random
from scipy import ndimage
import scipy
import scipy.stats as ss
from scipy.interpolate import interp2d
from scipy.linalg import orth
import albumentations

import ldm.modules.image_degradation.utils_image as util


def modcrop_np(img, sf):
    '''
    Args:
        img: numpy image, WxH or WxHxC
        sf: scale factor
    Return:
        cropped image
    '''
    w, h = img.shape[:2]
    im = np.copy(img)
    return im[:w - w % sf, :h - h % sf, ...]


"""
# --------------------------------------------
# anisotropic Gaussian kernels
# --------------------------------------------
"""

def analytic_kernel(k):
    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
    k_size = k.shape[0]
    # Calculate the big kernel's size
    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
    # Loop over the small kernel to fill the big one
    for r in range(k_size):
        for c in range(k_size):
            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
    # Crop the edges of the big kernel to ignore very small values and reduce the SR run time
    crop = k_size // 2
    cropped_big_k = big_k[crop:-crop, crop:-crop]
    # Normalize to 1
    return cropped_big_k / cropped_big_k.sum()

def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
    """ generate an anisotropic Gaussian kernel
    Args:
        ksize : e.g., 15, kernel size
        theta : [0, pi], rotation angle range
        l1    : [0.1, 50], scaling of eigenvalues
        l2    : [0.1, l1], scaling of eigenvalues
        If l1 = l2, will get an isotropic Gaussian kernel.
    Returns:
        k     : kernel
    """
    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
    D = np.array([[l1, 0], [0, l2]])
    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
    return k

def gm_blur_kernel(mean, cov, size=15):
    center = size / 2.0 + 0.5
    k = np.zeros([size, size])
    for y in range(size):
        for x in range(size):
            cy = y - center + 1
            cx = x - center + 1
            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
    k = k / np.sum(k)
    return k

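# Illustrative sketch (not part of the original module): how the anisotropic
# kernel generator above is typically used. `_demo_anisotropic_kernel` is a
# hypothetical helper added for documentation only.
def _demo_anisotropic_kernel():
    # Sample a random orientation and two eigenvalue scalings, then build a
    # 15x15 kernel; the result sums to 1 and can be passed to e.g. blur().
    theta = random.random() * np.pi
    k = anisotropic_Gaussian(ksize=15, theta=theta, l1=6.0, l2=2.0)
    assert abs(k.sum() - 1.0) < 1e-6
    return k
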
def shift_pixel(x, sf, upper_left=True):
    """shift pixel for super-resolution with different scale factors
    Args:
        x: WxHxC or WxH
        sf: scale factor
        upper_left: shift direction
    """
    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    if upper_left:
        x1 = xv + shift
        y1 = yv + shift
    else:
        x1 = xv - shift
        y1 = yv - shift

    x1 = np.clip(x1, 0, w - 1)
    y1 = np.clip(y1, 0, h - 1)

    if x.ndim == 2:
        x = interp2d(xv, yv, x)(x1, y1)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)

    return x

def blur(x, k):
    '''
    x: image, NxcxHxW
    k: kernel, Nx1xhxw
    '''
    n, c = x.shape[:2]
    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
    k = k.repeat(1, c, 1, 1)
    k = k.view(-1, 1, k.shape[2], k.shape[3])
    x = x.view(1, -1, x.shape[2], x.shape[3])
    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
    x = x.view(n, c, x.shape[2], x.shape[3])
    return x

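# Illustrative sketch (not part of the original module): applying blur() with a
# per-sample Gaussian kernel. `_demo_blur` is a hypothetical helper added for
# documentation only.
def _demo_blur():
    # Two RGB images of size 32x32 and one 15x15 kernel per image; blur()
    # replicates each kernel across channels and runs a grouped convolution,
    # so the output keeps the NxCxHxW shape of the input.
    x = torch.rand(2, 3, 32, 32)
    k_np = anisotropic_Gaussian(ksize=15, theta=0.3 * np.pi, l1=4.0, l2=1.0)
    k = torch.from_numpy(k_np).float()[None, None].repeat(2, 1, 1, 1)  # Nx1xhxw
    y = blur(x, k)
    assert y.shape == x.shape
    return y
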
def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
    """
    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
    # Kai Zhang
    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
    # max_var = 2.5 * sf
    """
    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
    theta = np.random.rand() * np.pi  # random theta
    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2

    # Set COV matrix using Lambdas and Theta
    LAMBDA = np.diag([lambda_1, lambda_2])
    Q = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    SIGMA = Q @ LAMBDA @ Q.T
    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]

    # Set expectation position (shifting kernel for aligned image)
    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
    MU = MU[None, None, :, None]

    # Create meshgrid for Gaussian
    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
    Z = np.stack([X, Y], 2)[:, :, :, None]

    # Calculate Gaussian for every pixel of the kernel
    ZZ = Z - MU
    ZZ_t = ZZ.transpose(0, 1, 3, 2)
    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)

    # shift the kernel so it will be centered
    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)

    # Normalize the kernel and return
    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
    kernel = raw_kernel / np.sum(raw_kernel)
    return kernel

def fspecial_gaussian(hsize, sigma):
    hsize = [hsize, hsize]
    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
    std = sigma
    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
    arg = -(x * x + y * y) / (2 * std * std)
    h = np.exp(arg)
    # zero out negligible values; np.finfo is used because the old scipy.finfo
    # alias has been removed from recent SciPy releases
    h[h < np.finfo(float).eps * h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h = h / sumh
    return h

def fspecial_laplacian(alpha):
    alpha = max([0, min([alpha, 1])])
    h1 = alpha / (alpha + 1)
    h2 = (1 - alpha) / (alpha + 1)
    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
    h = np.array(h)
    return h


def fspecial(filter_type, *args, **kwargs):
    '''
    python code from:
    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
    '''
    if filter_type == 'gaussian':
        return fspecial_gaussian(*args, **kwargs)
    if filter_type == 'laplacian':
        return fspecial_laplacian(*args, **kwargs)

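# Illustrative sketch (not part of the original module): the MATLAB-style
# fspecial() dispatcher above. `_demo_fspecial` is a hypothetical helper added
# for documentation only.
def _demo_fspecial():
    # A 7x7 Gaussian kernel with sigma=1.5 (normalized to sum to 1) and a
    # 3x3 Laplacian kernel with alpha=0.2.
    g = fspecial('gaussian', 7, 1.5)
    lap = fspecial('laplacian', 0.2)
    return g, lap
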
  181. """
  182. # --------------------------------------------
  183. # degradation models
  184. # --------------------------------------------
  185. """
  186. def bicubic_degradation(x, sf=3):
  187. '''
  188. Args:
  189. x: HxWxC image, [0, 1]
  190. sf: down-scale factor
  191. Return:
  192. bicubicly downsampled LR image
  193. '''
  194. x = util.imresize_np(x, scale=1 / sf)
  195. return x
def srmd_degradation(x, k, sf=3):
    ''' blur + bicubic downsampling
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2018learning,
          title={Learning a single convolutional super-resolution network for multiple degradations},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={3262--3271},
          year={2018}
        }
    '''
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
    x = bicubic_degradation(x, sf=sf)
    return x

def dpsr_degradation(x, k, sf=3):
    ''' bicubic downsampling + blur
    Args:
        x: HxWxC image, [0, 1]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    Reference:
        @inproceedings{zhang2019deep,
          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
          pages={1671--1681},
          year={2019}
        }
    '''
    x = bicubic_degradation(x, sf=sf)
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    return x

def classical_degradation(x, k, sf=3):
    ''' blur + downsampling
    Args:
        x: HxWxC image, [0, 1]/[0, 255]
        k: hxw, double
        sf: down-scale factor
    Return:
        downsampled LR image
    '''
    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
    st = 0
    return x[st::sf, st::sf, ...]

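# Illustrative sketch (not part of the original module): the classical
# "blur then subsample" pipeline above with an fspecial Gaussian kernel.
# `_demo_classical_degradation` is a hypothetical helper for documentation only.
def _demo_classical_degradation():
    hr = np.random.rand(96, 96, 3).astype(np.float32)  # HxWxC in [0, 1]
    k = fspecial('gaussian', 15, 2.0)                   # 15x15 blur kernel
    lr = classical_degradation(hr, k, sf=3)             # -> 32x32x3
    return lr
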
def add_sharpening(img, weight=0.5, radius=50, threshold=10):
    """USM sharpening, borrowed from Real-ESRGAN.
    Input image: I; Blurry image: B.
    1. K = I + weight * (I - B)
    2. Mask = 1 if abs(I - B) > threshold, else: 0
    3. Blur mask:
    4. Out = Mask * K + (1 - Mask) * I
    Args:
        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
        weight (float): Sharpening weight. Default: 0.5.
        radius (float): Kernel size of Gaussian blur. Default: 50.
        threshold (int): Residual threshold on the [0, 255] scale. Default: 10.
    """
    if radius % 2 == 0:
        radius += 1
    blur = cv2.GaussianBlur(img, (radius, radius), 0)
    residual = img - blur
    mask = np.abs(residual) * 255 > threshold
    mask = mask.astype('float32')
    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)

    K = img + weight * residual
    K = np.clip(K, 0, 1)
    return soft_mask * K + (1 - soft_mask) * img

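# Illustrative sketch (not part of the original module): USM sharpening as a
# pre-step before building the HQ target, mirroring its use in
# degradation_bsrgan_plus(use_sharp=True). `_demo_add_sharpening` is a
# hypothetical helper for documentation only.
def _demo_add_sharpening():
    img = np.random.rand(64, 64, 3).astype(np.float32)  # float32, [0, 1]
    sharp = add_sharpening(img, weight=0.5, radius=50, threshold=10)
    return np.clip(sharp, 0.0, 1.0)
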
def add_blur(img, sf=4):
    wd2 = 4.0 + sf
    wd = 2.0 + 0.2 * sf
    if random.random() < 0.5:
        l1 = wd2 * random.random()
        l2 = wd2 * random.random()
        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
    else:
        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')

    return img

def add_resize(img, sf=4):
    rnum = np.random.rand()
    if rnum > 0.8:  # up
        sf1 = random.uniform(1, 2)
    elif rnum < 0.7:  # down
        sf1 = random.uniform(0.5 / sf, 1)
    else:
        sf1 = 1.0
    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
    img = np.clip(img, 0.0, 1.0)

    return img

# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
#     noise_level = random.randint(noise_level1, noise_level2)
#     rnum = np.random.rand()
#     if rnum > 0.6:  # add color Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
#     elif rnum < 0.4:  # add grayscale Gaussian noise
#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
#     else:  # add noise
#         L = noise_level2 / 255.
#         D = np.diag(np.random.rand(3))
#         U = orth(np.random.rand(3, 3))
#         conv = np.dot(np.dot(np.transpose(U), D), U)
#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
#     img = np.clip(img, 0.0, 1.0)
#     return img

def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    rnum = np.random.rand()
    if rnum > 0.6:  # add color Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:  # add grayscale Gaussian noise
        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:  # add noise
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img

def add_speckle_noise(img, noise_level1=2, noise_level2=25):
    noise_level = random.randint(noise_level1, noise_level2)
    img = np.clip(img, 0.0, 1.0)
    rnum = random.random()
    if rnum > 0.6:
        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
    elif rnum < 0.4:
        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
    else:
        L = noise_level2 / 255.
        D = np.diag(np.random.rand(3))
        U = orth(np.random.rand(3, 3))
        conv = np.dot(np.dot(np.transpose(U), D), U)
        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
    img = np.clip(img, 0.0, 1.0)
    return img

def add_Poisson_noise(img):
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
    if random.random() < 0.5:
        img = np.random.poisson(img * vals).astype(np.float32) / vals
    else:
        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
        img += noise_gray[:, :, np.newaxis]
    img = np.clip(img, 0.0, 1.0)
    return img

def add_JPEG_noise(img):
    quality_factor = random.randint(30, 95)
    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
    img = cv2.imdecode(encimg, 1)
    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
    return img

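# Illustrative sketch (not part of the original module): add_JPEG_noise() does
# an in-memory JPEG encode/decode round trip at a random quality in [30, 95].
# `_demo_add_jpeg_noise` is a hypothetical helper for documentation only.
def _demo_add_jpeg_noise():
    img = np.random.rand(64, 64, 3).astype(np.float32)  # RGB float32 in [0, 1]
    out = add_JPEG_noise(img)                            # back to RGB float32 in [0, 1]
    assert out.shape == img.shape
    return out
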
def random_crop(lq, hq, sf=4, lq_patchsize=64):
    h, w = lq.shape[:2]
    rnd_h = random.randint(0, h - lq_patchsize)
    rnd_w = random.randint(0, w - lq_patchsize)
    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]

    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
    return lq, hq

def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    img: low-quality patch, size: lq_patchsize X lq_patchsize X C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) X (lq_patchsize x sf) X C, range: [0, 1]
    """
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop: height/width to multiples of sf
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    hq = img.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
                             interpolation=random.choice([1, 2, 3]))
        else:
            img = util.imresize_np(img, 1 / 2, True)
        img = np.clip(img, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_blur(img, sf=sf)
        elif i == 2:
            a, b = img.shape[1], img.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
                                 interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
                img = img[0::sf, 0::sf, ...]  # nearest downsampling
            img = np.clip(img, 0.0, 1.0)
        elif i == 3:
            # downsample3
            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            img = np.clip(img, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                img = add_JPEG_noise(img)
        elif i == 6:
            # add processed camera sensor noise
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)

    return img, hq

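# Illustrative sketch (not part of the original module): generating one LQ/HQ
# training pair with the BSRGAN pipeline above. `_demo_degradation_bsrgan` is a
# hypothetical helper for documentation only.
def _demo_degradation_bsrgan():
    # The HR input must be at least (lq_patchsize * sf) on each side, here 288x288.
    hr = np.random.rand(320, 320, 3).astype(np.float32)
    lq, hq = degradation_bsrgan(hr, sf=4, lq_patchsize=72)
    assert lq.shape[:2] == (72, 72) and hq.shape[:2] == (288, 288)
    return lq, hq
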
# todo: no isp_model?
def degradation_bsrgan_variant(image, sf=4, isp_model=None):
    """
    This is the degradation model of BSRGAN from the paper
    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
    ----------
    image: HxWxC, uint8, [0, 255]
    sf: scale factor
    isp_model: camera ISP model
    Returns
    -------
    example: dict with key "image" holding the degraded low-quality image, uint8, HxWxC
    """
    image = util.uint2single(image)
    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
    sf_ori = sf

    h1, w1 = image.shape[:2]
    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop: height/width to multiples of sf
    h, w = image.shape[:2]

    hq = image.copy()

    if sf == 4 and random.random() < scale2_prob:  # downsample1
        if np.random.rand() < 0.5:
            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
                               interpolation=random.choice([1, 2, 3]))
        else:
            image = util.imresize_np(image, 1 / 2, True)
        image = np.clip(image, 0.0, 1.0)
        sf = 2

    shuffle_order = random.sample(range(7), 7)
    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
    if idx1 > idx2:  # keep downsample3 last
        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]

    for i in shuffle_order:
        if i == 0:
            image = add_blur(image, sf=sf)
        elif i == 1:
            image = add_blur(image, sf=sf)
        elif i == 2:
            a, b = image.shape[1], image.shape[0]
            # downsample2
            if random.random() < 0.75:
                sf1 = random.uniform(1, 2 * sf)
                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
                                   interpolation=random.choice([1, 2, 3]))
            else:
                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
                k_shifted = shift_pixel(k, sf)
                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
                image = image[0::sf, 0::sf, ...]  # nearest downsampling
            image = np.clip(image, 0.0, 1.0)
        elif i == 3:
            # downsample3
            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
            image = np.clip(image, 0.0, 1.0)
        elif i == 4:
            # add Gaussian noise
            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
        elif i == 5:
            # add JPEG noise
            if random.random() < jpeg_prob:
                image = add_JPEG_noise(image)
        # elif i == 6:
        #     # add processed camera sensor noise
        #     if random.random() < isp_prob and isp_model is not None:
        #         with torch.no_grad():
        #             img, hq = isp_model.forward(img.copy(), hq)

    # add final JPEG compression noise
    image = add_JPEG_noise(image)
    image = util.single2uint(image)
    example = {"image": image}
    return example

# TODO: in case there is a pickle error, one needs to replace a += x with a = a + x in add_speckle_noise etc...
def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
    """
    This is an extended degradation model obtained by combining
    the degradation models of BSRGAN and Real-ESRGAN
    ----------
    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
    sf: scale factor
    shuffle_prob: probability of shuffling the degradation order
    use_sharp: sharpen the img
    Returns
    -------
    img: low-quality patch, size: lq_patchsize X lq_patchsize X C, range: [0, 1]
    hq: corresponding high-quality patch, size: (lq_patchsize x sf) X (lq_patchsize x sf) X C, range: [0, 1]
    """
    h1, w1 = img.shape[:2]
    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop: height/width to multiples of sf
    h, w = img.shape[:2]

    if h < lq_patchsize * sf or w < lq_patchsize * sf:
        raise ValueError(f'img size ({h1}X{w1}) is too small!')

    if use_sharp:
        img = add_sharpening(img)
    hq = img.copy()

    if random.random() < shuffle_prob:
        shuffle_order = random.sample(range(13), 13)
    else:
        shuffle_order = list(range(13))
        # local shuffle for noise, JPEG is always the last one
        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))

    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1

    for i in shuffle_order:
        if i == 0:
            img = add_blur(img, sf=sf)
        elif i == 1:
            img = add_resize(img, sf=sf)
        elif i == 2:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 3:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 4:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 5:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        elif i == 6:
            img = add_JPEG_noise(img)
        elif i == 7:
            img = add_blur(img, sf=sf)
        elif i == 8:
            img = add_resize(img, sf=sf)
        elif i == 9:
            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
        elif i == 10:
            if random.random() < poisson_prob:
                img = add_Poisson_noise(img)
        elif i == 11:
            if random.random() < speckle_prob:
                img = add_speckle_noise(img)
        elif i == 12:
            if random.random() < isp_prob and isp_model is not None:
                with torch.no_grad():
                    img, hq = isp_model.forward(img.copy(), hq)
        else:
            print('check the shuffle!')

    # resize to desired size
    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
                     interpolation=random.choice([1, 2, 3]))

    # add final JPEG compression noise
    img = add_JPEG_noise(img)

    # random crop
    img, hq = random_crop(img, hq, sf, lq_patchsize)

    return img, hq

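# Illustrative sketch (not part of the original module): one LQ/HQ pair from the
# extended BSRGAN + Real-ESRGAN pipeline above. `_demo_degradation_bsrgan_plus`
# is a hypothetical helper for documentation only.
def _demo_degradation_bsrgan_plus():
    hr = np.random.rand(288, 288, 3).astype(np.float32)  # >= lq_patchsize * sf = 256
    lq, hq = degradation_bsrgan_plus(hr, sf=4, lq_patchsize=64)
    assert lq.shape[:2] == (64, 64) and hq.shape[:2] == (256, 256)
    return lq, hq
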
if __name__ == '__main__':
    print("hey")
    img = util.imread_uint('utils/test.png', 3)
    print(img)
    img = util.uint2single(img)
    print(img)
    img = img[:448, :448]
    h = img.shape[0] // 4
    print("resizing to", h)
    sf = 4
    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
    img_hq = img  # use the (cropped) HR image as the high-quality reference; the original referenced an undefined img_hq
    for i in range(20):
        print(i)
        # the variant expects a uint8 input and returns {"image": uint8}; convert back to single for the resizing below
        img_lq = deg_fn(util.single2uint(img))["image"]
        img_lq = util.uint2single(img_lq)
        print(img_lq)
        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
        print(img_lq.shape)
        print("bicubic", img_lq_bicubic.shape)
        print(img_hq.shape)
        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                interpolation=0)
        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
                                        interpolation=0)
        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
        util.imsave(img_concat, str(i) + '.png')