image.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. from __future__ import annotations
  2. import re
  3. from io import BytesIO
  4. import base64
  5. from .typing import ImageType, Union, Image
  6. try:
  7. from PIL.Image import open as open_image, new as new_image
  8. from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
  9. has_requirements = True
  10. except ImportError:
  11. has_requirements = False
  12. from .errors import MissingRequirementsError
  13. ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
  14. EXTENSIONS_MAP: dict[str, str] = {
  15. "image/png": "png",
  16. "image/jpeg": "jpg",
  17. "image/gif": "gif",
  18. "image/webp": "webp",
  19. }
  20. def to_image(image: ImageType, is_svg: bool = False) -> Image:
  21. """
  22. Converts the input image to a PIL Image object.
  23. Args:
  24. image (Union[str, bytes, Image]): The input image.
  25. Returns:
  26. Image: The converted PIL Image object.
  27. """
  28. if not has_requirements:
  29. raise MissingRequirementsError('Install "pillow" package for images')
  30. if isinstance(image, str):
  31. is_data_uri_an_image(image)
  32. image = extract_data_uri(image)
  33. if is_svg:
  34. try:
  35. import cairosvg
  36. except ImportError:
  37. raise MissingRequirementsError('Install "cairosvg" package for svg images')
  38. if not isinstance(image, bytes):
  39. image = image.read()
  40. buffer = BytesIO()
  41. cairosvg.svg2png(image, write_to=buffer)
  42. return open_image(buffer)
  43. if isinstance(image, bytes):
  44. is_accepted_format(image)
  45. return open_image(BytesIO(image))
  46. elif not isinstance(image, Image):
  47. image = open_image(image)
  48. image.load()
  49. return image
  50. return image
  51. def is_allowed_extension(filename: str) -> bool:
  52. """
  53. Checks if the given filename has an allowed extension.
  54. Args:
  55. filename (str): The filename to check.
  56. Returns:
  57. bool: True if the extension is allowed, False otherwise.
  58. """
  59. return '.' in filename and \
  60. filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  61. def is_data_uri_an_image(data_uri: str) -> bool:
  62. """
  63. Checks if the given data URI represents an image.
  64. Args:
  65. data_uri (str): The data URI to check.
  66. Raises:
  67. ValueError: If the data URI is invalid or the image format is not allowed.
  68. """
  69. # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
  70. if not re.match(r'data:image/(\w+);base64,', data_uri):
  71. raise ValueError("Invalid data URI image.")
  72. # Extract the image format from the data URI
  73. image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
  74. # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
  75. if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
  76. raise ValueError("Invalid image format (from mime file type).")
  77. def is_accepted_format(binary_data: bytes) -> str:
  78. """
  79. Checks if the given binary data represents an image with an accepted format.
  80. Args:
  81. binary_data (bytes): The binary data to check.
  82. Raises:
  83. ValueError: If the image format is not allowed.
  84. """
  85. if binary_data.startswith(b'\xFF\xD8\xFF'):
  86. return "image/jpeg"
  87. elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
  88. return "image/png"
  89. elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
  90. return "image/gif"
  91. elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
  92. return "image/jpeg"
  93. elif binary_data.startswith(b'\xFF\xD8'):
  94. return "image/jpeg"
  95. elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
  96. return "image/webp"
  97. else:
  98. raise ValueError("Invalid image format (from magic code).")
  99. def extract_data_uri(data_uri: str) -> bytes:
  100. """
  101. Extracts the binary data from the given data URI.
  102. Args:
  103. data_uri (str): The data URI.
  104. Returns:
  105. bytes: The extracted binary data.
  106. """
  107. data = data_uri.split(",")[1]
  108. data = base64.b64decode(data)
  109. return data
  110. def get_orientation(image: Image) -> int:
  111. """
  112. Gets the orientation of the given image.
  113. Args:
  114. image (Image): The image.
  115. Returns:
  116. int: The orientation value.
  117. """
  118. exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
  119. if exif_data is not None:
  120. orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
  121. if orientation is not None:
  122. return orientation
  123. def process_image(image: Image, new_width: int, new_height: int) -> Image:
  124. """
  125. Processes the given image by adjusting its orientation and resizing it.
  126. Args:
  127. image (Image): The image to process.
  128. new_width (int): The new width of the image.
  129. new_height (int): The new height of the image.
  130. Returns:
  131. Image: The processed image.
  132. """
  133. # Fix orientation
  134. orientation = get_orientation(image)
  135. if orientation:
  136. if orientation > 4:
  137. image = image.transpose(FLIP_LEFT_RIGHT)
  138. if orientation in [3, 4]:
  139. image = image.transpose(ROTATE_180)
  140. if orientation in [5, 6]:
  141. image = image.transpose(ROTATE_270)
  142. if orientation in [7, 8]:
  143. image = image.transpose(ROTATE_90)
  144. # Resize image
  145. image.thumbnail((new_width, new_height))
  146. # Remove transparency
  147. if image.mode == "RGBA":
  148. image.load()
  149. white = new_image('RGB', image.size, (255, 255, 255))
  150. white.paste(image, mask=image.split()[-1])
  151. return white
  152. # Convert to RGB for jpg format
  153. elif image.mode != "RGB":
  154. image = image.convert("RGB")
  155. return image
  156. def to_base64_jpg(image: Image, compression_rate: float) -> str:
  157. """
  158. Converts the given image to a base64-encoded string.
  159. Args:
  160. image (Image.Image): The image to convert.
  161. compression_rate (float): The compression rate (0.0 to 1.0).
  162. Returns:
  163. str: The base64-encoded image.
  164. """
  165. output_buffer = BytesIO()
  166. image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
  167. return base64.b64encode(output_buffer.getvalue()).decode()
  168. def format_images_markdown(images: Union[str, list], alt: str, preview: Union[str, list] = None) -> str:
  169. """
  170. Formats the given images as a markdown string.
  171. Args:
  172. images: The images to format.
  173. alt (str): The alt for the images.
  174. preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
  175. Returns:
  176. str: The formatted markdown string.
  177. """
  178. if isinstance(images, str):
  179. result = f"[![{alt}]({preview.replace('{image}', images) if preview else images})]({images})"
  180. else:
  181. if not isinstance(preview, list):
  182. preview = [preview.replace('{image}', image) if preview else image for image in images]
  183. result = "\n".join(
  184. f"[![#{idx+1} {alt}]({preview[idx]})]({image})"
  185. #f'[<img src="{preview[idx]}" width="200" alt="#{idx+1} {alt}">]({image})'
  186. for idx, image in enumerate(images)
  187. )
  188. start_flag = "<!-- generated images start -->\n"
  189. end_flag = "<!-- generated images end -->\n"
  190. return f"\n{start_flag}{result}\n{end_flag}\n"
  191. def to_bytes(image: ImageType) -> bytes:
  192. """
  193. Converts the given image to bytes.
  194. Args:
  195. image (ImageType): The image to convert.
  196. Returns:
  197. bytes: The image as bytes.
  198. """
  199. if isinstance(image, bytes):
  200. return image
  201. elif isinstance(image, str):
  202. is_data_uri_an_image(image)
  203. return extract_data_uri(image)
  204. elif isinstance(image, Image):
  205. bytes_io = BytesIO()
  206. image.save(bytes_io, image.format)
  207. image.seek(0)
  208. return bytes_io.getvalue()
  209. else:
  210. return image.read()
  211. def to_data_uri(image: ImageType) -> str:
  212. if not isinstance(image, str):
  213. data = to_bytes(image)
  214. data_base64 = base64.b64encode(data).decode()
  215. return f"data:{is_accepted_format(data)};base64,{data_base64}"
  216. return image
  217. class ImageResponse:
  218. def __init__(
  219. self,
  220. images: Union[str, list],
  221. alt: str,
  222. options: dict = {}
  223. ):
  224. self.images = images
  225. self.alt = alt
  226. self.options = options
  227. def __str__(self) -> str:
  228. return format_images_markdown(self.images, self.alt, self.get("preview"))
  229. def get(self, key: str):
  230. return self.options.get(key)
  231. def get_list(self) -> list[str]:
  232. return [self.images] if isinstance(self.images, str) else self.images
  233. class ImagePreview(ImageResponse):
  234. def __str__(self):
  235. return ""
  236. def to_string(self):
  237. return super().__str__()
  238. class ImageDataResponse():
  239. def __init__(
  240. self,
  241. images: Union[str, list],
  242. alt: str,
  243. ):
  244. self.images = images
  245. self.alt = alt
  246. def get_list(self) -> list[str]:
  247. return [self.images] if isinstance(self.images, str) else self.images
  248. class ImageRequest:
  249. def __init__(
  250. self,
  251. options: dict = {}
  252. ):
  253. self.options = options
  254. def get(self, key: str):
  255. return self.options.get(key)