creation.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. import numpy as np
  4. from manimlib.animation.animation import Animation
  5. from manimlib.constants import WHITE
  6. from manimlib.mobject.svg.string_mobject import StringMobject
  7. from manimlib.mobject.types.vectorized_mobject import VMobject
  8. from manimlib.utils.bezier import integer_interpolate
  9. from manimlib.utils.rate_functions import linear
  10. from manimlib.utils.rate_functions import double_smooth
  11. from manimlib.utils.rate_functions import smooth
  12. from manimlib.utils.simple_functions import clip
  13. from typing import TYPE_CHECKING
  14. if TYPE_CHECKING:
  15. from typing import Callable
  16. from manimlib.mobject.mobject import Mobject
  17. from manimlib.scene.scene import Scene
  18. from manimlib.typing import ManimColor
  19. class ShowPartial(Animation, ABC):
  20. """
  21. Abstract class for ShowCreation and ShowPassingFlash
  22. """
  23. def __init__(self, mobject: Mobject, should_match_start: bool = False, **kwargs):
  24. self.should_match_start = should_match_start
  25. super().__init__(mobject, **kwargs)
  26. def interpolate_submobject(
  27. self,
  28. submob: VMobject,
  29. start_submob: VMobject,
  30. alpha: float
  31. ) -> None:
  32. submob.pointwise_become_partial(
  33. start_submob, *self.get_bounds(alpha)
  34. )
  35. @abstractmethod
  36. def get_bounds(self, alpha: float) -> tuple[float, float]:
  37. raise Exception("Not Implemented")
  38. class ShowCreation(ShowPartial):
  39. def __init__(self, mobject: Mobject, lag_ratio: float = 1.0, **kwargs):
  40. super().__init__(mobject, lag_ratio=lag_ratio, **kwargs)
  41. def get_bounds(self, alpha: float) -> tuple[float, float]:
  42. return (0, alpha)
  43. class Uncreate(ShowCreation):
  44. def __init__(
  45. self,
  46. mobject: Mobject,
  47. rate_func: Callable[[float], float] = lambda t: smooth(1 - t),
  48. remover: bool = True,
  49. should_match_start: bool = True,
  50. **kwargs,
  51. ):
  52. super().__init__(
  53. mobject,
  54. rate_func=rate_func,
  55. remover=remover,
  56. should_match_start=should_match_start,
  57. **kwargs,
  58. )
  59. class DrawBorderThenFill(Animation):
  60. def __init__(
  61. self,
  62. vmobject: VMobject,
  63. run_time: float = 2.0,
  64. rate_func: Callable[[float], float] = double_smooth,
  65. stroke_width: float = 2.0,
  66. stroke_color: ManimColor = None,
  67. draw_border_animation_config: dict = {},
  68. fill_animation_config: dict = {},
  69. **kwargs
  70. ):
  71. assert isinstance(vmobject, VMobject)
  72. self.sm_to_index = {hash(sm): 0 for sm in vmobject.get_family()}
  73. self.stroke_width = stroke_width
  74. self.stroke_color = stroke_color
  75. self.draw_border_animation_config = draw_border_animation_config
  76. self.fill_animation_config = fill_animation_config
  77. super().__init__(
  78. vmobject,
  79. run_time=run_time,
  80. rate_func=rate_func,
  81. **kwargs
  82. )
  83. self.mobject = vmobject
  84. def begin(self) -> None:
  85. self.mobject.set_animating_status(True)
  86. self.outline = self.get_outline()
  87. super().begin()
  88. self.mobject.match_style(self.outline)
  89. def finish(self) -> None:
  90. super().finish()
  91. self.mobject.refresh_joint_angles()
  92. def get_outline(self) -> VMobject:
  93. outline = self.mobject.copy()
  94. outline.set_fill(opacity=0)
  95. for sm in outline.family_members_with_points():
  96. sm.set_stroke(
  97. color=self.stroke_color or sm.get_stroke_color(),
  98. width=self.stroke_width,
  99. behind=self.mobject.stroke_behind,
  100. )
  101. return outline
  102. def get_all_mobjects(self) -> list[Mobject]:
  103. return [*super().get_all_mobjects(), self.outline]
  104. def interpolate_submobject(
  105. self,
  106. submob: VMobject,
  107. start: VMobject,
  108. outline: VMobject,
  109. alpha: float
  110. ) -> None:
  111. index, subalpha = integer_interpolate(0, 2, alpha)
  112. if index == 1 and self.sm_to_index[hash(submob)] == 0:
  113. # First time crossing over
  114. submob.set_data(outline.data)
  115. self.sm_to_index[hash(submob)] = 1
  116. if index == 0:
  117. submob.pointwise_become_partial(outline, 0, subalpha)
  118. else:
  119. submob.interpolate(outline, start, subalpha)
  120. class Write(DrawBorderThenFill):
  121. def __init__(
  122. self,
  123. vmobject: VMobject,
  124. run_time: float = -1, # If negative, this will be reassigned
  125. lag_ratio: float = -1, # If negative, this will be reassigned
  126. rate_func: Callable[[float], float] = linear,
  127. stroke_color: ManimColor = None,
  128. **kwargs
  129. ):
  130. if stroke_color is None:
  131. stroke_color = vmobject.get_color()
  132. family_size = len(vmobject.family_members_with_points())
  133. super().__init__(
  134. vmobject,
  135. run_time=self.compute_run_time(family_size, run_time),
  136. lag_ratio=self.compute_lag_ratio(family_size, lag_ratio),
  137. rate_func=rate_func,
  138. stroke_color=stroke_color,
  139. **kwargs
  140. )
  141. def compute_run_time(self, family_size: int, run_time: float):
  142. if run_time < 0:
  143. return 1 if family_size < 15 else 2
  144. return run_time
  145. def compute_lag_ratio(self, family_size: int, lag_ratio: float):
  146. if lag_ratio < 0:
  147. return min(4.0 / (family_size + 1.0), 0.2)
  148. return lag_ratio
  149. class ShowIncreasingSubsets(Animation):
  150. def __init__(
  151. self,
  152. group: Mobject,
  153. int_func: Callable[[float], float] = np.round,
  154. suspend_mobject_updating: bool = False,
  155. **kwargs
  156. ):
  157. self.all_submobs = list(group.submobjects)
  158. self.int_func = int_func
  159. super().__init__(
  160. group,
  161. suspend_mobject_updating=suspend_mobject_updating,
  162. **kwargs
  163. )
  164. def interpolate_mobject(self, alpha: float) -> None:
  165. n_submobs = len(self.all_submobs)
  166. alpha = self.rate_func(alpha)
  167. index = int(self.int_func(alpha * n_submobs))
  168. self.update_submobject_list(index)
  169. def update_submobject_list(self, index: int) -> None:
  170. self.mobject.set_submobjects(self.all_submobs[:index])
  171. class ShowSubmobjectsOneByOne(ShowIncreasingSubsets):
  172. def __init__(
  173. self,
  174. group: Mobject,
  175. int_func: Callable[[float], float] = np.ceil,
  176. **kwargs
  177. ):
  178. super().__init__(group, int_func=int_func, **kwargs)
  179. def update_submobject_list(self, index: int) -> None:
  180. index = int(clip(index, 0, len(self.all_submobs) - 1))
  181. if index == 0:
  182. self.mobject.set_submobjects([])
  183. else:
  184. self.mobject.set_submobjects([self.all_submobs[index - 1]])
  185. class AddTextWordByWord(ShowIncreasingSubsets):
  186. def __init__(
  187. self,
  188. string_mobject: StringMobject,
  189. time_per_word: float = 0.2,
  190. run_time: float = -1.0, # If negative, it will be recomputed with time_per_word
  191. rate_func: Callable[[float], float] = linear,
  192. **kwargs
  193. ):
  194. assert isinstance(string_mobject, StringMobject)
  195. grouped_mobject = string_mobject.build_groups()
  196. if run_time < 0:
  197. run_time = time_per_word * len(grouped_mobject)
  198. super().__init__(
  199. grouped_mobject,
  200. run_time=run_time,
  201. rate_func=rate_func,
  202. **kwargs
  203. )
  204. self.string_mobject = string_mobject
  205. def clean_up_from_scene(self, scene: Scene) -> None:
  206. scene.remove(self.mobject)
  207. if not self.is_remover():
  208. scene.add(self.string_mobject)