auto_plot_image.py 820 B

123456789101112131415161718192021
  1. import torch
  2. import numpy as np
  3. import cv2
  4. def plot_image(save_path, image, convert_RGB2BGR=True):
  5. if isinstance(image, torch.Tensor):
  6. image = image.detach().cpu().numpy()
  7. image = image.astype(float)
  8. if image.max() < 1.1 and image.min() > -0.1: # [0, 1]
  9. image = image * 255
  10. elif image.max() < 1.1 and image.min() > -1.1: # [-1, 1]
  11. image = (image + 1.0) * 0.5 * 255
  12. image = image.clip(0, 255)
  13. image = image.astype(np.uint8)
  14. if len(image.shape) == 4 and image.shape[0] == 1:
  15. image = image[0]
  16. if len(image.shape) == 3 and image.shape[0] <= 4: # C, H, W
  17. image = torch.from_numpy(image).permute(1, 2, 0).numpy()
  18. if len(image.shape) == 3 and convert_RGB2BGR:
  19. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  20. cv2.imwrite(save_path, image)