matrix.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. from __future__ import annotations
  2. import itertools as it
  3. import numpy as np
  4. from manimlib.constants import DOWN, LEFT, RIGHT, ORIGIN
  5. from manimlib.constants import DEGREES
  6. from manimlib.mobject.numbers import DecimalNumber
  7. from manimlib.mobject.svg.tex_mobject import Tex
  8. from manimlib.mobject.types.vectorized_mobject import VGroup
  9. from manimlib.mobject.types.vectorized_mobject import VMobject
  10. from typing import TYPE_CHECKING
  11. if TYPE_CHECKING:
  12. from typing import Sequence, Union, Tuple, Optional
  13. from manimlib.typing import ManimColor, Vect3, VectNArray, Self
  14. StringMatrixType = Union[Sequence[Sequence[str]], np.ndarray[int, np.dtype[np.str_]]]
  15. FloatMatrixType = Union[Sequence[Sequence[float]], VectNArray]
  16. VMobjectMatrixType = Sequence[Sequence[VMobject]]
  17. GenericMatrixType = Union[FloatMatrixType, StringMatrixType, VMobjectMatrixType]
  18. class Matrix(VMobject):
  19. def __init__(
  20. self,
  21. matrix: GenericMatrixType,
  22. v_buff: float = 0.5,
  23. h_buff: float = 0.5,
  24. bracket_h_buff: float = 0.2,
  25. bracket_v_buff: float = 0.25,
  26. height: float | None = None,
  27. element_config: dict = dict(),
  28. element_alignment_corner: Vect3 = DOWN,
  29. ellipses_row: Optional[int] = None,
  30. ellipses_col: Optional[int] = None,
  31. ):
  32. """
  33. Matrix can either include numbers, tex_strings,
  34. or mobjects
  35. """
  36. super().__init__()
  37. self.mob_matrix = self.create_mobject_matrix(
  38. matrix, v_buff, h_buff, element_alignment_corner,
  39. **element_config
  40. )
  41. # Create helpful groups for the elements
  42. n_cols = len(self.mob_matrix[0])
  43. self.elements = [elem for row in self.mob_matrix for elem in row]
  44. self.columns = VGroup(*(
  45. VGroup(*(row[i] for row in self.mob_matrix))
  46. for i in range(n_cols)
  47. ))
  48. self.rows = VGroup(*(VGroup(*row) for row in self.mob_matrix))
  49. if height is not None:
  50. self.rows.set_height(height - 2 * bracket_v_buff)
  51. self.brackets = self.create_brackets(self.rows, bracket_v_buff, bracket_h_buff)
  52. self.ellipses = []
  53. # Add elements and brackets
  54. self.add(*self.elements)
  55. self.add(*self.brackets)
  56. self.center()
  57. # Potentially add ellipses
  58. self.swap_entries_for_ellipses(
  59. ellipses_row,
  60. ellipses_col,
  61. )
  62. def copy(self, deep: bool = False):
  63. result = super().copy(deep)
  64. self_family = self.get_family()
  65. copy_family = result.get_family()
  66. for attr in ["elements", "ellipses"]:
  67. setattr(result, attr, [
  68. copy_family[self_family.index(mob)]
  69. for mob in getattr(self, attr)
  70. ])
  71. return result
  72. def create_mobject_matrix(
  73. self,
  74. matrix: GenericMatrixType,
  75. v_buff: float,
  76. h_buff: float,
  77. aligned_corner: Vect3,
  78. **element_config
  79. ) -> VMobjectMatrixType:
  80. """
  81. Creates and organizes the matrix of mobjects
  82. """
  83. mob_matrix = [
  84. [
  85. self.element_to_mobject(element, **element_config)
  86. for element in row
  87. ]
  88. for row in matrix
  89. ]
  90. max_width = max(elem.get_width() for row in mob_matrix for elem in row)
  91. max_height = max(elem.get_height() for row in mob_matrix for elem in row)
  92. x_step = (max_width + h_buff) * RIGHT
  93. y_step = (max_height + v_buff) * DOWN
  94. for i, row in enumerate(mob_matrix):
  95. for j, elem in enumerate(row):
  96. elem.move_to(i * y_step + j * x_step, aligned_corner)
  97. return mob_matrix
  98. def element_to_mobject(self, element, **config) -> VMobject:
  99. if isinstance(element, VMobject):
  100. return element
  101. elif isinstance(element, float | complex):
  102. return DecimalNumber(element, **config)
  103. else:
  104. return Tex(str(element), **config)
  105. def create_brackets(self, rows, v_buff: float, h_buff: float) -> VGroup:
  106. brackets = Tex("".join((
  107. R"\left[\begin{array}{c}",
  108. *len(rows) * [R"\quad \\"],
  109. R"\end{array}\right]",
  110. )))
  111. brackets.set_height(rows.get_height() + v_buff)
  112. l_bracket = brackets[:len(brackets) // 2]
  113. r_bracket = brackets[len(brackets) // 2:]
  114. l_bracket.next_to(rows, LEFT, h_buff)
  115. r_bracket.next_to(rows, RIGHT, h_buff)
  116. return VGroup(l_bracket, r_bracket)
  117. def get_column(self, index: int):
  118. if not 0 <= index < len(self.columns):
  119. raise IndexError(f"Index {index} out of bound for matrix with {len(self.columns)} columns")
  120. return self.columns[index]
  121. def get_row(self, index: int):
  122. if not 0 <= index < len(self.rows):
  123. raise IndexError(f"Index {index} out of bound for matrix with {len(self.rows)} rows")
  124. return self.rows[index]
  125. def get_columns(self) -> VGroup:
  126. return self.columns
  127. def get_rows(self) -> VGroup:
  128. return self.rows
  129. def set_column_colors(self, *colors: ManimColor) -> Self:
  130. columns = self.get_columns()
  131. for color, column in zip(colors, columns):
  132. column.set_color(color)
  133. return self
  134. def add_background_to_entries(self) -> Self:
  135. for mob in self.get_entries():
  136. mob.add_background_rectangle()
  137. return self
  138. def swap_entry_for_dots(self, entry, dots):
  139. dots.move_to(entry)
  140. entry.become(dots)
  141. if entry in self.elements:
  142. self.elements.remove(entry)
  143. if entry not in self.ellipses:
  144. self.ellipses.append(entry)
  145. def swap_entries_for_ellipses(
  146. self,
  147. row_index: Optional[int] = None,
  148. col_index: Optional[int] = None,
  149. height_ratio: float = 0.65,
  150. width_ratio: float = 0.4
  151. ):
  152. rows = self.get_rows()
  153. cols = self.get_columns()
  154. avg_row_height = rows.get_height() / len(rows)
  155. vdots_height = height_ratio * avg_row_height
  156. avg_col_width = cols.get_width() / len(cols)
  157. hdots_width = width_ratio * avg_col_width
  158. use_vdots = row_index is not None and -len(rows) <= row_index < len(rows)
  159. use_hdots = col_index is not None and -len(cols) <= col_index < len(cols)
  160. if use_vdots:
  161. for column in cols:
  162. # Add vdots
  163. dots = Tex(R"\vdots")
  164. dots.set_height(vdots_height)
  165. self.swap_entry_for_dots(column[row_index], dots)
  166. if use_hdots:
  167. for row in rows:
  168. # Add hdots
  169. dots = Tex(R"\hdots")
  170. dots.set_width(hdots_width)
  171. self.swap_entry_for_dots(row[col_index], dots)
  172. if use_vdots and use_hdots:
  173. rows[row_index][col_index].rotate(-45 * DEGREES)
  174. return self
  175. def get_mob_matrix(self) -> VMobjectMatrixType:
  176. return self.mob_matrix
  177. def get_entries(self) -> VGroup:
  178. return VGroup(*self.elements)
  179. def get_brackets(self) -> VGroup:
  180. return VGroup(*self.brackets)
  181. def get_ellipses(self) -> VGroup:
  182. return VGroup(*self.ellipses)
  183. class DecimalMatrix(Matrix):
  184. def __init__(
  185. self,
  186. matrix: FloatMatrixType,
  187. num_decimal_places: int = 2,
  188. decimal_config: dict = dict(),
  189. **config
  190. ):
  191. self.float_matrix = matrix
  192. super().__init__(
  193. matrix,
  194. element_config=dict(
  195. num_decimal_places=num_decimal_places,
  196. **decimal_config
  197. ),
  198. **config
  199. )
  200. def element_to_mobject(self, element, **decimal_config) -> DecimalNumber:
  201. return DecimalNumber(element, **decimal_config)
  202. class IntegerMatrix(DecimalMatrix):
  203. def __init__(
  204. self,
  205. matrix: FloatMatrixType,
  206. num_decimal_places: int = 0,
  207. decimal_config: dict = dict(),
  208. **config
  209. ):
  210. super().__init__(matrix, num_decimal_places, decimal_config, **config)
  211. class TexMatrix(Matrix):
  212. def __init__(
  213. self,
  214. matrix: StringMatrixType,
  215. tex_config: dict = dict(),
  216. **config,
  217. ):
  218. super().__init__(
  219. matrix,
  220. element_config=tex_config,
  221. **config
  222. )
  223. class MobjectMatrix(Matrix):
  224. def __init__(
  225. self,
  226. group: VGroup,
  227. n_rows: int | None = None,
  228. n_cols: int | None = None,
  229. height: float = 4.0,
  230. element_alignment_corner=ORIGIN,
  231. **config,
  232. ):
  233. # Have fallback defaults of n_rows and n_cols
  234. n_mobs = len(group)
  235. if n_rows is None:
  236. n_rows = int(np.sqrt(n_mobs)) if n_cols is None else n_mobs // n_cols
  237. if n_cols is None:
  238. n_cols = n_mobs // n_rows
  239. if len(group) < n_rows * n_cols:
  240. raise Exception("Input to MobjectMatrix must have at least n_rows * n_cols entries")
  241. mob_matrix = [
  242. [group[n * n_cols + k] for k in range(n_cols)]
  243. for n in range(n_rows)
  244. ]
  245. config.update(
  246. height=height,
  247. element_alignment_corner=element_alignment_corner,
  248. )
  249. super().__init__(mob_matrix, **config)
  250. def element_to_mobject(self, element: VMobject, **config) -> VMobject:
  251. return element