color.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from __future__ import annotations
  2. from colour import Color
  3. from colour import hex2rgb
  4. from colour import rgb2hex
  5. import numpy as np
  6. import random
  7. from manimlib.constants import COLORMAP_3B1B
  8. from manimlib.constants import WHITE
  9. from manimlib.utils.bezier import interpolate
  10. from manimlib.utils.iterables import resize_with_interpolation
  11. from typing import TYPE_CHECKING
  12. if TYPE_CHECKING:
  13. from typing import Iterable, Sequence
  14. from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array
  15. def color_to_rgb(color: ManimColor) -> Vect3:
  16. if isinstance(color, str):
  17. return hex_to_rgb(color)
  18. elif isinstance(color, Color):
  19. return np.array(color.get_rgb())
  20. else:
  21. raise Exception("Invalid color type")
  22. def color_to_rgba(color: ManimColor, alpha: float = 1.0) -> Vect4:
  23. return np.array([*color_to_rgb(color), alpha])
  24. def rgb_to_color(rgb: Vect3 | Sequence[float]) -> Color:
  25. try:
  26. return Color(rgb=tuple(rgb))
  27. except ValueError:
  28. return Color(WHITE)
  29. def rgba_to_color(rgba: Vect4) -> Color:
  30. return rgb_to_color(rgba[:3])
  31. def rgb_to_hex(rgb: Vect3 | Sequence[float]) -> str:
  32. return rgb2hex(rgb, force_long=True).upper()
  33. def hex_to_rgb(hex_code: str) -> Vect3:
  34. return np.array(hex2rgb(hex_code))
  35. def invert_color(color: ManimColor) -> Color:
  36. return rgb_to_color(1.0 - color_to_rgb(color))
  37. def color_to_int_rgb(color: ManimColor) -> np.ndarray[int, np.dtype[np.uint8]]:
  38. return (255 * color_to_rgb(color)).astype('uint8')
  39. def color_to_int_rgba(color: ManimColor, opacity: float = 1.0) -> np.ndarray[int, np.dtype[np.uint8]]:
  40. alpha = int(255 * opacity)
  41. return np.array([*color_to_int_rgb(color), alpha], dtype=np.uint8)
  42. def color_to_hex(color: ManimColor) -> str:
  43. return Color(color).get_hex_l().upper()
  44. def hex_to_int(rgb_hex: str) -> int:
  45. return int(rgb_hex[1:], 16)
  46. def int_to_hex(rgb_int: int) -> str:
  47. return f"#{rgb_int:06x}".upper()
  48. def color_gradient(
  49. reference_colors: Iterable[ManimColor],
  50. length_of_output: int
  51. ) -> list[Color]:
  52. if length_of_output == 0:
  53. return []
  54. rgbs = list(map(color_to_rgb, reference_colors))
  55. alphas = np.linspace(0, (len(rgbs) - 1), length_of_output)
  56. floors = alphas.astype('int')
  57. alphas_mod1 = alphas % 1
  58. # End edge case
  59. alphas_mod1[-1] = 1
  60. floors[-1] = len(rgbs) - 2
  61. return [
  62. rgb_to_color(np.sqrt(interpolate(rgbs[i]**2, rgbs[i + 1]**2, alpha)))
  63. for i, alpha in zip(floors, alphas_mod1)
  64. ]
  65. def interpolate_color(
  66. color1: ManimColor,
  67. color2: ManimColor,
  68. alpha: float
  69. ) -> Color:
  70. rgb = np.sqrt(interpolate(color_to_rgb(color1)**2, color_to_rgb(color2)**2, alpha))
  71. return rgb_to_color(rgb)
  72. def interpolate_color_by_hsl(
  73. color1: ManimColor,
  74. color2: ManimColor,
  75. alpha: float
  76. ) -> Color:
  77. hsl1 = np.array(Color(color1).get_hsl())
  78. hsl2 = np.array(Color(color2).get_hsl())
  79. return Color(hsl=interpolate(hsl1, hsl2, alpha))
  80. def average_color(*colors: ManimColor) -> Color:
  81. rgbs = np.array(list(map(color_to_rgb, colors)))
  82. return rgb_to_color(np.sqrt((rgbs**2).mean(0)))
  83. def random_color() -> Color:
  84. return Color(rgb=tuple(np.random.random(3)))
  85. def random_bright_color(
  86. hue_range: tuple[float, float] = (0.0, 1.0),
  87. saturation_range: tuple[float, float] = (0.5, 0.8),
  88. luminance_range: tuple[float, float] = (0.5, 1.0),
  89. ) -> Color:
  90. return Color(hsl=(
  91. interpolate(*hue_range, random.random()),
  92. interpolate(*saturation_range, random.random()),
  93. interpolate(*luminance_range, random.random()),
  94. ))
  95. def get_colormap_list(
  96. map_name: str = "viridis",
  97. n_colors: int = 9
  98. ) -> Vect3Array:
  99. """
  100. Options for map_name:
  101. 3b1b_colormap
  102. magma
  103. inferno
  104. plasma
  105. viridis
  106. cividis
  107. twilight
  108. twilight_shifted
  109. turbo
  110. """
  111. from matplotlib.cm import cmaps_listed
  112. if map_name == "3b1b_colormap":
  113. rgbs = np.array([color_to_rgb(color) for color in COLORMAP_3B1B])
  114. else:
  115. rgbs = cmaps_listed[map_name].colors # Make more general?
  116. return resize_with_interpolation(np.array(rgbs), n_colors)