123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- from __future__ import annotations
- import itertools as it
- from difflib import SequenceMatcher
- from manimlib.animation.composition import AnimationGroup
- from manimlib.animation.fading import FadeInFromPoint
- from manimlib.animation.fading import FadeOutToPoint
- from manimlib.animation.transform import Transform
- from manimlib.mobject.mobject import Mobject
- from manimlib.mobject.types.vectorized_mobject import VMobject
- from manimlib.mobject.svg.string_mobject import StringMobject
- from typing import TYPE_CHECKING
- if TYPE_CHECKING:
- from typing import Iterable
- from manimlib.scene.scene import Scene
class TransformMatchingParts(AnimationGroup):
    """
    Animate ``source`` into ``target`` by pairing up their pieces.

    "Pieces" are the family members of each mobject that have points. Each
    piece is accounted for as follows:

    1. Every ``(source_part, target_part)`` in ``matched_pairs`` is turned
       into a transform, provided none of its pieces has been claimed yet.
    2. Remaining source/target pieces with the same shape (according to
       ``has_same_shape_as``) are paired greedily, in the order returned by
       ``family_members_with_points``.
    3. Source pieces still unclaimed fade out toward ``target``'s center;
       target pieces still unclaimed fade in from ``source``'s center.

    ``match_animation`` is used for pairs whose shapes agree and
    ``mismatch_animation`` for pairs whose shapes differ. ``run_time`` and
    ``lag_ratio`` configure the group as a whole; every other keyword
    argument is forwarded to each per-piece animation.
    """

    def __init__(
        self,
        source: Mobject,
        target: Mobject,
        matched_pairs: Iterable[tuple[Mobject, Mobject]] = (),
        match_animation: type = Transform,
        mismatch_animation: type = Transform,
        run_time: float = 2,
        lag_ratio: float = 0,
        **kwargs,
    ):
        self.source = source
        self.target = target
        self.match_animation = match_animation
        self.mismatch_animation = mismatch_animation
        self.anim_config = dict(**kwargs)

        # We progressively build up a list of animations from pieces in
        # source to pieces in target. These two lists hold the pieces not
        # yet claimed by any transform; add_transform removes from them.
        self.source_pieces = source.family_members_with_points()
        self.target_pieces = target.family_members_with_points()
        self.anims = []

        for pair in matched_pairs:
            self.add_transform(*pair)

        # Match any remaining pairs with the same shape
        for pair in self.find_pairs_with_matching_shapes(self.source_pieces, self.target_pieces):
            self.add_transform(*pair)

        # Finally, account for mismatches. A piece that already lies in the
        # family of some queued animation's mobject is skipped, so it is not
        # animated twice.
        for source_piece in self.source_pieces:
            if any(source_piece in anim.mobject.get_family() for anim in self.anims):
                continue
            self.anims.append(FadeOutToPoint(
                source_piece, target.get_center(),
                **self.anim_config
            ))
        for target_piece in self.target_pieces:
            if any(target_piece in anim.mobject.get_family() for anim in self.anims):
                continue
            self.anims.append(FadeInFromPoint(
                target_piece, source.get_center(),
                **self.anim_config
            ))

        super().__init__(
            *self.anims,
            run_time=run_time,
            lag_ratio=lag_ratio,
        )

    def add_transform(
        self,
        source: Mobject,
        target: Mobject,
    ):
        """
        Queue an animation from ``source`` to ``target`` and claim their pieces.

        Nothing is queued when either side has no pieces with points, or when
        any of their pieces has already been claimed by an earlier transform.
        The animation class is ``match_animation`` if the two shapes agree,
        otherwise ``mismatch_animation``.
        """
        new_source_pieces = source.family_members_with_points()
        new_target_pieces = target.family_members_with_points()
        if not new_source_pieces or not new_target_pieces:
            # Don't animate null sources or null targets
            return
        source_is_new = all(piece in self.source_pieces for piece in new_source_pieces)
        target_is_new = all(piece in self.target_pieces for piece in new_target_pieces)
        if not source_is_new or not target_is_new:
            return

        transform_type = self.mismatch_animation
        if source.has_same_shape_as(target):
            transform_type = self.match_animation

        self.anims.append(transform_type(source, target, **self.anim_config))
        for piece in new_source_pieces:
            self.source_pieces.remove(piece)
        for piece in new_target_pieces:
            self.target_pieces.remove(piece)

    def find_pairs_with_matching_shapes(
        self,
        chars1: list[Mobject],
        chars2: list[Mobject]
    ) -> list[tuple[Mobject, Mobject]]:
        """
        Return every ``(c1, c2)`` from ``chars1 x chars2`` whose shapes agree.

        Pairs are listed with ``chars1`` as the outer loop; the caller is
        responsible for discarding pairs whose pieces were already used.
        """
        return [
            (char1, char2)
            for char1, char2 in it.product(chars1, chars2)
            if char1.has_same_shape_as(char2)
        ]

    def clean_up_from_scene(self, scene: Scene) -> None:
        """Replace the animated group in ``scene`` with the final ``target``."""
        super().clean_up_from_scene(scene)
        scene.remove(self.mobject)
        scene.add(self.target)
