combine_model_seq.py

  1. """
  2. Copyright 2023 Yingqiang Ge
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. """
__author__ = "Yingqiang Ge"
__copyright__ = "Copyright 2023, OpenAGI"
__date__ = "2023/04/10"
__license__ = "Apache 2.0"
__version__ = "0.0.1"

import os
# os.chdir('../')

import models.github_models.colorization.colorizers as colorizers
from models.github_models.colorization.colorizers import *

from torchvision import transforms
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    DetrImageProcessor,
    DetrForObjectDetection,
    ViTFeatureExtractor,
    ViTForImageClassification,
    AutoImageProcessor,
    Swin2SRForImageSuperResolution,
    set_seed,
    ViltProcessor,
    ViltForQuestionAnswering,
    VisionEncoderDecoderModel,
)
from diffusers import StableDiffusionPipeline

import torch
import torch.nn.functional as F  # used by image_deblurring / image_denoising for reflect padding
import numpy as np  # used throughout for array <-> tensor conversions

from runpy import run_path
from skimage import img_as_ubyte
import cv2

import warnings
warnings.filterwarnings('ignore')

import gc

set_seed(42)

class SeqCombine:
    def __init__(self, args):
        self.device_list = args.device_list
        os.environ['TRANSFORMERS_CACHE'] = args.huggingface_cache

        print("Initializing image classifier...")
        self.img_classifier_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224', device_map='auto')
        self.img_classifier = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')  # .to(self.device)
        self.img_classifier.eval()

        print("Initializing colorizers...")
        self.colorizer = colorizers.siggraph17().eval()  # .to(self.device)

        print("Initializing object detector...")
        self.object_detector_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101")
        self.object_detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101")  # .to(self.device)
        self.object_detector.eval()

        print("Initializing image super resolution...")
        self.image_super_resolution_processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
        self.image_super_resolution_model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
        self.image_super_resolution_model.eval()

        print("Initializing image caption...")
        self.image_caption_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")  # .to(self.device)
        self.image_caption_model.eval()
        self.image_caption_feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        self.image_caption_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

        print("Initializing text to image generator...")
        self.text_to_image_generator = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", cache_dir=args.huggingface_cache)  # , torch_dtype=torch.float16)  # .to(self.device)

        # Disable the safety checker so the pipeline always returns the generated images.
        def dummy(images, **kwargs):
            return images, False

        self.text_to_image_generator.safety_checker = dummy
        self.text_to_image_generator.enable_attention_slicing()

        print("Initializing sentiment analysis...")
        self.sentiment_analysis_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
        self.sentiment_analysis_module = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
        self.sentiment_analysis_module.eval()

        print("Initializing QA...")
        self.question_answering_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased-distilled-squad")
        self.question_answerer = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-cased-distilled-squad")
        self.question_answerer.eval()

        print("Initializing summarization...")
        self.summarization_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        self.summarizer = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
        self.summarizer.eval()

        print("Initializing text generation...")
        self.text_generation_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.text_generation_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.text_generator = AutoModelForCausalLM.from_pretrained("gpt2")
        self.text_generator.eval()

        print("Initializing VQA...")
        self.vqa_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        self.vqa_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
        self.vqa_model.eval()

        self.img_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.PILToTensor(),
        ])

        print("Initializing image deblurring...")
        # Load the Restormer deblurring and denoising models from the bundled GitHub sources.
        parameters = {'inp_channels': 3, 'out_channels': 3, 'dim': 48, 'num_blocks': [4, 6, 6, 8],
                      'num_refinement_blocks': 4, 'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66,
                      'bias': False, 'LayerNorm_type': 'WithBias', 'dual_pixel_task': False}
        weights = os.path.join('models', 'github_models', 'Restormer', 'Defocus_Deblurring', 'pretrained_models', 'single_image_defocus_deblurring.pth')
        load_arch = run_path(os.path.join('models', 'github_models', 'Restormer', 'basicsr', 'models', 'archs', 'restormer_arch.py'))
        self.image_deblurring_model = load_arch['Restormer'](**parameters)
        # device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
        # model.to(device)
        checkpoint = torch.load(weights)
        self.image_deblurring_model.load_state_dict(checkpoint['params'])
        self.image_deblurring_model.eval()

        print("Initializing image denoising...")
        parameters = {'inp_channels': 3, 'out_channels': 3, 'dim': 48, 'num_blocks': [4, 6, 6, 8],
                      'num_refinement_blocks': 4, 'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66,
                      'bias': False, 'LayerNorm_type': 'WithBias', 'dual_pixel_task': False}
        weights = os.path.join('models', 'github_models', 'Restormer', 'Denoising', 'pretrained_models', 'real_denoising.pth')
        parameters['LayerNorm_type'] = 'BiasFree'
        load_arch = run_path(os.path.join('models', 'github_models', 'Restormer', 'basicsr', 'models', 'archs', 'restormer_arch.py'))
        self.image_denoising_model = load_arch['Restormer'](**parameters)
        # device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
        # model.to(device)
        checkpoint = torch.load(weights)
        self.image_denoising_model.load_state_dict(checkpoint['params'])
        self.image_denoising_model.eval()

        print("Initializing translator...")
        self.translation_tokenizer = AutoTokenizer.from_pretrained("t5-base")
        self.translator = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
        self.translator.eval()

        print("Initializing unmasker...")
        self.unmask_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        self.unmasker = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")
        self.unmasker.eval()

        # Map each task name to its forward wrapper and the underlying model that gets placed on a device.
        self.module2function_dict = {
            "Image Classification": ["self.image_classification", self.img_classifier],
            "Colorization": ["self.image_colorization", self.colorizer],
            "Object Detection": ["self.image_object_detect", self.object_detector],
            "Image Deblurring": ["self.image_deblurring", self.image_deblurring_model],
            "Image Denoising": ["self.image_denoising", self.image_denoising_model],
            "Image Super Resolution": ["self.image_super_resolution", self.image_super_resolution_model],
            "Image Captioning": ["self.image_caption", self.image_caption_model],
            "Text to Image Generation": ["self.text_to_image_generation", self.text_to_image_generator],
            "Visual Question Answering": ["self.vqa", self.vqa_model],
            "Sentiment Analysis": ["self.sentiment_analysis", self.sentiment_analysis_module],
            "Question Answering": ["self.question_answering", self.question_answerer],
            "Text Summarization": ["self.text_summarization", self.summarizer],
            "Text Generation": ["self.text_generation", self.text_generator],
            "Machine Translation": ["self.machine_translation", self.translator],
            "Fill Mask": ["self.fill_mask", self.unmasker],
        }
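
    # The three methods below manage the lifecycle of a generated module sequence:
    # construct_module_seq resolves each task name via module2function_dict and
    # greedily places its model on the first device in self.device_list that still
    # reports at least 3 GiB of free memory (advancing to the next device otherwise),
    # run_module_seq chains the modules by feeding each output into the next module,
    # and close_module_seq moves every model back to the CPU and frees GPU memory.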
    def construct_module_seq(self, generated_module_seq):
        module_list = generated_module_seq.split(",")
        self.module_function_list = []
        self.module_list = []
        self.used_device_list = []

        i = 0
        cur_device = self.device_list[i]
        for module in module_list:
            module = module.strip()
            temp_values = self.module2function_dict[module]
            temp_m = temp_values[1]

            if cur_device != "cpu":
                if torch.cuda.mem_get_info(cur_device)[0] / 1024**3 >= 3:
                    temp_m = temp_m.to(cur_device)
                    self.used_device_list.append(cur_device)
                else:
                    i += 1
                    cur_device = self.device_list[i]
                    temp_m = temp_m.to(cur_device)
                    self.used_device_list.append(cur_device)
            else:
                temp_m = temp_m.to(cur_device)
                self.used_device_list.append(cur_device)

            temp_f = eval(temp_values[0])
            self.module_function_list.append(temp_f)
            self.module_list.append(temp_m)

    def run_module_seq(self, input_data):
        temp = input_data
        for i, m in enumerate(self.module_function_list):
            temp = m(temp, self.used_device_list[i])
        return temp

    def close_module_seq(self):
        for m in self.module_list:
            m = m.to(torch.device("cpu"))
        torch.cuda.empty_cache()
        gc.collect()
        return
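
    # ------------------------------------------------------------------
    # Per-task wrappers. Each takes the previous module's output together
    # with the device its model was placed on, and returns output in a
    # form the next module in the sequence can consume directly.
    # ------------------------------------------------------------------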

    def image_classification(self, imgs, device):
        img_classifier_inputs = self.img_classifier_feature_extractor(images=imgs, return_tensors="pt").to(device)
        with torch.no_grad():
            img_classifier_outputs = self.img_classifier(**img_classifier_inputs)
        img_classifier_logits = img_classifier_outputs.logits
        # model predicts one of the 1000 ImageNet classes
        predicted_class_idx = img_classifier_logits.argmax(1)  # .item()
        predicted_class = [self.img_classifier.config.id2label[i.item()] for i in predicted_class_idx]
        return predicted_class

    def image_colorization(self, imgs, device):
        temp_imgs = []
        for img in imgs:
            img = img.permute(1, 2, 0).cpu().numpy()
            (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256, 256))
            tens_l_rs = tens_l_rs.to(device)

            # colorizer outputs 256x256 ab map
            # resize and concatenate to original L channel
            img_bw = postprocess_tens(tens_l_orig, torch.cat((0 * tens_l_orig, 0 * tens_l_orig), dim=1))
            out_img = postprocess_tens(tens_l_orig, self.colorizer(tens_l_rs).cpu())
            norm_out_img = cv2.normalize(out_img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            norm_out_img = norm_out_img.astype(np.uint8)
            # colorized_img = Image.fromarray(norm_out_img, 'RGB')
            # temp_imgs.append(torch.from_numpy(np.array(colorized_img)).permute(2, 0, 1))
            temp_imgs.append(torch.from_numpy(norm_out_img).permute(2, 0, 1))
        return temp_imgs

    def image_object_detect(self, imgs, device):
        imgs = torch.stack(imgs)
        object_detector_inputs = self.object_detector_processor(images=imgs, return_tensors="pt").to(device)
        with torch.no_grad():
            object_detector_outputs = self.object_detector(**object_detector_inputs)

        # convert outputs (bounding boxes and class logits) to COCO API
        # let's only keep detections with score > 0.9
        target_sizes = torch.tensor([[object_detector_inputs['pixel_values'].shape[2],
                                      object_detector_inputs['pixel_values'].shape[3]]
                                     for i in range(object_detector_inputs['pixel_values'].shape[0])]).to(device)
        results = self.object_detector_processor.post_process_object_detection(object_detector_outputs, target_sizes=target_sizes, threshold=0.9)

        predicted_results = []
        for r in results:
            output = ""
            for score, label, box in zip(r["scores"], r["labels"], r["boxes"]):
                output += self.object_detector.config.id2label[label.item()]
                output += ", "
            predicted_results.append(output[:-2])
        return predicted_results

    def image_caption(self, imgs, device):
        """
        input:
            imgs: list of image tensors
        output:
            preds: list of strings
        """
        max_length = 40
        num_beams = 4
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
        pixel_values = self.image_caption_feature_extractor(images=imgs, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)
        with torch.no_grad():
            output_ids = self.image_caption_model.generate(pixel_values, **gen_kwargs)
        preds = self.image_caption_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        preds = [pred.strip() for pred in preds]
        return preds

    def text_to_image_generation(self, prompts, device):
        with torch.no_grad():
            images = self.text_to_image_generator(prompts).images
        return [self.img_transform(im) for im in images]

    def image_super_resolution(self, imgs, device):
        """
        input:
            imgs: a list of Images, image tensors (3,H,W), or image arrays (3,H,W);
                  here it is assumed to be a list of image tensors
        output:
            reformed_outputs: list of upscaled image tensors, each (3,H,W)
        """
        batch_size = len(imgs)
        inputs = torch.stack(imgs).permute(0, 2, 3, 1)
        inputs = self.image_super_resolution_processor(inputs, return_tensors="pt").to(device)

        # forward pass
        with torch.no_grad():
            outputs = self.image_super_resolution_model(**inputs)

        reformed_outputs = []
        for i in range(batch_size):
            output_ = outputs.reconstruction.data[i]
            output = output_.squeeze().float().cpu().clamp_(0, 1).numpy()
            output = np.moveaxis(output, source=0, destination=-1)
            output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
            reformed_outputs.append(torch.from_numpy(output).permute(2, 0, 1))
        return reformed_outputs

    def vqa(self, input_data, device):
        """
        input:
            imgs: list of image tensors
            questions: list of strings
        output:
            answers: list of strings
        """
        imgs = input_data[1]
        questions = list(input_data[0])
        encoding = self.vqa_processor(imgs, questions, return_tensors="pt", padding=True).to(device)

        # forward pass
        with torch.no_grad():
            outputs = self.vqa_model(**encoding)
        logits = outputs.logits
        idxs = torch.argmax(logits, 1)
        answers = [self.vqa_model.config.id2label[idx.item()] for idx in idxs]
        return answers

    def sentiment_analysis(self, sentences, device):
        """
        input:
            sentences: list of strings
        output:
            predicted_labels: list of strings
        """
        inputs = self.sentiment_analysis_tokenizer(sentences, return_tensors="pt", padding=True).to(device)

        # Get the outputs from the model
        with torch.no_grad():
            outputs = self.sentiment_analysis_module(**inputs)

        # Get the logits from the outputs
        logits = outputs.logits
        # Apply softmax to get probabilities
        probabilities = torch.softmax(logits, dim=-1)
        # Get the most likely label for each sentence
        predicted_label_ids = torch.argmax(probabilities, 1)  # .item()
        predicted_labels = [self.sentiment_analysis_module.config.id2label[i.item()] for i in predicted_label_ids]
        return predicted_labels

    def question_answering(self, input_data, device):
        """
        input:
            questions: list of strings
            contexts: list of strings
        output:
            results: list of strings
        """
        questions = list(input_data[1])
        contexts = list(input_data[0])
        batch_size = len(questions)
        inputs = self.question_answering_tokenizer(questions, contexts, return_tensors="pt", padding=True).to(device)

        # Get the outputs from the model
        with torch.no_grad():
            outputs = self.question_answerer(**inputs)

        # Get the start and end logits
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits
        # Get the most likely start and end indices per example
        start_indices = torch.argmax(start_logits, dim=1)
        end_indices = torch.argmax(end_logits, dim=1)

        # Get the answer span from the inputs
        results = []
        for i in range(batch_size):
            answer_ids = inputs["input_ids"][i][start_indices[i]:end_indices[i] + 1]
            answer_tokens = self.question_answering_tokenizer.convert_ids_to_tokens(answer_ids)
            answer_text = self.question_answering_tokenizer.convert_tokens_to_string(answer_tokens)
            # Collect the answer text
            results.append(answer_text)
        return results

    def text_summarization(self, text, device):
        """
        input:
            text: list of strings
        output:
            summary_text: list of strings
        """
        inputs = self.summarization_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)

        # Get the outputs from the model
        with torch.no_grad():
            outputs = self.summarizer.generate(**inputs)
        summary_text = [self.summarization_tokenizer.decode(summary_ids).strip("</s>") for summary_ids in outputs]
        return summary_text

    def machine_translation(self, sentences, device):
        """
        input:
            sentences: list of English strings
        output:
            translated_text: list of German strings
        """
        text = ["translate English to German: " + s for s in sentences]

        # Encode the input with the tokenizer
        inputs = self.translation_tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)

        # Get the outputs from the model
        with torch.no_grad():
            outputs = self.translator.generate(**inputs, min_length=5, max_length=1000)

        # Decode the outputs to get the translated text
        translated_text = [self.translation_tokenizer.decode(translated_ids).strip("<pad></s>") for translated_ids in outputs]
        return translated_text

    def text_generation(self, sentences, device):
        """
        input:
            sentences: list of strings
        output:
            generated_text: list of strings
        """
        res = []
        for s in sentences:
            # Encode the input with the tokenizer
            inputs = self.text_generation_tokenizer(s, return_tensors="pt", padding=True).to(device)

            # Get the outputs from the model
            with torch.no_grad():
                outputs = self.text_generator.generate(**inputs, min_length=5, max_length=30, pad_token_id=50256)

            # Decode the outputs to get the generated text
            generated_s = self.text_generation_tokenizer.decode(outputs[0])
            res.append(generated_s)
        return res

    def fill_mask(self, sentences, device):
        """
        input:
            sentences: list of strings with "[MASK]"
        output:
            results: list of strings
        """
        batch_size = len(sentences)
        inputs = self.unmask_tokenizer(sentences, return_tensors="pt", padding=True).to(device)

        # Get the outputs from the model
        with torch.no_grad():
            outputs = self.unmasker(**inputs)

        # Get the logits from the outputs
        logits = outputs.logits
        # Apply softmax to get probabilities
        probabilities = torch.softmax(logits, dim=-1)

        results = []
        for i in range(batch_size):
            # Get the top token and its probability for the masked position
            masked_index = inputs.input_ids[i].tolist().index(self.unmask_tokenizer.mask_token_id)
            top_tokens = torch.topk(probabilities[i][masked_index], k=1)
            # Decode the token to get the word
            word = self.unmask_tokenizer.convert_ids_to_tokens(top_tokens.indices)
            completed_text = sentences[i].replace(self.unmask_tokenizer.mask_token, word[0])
            results.append(completed_text)
        return results

    def image_deblurring(self, imgs, device):
        restoreds = []
        with torch.no_grad():
            img_multiple_of = 8
            for cur in imgs:
                h = cur.shape[1]
                w = cur.shape[2]
                img = cur.contiguous().view(h, w, 3)
                input_ = img.float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

                # Pad the input so that height and width are multiples of 8
                height, width = input_.shape[2], input_.shape[3]
                H, W = ((height + img_multiple_of) // img_multiple_of) * img_multiple_of, ((width + img_multiple_of) // img_multiple_of) * img_multiple_of
                padh = H - height if height % img_multiple_of != 0 else 0
                padw = W - width if width % img_multiple_of != 0 else 0
                input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

                restored = self.image_deblurring_model(input_)
                restored = torch.clamp(restored, 0, 1)
                restored = restored[:, :, :height, :width]
                restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
                restored = img_as_ubyte(restored[0]).reshape((3, h, w))
                restored = torch.from_numpy(restored)
                restoreds.append(restored)
                # print(restored)
        return restoreds

    def image_denoising(self, imgs, device):
        restoreds = []
        img_multiple_of = 8
        with torch.no_grad():
            for cur in imgs:
                h = cur.shape[1]
                w = cur.shape[2]
                img = cur.contiguous().view(h, w, 3)
                input_ = img.float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

                height, width = input_.shape[2], input_.shape[3]
                H, W = ((height + img_multiple_of) // img_multiple_of) * img_multiple_of, ((width + img_multiple_of) // img_multiple_of) * img_multiple_of
                padh = H - height if height % img_multiple_of != 0 else 0
                padw = W - width if width % img_multiple_of != 0 else 0
                input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

                restored = self.image_denoising_model(input_)
                restored = torch.clamp(restored, 0, 1)
                restored = restored[:, :, :height, :width]
                restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
                restored = img_as_ubyte(restored[0]).reshape((3, h, w))
                restored = torch.from_numpy(restored)
                restoreds.append(restored)
                # print(restored)
        return restoreds
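
# Minimal usage sketch (illustrative only, not part of the OpenAGI codebase): the
# device names, cache path, task sequence, and `list_of_image_tensors` below are
# assumptions; `args` only needs the two attributes read in __init__
# (device_list and huggingface_cache).
#
#     from types import SimpleNamespace
#
#     args = SimpleNamespace(device_list=["cuda:0", "cpu"],
#                            huggingface_cache="./hf_cache")
#     seq_combine = SeqCombine(args)
#     seq_combine.construct_module_seq("Image Denoising, Image Captioning, Text Summarization")
#     outputs = seq_combine.run_module_seq(list_of_image_tensors)  # list of (3,H,W) uint8 tensors
#     seq_combine.close_module_seq()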