mobject.py 78 KB


  1. from __future__ import annotations
  2. import copy
  3. from functools import wraps
  4. import itertools as it
  5. import os
  6. import pickle
  7. import random
  8. import sys
  9. import moderngl
  10. import numbers
  11. import numpy as np
  12. from manimlib.constants import DEFAULT_MOBJECT_TO_EDGE_BUFFER
  13. from manimlib.constants import DEFAULT_MOBJECT_TO_MOBJECT_BUFFER
  14. from manimlib.constants import DOWN, IN, LEFT, ORIGIN, OUT, RIGHT, UP
  15. from manimlib.constants import FRAME_X_RADIUS, FRAME_Y_RADIUS
  16. from manimlib.constants import MED_SMALL_BUFF
  17. from manimlib.constants import TAU
  18. from manimlib.constants import WHITE
  19. from manimlib.event_handler import EVENT_DISPATCHER
  20. from manimlib.event_handler.event_listner import EventListener
  21. from manimlib.event_handler.event_type import EventType
  22. from manimlib.logger import log
  23. from manimlib.shader_wrapper import ShaderWrapper
  24. from manimlib.utils.color import color_gradient
  25. from manimlib.utils.color import color_to_rgb
  26. from manimlib.utils.color import get_colormap_list
  27. from manimlib.utils.color import rgb_to_hex
  28. from manimlib.utils.iterables import arrays_match
  29. from manimlib.utils.iterables import array_is_constant
  30. from manimlib.utils.iterables import batch_by_property
  31. from manimlib.utils.iterables import list_update
  32. from manimlib.utils.iterables import listify
  33. from manimlib.utils.iterables import resize_array
  34. from manimlib.utils.iterables import resize_preserving_order
  35. from manimlib.utils.iterables import resize_with_interpolation
  36. from manimlib.utils.bezier import integer_interpolate
  37. from manimlib.utils.bezier import interpolate
  38. from manimlib.utils.paths import straight_path
  39. from manimlib.utils.shaders import get_colormap_code
  40. from manimlib.utils.space_ops import angle_of_vector
  41. from manimlib.utils.space_ops import get_norm
  42. from manimlib.utils.space_ops import rotation_matrix_transpose
  43. from typing import TYPE_CHECKING
  44. from typing import TypeVar, Generic, Iterable
  45. SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')
  46. if TYPE_CHECKING:
  47. from typing import Callable, Iterator, Union, Tuple, Optional, Any
  48. import numpy.typing as npt
  49. from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
  50. from moderngl.context import Context
  51. T = TypeVar('T')
  52. TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None]
  53. NonTimeUpdater = Callable[["Mobject"], "Mobject" | None]
  54. Updater = Union[TimeBasedUpdater, NonTimeUpdater]
  55. class Mobject(object):
  56. """
  57. Mathematical Object
  58. """
    # Spatial dimension of point data; compute_bounding_box and
    # arrange_to_fit_dim size their arrays with it.
    dim: int = 3
    # Shader directory name; copied verbatim by become().
    shader_folder: str = ""
    # moderngl primitive used for rendering; copied verbatim by become().
    render_primitive: int = moderngl.TRIANGLE_STRIP
    # Must match in attributes of vert shader
    # One structured record per vertex: a 3-component 'point' and a
    # 4-component 'rgba' field, both float32.
    data_dtype: np.dtype = np.dtype([
        ('point', np.float32, (3,)),
        ('rgba', np.float32, (4,)),
    ])
    # Data fields aligned between mobjects; not used in this part of the file
    # (presumably consumed by family/data alignment code — TODO confirm).
    aligned_data_keys = ['point']
    # Data fields treated as positions: apply_points_function transforms each
    # of these arrays with the supplied function.
    pointlike_data_keys = ['point']
    def __init__(
        self,
        color: ManimColor = WHITE,
        opacity: float = 1.0,
        shading: Tuple[float, float, float] = (0.0, 0.0, 0.0),
        # For shaders
        texture_paths: dict[str, str] | None = None,
        # If true, the mobject will not get rotated according to camera position
        is_fixed_in_frame: bool = False,
        depth_test: bool = False,
        z_index: int = 0,
    ):
        """
        Create a mobject with no points and no submobjects, then run the
        init_* hooks that subclasses override to populate it.

        Args:
            color: Initial color, applied by init_colors() via set_color().
            opacity: Initial opacity, passed to set_color() with color.
            shading: Three floats, stored as the "shading" uniform.
            texture_paths: Stored as-is for shader use (presumably texture
                name -> file path; TODO confirm against its consumers).
            is_fixed_in_frame: If true, fix_in_frame() is called at the end.
            depth_test: If true, apply_depth_test() is called at the end.
            z_index: Stored as-is; its use lies outside this part of the file.
        """
        self.color = color
        self.opacity = opacity
        self.shading = shading
        self.texture_paths = texture_paths
        self.depth_test = depth_test
        self.z_index = z_index

        # Internal state
        self.submobjects: list[Mobject] = []
        self.parents: list[Mobject] = []
        # Cached flattened family; None means "rebuild on next get_family()"
        self.family: list[Mobject] | None = [self]
        self.locked_data_keys: set[str] = set()
        self.const_data_keys: set[str] = set()
        self.locked_uniform_keys: set[str] = set()
        self.saved_state = None
        self.target = None
        # Rows are [min corner, midpoint, max corner] (see compute_bounding_box)
        self.bounding_box: Vect3Array = np.zeros((3, 3))
        self.shader_wrapper: Optional[ShaderWrapper] = None
        self._is_animating: bool = False
        self._needs_new_bounding_box: bool = True
        self._data_has_changed: bool = True
        self.shader_code_replacements: dict[str, str] = dict()

        # Order matters: init_data creates self.data, init_uniforms reads
        # self.shading, and init_colors runs after init_points (presumably so
        # the colors are applied to the points just created).
        self.init_data()
        self.init_uniforms()
        self.init_updaters()
        self.init_event_listners()
        self.init_points()
        self.init_colors()

        if self.depth_test:
            self.apply_depth_test()
        if is_fixed_in_frame:
            self.fix_in_frame()
  112. def __str__(self):
  113. return self.__class__.__name__
  114. def __add__(self, other: Mobject) -> Mobject:
  115. assert isinstance(other, Mobject)
  116. return self.get_group_class()(self, other)
  117. def __mul__(self, other: int) -> Mobject:
  118. assert isinstance(other, int)
  119. return self.replicate(other)
  120. def init_data(self, length: int = 0):
  121. self.data = np.zeros(length, dtype=self.data_dtype)
  122. self._data_defaults = np.ones(1, dtype=self.data.dtype)
  123. def init_uniforms(self):
  124. self.uniforms: UniformDict = {
  125. "is_fixed_in_frame": 0.0,
  126. "shading": np.array(self.shading, dtype=float),
  127. "clip_plane": np.zeros(4),
  128. }
  129. def init_colors(self):
  130. self.set_color(self.color, self.opacity)
  131. def init_points(self):
  132. # Typically implemented in subclass, unlpess purposefully left blank
  133. pass
  134. def set_uniforms(self, uniforms: dict) -> Self:
  135. for key, value in uniforms.items():
  136. if isinstance(value, np.ndarray):
  137. value = value.copy()
  138. self.uniforms[key] = value
  139. return self
  140. @property
  141. def animate(self) -> _AnimationBuilder:
  142. """
  143. Methods called with Mobject.animate.method() can be passed
  144. into a Scene.play call, as if you were calling
  145. ApplyMethod(mobject.method)
  146. Borrowed from https://github.com/ManimCommunity/manim/
  147. """
  148. return _AnimationBuilder(self)
  149. @property
  150. def always(self) -> _UpdaterBuilder:
  151. """
  152. Methods called with mobject.always.method(*args, **kwargs)
  153. will result in the call mobject.method(*args, **kwargs)
  154. on every frame
  155. """
  156. return _UpdaterBuilder(self)
  157. @property
  158. def f_always(self) -> _FunctionalUpdaterBuilder:
  159. """
  160. Similar to Mobject.always, but with the intent that arguments
  161. are functions returning the corresponding type fit for the method
  162. Methods called with
  163. mobject.f_always.method(
  164. func1, func2, ...,
  165. kwarg1=kw_func1,
  166. kwarg2=kw_func2,
  167. ...
  168. )
  169. will result in the call
  170. mobject.method(
  171. func1(), func2(), ...,
  172. kwarg1=kw_func1(),
  173. kwarg2=kw_func2(),
  174. ...
  175. )
  176. on every frame
  177. """
  178. return _FunctionalUpdaterBuilder(self)
  179. def note_changed_data(self, recurse_up: bool = True) -> Self:
  180. self._data_has_changed = True
  181. if recurse_up:
  182. for mob in self.parents:
  183. mob.note_changed_data()
  184. return self
  185. @staticmethod
  186. def affects_data(func: Callable[..., T]) -> Callable[..., T]:
  187. @wraps(func)
  188. def wrapper(self, *args, **kwargs):
  189. result = func(self, *args, **kwargs)
  190. self.note_changed_data()
  191. return result
  192. return wrapper
  193. @staticmethod
  194. def affects_family_data(func: Callable[..., T]) -> Callable[..., T]:
  195. @wraps(func)
  196. def wrapper(self, *args, **kwargs):
  197. result = func(self, *args, **kwargs)
  198. for mob in self.family_members_with_points():
  199. mob.note_changed_data()
  200. return result
  201. return wrapper
    # Only these methods should directly affect points
    @affects_data
    def set_data(self, data: np.ndarray) -> Self:
        """
        Overwrite this mobject's records with the values in `data`.

        `data` must share this mobject's structured dtype (asserted); the
        storage is resized to len(data) first and values are copied in.
        """
        assert data.dtype == self.data.dtype
        self.resize_points(len(data))
        self.data[:] = data
        return self

    @affects_data
    def resize_points(
        self,
        new_length: int,
        resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array
    ) -> Self:
        """
        Resize self.data to `new_length` records using `resize_func`.

        Shrinking to zero stashes the first record in self._data_defaults.
        Growing from zero first reseeds self.data from that stash (initially
        all ones, see init_data). Presumably this lets non-point fields such
        as rgba survive an empty state; TODO confirm with resize_array's semantics.
        Marks the bounding box of self and its ancestors as stale.
        """
        if new_length == 0:
            if len(self.data) > 0:
                self._data_defaults[:1] = self.data[:1]
        elif self.get_num_points() == 0:
            # Growing from empty: start from the remembered single record
            self.data = self._data_defaults.copy()
        self.data = resize_func(self.data, new_length)
        self.refresh_bounding_box()
        return self
  223. @affects_data
  224. def set_points(self, points: Vect3Array | list[Vect3]) -> Self:
  225. self.resize_points(len(points), resize_func=resize_preserving_order)
  226. self.data["point"][:] = points
  227. return self
  228. @affects_data
  229. def append_points(self, new_points: Vect3Array) -> Self:
  230. n = self.get_num_points()
  231. self.resize_points(n + len(new_points))
  232. # Have most data default to the last value
  233. self.data[n:] = self.data[n - 1]
  234. # Then read in new points
  235. self.data["point"][n:] = new_points
  236. self.refresh_bounding_box()
  237. return self
  238. @affects_family_data
  239. def reverse_points(self) -> Self:
  240. for mob in self.get_family():
  241. mob.data[:] = mob.data[::-1]
  242. return self
    @affects_family_data
    def apply_points_function(
        self,
        func: Callable[[np.ndarray], np.ndarray],
        about_point: Vect3 | None = None,
        about_edge: Vect3 = ORIGIN,
        works_on_bounding_box: bool = False
    ) -> Self:
        """
        Apply `func` to every point-like data field of every family member.

        Args:
            func: Takes an array of points (presumably shape (N, 3)) and
                returns an array of the same shape; the result is written
                back in place.
            about_point: If given, `func` is applied relative to this point:
                points are shifted so it is the origin, transformed, and
                shifted back.
            about_edge: When about_point is None, about_point is taken from
                get_bounding_box_point(about_edge). If this is None too,
                `func` is applied to the raw coordinates.
            works_on_bounding_box: Set when transforming the cached
                [min, mid, max] box rows with `func` still gives a valid box
                (shift/scale/stretch pass True). The boxes are then transformed
                directly and only parents are marked stale. Otherwise the
                whole family's boxes are marked for recomputation.
        """
        if about_point is None and about_edge is not None:
            about_point = self.get_bounding_box_point(about_edge)

        for mob in self.get_family():
            arrs = []
            if mob.has_points():
                for key in mob.pointlike_data_keys:
                    arrs.append(mob.data[key])
            if works_on_bounding_box:
                arrs.append(mob.get_bounding_box())

            for arr in arrs:
                # Slice assignment writes into the stored arrays themselves
                # (data fields / bounding_box), not into rebound locals
                if about_point is None:
                    arr[:] = func(arr)
                else:
                    arr[:] = func(arr - about_point) + about_point

        if not works_on_bounding_box:
            self.refresh_bounding_box(recurse_down=True)
        else:
            for parent in self.parents:
                parent.refresh_bounding_box()
        return self
  271. # Others related to points
  272. def match_points(self, mobject: Mobject) -> Self:
  273. self.set_points(mobject.get_points())
  274. return self
  275. def get_points(self) -> Vect3Array:
  276. return self.data["point"]
  277. def clear_points(self) -> Self:
  278. self.resize_points(0)
  279. return self
  280. def get_num_points(self) -> int:
  281. return len(self.get_points())
  282. def get_all_points(self) -> Vect3Array:
  283. if self.submobjects:
  284. return np.vstack([sm.get_points() for sm in self.get_family()])
  285. else:
  286. return self.get_points()
  287. def has_points(self) -> bool:
  288. return len(self.get_points()) > 0
  289. def get_bounding_box(self) -> Vect3Array:
  290. if self._needs_new_bounding_box:
  291. self.bounding_box[:] = self.compute_bounding_box()
  292. self._needs_new_bounding_box = False
  293. return self.bounding_box
  294. def compute_bounding_box(self) -> Vect3Array:
  295. all_points = np.vstack([
  296. self.get_points(),
  297. *(
  298. mob.get_bounding_box()
  299. for mob in self.get_family()[1:]
  300. if mob.has_points()
  301. )
  302. ])
  303. if len(all_points) == 0:
  304. return np.zeros((3, self.dim))
  305. else:
  306. # Lower left and upper right corners
  307. mins = all_points.min(0)
  308. maxs = all_points.max(0)
  309. mids = (mins + maxs) / 2
  310. return np.array([mins, mids, maxs])
  311. def refresh_bounding_box(
  312. self,
  313. recurse_down: bool = False,
  314. recurse_up: bool = True
  315. ) -> Self:
  316. for mob in self.get_family(recurse_down):
  317. mob._needs_new_bounding_box = True
  318. if recurse_up:
  319. for parent in self.parents:
  320. parent.refresh_bounding_box()
  321. return self
  322. def are_points_touching(
  323. self,
  324. points: Vect3Array,
  325. buff: float = 0
  326. ) -> np.ndarray:
  327. bb = self.get_bounding_box()
  328. mins = (bb[0] - buff)
  329. maxs = (bb[2] + buff)
  330. return ((points >= mins) * (points <= maxs)).all(1)
  331. def is_point_touching(
  332. self,
  333. point: Vect3,
  334. buff: float = 0
  335. ) -> bool:
  336. return self.are_points_touching(np.array(point, ndmin=2), buff)[0]
  337. def is_touching(self, mobject: Mobject, buff: float = 1e-2) -> bool:
  338. bb1 = self.get_bounding_box()
  339. bb2 = mobject.get_bounding_box()
  340. return not any((
  341. (bb2[2] < bb1[0] - buff).any(), # E.g. Right of mobject is left of self's left
  342. (bb2[0] > bb1[2] + buff).any(), # E.g. Left of mobject is right of self's right
  343. ))
  344. # Family matters
  345. def __getitem__(self, value: int | slice) -> Mobject:
  346. if isinstance(value, slice):
  347. GroupClass = self.get_group_class()
  348. return GroupClass(*self.split().__getitem__(value))
  349. return self.split().__getitem__(value)
  350. def __iter__(self) -> Iterator[Self]:
  351. return iter(self.split())
  352. def __len__(self) -> int:
  353. return len(self.split())
  354. def split(self) -> list[Self]:
  355. return self.submobjects
  356. @affects_data
  357. def note_changed_family(self, only_changed_order=False) -> Self:
  358. self.family = None
  359. if not only_changed_order:
  360. self.refresh_has_updater_status()
  361. self.refresh_bounding_box()
  362. for parent in self.parents:
  363. parent.note_changed_family()
  364. return self
  365. def get_family(self, recurse: bool = True) -> list[Mobject]:
  366. if not recurse:
  367. return [self]
  368. if self.family is None:
  369. # Reconstruct and save
  370. sub_families = (sm.get_family() for sm in self.submobjects)
  371. self.family = [self, *it.chain(*sub_families)]
  372. return self.family
  373. def family_members_with_points(self) -> list[Mobject]:
  374. return [m for m in self.get_family() if len(m.data) > 0]
    def get_ancestors(self, extended: bool = False) -> list[Mobject]:
        """
        Returns parents, grandparents, etc., without duplicates.
        Order of result should be from higher members of the hierarchy down.
        If extended is set to true, it includes the ancestors of all family members,
        e.g. any other parents of a submobject.
        """
        ancestors = []
        # Walk parent links depth-first, using to_process as a stack
        to_process = list(self.get_family(recurse=extended))
        # Starting members are never reported. `excluded` is not extended
        # during the walk, so an ancestor reachable by several paths can be
        # visited more than once; the dedup below removes the repeats.
        excluded = set(to_process)
        while to_process:
            for p in to_process.pop().parents:
                if p not in excluded:
                    ancestors.append(p)
                    to_process.append(p)
        # Ensure mobjects highest in the hierarchy show up first
        ancestors.reverse()
        # Remove list redundancies while preserving order
        return list(dict.fromkeys(ancestors))
  394. def add(self, *mobjects: Mobject) -> Self:
  395. if self in mobjects:
  396. raise Exception("Mobject cannot contain self")
  397. for mobject in mobjects:
  398. if mobject not in self.submobjects:
  399. self.submobjects.append(mobject)
  400. if self not in mobject.parents:
  401. mobject.parents.append(self)
  402. self.note_changed_family()
  403. return self
  404. def remove(
  405. self,
  406. *to_remove: Mobject,
  407. reassemble: bool = True,
  408. recurse: bool = True
  409. ) -> Self:
  410. for parent in self.get_family(recurse):
  411. for child in to_remove:
  412. if child in parent.submobjects:
  413. parent.submobjects.remove(child)
  414. if parent in child.parents:
  415. child.parents.remove(parent)
  416. if reassemble:
  417. parent.note_changed_family()
  418. return self
  419. def clear(self) -> Self:
  420. self.remove(*self.submobjects, recurse=False)
  421. return self
  422. def add_to_back(self, *mobjects: Mobject) -> Self:
  423. self.set_submobjects(list_update(mobjects, self.submobjects))
  424. return self
  425. def replace_submobject(self, index: int, new_submob: Mobject) -> Self:
  426. old_submob = self.submobjects[index]
  427. if self in old_submob.parents:
  428. old_submob.parents.remove(self)
  429. self.submobjects[index] = new_submob
  430. new_submob.parents.append(self)
  431. self.note_changed_family()
  432. return self
  433. def insert_submobject(self, index: int, new_submob: Mobject) -> Self:
  434. self.submobjects.insert(index, new_submob)
  435. self.note_changed_family()
  436. return self
  437. def set_submobjects(self, submobject_list: list[Mobject]) -> Self:
  438. if self.submobjects == submobject_list:
  439. return self
  440. self.clear()
  441. self.add(*submobject_list)
  442. return self
  443. def digest_mobject_attrs(self) -> Self:
  444. """
  445. Ensures all attributes which are mobjects are included
  446. in the submobjects list.
  447. """
  448. mobject_attrs = [x for x in list(self.__dict__.values()) if isinstance(x, Mobject)]
  449. self.set_submobjects(list_update(self.submobjects, mobject_attrs))
  450. return self
  451. # Submobject organization
  452. def arrange(
  453. self,
  454. direction: Vect3 = RIGHT,
  455. center: bool = True,
  456. **kwargs
  457. ) -> Self:
  458. for m1, m2 in zip(self.submobjects, self.submobjects[1:]):
  459. m2.next_to(m1, direction, **kwargs)
  460. if center:
  461. self.center()
  462. return self
  463. def arrange_in_grid(
  464. self,
  465. n_rows: int | None = None,
  466. n_cols: int | None = None,
  467. buff: float | None = None,
  468. h_buff: float | None = None,
  469. v_buff: float | None = None,
  470. buff_ratio: float | None = None,
  471. h_buff_ratio: float = 0.5,
  472. v_buff_ratio: float = 0.5,
  473. aligned_edge: Vect3 = ORIGIN,
  474. fill_rows_first: bool = True
  475. ) -> Self:
  476. submobs = self.submobjects
  477. n_submobs = len(submobs)
  478. if n_rows is None:
  479. n_rows = int(np.sqrt(n_submobs)) if n_cols is None else n_submobs // n_cols
  480. if n_cols is None:
  481. n_cols = n_submobs // n_rows
  482. if buff is not None:
  483. h_buff = buff
  484. v_buff = buff
  485. else:
  486. if buff_ratio is not None:
  487. v_buff_ratio = buff_ratio
  488. h_buff_ratio = buff_ratio
  489. if h_buff is None:
  490. h_buff = h_buff_ratio * self[0].get_width()
  491. if v_buff is None:
  492. v_buff = v_buff_ratio * self[0].get_height()
  493. x_unit = h_buff + max([sm.get_width() for sm in submobs])
  494. y_unit = v_buff + max([sm.get_height() for sm in submobs])
  495. for index, sm in enumerate(submobs):
  496. if fill_rows_first:
  497. x, y = index % n_cols, index // n_cols
  498. else:
  499. x, y = index // n_rows, index % n_rows
  500. sm.move_to(ORIGIN, aligned_edge)
  501. sm.shift(x * x_unit * RIGHT + y * y_unit * DOWN)
  502. self.center()
  503. return self
  504. def arrange_to_fit_dim(self, length: float, dim: int, about_edge=ORIGIN) -> Self:
  505. ref_point = self.get_bounding_box_point(about_edge)
  506. n_submobs = len(self.submobjects)
  507. if n_submobs <= 1:
  508. return
  509. total_length = sum(sm.length_over_dim(dim) for sm in self.submobjects)
  510. buff = (length - total_length) / (n_submobs - 1)
  511. vect = np.zeros(self.dim)
  512. vect[dim] = 1
  513. x = 0
  514. for submob in self.submobjects:
  515. submob.set_coord(x, dim, -vect)
  516. x += submob.length_over_dim(dim) + buff
  517. self.move_to(ref_point, about_edge)
  518. return self
  519. def arrange_to_fit_width(self, width: float, about_edge=ORIGIN) -> Self:
  520. return self.arrange_to_fit_dim(width, 0, about_edge)
  521. def arrange_to_fit_height(self, height: float, about_edge=ORIGIN) -> Self:
  522. return self.arrange_to_fit_dim(height, 1, about_edge)
  523. def arrange_to_fit_depth(self, depth: float, about_edge=ORIGIN) -> Self:
  524. return self.arrange_to_fit_dim(depth, 2, about_edge)
  525. def sort(
  526. self,
  527. point_to_num_func: Callable[[np.ndarray], float] = lambda p: p[0],
  528. submob_func: Callable[[Mobject]] | None = None
  529. ) -> Self:
  530. if submob_func is not None:
  531. self.submobjects.sort(key=submob_func)
  532. else:
  533. self.submobjects.sort(key=lambda m: point_to_num_func(m.get_center()))
  534. self.note_changed_family(only_changed_order=True)
  535. return self
  536. def shuffle(self, recurse: bool = False) -> Self:
  537. if recurse:
  538. for submob in self.submobjects:
  539. submob.shuffle(recurse=True)
  540. random.shuffle(self.submobjects)
  541. self.note_changed_family(only_changed_order=True)
  542. return self
  543. def reverse_submobjects(self) -> Self:
  544. self.submobjects.reverse()
  545. self.note_changed_family(only_changed_order=True)
  546. return self
  547. # Copying and serialization
  548. @staticmethod
  549. def stash_mobject_pointers(func: Callable[..., T]) -> Callable[..., T]:
  550. @wraps(func)
  551. def wrapper(self, *args, **kwargs):
  552. uncopied_attrs = ["parents", "target", "saved_state"]
  553. stash = dict()
  554. for attr in uncopied_attrs:
  555. if hasattr(self, attr):
  556. value = getattr(self, attr)
  557. stash[attr] = value
  558. null_value = [] if isinstance(value, list) else None
  559. setattr(self, attr, null_value)
  560. result = func(self, *args, **kwargs)
  561. self.__dict__.update(stash)
  562. return result
  563. return wrapper
  564. @stash_mobject_pointers
  565. def serialize(self) -> bytes:
  566. return pickle.dumps(self)
  567. def deserialize(self, data: bytes) -> Self:
  568. self.become(pickle.loads(data))
  569. return self
  570. @stash_mobject_pointers
  571. def deepcopy(self) -> Self:
  572. return copy.deepcopy(self)
    def copy(self, deep: bool = False) -> Self:
        """
        Return a copy of this mobject together with its whole submobject tree.

        With deep=True this defers to deepcopy(). Otherwise it starts from a
        copy.copy() and then re-creates the state that must not be shared:
        uniforms arrays, submobjects (copied recursively), the updater list,
        every ndarray attribute, and attributes that point at family members
        (re-pointed to the matching member of the copy). The copy starts with
        no parents, target, saved_state or shader_wrapper. Other mutable
        attributes, such as the locked_*_keys sets, remain shared with self.
        """
        if deep:
            return self.deepcopy()

        result = copy.copy(self)

        result.parents = []
        result.target = None
        result.saved_state = None

        # copy.copy is only a shallow copy, so the internal
        # data which are numpy arrays or other mobjects still
        # need to be further copied.
        result.uniforms = {
            key: value.copy() if isinstance(value, np.ndarray) else value
            for key, value in self.uniforms.items()
        }

        # Instead of adding using result.add, which does some checks for updating
        # updater statues and bounding box, just directly modify the family-related
        # lists
        result.submobjects = [sm.copy() for sm in self.submobjects]
        for sm in result.submobjects:
            sm.parents = [result]
        result.family = [result, *it.chain(*(sm.get_family() for sm in result.submobjects))]

        # Similarly, instead of calling match_updaters, since we know the status
        # won't have changed, just directly match.
        result.updaters = list(self.updaters)
        result._data_has_changed = True
        result.shader_wrapper = None

        # Both families are built in the same depth-first order, so a member's
        # index in self's family identifies its counterpart in the copy.
        family = self.get_family()
        for attr, value in self.__dict__.items():
            if isinstance(value, Mobject) and value is not self:
                if value in family:
                    setattr(result, attr, result.family[family.index(value)])
            elif isinstance(value, np.ndarray):
                # Includes self.data, self.bounding_box and self._data_defaults
                setattr(result, attr, value.copy())
        return result
  607. def generate_target(self, use_deepcopy: bool = False) -> Self:
  608. self.target = self.copy(deep=use_deepcopy)
  609. self.target.saved_state = self.saved_state
  610. return self.target
  611. def save_state(self, use_deepcopy: bool = False) -> Self:
  612. self.saved_state = self.copy(deep=use_deepcopy)
  613. self.saved_state.target = self.target
  614. return self
  615. def restore(self) -> Self:
  616. if not hasattr(self, "saved_state") or self.saved_state is None:
  617. raise Exception("Trying to restore without having saved")
  618. self.become(self.saved_state)
  619. return self
  620. def save_to_file(self, file_path: str) -> Self:
  621. with open(file_path, "wb") as fp:
  622. fp.write(self.serialize())
  623. log.info(f"Saved mobject to {file_path}")
  624. return self
  625. @staticmethod
  626. def load(file_path) -> Mobject:
  627. if not os.path.exists(file_path):
  628. log.error(f"No file found at {file_path}")
  629. sys.exit(2)
  630. with open(file_path, "rb") as fp:
  631. mobject = pickle.load(fp)
  632. return mobject
    def become(self, mobject: Mobject, match_updaters=False) -> Self:
        """
        Edit all data and submobjects to be identical
        to another mobject.

        The families are first aligned (align_family, defined elsewhere) and
        then copied member-by-member in get_family() order. Copied per member:
        data, uniforms, the cached bounding box and its staleness flag, and
        render settings. Attributes of `mobject` that point at members of its
        family are re-pointed on self to the corresponding members of self's
        family. Updaters are copied only when match_updaters is true.
        """
        self.align_family(mobject)
        family1 = self.get_family()
        family2 = mobject.get_family()
        for sm1, sm2 in zip(family1, family2):
            sm1.set_data(sm2.data)
            sm1.set_uniforms(sm2.uniforms)
            # Copy the cached box and its dirty flag rather than recomputing
            sm1.bounding_box[:] = sm2.bounding_box
            sm1.shader_folder = sm2.shader_folder
            sm1.texture_paths = sm2.texture_paths
            sm1.depth_test = sm2.depth_test
            sm1.render_primitive = sm2.render_primitive
            sm1._needs_new_bounding_box = sm2._needs_new_bounding_box
        # Make sure named family members carry over
        for attr, value in list(mobject.__dict__.items()):
            if isinstance(value, Mobject) and value in family2:
                setattr(self, attr, family1[family2.index(value)])
        if match_updaters:
            self.match_updaters(mobject)
        return self
  657. def looks_identical(self, mobject: Mobject) -> bool:
  658. fam1 = self.family_members_with_points()
  659. fam2 = mobject.family_members_with_points()
  660. if len(fam1) != len(fam2):
  661. return False
  662. for m1, m2 in zip(fam1, fam2):
  663. if m1.get_num_points() != m2.get_num_points():
  664. return False
  665. if not m1.data.dtype == m2.data.dtype:
  666. return False
  667. for key in m1.data.dtype.names:
  668. if not np.isclose(m1.data[key], m2.data[key]).all():
  669. return False
  670. if set(m1.uniforms).difference(m2.uniforms):
  671. return False
  672. for key in m1.uniforms:
  673. value1 = m1.uniforms[key]
  674. value2 = m2.uniforms[key]
  675. if isinstance(value1, np.ndarray) and isinstance(value2, np.ndarray) and not value1.size == value2.size:
  676. return False
  677. if not np.isclose(value1, value2).all():
  678. return False
  679. return True
  680. def has_same_shape_as(self, mobject: Mobject) -> bool:
  681. # Normalize both point sets by centering and making height 1
  682. points1, points2 = (
  683. (m.get_all_points() - m.get_center()) / m.get_height()
  684. for m in (self, mobject)
  685. )
  686. if len(points1) != len(points2):
  687. return False
  688. return bool(np.isclose(points1, points2, atol=self.get_width() * 1e-2).all())
  689. # Creating new Mobjects from this one
  690. def replicate(self, n: int) -> Self:
  691. group_class = self.get_group_class()
  692. return group_class(*(self.copy() for _ in range(n)))
  693. def get_grid(
  694. self,
  695. n_rows: int,
  696. n_cols: int,
  697. height: float | None = None,
  698. width: float | None = None,
  699. group_by_rows: bool = False,
  700. group_by_cols: bool = False,
  701. **kwargs
  702. ) -> Self:
  703. """
  704. Returns a new mobject containing multiple copies of this one
  705. arranged in a grid
  706. """
  707. total = n_rows * n_cols
  708. grid = self.replicate(total)
  709. if group_by_cols:
  710. kwargs["fill_rows_first"] = False
  711. grid.arrange_in_grid(n_rows, n_cols, **kwargs)
  712. if height is not None:
  713. grid.set_height(height)
  714. if width is not None:
  715. grid.set_height(width)
  716. group_class = self.get_group_class()
  717. if group_by_rows:
  718. return group_class(*(grid[n:n + n_cols] for n in range(0, total, n_cols)))
  719. elif group_by_cols:
  720. return group_class(*(grid[n:n + n_rows] for n in range(0, total, n_rows)))
  721. else:
  722. return grid
  723. # Updating
  724. def init_updaters(self):
  725. self.updaters: list[Updater] = list()
  726. self._has_updaters_in_family: Optional[bool] = False
  727. self.updating_suspended: bool = False
  728. def update(self, dt: float = 0, recurse: bool = True) -> Self:
  729. if not self.has_updaters() or self.updating_suspended:
  730. return self
  731. if recurse:
  732. for submob in self.submobjects:
  733. submob.update(dt, recurse)
  734. for updater in self.updaters:
  735. # This is hacky, but if an updater takes dt as an arg,
  736. # it will be passed the change in time from here
  737. if "dt" in updater.__code__.co_varnames:
  738. updater(self, dt=dt)
  739. else:
  740. updater(self)
  741. return self
  742. def get_updaters(self) -> list[Updater]:
  743. return self.updaters
  744. def add_updater(self, update_func: Updater, call: bool = True) -> Self:
  745. self.updaters.append(update_func)
  746. if call:
  747. self.update(dt=0)
  748. self.refresh_has_updater_status()
  749. return self
  750. def insert_updater(self, update_func: Updater, index=0):
  751. self.updaters.insert(index, update_func)
  752. self.refresh_has_updater_status()
  753. return self
  754. def remove_updater(self, update_func: Updater) -> Self:
  755. while update_func in self.updaters:
  756. self.updaters.remove(update_func)
  757. self.refresh_has_updater_status()
  758. return self
  759. def clear_updaters(self, recurse: bool = True) -> Self:
  760. for mob in self.get_family(recurse):
  761. mob.updaters = []
  762. mob._has_updaters_in_family = False
  763. for parent in self.get_ancestors():
  764. parent._has_updaters_in_family = False
  765. return self
  766. def match_updaters(self, mobject: Mobject) -> Self:
  767. self.updaters = list(mobject.updaters)
  768. self.refresh_has_updater_status()
  769. return self
  770. def suspend_updating(self, recurse: bool = True) -> Self:
  771. self.updating_suspended = True
  772. if recurse:
  773. for submob in self.submobjects:
  774. submob.suspend_updating(recurse)
  775. return self
  776. def resume_updating(self, recurse: bool = True, call_updater: bool = True) -> Self:
  777. self.updating_suspended = False
  778. if recurse:
  779. for submob in self.submobjects:
  780. submob.resume_updating(recurse)
  781. for parent in self.parents:
  782. parent.resume_updating(recurse=False, call_updater=False)
  783. if call_updater:
  784. self.update(dt=0, recurse=recurse)
  785. return self
  786. def has_updaters(self) -> bool:
  787. if self._has_updaters_in_family is None:
  788. # Recompute and save
  789. self._has_updaters_in_family = bool(self.updaters) or any(
  790. sm.has_updaters() for sm in self.submobjects
  791. )
  792. return self._has_updaters_in_family
  793. def refresh_has_updater_status(self) -> Self:
  794. self._has_updaters_in_family = None
  795. for parent in self.parents:
  796. parent.refresh_has_updater_status()
  797. return self
  798. # Check if mark as static or not for camera
  799. def is_changing(self) -> bool:
  800. return self._is_animating or self.has_updaters()
  801. def set_animating_status(self, is_animating: bool, recurse: bool = True) -> Self:
  802. for mob in (*self.get_family(recurse), *self.get_ancestors()):
  803. mob._is_animating = is_animating
  804. return self
  805. # Transforming operations
  806. def shift(self, vector: Vect3) -> Self:
  807. self.apply_points_function(
  808. lambda points: points + vector,
  809. about_edge=None,
  810. works_on_bounding_box=True,
  811. )
  812. return self
  813. def scale(
  814. self,
  815. scale_factor: float | npt.ArrayLike,
  816. min_scale_factor: float = 1e-8,
  817. about_point: Vect3 | None = None,
  818. about_edge: Vect3 = ORIGIN
  819. ) -> Self:
  820. """
  821. Default behavior is to scale about the center of the mobject.
  822. The argument about_edge can be a vector, indicating which side of
  823. the mobject to scale about, e.g., mob.scale(about_edge = RIGHT)
  824. scales about mob.get_right().
  825. Otherwise, if about_point is given a value, scaling is done with
  826. respect to that point.
  827. """
  828. if isinstance(scale_factor, numbers.Number):
  829. scale_factor = max(scale_factor, min_scale_factor)
  830. else:
  831. scale_factor = np.array(scale_factor).clip(min=min_scale_factor)
  832. self.apply_points_function(
  833. lambda points: scale_factor * points,
  834. about_point=about_point,
  835. about_edge=about_edge,
  836. works_on_bounding_box=True,
  837. )
  838. for mob in self.get_family():
  839. mob._handle_scale_side_effects(scale_factor)
  840. return self
  841. def _handle_scale_side_effects(self, scale_factor):
  842. # In case subclasses, such as DecimalNumber, need to make
  843. # any other changes when the size gets altered
  844. pass
  845. def stretch(self, factor: float, dim: int, **kwargs) -> Self:
  846. def func(points):
  847. points[:, dim] *= factor
  848. return points
  849. self.apply_points_function(func, works_on_bounding_box=True, **kwargs)
  850. return self
  851. def rotate_about_origin(self, angle: float, axis: Vect3 = OUT) -> Self:
  852. return self.rotate(angle, axis, about_point=ORIGIN)
  853. def rotate(
  854. self,
  855. angle: float,
  856. axis: Vect3 = OUT,
  857. about_point: Vect3 | None = None,
  858. **kwargs
  859. ) -> Self:
  860. rot_matrix_T = rotation_matrix_transpose(angle, axis)
  861. self.apply_points_function(
  862. lambda points: np.dot(points, rot_matrix_T),
  863. about_point,
  864. **kwargs
  865. )
  866. return self
  867. def flip(self, axis: Vect3 = UP, **kwargs) -> Self:
  868. return self.rotate(TAU / 2, axis, **kwargs)
  869. def apply_function(self, function: Callable[[np.ndarray], np.ndarray], **kwargs) -> Self:
  870. # Default to applying matrix about the origin, not mobjects center
  871. if len(kwargs) == 0:
  872. kwargs["about_point"] = ORIGIN
  873. self.apply_points_function(
  874. lambda points: np.array([function(p) for p in points]),
  875. **kwargs
  876. )
  877. return self
  878. def apply_function_to_position(self, function: Callable[[np.ndarray], np.ndarray]) -> Self:
  879. self.move_to(function(self.get_center()))
  880. return self
  881. def apply_function_to_submobject_positions(
  882. self,
  883. function: Callable[[np.ndarray], np.ndarray]
  884. ) -> Self:
  885. for submob in self.submobjects:
  886. submob.apply_function_to_position(function)
  887. return self
  888. def apply_matrix(self, matrix: npt.ArrayLike, **kwargs) -> Self:
  889. # Default to applying matrix about the origin, not mobjects center
  890. if ("about_point" not in kwargs) and ("about_edge" not in kwargs):
  891. kwargs["about_point"] = ORIGIN
  892. full_matrix = np.identity(self.dim)
  893. matrix = np.array(matrix)
  894. full_matrix[:matrix.shape[0], :matrix.shape[1]] = matrix
  895. self.apply_points_function(
  896. lambda points: np.dot(points, full_matrix.T),
  897. **kwargs
  898. )
  899. return self
  900. def apply_complex_function(self, function: Callable[[complex], complex], **kwargs) -> Self:
  901. def R3_func(point):
  902. x, y, z = point
  903. xy_complex = function(complex(x, y))
  904. return [
  905. xy_complex.real,
  906. xy_complex.imag,
  907. z
  908. ]
  909. return self.apply_function(R3_func, **kwargs)
  910. def wag(
  911. self,
  912. direction: Vect3 = RIGHT,
  913. axis: Vect3 = DOWN,
  914. wag_factor: float = 1.0
  915. ) -> Self:
  916. for mob in self.family_members_with_points():
  917. alphas = np.dot(mob.get_points(), np.transpose(axis))
  918. alphas -= min(alphas)
  919. alphas /= max(alphas)
  920. alphas = alphas**wag_factor
  921. mob.set_points(mob.get_points() + np.dot(
  922. alphas.reshape((len(alphas), 1)),
  923. np.array(direction).reshape((1, mob.dim))
  924. ))
  925. return self
  926. # Positioning methods
  927. def center(self) -> Self:
  928. self.shift(-self.get_center())
  929. return self
  930. def align_on_border(
  931. self,
  932. direction: Vect3,
  933. buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
  934. ) -> Self:
  935. """
  936. Direction just needs to be a vector pointing towards side or
  937. corner in the 2d plane.
  938. """
  939. target_point = np.sign(direction) * (FRAME_X_RADIUS, FRAME_Y_RADIUS, 0)
  940. point_to_align = self.get_bounding_box_point(direction)
  941. shift_val = target_point - point_to_align - buff * np.array(direction)
  942. shift_val = shift_val * abs(np.sign(direction))
  943. self.shift(shift_val)
  944. return self
  945. def to_corner(
  946. self,
  947. corner: Vect3 = LEFT + DOWN,
  948. buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
  949. ) -> Self:
  950. return self.align_on_border(corner, buff)
  951. def to_edge(
  952. self,
  953. edge: Vect3 = LEFT,
  954. buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
  955. ) -> Self:
  956. return self.align_on_border(edge, buff)
  957. def next_to(
  958. self,
  959. mobject_or_point: Mobject | Vect3,
  960. direction: Vect3 = RIGHT,
  961. buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER,
  962. aligned_edge: Vect3 = ORIGIN,
  963. submobject_to_align: Mobject | None = None,
  964. index_of_submobject_to_align: int | slice | None = None,
  965. coor_mask: Vect3 = np.array([1, 1, 1]),
  966. ) -> Self:
  967. if isinstance(mobject_or_point, Mobject):
  968. mob = mobject_or_point
  969. if index_of_submobject_to_align is not None:
  970. target_aligner = mob[index_of_submobject_to_align]
  971. else:
  972. target_aligner = mob
  973. target_point = target_aligner.get_bounding_box_point(
  974. aligned_edge + direction
  975. )
  976. else:
  977. target_point = mobject_or_point
  978. if submobject_to_align is not None:
  979. aligner = submobject_to_align
  980. elif index_of_submobject_to_align is not None:
  981. aligner = self[index_of_submobject_to_align]
  982. else:
  983. aligner = self
  984. point_to_align = aligner.get_bounding_box_point(aligned_edge - direction)
  985. self.shift((target_point - point_to_align + buff * direction) * coor_mask)
  986. return self
  987. def shift_onto_screen(self, **kwargs) -> Self:
  988. space_lengths = [FRAME_X_RADIUS, FRAME_Y_RADIUS]
  989. for vect in UP, DOWN, LEFT, RIGHT:
  990. dim = np.argmax(np.abs(vect))
  991. buff = kwargs.get("buff", DEFAULT_MOBJECT_TO_EDGE_BUFFER)
  992. max_val = space_lengths[dim] - buff
  993. edge_center = self.get_edge_center(vect)
  994. if np.dot(edge_center, vect) > max_val:
  995. self.to_edge(vect, **kwargs)
  996. return self
  997. def is_off_screen(self) -> bool:
  998. if self.get_left()[0] > FRAME_X_RADIUS:
  999. return True
  1000. if self.get_right()[0] < -FRAME_X_RADIUS:
  1001. return True
  1002. if self.get_bottom()[1] > FRAME_Y_RADIUS:
  1003. return True
  1004. if self.get_top()[1] < -FRAME_Y_RADIUS:
  1005. return True
  1006. return False
  1007. def stretch_about_point(self, factor: float, dim: int, point: Vect3) -> Self:
  1008. return self.stretch(factor, dim, about_point=point)
  1009. def stretch_in_place(self, factor: float, dim: int) -> Self:
  1010. # Now redundant with stretch
  1011. return self.stretch(factor, dim)
  1012. def rescale_to_fit(self, length: float, dim: int, stretch: bool = False, **kwargs) -> Self:
  1013. old_length = self.length_over_dim(dim)
  1014. if old_length == 0:
  1015. return self
  1016. if stretch:
  1017. self.stretch(length / old_length, dim, **kwargs)
  1018. else:
  1019. self.scale(length / old_length, **kwargs)
  1020. return self
  1021. def stretch_to_fit_width(self, width: float, **kwargs) -> Self:
  1022. return self.rescale_to_fit(width, 0, stretch=True, **kwargs)
  1023. def stretch_to_fit_height(self, height: float, **kwargs) -> Self:
  1024. return self.rescale_to_fit(height, 1, stretch=True, **kwargs)
  1025. def stretch_to_fit_depth(self, depth: float, **kwargs) -> Self:
  1026. return self.rescale_to_fit(depth, 2, stretch=True, **kwargs)
  1027. def set_width(self, width: float, stretch: bool = False, **kwargs) -> Self:
  1028. return self.rescale_to_fit(width, 0, stretch=stretch, **kwargs)
  1029. def set_height(self, height: float, stretch: bool = False, **kwargs) -> Self:
  1030. return self.rescale_to_fit(height, 1, stretch=stretch, **kwargs)
  1031. def set_depth(self, depth: float, stretch: bool = False, **kwargs) -> Self:
  1032. return self.rescale_to_fit(depth, 2, stretch=stretch, **kwargs)
  1033. def set_max_width(self, max_width: float, **kwargs) -> Self:
  1034. if self.get_width() > max_width:
  1035. self.set_width(max_width, **kwargs)
  1036. return self
  1037. def set_max_height(self, max_height: float, **kwargs) -> Self:
  1038. if self.get_height() > max_height:
  1039. self.set_height(max_height, **kwargs)
  1040. return self
  1041. def set_max_depth(self, max_depth: float, **kwargs) -> Self:
  1042. if self.get_depth() > max_depth:
  1043. self.set_depth(max_depth, **kwargs)
  1044. return self
  1045. def set_min_width(self, min_width: float, **kwargs) -> Self:
  1046. if self.get_width() < min_width:
  1047. self.set_width(min_width, **kwargs)
  1048. return self
  1049. def set_min_height(self, min_height: float, **kwargs) -> Self:
  1050. if self.get_height() < min_height:
  1051. self.set_height(min_height, **kwargs)
  1052. return self
  1053. def set_min_depth(self, min_depth: float, **kwargs) -> Self:
  1054. if self.get_depth() < min_depth:
  1055. self.set_depth(min_depth, **kwargs)
  1056. return self
  1057. def set_shape(
  1058. self,
  1059. width: Optional[float] = None,
  1060. height: Optional[float] = None,
  1061. depth: Optional[float] = None,
  1062. **kwargs
  1063. ) -> Self:
  1064. if width is not None:
  1065. self.set_width(width, stretch=True, **kwargs)
  1066. if height is not None:
  1067. self.set_height(height, stretch=True, **kwargs)
  1068. if depth is not None:
  1069. self.set_depth(depth, stretch=True, **kwargs)
  1070. return self
  1071. def set_coord(self, value: float, dim: int, direction: Vect3 = ORIGIN) -> Self:
  1072. curr = self.get_coord(dim, direction)
  1073. shift_vect = np.zeros(self.dim)
  1074. shift_vect[dim] = value - curr
  1075. self.shift(shift_vect)
  1076. return self
  1077. def set_x(self, x: float, direction: Vect3 = ORIGIN) -> Self:
  1078. return self.set_coord(x, 0, direction)
  1079. def set_y(self, y: float, direction: Vect3 = ORIGIN) -> Self:
  1080. return self.set_coord(y, 1, direction)
  1081. def set_z(self, z: float, direction: Vect3 = ORIGIN) -> Self:
  1082. return self.set_coord(z, 2, direction)
  1083. def set_z_index(self, z_index: int) -> Self:
  1084. self.z_index = z_index
  1085. return self
  1086. def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self:
  1087. self.scale(factor, **kwargs)
  1088. for submob in self.submobjects:
  1089. submob.scale(1. / factor)
  1090. return self
  1091. def move_to(
  1092. self,
  1093. point_or_mobject: Mobject | Vect3,
  1094. aligned_edge: Vect3 = ORIGIN,
  1095. coor_mask: Vect3 = np.array([1, 1, 1])
  1096. ) -> Self:
  1097. if isinstance(point_or_mobject, Mobject):
  1098. target = point_or_mobject.get_bounding_box_point(aligned_edge)
  1099. else:
  1100. target = point_or_mobject
  1101. point_to_align = self.get_bounding_box_point(aligned_edge)
  1102. self.shift((target - point_to_align) * coor_mask)
  1103. return self
  1104. def replace(self, mobject: Mobject, dim_to_match: int = 0, stretch: bool = False) -> Self:
  1105. if not mobject.get_num_points() and not mobject.submobjects:
  1106. self.scale(0)
  1107. return self
  1108. if stretch:
  1109. for i in range(self.dim):
  1110. self.rescale_to_fit(mobject.length_over_dim(i), i, stretch=True)
  1111. else:
  1112. self.rescale_to_fit(
  1113. mobject.length_over_dim(dim_to_match),
  1114. dim_to_match,
  1115. stretch=False
  1116. )
  1117. self.shift(mobject.get_center() - self.get_center())
  1118. return self
  1119. def surround(
  1120. self,
  1121. mobject: Mobject,
  1122. dim_to_match: int = 0,
  1123. stretch: bool = False,
  1124. buff: float = MED_SMALL_BUFF
  1125. ) -> Self:
  1126. self.replace(mobject, dim_to_match, stretch)
  1127. length = mobject.length_over_dim(dim_to_match)
  1128. self.scale((length + buff) / length)
  1129. return self
  1130. def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self:
  1131. curr_start, curr_end = self.get_start_and_end()
  1132. curr_vect = curr_end - curr_start
  1133. if np.all(curr_vect == 0):
  1134. raise Exception("Cannot position endpoints of closed loop")
  1135. target_vect = end - start
  1136. self.scale(
  1137. get_norm(target_vect) / get_norm(curr_vect),
  1138. about_point=curr_start,
  1139. )
  1140. self.rotate(
  1141. angle_of_vector(target_vect) - angle_of_vector(curr_vect),
  1142. )
  1143. self.rotate(
  1144. np.arctan2(curr_vect[2], get_norm(curr_vect[:2])) - np.arctan2(target_vect[2], get_norm(target_vect[:2])),
  1145. axis=np.array([-target_vect[1], target_vect[0], 0]),
  1146. )
  1147. self.shift(start - self.get_start())
  1148. return self
  1149. # Color functions
  1150. @affects_family_data
  1151. def set_rgba_array(
  1152. self,
  1153. rgba_array: npt.ArrayLike,
  1154. name: str = "rgba",
  1155. recurse: bool = False
  1156. ) -> Self:
  1157. for mob in self.get_family(recurse):
  1158. data = mob.data if mob.get_num_points() > 0 else mob._data_defaults
  1159. data[name][:] = rgba_array
  1160. return self
  1161. def set_color_by_rgba_func(
  1162. self,
  1163. func: Callable[[Vect3], Vect4],
  1164. recurse: bool = True
  1165. ) -> Self:
  1166. """
  1167. Func should take in a point in R3 and output an rgba value
  1168. """
  1169. for mob in self.get_family(recurse):
  1170. rgba_array = [func(point) for point in mob.get_points()]
  1171. mob.set_rgba_array(rgba_array)
  1172. return self
  1173. def set_color_by_rgb_func(
  1174. self,
  1175. func: Callable[[Vect3], Vect3],
  1176. opacity: float = 1,
  1177. recurse: bool = True
  1178. ) -> Self:
  1179. """
  1180. Func should take in a point in R3 and output an rgb value
  1181. """
  1182. for mob in self.get_family(recurse):
  1183. rgba_array = [[*func(point), opacity] for point in mob.get_points()]
  1184. mob.set_rgba_array(rgba_array)
  1185. return self
  1186. @affects_family_data
  1187. def set_rgba_array_by_color(
  1188. self,
  1189. color: ManimColor | Iterable[ManimColor] | None = None,
  1190. opacity: float | Iterable[float] | None = None,
  1191. name: str = "rgba",
  1192. recurse: bool = True
  1193. ) -> Self:
  1194. for mob in self.get_family(recurse):
  1195. data = mob.data if mob.has_points() > 0 else mob._data_defaults
  1196. if color is not None:
  1197. rgbs = np.array(list(map(color_to_rgb, listify(color))))
  1198. if 1 < len(rgbs):
  1199. rgbs = resize_with_interpolation(rgbs, len(data))
  1200. data[name][:, :3] = rgbs
  1201. if opacity is not None:
  1202. if not isinstance(opacity, (float, int)):
  1203. opacity = resize_with_interpolation(np.array(opacity), len(data))
  1204. data[name][:, 3] = opacity
  1205. return self
  1206. def set_color(
  1207. self,
  1208. color: ManimColor | Iterable[ManimColor] | None,
  1209. opacity: float | Iterable[float] | None = None,
  1210. recurse: bool = True
  1211. ) -> Self:
  1212. self.set_rgba_array_by_color(color, opacity, recurse=False)
  1213. # Recurse to submobjects differently from how set_rgba_array_by_color
  1214. # in case they implement set_color differently
  1215. if recurse:
  1216. for submob in self.submobjects:
  1217. submob.set_color(color, recurse=True)
  1218. return self
  1219. def set_opacity(
  1220. self,
  1221. opacity: float | Iterable[float] | None,
  1222. recurse: bool = True
  1223. ) -> Self:
  1224. self.set_rgba_array_by_color(color=None, opacity=opacity, recurse=False)
  1225. if recurse:
  1226. for submob in self.submobjects:
  1227. submob.set_opacity(opacity, recurse=True)
  1228. return self
  1229. def get_color(self) -> str:
  1230. return rgb_to_hex(self.data["rgba"][0, :3])
  1231. def get_opacity(self) -> float:
  1232. return float(self.data["rgba"][0, 3])
  1233. def get_opacities(self) -> float:
  1234. return self.data["rgba"][:, 3]
  1235. def set_color_by_gradient(self, *colors: ManimColor) -> Self:
  1236. if self.has_points():
  1237. self.set_color(colors)
  1238. else:
  1239. self.set_submobject_colors_by_gradient(*colors)
  1240. return self
  1241. def set_submobject_colors_by_gradient(self, *colors: ManimColor) -> Self:
  1242. if len(colors) == 0:
  1243. raise Exception("Need at least one color")
  1244. elif len(colors) == 1:
  1245. return self.set_color(*colors)
  1246. # mobs = self.family_members_with_points()
  1247. mobs = self.submobjects
  1248. new_colors = color_gradient(colors, len(mobs))
  1249. for mob, color in zip(mobs, new_colors):
  1250. mob.set_color(color)
  1251. return self
  1252. def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self:
  1253. self.set_opacity(1.0 - darkness, recurse=recurse)
  1254. def get_shading(self) -> np.ndarray:
  1255. return self.uniforms["shading"]
  1256. def set_shading(
  1257. self,
  1258. reflectiveness: float | None = None,
  1259. gloss: float | None = None,
  1260. shadow: float | None = None,
  1261. recurse: bool = True
  1262. ) -> Self:
  1263. """
  1264. Larger reflectiveness makes things brighter when facing the light
  1265. Larger shadow makes faces opposite the light darker
  1266. Makes parts bright where light gets reflected toward the camera
  1267. """
  1268. for mob in self.get_family(recurse):
  1269. shading = mob.uniforms["shading"]
  1270. for i, value in enumerate([reflectiveness, gloss, shadow]):
  1271. if value is not None:
  1272. shading[i] = value
  1273. mob.set_uniform(shading=shading, recurse=False)
  1274. return self
  1275. def get_reflectiveness(self) -> float:
  1276. return self.get_shading()[0]
  1277. def get_gloss(self) -> float:
  1278. return self.get_shading()[1]
  1279. def get_shadow(self) -> float:
  1280. return self.get_shading()[2]
  1281. def set_reflectiveness(self, reflectiveness: float, recurse: bool = True) -> Self:
  1282. self.set_shading(reflectiveness=reflectiveness, recurse=recurse)
  1283. return self
  1284. def set_gloss(self, gloss: float, recurse: bool = True) -> Self:
  1285. self.set_shading(gloss=gloss, recurse=recurse)
  1286. return self
  1287. def set_shadow(self, shadow: float, recurse: bool = True) -> Self:
  1288. self.set_shading(shadow=shadow, recurse=recurse)
  1289. return self
  1290. # Background rectangle
  1291. def add_background_rectangle(
  1292. self,
  1293. color: ManimColor | None = None,
  1294. opacity: float = 1.0,
  1295. **kwargs
  1296. ) -> Self:
  1297. from manimlib.mobject.shape_matchers import BackgroundRectangle
  1298. self.background_rectangle = BackgroundRectangle(
  1299. self, color=color,
  1300. fill_opacity=opacity,
  1301. **kwargs
  1302. )
  1303. self.add_to_back(self.background_rectangle)
  1304. return self
  1305. def add_background_rectangle_to_submobjects(self, **kwargs) -> Self:
  1306. for submobject in self.submobjects:
  1307. submobject.add_background_rectangle(**kwargs)
  1308. return self
  1309. def add_background_rectangle_to_family_members_with_points(self, **kwargs) -> Self:
  1310. for mob in self.family_members_with_points():
  1311. mob.add_background_rectangle(**kwargs)
  1312. return self
  1313. # Getters
  1314. def get_bounding_box_point(self, direction: Vect3) -> Vect3:
  1315. bb = self.get_bounding_box()
  1316. indices = (np.sign(direction) + 1).astype(int)
  1317. return np.array([
  1318. bb[indices[i]][i]
  1319. for i in range(3)
  1320. ])
  1321. def get_edge_center(self, direction: Vect3) -> Vect3:
  1322. return self.get_bounding_box_point(direction)
  1323. def get_corner(self, direction: Vect3) -> Vect3:
  1324. return self.get_bounding_box_point(direction)
  1325. def get_all_corners(self):
  1326. bb = self.get_bounding_box()
  1327. return np.array([
  1328. [bb[indices[-i + 1]][i] for i in range(3)]
  1329. for indices in it.product([0, 2], repeat=3)
  1330. ])
  1331. def get_center(self) -> Vect3:
  1332. return self.get_bounding_box()[1]
  1333. def get_center_of_mass(self) -> Vect3:
  1334. return self.get_all_points().mean(0)
  1335. def get_boundary_point(self, direction: Vect3) -> Vect3:
  1336. all_points = self.get_all_points()
  1337. boundary_directions = all_points - self.get_center()
  1338. norms = np.linalg.norm(boundary_directions, axis=1)
  1339. boundary_directions /= np.repeat(norms, 3).reshape((len(norms), 3))
  1340. index = np.argmax(np.dot(boundary_directions, np.array(direction).T))
  1341. return all_points[index]
  1342. def get_continuous_bounding_box_point(self, direction: Vect3) -> Vect3:
  1343. dl, center, ur = self.get_bounding_box()
  1344. corner_vect = (ur - center)
  1345. return center + direction / np.max(np.abs(np.true_divide(
  1346. direction, corner_vect,
  1347. out=np.zeros(len(direction)),
  1348. where=((corner_vect) != 0)
  1349. )))
  1350. def get_top(self) -> Vect3:
  1351. return self.get_edge_center(UP)
  1352. def get_bottom(self) -> Vect3:
  1353. return self.get_edge_center(DOWN)
  1354. def get_right(self) -> Vect3:
  1355. return self.get_edge_center(RIGHT)
  1356. def get_left(self) -> Vect3:
  1357. return self.get_edge_center(LEFT)
  1358. def get_zenith(self) -> Vect3:
  1359. return self.get_edge_center(OUT)
  1360. def get_nadir(self) -> Vect3:
  1361. return self.get_edge_center(IN)
  1362. def length_over_dim(self, dim: int) -> float:
  1363. bb = self.get_bounding_box()
  1364. return abs((bb[2] - bb[0])[dim])
  1365. def get_width(self) -> float:
  1366. return self.length_over_dim(0)
  1367. def get_height(self) -> float:
  1368. return self.length_over_dim(1)
  1369. def get_depth(self) -> float:
  1370. return self.length_over_dim(2)
  1371. def get_shape(self) -> Tuple[float]:
  1372. return tuple(self.length_over_dim(dim) for dim in range(3))
  1373. def get_coord(self, dim: int, direction: Vect3 = ORIGIN) -> float:
  1374. """
  1375. Meant to generalize get_x, get_y, get_z
  1376. """
  1377. return self.get_bounding_box_point(direction)[dim]
  1378. def get_x(self, direction=ORIGIN) -> float:
  1379. return self.get_coord(0, direction)
  1380. def get_y(self, direction=ORIGIN) -> float:
  1381. return self.get_coord(1, direction)
  1382. def get_z(self, direction=ORIGIN) -> float:
  1383. return self.get_coord(2, direction)
  1384. def get_start(self) -> Vect3:
  1385. self.throw_error_if_no_points()
  1386. return self.get_points()[0].copy()
  1387. def get_end(self) -> Vect3:
  1388. self.throw_error_if_no_points()
  1389. return self.get_points()[-1].copy()
  1390. def get_start_and_end(self) -> tuple[Vect3, Vect3]:
  1391. self.throw_error_if_no_points()
  1392. points = self.get_points()
  1393. return (points[0].copy(), points[-1].copy())
  1394. def point_from_proportion(self, alpha: float) -> Vect3:
  1395. points = self.get_points()
  1396. i, subalpha = integer_interpolate(0, len(points) - 1, alpha)
  1397. return interpolate(points[i], points[i + 1], subalpha)
  1398. def pfp(self, alpha):
  1399. """Abbreviation for point_from_proportion"""
  1400. return self.point_from_proportion(alpha)
  1401. def get_pieces(self, n_pieces: int) -> Group:
  1402. template = self.copy()
  1403. template.set_submobjects([])
  1404. alphas = np.linspace(0, 1, n_pieces + 1)
  1405. return Group(*[
  1406. template.copy().pointwise_become_partial(
  1407. self, a1, a2
  1408. )
  1409. for a1, a2 in zip(alphas[:-1], alphas[1:])
  1410. ])
  1411. def get_z_index_reference_point(self) -> Vect3:
  1412. # TODO, better place to define default z_index_group?
  1413. z_index_group = getattr(self, "z_index_group", self)
  1414. return z_index_group.get_center()
  1415. # Match other mobject properties
  1416. def match_color(self, mobject: Mobject) -> Self:
  1417. return self.set_color(mobject.get_color())
  1418. def match_style(self, mobject: Mobject) -> Self:
  1419. self.set_color(mobject.get_color())
  1420. self.set_opacity(mobject.get_opacity())
  1421. self.set_shading(*mobject.get_shading())
  1422. return self
  1423. def match_dim_size(self, mobject: Mobject, dim: int, **kwargs) -> Self:
  1424. return self.rescale_to_fit(
  1425. mobject.length_over_dim(dim), dim,
  1426. **kwargs
  1427. )
  1428. def match_width(self, mobject: Mobject, **kwargs) -> Self:
  1429. return self.match_dim_size(mobject, 0, **kwargs)
  1430. def match_height(self, mobject: Mobject, **kwargs) -> Self:
  1431. return self.match_dim_size(mobject, 1, **kwargs)
  1432. def match_depth(self, mobject: Mobject, **kwargs) -> Self:
  1433. return self.match_dim_size(mobject, 2, **kwargs)
  1434. def match_coord(
  1435. self,
  1436. mobject_or_point: Mobject | Vect3,
  1437. dim: int,
  1438. direction: Vect3 = ORIGIN
  1439. ) -> Self:
  1440. if isinstance(mobject_or_point, Mobject):
  1441. coord = mobject_or_point.get_coord(dim, direction)
  1442. else:
  1443. coord = mobject_or_point[dim]
  1444. return self.set_coord(coord, dim=dim, direction=direction)
  1445. def match_x(
  1446. self,
  1447. mobject_or_point: Mobject | Vect3,
  1448. direction: Vect3 = ORIGIN
  1449. ) -> Self:
  1450. return self.match_coord(mobject_or_point, 0, direction)
  1451. def match_y(
  1452. self,
  1453. mobject_or_point: Mobject | Vect3,
  1454. direction: Vect3 = ORIGIN
  1455. ) -> Self:
  1456. return self.match_coord(mobject_or_point, 1, direction)
  1457. def match_z(
  1458. self,
  1459. mobject_or_point: Mobject | Vect3,
  1460. direction: Vect3 = ORIGIN
  1461. ) -> Self:
  1462. return self.match_coord(mobject_or_point, 2, direction)
  1463. def align_to(
  1464. self,
  1465. mobject_or_point: Mobject | Vect3,
  1466. direction: Vect3 = ORIGIN
  1467. ) -> Self:
  1468. """
  1469. Examples:
  1470. mob1.align_to(mob2, UP) moves mob1 vertically so that its
  1471. top edge lines ups with mob2's top edge.
  1472. mob1.align_to(mob2, alignment_vect = RIGHT) moves mob1
  1473. horizontally so that it's center is directly above/below
  1474. the center of mob2
  1475. """
  1476. if isinstance(mobject_or_point, Mobject):
  1477. point = mobject_or_point.get_bounding_box_point(direction)
  1478. else:
  1479. point = mobject_or_point
  1480. for dim in range(self.dim):
  1481. if direction[dim] != 0:
  1482. self.set_coord(point[dim], dim, direction)
  1483. return self
  1484. def get_group_class(self):
  1485. return Group
  1486. # Alignment
  1487. def is_aligned_with(self, mobject: Mobject) -> bool:
  1488. if len(self.data) != len(mobject.data):
  1489. return False
  1490. if len(self.submobjects) != len(mobject.submobjects):
  1491. return False
  1492. return all(
  1493. sm1.is_aligned_with(sm2)
  1494. for sm1, sm2 in zip(self.submobjects, mobject.submobjects)
  1495. )
  1496. def align_data_and_family(self, mobject: Mobject) -> Self:
  1497. self.align_family(mobject)
  1498. self.align_data(mobject)
  1499. return self
  1500. def align_data(self, mobject: Mobject) -> Self:
  1501. for mob1, mob2 in zip(self.get_family(), mobject.get_family()):
  1502. mob1.align_points(mob2)
  1503. return self
  1504. def align_points(self, mobject: Mobject) -> Self:
  1505. max_len = max(self.get_num_points(), mobject.get_num_points())
  1506. for mob in (self, mobject):
  1507. mob.resize_points(max_len, resize_func=resize_preserving_order)
  1508. return self
  1509. def align_family(self, mobject: Mobject) -> Self:
  1510. mob1 = self
  1511. mob2 = mobject
  1512. n1 = len(mob1)
  1513. n2 = len(mob2)
  1514. if n1 != n2:
  1515. mob1.add_n_more_submobjects(max(0, n2 - n1))
  1516. mob2.add_n_more_submobjects(max(0, n1 - n2))
  1517. # Recurse
  1518. for sm1, sm2 in zip(mob1.submobjects, mob2.submobjects):
  1519. sm1.align_family(sm2)
  1520. return self
  1521. def push_self_into_submobjects(self) -> Self:
  1522. copy = self.copy()
  1523. copy.set_submobjects([])
  1524. self.resize_points(0)
  1525. self.add(copy)
  1526. return self
  1527. def add_n_more_submobjects(self, n: int) -> Self:
  1528. if n == 0:
  1529. return self
  1530. curr = len(self.submobjects)
  1531. if curr == 0:
  1532. # If empty, simply add n point mobjects
  1533. null_mob = self.copy()
  1534. null_mob.set_points([self.get_center()])
  1535. self.set_submobjects([
  1536. null_mob.copy()
  1537. for k in range(n)
  1538. ])
  1539. return self
  1540. target = curr + n
  1541. repeat_indices = (np.arange(target) * curr) // target
  1542. split_factors = [
  1543. (repeat_indices == i).sum()
  1544. for i in range(curr)
  1545. ]
  1546. new_submobs = []
  1547. for submob, sf in zip(self.submobjects, split_factors):
  1548. new_submobs.append(submob)
  1549. for k in range(1, sf):
  1550. new_submobs.append(submob.invisible_copy())
  1551. self.set_submobjects(new_submobs)
  1552. return self
  1553. def invisible_copy(self) -> Self:
  1554. return self.copy().set_opacity(0)
  1555. # Interpolate
  1556. def interpolate(
  1557. self,
  1558. mobject1: Mobject,
  1559. mobject2: Mobject,
  1560. alpha: float,
  1561. path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path
  1562. ) -> Self:
  1563. keys = [k for k in self.data.dtype.names if k not in self.locked_data_keys]
  1564. if keys:
  1565. self.note_changed_data()
  1566. for key in keys:
  1567. md1 = mobject1.data[key]
  1568. md2 = mobject2.data[key]
  1569. if key in self.const_data_keys:
  1570. md1 = md1[0]
  1571. md2 = md2[0]
  1572. if key in self.pointlike_data_keys:
  1573. self.data[key] = path_func(md1, md2, alpha)
  1574. else:
  1575. self.data[key] = (1 - alpha) * md1 + alpha * md2
  1576. for key in self.uniforms:
  1577. if key in self.locked_uniform_keys:
  1578. continue
  1579. if key not in mobject1.uniforms or key not in mobject2.uniforms:
  1580. continue
  1581. self.uniforms[key] = (1 - alpha) * mobject1.uniforms[key] + alpha * mobject2.uniforms[key]
  1582. self.bounding_box[:] = path_func(mobject1.bounding_box, mobject2.bounding_box, alpha)
  1583. return self
  1584. def pointwise_become_partial(self, mobject, a, b) -> Self:
  1585. """
  1586. Set points in such a way as to become only
  1587. part of mobject.
  1588. Inputs 0 <= a < b <= 1 determine what portion
  1589. of mobject to become.
  1590. """
  1591. # To be implemented in subclass
  1592. return self
  1593. # Locking data
  1594. def lock_data(self, keys: Iterable[str]) -> Self:
  1595. """
  1596. To speed up some animations, particularly transformations,
  1597. it can be handy to acknowledge which pieces of data
  1598. won't change during the animation so that calls to
  1599. interpolate can skip this, and so that it's not
  1600. read into the shader_wrapper objects needlessly
  1601. """
  1602. if self.has_updaters():
  1603. return self
  1604. self.locked_data_keys = set(keys)
  1605. return self
  1606. def lock_uniforms(self, keys: Iterable[str]) -> Self:
  1607. if self.has_updaters():
  1608. return self
  1609. self.locked_uniform_keys = set(keys)
  1610. return self
  1611. def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject) -> Self:
  1612. tuples = zip(
  1613. self.get_family(),
  1614. mobject1.get_family(),
  1615. mobject2.get_family(),
  1616. )
  1617. for sm, sm1, sm2 in tuples:
  1618. if not sm.data.dtype == sm1.data.dtype == sm2.data.dtype:
  1619. continue
  1620. sm.lock_data(
  1621. key for key in sm.data.dtype.names
  1622. if arrays_match(sm1.data[key], sm2.data[key])
  1623. )
  1624. sm.lock_uniforms(
  1625. key for key in self.uniforms
  1626. if all(listify(mobject1.uniforms.get(key, 0) == mobject2.uniforms.get(key, 0)))
  1627. )
  1628. sm.const_data_keys = set(
  1629. key for key in sm.data.dtype.names
  1630. if key not in sm.locked_data_keys
  1631. if all(
  1632. array_is_constant(mob.data[key])
  1633. for mob in (sm, sm1, sm2)
  1634. )
  1635. )
  1636. return self
  1637. def unlock_data(self) -> Self:
  1638. for mob in self.get_family():
  1639. mob.locked_data_keys = set()
  1640. mob.const_data_keys = set()
  1641. mob.locked_uniform_keys = set()
  1642. return self
  1643. # Operations touching shader uniforms
  1644. @staticmethod
  1645. def affects_shader_info_id(func: Callable[..., T]) -> Callable[..., T]:
  1646. @wraps(func)
  1647. def wrapper(self, *args, **kwargs):
  1648. result = func(self, *args, **kwargs)
  1649. self.refresh_shader_wrapper_id()
  1650. return result
  1651. return wrapper
  1652. @affects_shader_info_id
  1653. def set_uniform(self, recurse: bool = True, **new_uniforms) -> Self:
  1654. for mob in self.get_family(recurse):
  1655. mob.uniforms.update(new_uniforms)
  1656. return self
  1657. @affects_shader_info_id
  1658. def fix_in_frame(self, recurse: bool = True) -> Self:
  1659. self.set_uniform(recurse, is_fixed_in_frame=1.0)
  1660. return self
  1661. @affects_shader_info_id
  1662. def unfix_from_frame(self, recurse: bool = True) -> Self:
  1663. self.set_uniform(recurse, is_fixed_in_frame=0.0)
  1664. return self
  1665. def is_fixed_in_frame(self) -> bool:
  1666. return bool(self.uniforms["is_fixed_in_frame"])
  1667. @affects_shader_info_id
  1668. def apply_depth_test(self, recurse: bool = True) -> Self:
  1669. for mob in self.get_family(recurse):
  1670. mob.depth_test = True
  1671. return self
  1672. @affects_shader_info_id
  1673. def deactivate_depth_test(self, recurse: bool = True) -> Self:
  1674. for mob in self.get_family(recurse):
  1675. mob.depth_test = False
  1676. return self
  1677. def set_clip_plane(
  1678. self,
  1679. vect: Vect3 | None = None,
  1680. threshold: float | None = None
  1681. ) -> Self:
  1682. if vect is not None:
  1683. self.uniforms["clip_plane"][:3] = vect
  1684. if threshold is not None:
  1685. self.uniforms["clip_plane"][3] = threshold
  1686. return self
  1687. def deactivate_clip_plane(self) -> Self:
  1688. self.uniforms["clip_plane"][:] = 0
  1689. return self
  1690. # Shader code manipulation
  1691. @affects_data
  1692. def replace_shader_code(self, old: str, new: str) -> Self:
  1693. for mob in self.get_family():
  1694. mob.shader_code_replacements[old] = new
  1695. mob.shader_wrapper = None
  1696. return self
  1697. def set_color_by_code(self, glsl_code: str) -> Self:
  1698. """
  1699. Takes a snippet of code and inserts it into a
  1700. context which has the following variables:
  1701. vec4 color, vec3 point, vec3 unit_normal.
  1702. The code should change the color variable
  1703. """
  1704. self.replace_shader_code(
  1705. "///// INSERT COLOR FUNCTION HERE /////",
  1706. glsl_code
  1707. )
  1708. return self
  1709. def set_color_by_xyz_func(
  1710. self,
  1711. glsl_snippet: str,
  1712. min_value: float = -5.0,
  1713. max_value: float = 5.0,
  1714. colormap: str = "viridis"
  1715. ) -> Self:
  1716. """
  1717. Pass in a glsl expression in terms of x, y and z which returns
  1718. a float.
  1719. """
  1720. # TODO, add a version of this which changes the point data instead
  1721. # of the shader code
  1722. for char in "xyz":
  1723. glsl_snippet = glsl_snippet.replace(char, "point." + char)
  1724. rgb_list = get_colormap_list(colormap)
  1725. self.set_color_by_code(
  1726. "color.rgb = float_to_color({}, {}, {}, {});".format(
  1727. glsl_snippet,
  1728. float(min_value),
  1729. float(max_value),
  1730. get_colormap_code(rgb_list)
  1731. )
  1732. )
  1733. return self
  1734. # For shader data
  1735. def init_shader_wrapper(self, ctx: Context):
  1736. self.shader_wrapper = ShaderWrapper(
  1737. ctx=ctx,
  1738. vert_data=self.data,
  1739. shader_folder=self.shader_folder,
  1740. mobject_uniforms=self.uniforms,
  1741. texture_paths=self.texture_paths,
  1742. depth_test=self.depth_test,
  1743. render_primitive=self.render_primitive,
  1744. code_replacements=self.shader_code_replacements,
  1745. )
  1746. def refresh_shader_wrapper_id(self):
  1747. for submob in self.get_family():
  1748. if submob.shader_wrapper is not None:
  1749. submob.shader_wrapper.depth_test = submob.depth_test
  1750. submob.shader_wrapper.refresh_id()
  1751. for mob in (self, *self.get_ancestors()):
  1752. mob._data_has_changed = True
  1753. return self
  1754. def get_shader_wrapper(self, ctx: Context) -> ShaderWrapper:
  1755. if self.shader_wrapper is None:
  1756. self.init_shader_wrapper(ctx)
  1757. return self.shader_wrapper
  1758. def get_shader_wrapper_list(self, ctx: Context) -> list[ShaderWrapper]:
  1759. family = self.family_members_with_points()
  1760. batches = batch_by_property(family, lambda sm: sm.get_shader_wrapper(ctx).get_id())
  1761. result = []
  1762. for submobs, sid in batches:
  1763. shader_wrapper = submobs[0].shader_wrapper
  1764. data_list = [sm.get_shader_data() for sm in submobs]
  1765. shader_wrapper.read_in(data_list)
  1766. result.append(shader_wrapper)
  1767. return result
  1768. def get_shader_data(self) -> np.ndarray:
  1769. indices = self.get_shader_vert_indices()
  1770. if indices is not None:
  1771. return self.data[indices]
  1772. else:
  1773. return self.data
  1774. def get_uniforms(self):
  1775. return self.uniforms
  1776. def get_shader_vert_indices(self) -> Optional[np.ndarray]:
  1777. return None
  1778. def render(self, ctx: Context, camera_uniforms: dict):
  1779. if self._data_has_changed:
  1780. self.shader_wrappers = self.get_shader_wrapper_list(ctx)
  1781. self._data_has_changed = False
  1782. for shader_wrapper in self.shader_wrappers:
  1783. shader_wrapper.update_program_uniforms(camera_uniforms)
  1784. shader_wrapper.pre_render()
  1785. shader_wrapper.render()
  1786. # Event Handlers
