transform_matching_parts.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. from __future__ import annotations
  2. import itertools as it
  3. from difflib import SequenceMatcher
  4. from manimlib.animation.composition import AnimationGroup
  5. from manimlib.animation.fading import FadeInFromPoint
  6. from manimlib.animation.fading import FadeOutToPoint
  7. from manimlib.animation.transform import Transform
  8. from manimlib.mobject.mobject import Mobject
  9. from manimlib.mobject.types.vectorized_mobject import VMobject
  10. from manimlib.mobject.svg.string_mobject import StringMobject
  11. from typing import TYPE_CHECKING
  12. if TYPE_CHECKING:
  13. from typing import Iterable
  14. from manimlib.scene.scene import Scene
  class TransformMatchingParts(AnimationGroup):
      """
      Animate ``source`` into ``target`` by pairing up their pieces.

      A "piece" is any family member (of ``source`` or ``target``) that has
      points. Pieces are accounted for in three passes:

      1. Each explicitly supplied ``(source_sub, target_sub)`` pair in
         ``matched_pairs`` becomes one transform.
      2. Remaining source/target pieces for which ``has_same_shape_as`` is
         true are paired up and transformed.
      3. Any source piece still unaccounted for is handed to FadeOutToPoint
         aimed at ``target.get_center()``; any target piece still unaccounted
         for is handed to FadeInFromPoint from ``source.get_center()``.

      Extra keyword arguments are forwarded to every sub-animation, while
      ``run_time`` and ``lag_ratio`` configure the group as a whole.
      """

      def __init__(
          self,
          source: Mobject,
          target: Mobject,
          matched_pairs: Iterable[tuple[Mobject, Mobject]] = (),
          match_animation: type = Transform,
          mismatch_animation: type = Transform,
          run_time: float = 2,
          lag_ratio: float = 0,
          **kwargs,
      ):
          self.source = source
          self.target = target
          # Used when a transformed pair has the same shape / a different shape
          self.match_animation = match_animation
          self.mismatch_animation = mismatch_animation
          # Per-sub-animation configuration (not passed to the group itself)
          self.anim_config = dict(**kwargs)

          # We progressively build up a list of transforms from pieces in
          # source to those in target. These two lists track which pieces
          # have not been accounted for yet; add_transform removes from them.
          self.source_pieces = source.family_members_with_points()
          self.target_pieces = target.family_members_with_points()
          self.anims = []

          for pair in matched_pairs:
              self.add_transform(*pair)

          # Match any remaining pairs with the same shape
          for pair in self.find_pairs_with_matching_shapes(self.source_pieces, self.target_pieces):
              self.add_transform(*pair)

          # Finally, account for mismatches. A leftover piece is skipped if it
          # already lies in the family of some queued animation's mobject
          # (presumably a piece nested inside another piece that has points).
          for source_piece in self.source_pieces:
              if self._is_covered(source_piece):
                  continue
              self.anims.append(FadeOutToPoint(
                  source_piece, target.get_center(),
                  **self.anim_config
              ))
          for target_piece in self.target_pieces:
              if self._is_covered(target_piece):
                  continue
              self.anims.append(FadeInFromPoint(
                  target_piece, source.get_center(),
                  **self.anim_config
              ))

          super().__init__(
              *self.anims,
              run_time=run_time,
              lag_ratio=lag_ratio,
          )

      def _is_covered(self, piece: Mobject) -> bool:
          """Return True if ``piece`` is in the family of any queued animation's mobject."""
          return any(piece in anim.mobject.get_family() for anim in self.anims)

      def add_transform(
          self,
          source: Mobject,
          target: Mobject,
      ):
          """
          Queue a transform from ``source`` to ``target`` when possible.

          Nothing is queued if either side has no pieces with points, or if
          any of their pieces has already been consumed by an earlier call.
          Otherwise ``match_animation`` is used when the two have the same
          shape (``mismatch_animation`` otherwise), and all of their pieces
          are removed from the pools of unaccounted-for pieces.
          """
          new_source_pieces = source.family_members_with_points()
          new_target_pieces = target.family_members_with_points()
          if len(new_source_pieces) == 0 or len(new_target_pieces) == 0:
              # Don't animate null sources or null targets
              return
          # Every piece must still be available; otherwise part of this pair
          # has already been claimed by another transform.
          source_is_new = all(char in self.source_pieces for char in new_source_pieces)
          target_is_new = all(char in self.target_pieces for char in new_target_pieces)
          if not source_is_new or not target_is_new:
              return

          transform_type = self.mismatch_animation
          if source.has_same_shape_as(target):
              transform_type = self.match_animation

          self.anims.append(transform_type(source, target, **self.anim_config))
          for char in new_source_pieces:
              self.source_pieces.remove(char)
          for char in new_target_pieces:
              self.target_pieces.remove(char)

      def find_pairs_with_matching_shapes(
          self,
          chars1: list[Mobject],
          chars2: list[Mobject]
      ) -> list[tuple[Mobject, Mobject]]:
          """
          Return every ``(char1, char2)`` from ``chars1 x chars2`` with the same shape.

          The result is a fully built list, so callers may mutate
          ``chars1``/``chars2`` while consuming it.
          """
          return [
              (char1, char2)
              for char1, char2 in it.product(chars1, chars2)
              if char1.has_same_shape_as(char2)
          ]

      def clean_up_from_scene(self, scene: Scene) -> None:
          """After the animation, replace the group's mobject in ``scene`` with ``target``."""
          super().clean_up_from_scene(scene)
          scene.remove(self.mobject)
          scene.add(self.target)
  class TransformMatchingShapes(TransformMatchingParts):
      """Another name for TransformMatchingParts; it adds no behavior of its own."""
  class TransformMatchingStrings(TransformMatchingParts):
      """
      TransformMatchingParts for StringMobjects.

      Before the generic shape matching, parts of the two strings are paired
      by key (``matched_keys`` / ``key_map``) and then by the longest runs of
      identical symbol substrings (see ``matching_blocks``).
      """

      def __init__(
          self,
          source: StringMobject,
          target: StringMobject,
          matched_keys: Iterable[str] = (),
          key_map: dict[str, str] | None = None,
          matched_pairs: Iterable[tuple[VMobject, VMobject]] = (),
          **kwargs,
      ):
          # Avoid a shared mutable default; None behaves like an empty map.
          if key_map is None:
              key_map = {}
          matched_pairs = [
              *matched_pairs,
              *self.matching_blocks(source, target, matched_keys, key_map),
          ]
          super().__init__(
              source, target,
              matched_pairs=matched_pairs,
              **kwargs,
          )

      def matching_blocks(
          self,
          source: StringMobject,
          target: StringMobject,
          matched_keys: Iterable[str],
          key_map: dict[str, str]
      ) -> list[tuple[VMobject, VMobject]]:
          """
          Return ``(source_part, target_part)`` pairs to transform into each other.

          User-specified matches come first: ``(source[key], target[key])`` for
          each key in ``matched_keys``, then ``(source[key1], target[key2])`` for
          each item of ``key_map``. Symbols overlapping those parts are then
          excluded, and the longest common run of symbol substrings is
          repeatedly found, converted to a pair of path slices, and excluded
          until no common run remains.
          """
          # Copied because these lists are overwritten in place below
          syms1 = list(source.get_symbol_substrings())
          syms2 = list(target.get_symbol_substrings())
          counts1 = list(map(source.substr_to_path_count, syms1))
          counts2 = list(map(target.substr_to_path_count, syms2))
          # Symbols and paths are not one-to-one: symbol k is drawn by paths
          # offsets[k]:offsets[k + 1] of its mobject.
          offsets1 = [0, *it.accumulate(counts1)]
          offsets2 = [0, *it.accumulate(counts2)]

          # Start with user specified matches
          blocks = [(source[key], target[key]) for key in matched_keys]
          blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()]

          # Nullify any symbols whose paths intersect those matches
          for sub_source, sub_target in blocks:
              self._nullify_claimed_symbols(source, syms1, offsets1, sub_source, "Null1")
              self._nullify_claimed_symbols(target, syms2, offsets2, sub_target, "Null2")

          # Group together longest matching substrings. The two null markers
          # never compare equal, so every pass removes at least one symbol on
          # each side and the loop terminates.
          while True:
              matcher = SequenceMatcher(None, syms1, syms2)
              match = matcher.find_longest_match(0, len(syms1), 0, len(syms2))
              if match.size == 0:
                  break
              end_a = match.a + match.size
              end_b = match.b + match.size
              blocks.append((
                  source[offsets1[match.a]:offsets1[end_a]],
                  target[offsets2[match.b]:offsets2[end_b]],
              ))
              syms1[match.a:end_a] = ["Null1"] * match.size
              syms2[match.b:end_b] = ["Null2"] * match.size
          return blocks

      @staticmethod
      def _nullify_claimed_symbols(
          string_mob: StringMobject,
          syms: list[str],
          offsets: list[int],
          claimed_part: VMobject,
          null_symbol: str,
      ) -> None:
          """Overwrite with ``null_symbol`` every entry of ``syms`` whose paths lie in ``claimed_part``."""
          # Hoisted so the family is computed once per claimed part
          claimed = claimed_part.family_members_with_points()
          n_paths = len(string_mob)
          for k in range(len(syms)):
              # Clamp so that symbol/path count disagreements cannot index past the end
              start = min(offsets[k], n_paths)
              stop = min(offsets[k + 1], n_paths)
              if any(string_mob[p] in claimed for p in range(start, stop)):
                  syms[k] = null_symbol
  class TransformMatchingTex(TransformMatchingStrings):
      """Another name for TransformMatchingStrings; it adds no behavior of its own."""