# images.py (1.1 KB)
  1. import logging
  2. import numpy as np
  3. from ray.rllib.utils.annotations import DeveloperAPI
  4. logger = logging.getLogger(__name__)
  5. try:
  6. import cv2
  7. cv2.ocl.setUseOpenCL(False)
  8. logger.debug("CV2 found for image processing.")
  9. except ImportError:
  10. cv2 = None
  11. if cv2 is None:
  12. try:
  13. from skimage import color, io, transform
  14. logger.debug("CV2 not found for image processing, using Skimage.")
  15. except ImportError:
  16. raise ModuleNotFoundError("Either scikit-image or opencv is required")
  17. @DeveloperAPI
  18. def resize(img: np.ndarray, height: int, width: int) -> np.ndarray:
  19. if cv2:
  20. return cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
  21. return transform.resize(img, (height, width))
  22. @DeveloperAPI
  23. def rgb2gray(img: np.ndarray) -> np.ndarray:
  24. if cv2:
  25. return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  26. return color.rgb2gray(img)
  27. @DeveloperAPI
  28. def imread(img_file: str) -> np.ndarray:
  29. if cv2:
  30. return cv2.imread(img_file).astype(np.float32)
  31. return io.imread(img_file).astype(np.float32)