"""
Event handling follows the event-bubbling model of the DOM in JavaScript.
Return False from a callback to stop the event from bubbling further.
To learn more, visit https://www.quirksmode.org/js/events_order.html
An event callback is a callable taking two arguments:
1. a Mobject
2. the event data (a dict)
"""
  1795. def init_event_listners(self):
  1796. self.event_listners: list[EventListener] = []
  1797. def add_event_listner(
  1798. self,
  1799. event_type: EventType,
  1800. event_callback: Callable[[Mobject, dict[str]]]
  1801. ):
  1802. event_listner = EventListener(self, event_type, event_callback)
  1803. self.event_listners.append(event_listner)
  1804. EVENT_DISPATCHER.add_listner(event_listner)
  1805. return self
  1806. def remove_event_listner(
  1807. self,
  1808. event_type: EventType,
  1809. event_callback: Callable[[Mobject, dict[str]]]
  1810. ):
  1811. event_listner = EventListener(self, event_type, event_callback)
  1812. while event_listner in self.event_listners:
  1813. self.event_listners.remove(event_listner)
  1814. EVENT_DISPATCHER.remove_listner(event_listner)
  1815. return self
  1816. def clear_event_listners(self, recurse: bool = True):
  1817. self.event_listners = []
  1818. if recurse:
  1819. for submob in self.submobjects:
  1820. submob.clear_event_listners(recurse=recurse)
  1821. return self
  1822. def get_event_listners(self):
  1823. return self.event_listners
  1824. def get_family_event_listners(self):
  1825. return list(it.chain(*[sm.get_event_listners() for sm in self.get_family()]))
  1826. def get_has_event_listner(self):
  1827. return any(
  1828. mob.get_event_listners()
  1829. for mob in self.get_family()
  1830. )
  1831. def add_mouse_motion_listner(self, callback):
  1832. self.add_event_listner(EventType.MouseMotionEvent, callback)
  1833. def remove_mouse_motion_listner(self, callback):
  1834. self.remove_event_listner(EventType.MouseMotionEvent, callback)
  1835. def add_mouse_press_listner(self, callback):
  1836. self.add_event_listner(EventType.MousePressEvent, callback)
  1837. def remove_mouse_press_listner(self, callback):
  1838. self.remove_event_listner(EventType.MousePressEvent, callback)
  1839. def add_mouse_release_listner(self, callback):
  1840. self.add_event_listner(EventType.MouseReleaseEvent, callback)
  1841. def remove_mouse_release_listner(self, callback):
  1842. self.remove_event_listner(EventType.MouseReleaseEvent, callback)
  1843. def add_mouse_drag_listner(self, callback):
  1844. self.add_event_listner(EventType.MouseDragEvent, callback)
  1845. def remove_mouse_drag_listner(self, callback):
  1846. self.remove_event_listner(EventType.MouseDragEvent, callback)
  1847. def add_mouse_scroll_listner(self, callback):
  1848. self.add_event_listner(EventType.MouseScrollEvent, callback)
  1849. def remove_mouse_scroll_listner(self, callback):
  1850. self.remove_event_listner(EventType.MouseScrollEvent, callback)
  1851. def add_key_press_listner(self, callback):
  1852. self.add_event_listner(EventType.KeyPressEvent, callback)
  1853. def remove_key_press_listner(self, callback):
  1854. self.remove_event_listner(EventType.KeyPressEvent, callback)
  1855. def add_key_release_listner(self, callback):
  1856. self.add_event_listner(EventType.KeyReleaseEvent, callback)
  1857. def remove_key_release_listner(self, callback):
  1858. self.remove_event_listner(EventType.KeyReleaseEvent, callback)
  1859. # Errors
  1860. def throw_error_if_no_points(self):
  1861. if not self.has_points():
  1862. message = "Cannot call Mobject.{} " +\
  1863. "for a Mobject with no points"
  1864. caller_name = sys._getframe(1).f_code.co_name
  1865. raise Exception(message.format(caller_name))
  1866. class Group(Mobject, Generic[SubmobjectType]):
  1867. def __init__(self, *mobjects: SubmobjectType | Iterable[SubmobjectType], **kwargs):
  1868. super().__init__(**kwargs)
  1869. self._ingest_args(*mobjects)
  1870. def _ingest_args(self, *args: Mobject | Iterable[Mobject]):
  1871. if len(args) == 0:
  1872. return
  1873. if all(isinstance(mob, Mobject) for mob in args):
  1874. self.add(*args)
  1875. elif isinstance(args[0], Iterable):
  1876. self.add(*args[0])
  1877. else:
  1878. raise Exception(f"Invalid argument to Group of type {type(args[0])}")
  1879. def __add__(self, other: Mobject | Group) -> Self:
  1880. assert isinstance(other, Mobject)
  1881. return self.add(other)
  1882. # This is just here to make linters happy with references to things like Group(...)[0]
  1883. def __getitem__(self, index) -> SubmobjectType:
  1884. return super().__getitem__(index)
  1885. class Point(Mobject):
  1886. def __init__(
  1887. self,
  1888. location: Vect3 = ORIGIN,
  1889. artificial_width: float = 1e-6,
  1890. artificial_height: float = 1e-6,
  1891. **kwargs
  1892. ):
  1893. self.artificial_width = artificial_width
  1894. self.artificial_height = artificial_height
  1895. super().__init__(**kwargs)
  1896. self.set_location(location)
  1897. def get_width(self) -> float:
  1898. return self.artificial_width
  1899. def get_height(self) -> float:
  1900. return self.artificial_height
  1901. def get_location(self) -> Vect3:
  1902. return self.get_points()[0].copy()
  1903. def get_bounding_box_point(self, *args, **kwargs) -> Vect3:
  1904. return self.get_location()
  1905. def set_location(self, new_loc: npt.ArrayLike) -> Self:
  1906. self.set_points(np.array(new_loc, ndmin=2, dtype=float))
  1907. return self
  1908. class _AnimationBuilder:
  1909. def __init__(self, mobject: Mobject):
  1910. self.mobject = mobject
  1911. self.overridden_animation = None
  1912. self.mobject.generate_target()
  1913. self.is_chaining = False
  1914. self.methods: list[Callable] = []
  1915. self.anim_args = {}
  1916. self.can_pass_args = True
  1917. def __getattr__(self, method_name: str):
  1918. method = getattr(self.mobject.target, method_name)
  1919. self.methods.append(method)
  1920. has_overridden_animation = hasattr(method, "_override_animate")
  1921. if (self.is_chaining and has_overridden_animation) or self.overridden_animation:
  1922. raise NotImplementedError(
  1923. "Method chaining is currently not supported for " + \
  1924. "overridden animations"
  1925. )
  1926. def update_target(*method_args, **method_kwargs):
  1927. if has_overridden_animation:
  1928. self.overridden_animation = method._override_animate(
  1929. self.mobject, *method_args, **method_kwargs
  1930. )
  1931. else:
  1932. method(*method_args, **method_kwargs)
  1933. return self
  1934. self.is_chaining = True
  1935. return update_target
  1936. def __call__(self, **kwargs):
  1937. return self.set_anim_args(**kwargs)
  1938. def set_anim_args(self, **kwargs):
  1939. '''
  1940. You can change the args of :class:`~manimlib.animation.transform.Transform`, such as
  1941. - ``run_time``
  1942. - ``time_span``
  1943. - ``rate_func``
  1944. - ``lag_ratio``
  1945. - ``path_arc``
  1946. - ``path_func``
  1947. and so on.
  1948. '''
  1949. if not self.can_pass_args:
  1950. raise ValueError(
  1951. "Animation arguments can only be passed by calling ``animate`` " + \
  1952. "or ``set_anim_args`` and can only be passed once",
  1953. )
  1954. self.anim_args = kwargs
  1955. self.can_pass_args = False
  1956. return self
  1957. def build(self):
  1958. from manimlib.animation.transform import _MethodAnimation
  1959. if self.overridden_animation:
  1960. return self.overridden_animation
  1961. return _MethodAnimation(self.mobject, self.methods, **self.anim_args)
  1962. def override_animate(method):
  1963. def decorator(animation_method):
  1964. method._override_animate = animation_method
  1965. return animation_method
  1966. return decorator
  1967. class _UpdaterBuilder:
  1968. def __init__(self, mobject: Mobject):
  1969. self.mobject = mobject
  1970. def __getattr__(self, method_name: str):
  1971. def add_updater(*method_args, **method_kwargs):
  1972. self.mobject.add_updater(
  1973. lambda m: getattr(m, method_name)(*method_args, **method_kwargs)
  1974. )
  1975. return self
  1976. return add_updater
  1977. class _FunctionalUpdaterBuilder:
  1978. def __init__(self, mobject: Mobject):
  1979. self.mobject = mobject
  1980. def __getattr__(self, method_name: str):
  1981. def add_updater(*method_args, **method_kwargs):
  1982. self.mobject.add_updater(
  1983. lambda m: getattr(m, method_name)(
  1984. *(arg() for arg in method_args),
  1985. **{
  1986. key: value()
  1987. for key, value in method_kwargs.items()
  1988. }
  1989. )
  1990. )
  1991. return self
  1992. return add_updater