class TransformMatchingShapes(TransformMatchingParts):
    """Same behavior as TransformMatchingParts; this subclass adds nothing of its own."""
class TransformMatchingStrings(TransformMatchingParts):
    """
    TransformMatchingParts for StringMobjects, pairing pieces by their text.

    Candidate pairs are tried in this order of precedence:

    1. explicitly supplied ``matched_pairs``;
    2. ``matched_keys`` -- each key is looked up in both ``source`` and
       ``target`` and the two selections are paired;
    3. ``key_map`` -- ``{source_key: target_key}`` selections are paired;
    4. the longest common runs of symbols in the two strings, found greedily.

    Whatever remains unpaired is handled by TransformMatchingParts (shape
    matching, then fading out / fading in).
    """

    def __init__(
        self,
        source: StringMobject,
        target: StringMobject,
        matched_keys: Iterable[str] = (),
        key_map: dict[str, str] | None = None,
        matched_pairs: Iterable[tuple[VMobject, VMobject]] = (),
        **kwargs,
    ):
        # Immutable / None defaults avoid sharing one mutable object
        # across every call.
        if key_map is None:
            key_map = {}
        matched_pairs = [
            *matched_pairs,
            *self.matching_blocks(source, target, matched_keys, key_map),
        ]
        super().__init__(
            source, target,
            matched_pairs=matched_pairs,
            **kwargs,
        )

    def matching_blocks(
        self,
        source: StringMobject,
        target: StringMobject,
        matched_keys: Iterable[str],
        key_map: dict[str, str]
    ) -> list[tuple[VMobject, VMobject]]:
        """
        Return ``(source_part, target_part)`` pairs to transform into each other.

        User-specified pairs (from ``matched_keys``, then ``key_map``) come
        first. Symbols whose paths fall inside those pairs are excluded, then
        the longest common run of remaining symbols is repeatedly peeled off
        until the two symbol lists share nothing.
        """
        syms1 = source.get_symbol_substrings()
        syms2 = target.get_symbol_substrings()
        counts1 = list(map(source.substr_to_path_count, syms1))
        counts2 = list(map(target.substr_to_path_count, syms2))
        # A symbol may correspond to zero or several submobjects (paths), so
        # a symbol index is not a submobject index. starts[i] is the index of
        # the first submobject of symbol i; starts[i + 1] is one past its last.
        starts1 = [0, *it.accumulate(counts1)]
        starts2 = [0, *it.accumulate(counts2)]

        # Start with user specified matches
        blocks = [(source[key], target[key]) for key in matched_keys]
        blocks += [(source[key1], target[key2]) for key1, key2 in key_map.items()]

        # Nullify any symbol whose own submobjects intersect one of those
        # matches, so the greedy search below cannot claim it again.
        for sub_source, sub_target in blocks:
            claimed1 = sub_source.family_members_with_points()
            claimed2 = sub_target.family_members_with_points()
            for i in range(len(syms1)):
                if any(source[k] in claimed1 for k in range(starts1[i], starts1[i + 1])):
                    syms1[i] = "Null1"
            for j in range(len(syms2)):
                if any(target[k] in claimed2 for k in range(starts2[j], starts2[j + 1])):
                    syms2[j] = "Null2"

        # Group together longest matching substrings. Distinct null markers
        # on each side can never match, so every iteration consumes at least
        # one symbol per side and the loop terminates.
        while True:
            matcher = SequenceMatcher(None, syms1, syms2)
            match = matcher.find_longest_match(0, len(syms1), 0, len(syms2))
            if match.size == 0:
                break

            a, b, n = match.a, match.b, match.size
            blocks.append((
                source[starts1[a]:starts1[a + n]],
                target[starts2[b]:starts2[b + n]],
            ))

            syms1[a:a + n] = ["Null1"] * n
            syms2[b:b + n] = ["Null2"] * n

        return blocks
class TransformMatchingTex(TransformMatchingStrings):
    """Same behavior as TransformMatchingStrings; this subclass adds nothing of its own."""
|