svg_mobject.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. from __future__ import annotations
  2. import os
  3. from xml.etree import ElementTree as ET
  4. import numpy as np
  5. import svgelements as se
  6. import io
  7. from manimlib.constants import RIGHT
  8. from manimlib.logger import log
  9. from manimlib.mobject.geometry import Circle
  10. from manimlib.mobject.geometry import Line
  11. from manimlib.mobject.geometry import Polygon
  12. from manimlib.mobject.geometry import Polyline
  13. from manimlib.mobject.geometry import Rectangle
  14. from manimlib.mobject.geometry import RoundedRectangle
  15. from manimlib.mobject.types.vectorized_mobject import VMobject
  16. from manimlib.utils.directories import get_mobject_data_dir
  17. from manimlib.utils.images import get_full_vector_image_path
  18. from manimlib.utils.iterables import hash_obj
  19. from manimlib.utils.simple_functions import hash_string
  20. from typing import TYPE_CHECKING
  21. if TYPE_CHECKING:
  22. from typing import Tuple
  23. from manimlib.typing import ManimColor, Vect3Array
  24. SVG_HASH_TO_MOB_MAP: dict[int, list[VMobject]] = {}
  25. PATH_TO_POINTS: dict[str, Vect3Array] = {}
  26. def _convert_point_to_3d(x: float, y: float) -> np.ndarray:
  27. return np.array([x, y, 0.0])
  28. class SVGMobject(VMobject):
  29. file_name: str = ""
  30. height: float | None = 2.0
  31. width: float | None = None
  32. def __init__(
  33. self,
  34. file_name: str = "",
  35. should_center: bool = True,
  36. height: float | None = None,
  37. width: float | None = None,
  38. # Style that overrides the original svg
  39. color: ManimColor = None,
  40. fill_color: ManimColor = None,
  41. fill_opacity: float | None = None,
  42. stroke_width: float | None = 0.0,
  43. stroke_color: ManimColor = None,
  44. stroke_opacity: float | None = None,
  45. # Style that fills only when not specified
  46. # If None, regarded as default values from svg standard
  47. svg_default: dict = dict(
  48. color=None,
  49. opacity=None,
  50. fill_color=None,
  51. fill_opacity=None,
  52. stroke_width=None,
  53. stroke_color=None,
  54. stroke_opacity=None,
  55. ),
  56. path_string_config: dict = dict(),
  57. **kwargs
  58. ):
  59. self.file_name = file_name or self.file_name
  60. self.svg_default = dict(svg_default)
  61. self.path_string_config = dict(path_string_config)
  62. super().__init__(**kwargs )
  63. self.init_svg_mobject()
  64. self.ensure_positive_orientation()
  65. # Rather than passing style into super().__init__
  66. # do it after svg has been taken in
  67. self.set_style(
  68. fill_color=color or fill_color,
  69. fill_opacity=fill_opacity,
  70. stroke_color=color or stroke_color,
  71. stroke_width=stroke_width,
  72. stroke_opacity=stroke_opacity,
  73. )
  74. # Initialize position
  75. height = height or self.height
  76. width = width or self.width
  77. if should_center:
  78. self.center()
  79. if height is not None:
  80. self.set_height(height)
  81. if width is not None:
  82. self.set_width(width)
  83. def init_svg_mobject(self) -> None:
  84. hash_val = hash_obj(self.hash_seed)
  85. if hash_val in SVG_HASH_TO_MOB_MAP:
  86. submobs = [sm.copy() for sm in SVG_HASH_TO_MOB_MAP[hash_val]]
  87. else:
  88. submobs = self.mobjects_from_file(self.get_file_path())
  89. SVG_HASH_TO_MOB_MAP[hash_val] = [sm.copy() for sm in submobs]
  90. self.add(*submobs)
  91. self.flip(RIGHT) # Flip y
  92. @property
  93. def hash_seed(self) -> tuple:
  94. # Returns data which can uniquely represent the result of `init_points`.
  95. # The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`.
  96. return (
  97. self.__class__.__name__,
  98. self.svg_default,
  99. self.path_string_config,
  100. self.file_name
  101. )
  102. def mobjects_from_file(self, file_path: str) -> list[VMobject]:
  103. element_tree = ET.parse(file_path)
  104. new_tree = self.modify_xml_tree(element_tree)
  105. # New svg based on tree contents
  106. data_stream = io.BytesIO()
  107. new_tree.write(data_stream)
  108. data_stream.seek(0)
  109. svg = se.SVG.parse(data_stream)
  110. data_stream.close()
  111. return self.mobjects_from_svg(svg)
  112. def get_file_path(self) -> str:
  113. if self.file_name is None:
  114. raise Exception("Must specify file for SVGMobject")
  115. return get_full_vector_image_path(self.file_name)
  116. def modify_xml_tree(self, element_tree: ET.ElementTree) -> ET.ElementTree:
  117. config_style_attrs = self.generate_config_style_dict()
  118. style_keys = (
  119. "fill",
  120. "fill-opacity",
  121. "stroke",
  122. "stroke-opacity",
  123. "stroke-width",
  124. "style"
  125. )
  126. root = element_tree.getroot()
  127. style_attrs = {
  128. k: v
  129. for k, v in root.attrib.items()
  130. if k in style_keys
  131. }
  132. # Ignore other attributes in case that svgelements cannot parse them
  133. SVG_XMLNS = "{http://www.w3.org/2000/svg}"
  134. new_root = ET.Element("svg")
  135. config_style_node = ET.SubElement(new_root, f"{SVG_XMLNS}g", config_style_attrs)
  136. root_style_node = ET.SubElement(config_style_node, f"{SVG_XMLNS}g", style_attrs)
  137. root_style_node.extend(root)
  138. return ET.ElementTree(new_root)
  139. def generate_config_style_dict(self) -> dict[str, str]:
  140. keys_converting_dict = {
  141. "fill": ("color", "fill_color"),
  142. "fill-opacity": ("opacity", "fill_opacity"),
  143. "stroke": ("color", "stroke_color"),
  144. "stroke-opacity": ("opacity", "stroke_opacity"),
  145. "stroke-width": ("stroke_width",)
  146. }
  147. svg_default_dict = self.svg_default
  148. result = {}
  149. for svg_key, style_keys in keys_converting_dict.items():
  150. for style_key in style_keys:
  151. if svg_default_dict[style_key] is None:
  152. continue
  153. result[svg_key] = str(svg_default_dict[style_key])
  154. return result
  155. def mobjects_from_svg(self, svg: se.SVG) -> list[VMobject]:
  156. result = []
  157. for shape in svg.elements():
  158. if isinstance(shape, (se.Group, se.Use)):
  159. continue
  160. elif isinstance(shape, se.Path):
  161. mob = self.path_to_mobject(shape)
  162. elif isinstance(shape, se.SimpleLine):
  163. mob = self.line_to_mobject(shape)
  164. elif isinstance(shape, se.Rect):
  165. mob = self.rect_to_mobject(shape)
  166. elif isinstance(shape, (se.Circle, se.Ellipse)):
  167. mob = self.ellipse_to_mobject(shape)
  168. elif isinstance(shape, se.Polygon):
  169. mob = self.polygon_to_mobject(shape)
  170. elif isinstance(shape, se.Polyline):
  171. mob = self.polyline_to_mobject(shape)
  172. # elif isinstance(shape, se.Text):
  173. # mob = self.text_to_mobject(shape)
  174. elif type(shape) == se.SVGElement:
  175. continue
  176. else:
  177. log.warning("Unsupported element type: %s", type(shape))
  178. continue
  179. if not mob.has_points():
  180. continue
  181. if isinstance(shape, se.GraphicObject):
  182. self.apply_style_to_mobject(mob, shape)
  183. if isinstance(shape, se.Transformable) and shape.apply:
  184. self.handle_transform(mob, shape.transform)
  185. result.append(mob)
  186. return result
  187. @staticmethod
  188. def handle_transform(mob: VMobject, matrix: se.Matrix) -> VMobject:
  189. mat = np.array([
  190. [matrix.a, matrix.c],
  191. [matrix.b, matrix.d]
  192. ])
  193. vec = np.array([matrix.e, matrix.f, 0.0])
  194. mob.apply_matrix(mat)
  195. mob.shift(vec)
  196. return mob
  197. @staticmethod
  198. def apply_style_to_mobject(
  199. mob: VMobject,
  200. shape: se.GraphicObject
  201. ) -> VMobject:
  202. mob.set_style(
  203. stroke_width=shape.stroke_width,
  204. stroke_color=shape.stroke.hexrgb,
  205. stroke_opacity=shape.stroke.opacity,
  206. fill_color=shape.fill.hexrgb,
  207. fill_opacity=shape.fill.opacity
  208. )
  209. return mob
  210. def path_to_mobject(self, path: se.Path) -> VMobjectFromSVGPath:
  211. return VMobjectFromSVGPath(path, **self.path_string_config)
  212. def line_to_mobject(self, line: se.SimpleLine) -> Line:
  213. return Line(
  214. start=_convert_point_to_3d(line.x1, line.y1),
  215. end=_convert_point_to_3d(line.x2, line.y2)
  216. )
  217. def rect_to_mobject(self, rect: se.Rect) -> Rectangle:
  218. if rect.rx == 0 or rect.ry == 0:
  219. mob = Rectangle(
  220. width=rect.width,
  221. height=rect.height,
  222. )
  223. else:
  224. mob = RoundedRectangle(
  225. width=rect.width,
  226. height=rect.height * rect.rx / rect.ry,
  227. corner_radius=rect.rx
  228. )
  229. mob.stretch_to_fit_height(rect.height)
  230. mob.shift(_convert_point_to_3d(
  231. rect.x + rect.width / 2,
  232. rect.y + rect.height / 2
  233. ))
  234. return mob
  235. def ellipse_to_mobject(self, ellipse: se.Circle | se.Ellipse) -> Circle:
  236. mob = Circle(radius=ellipse.rx)
  237. mob.stretch_to_fit_height(2 * ellipse.ry)
  238. mob.shift(_convert_point_to_3d(
  239. ellipse.cx, ellipse.cy
  240. ))
  241. return mob
  242. def polygon_to_mobject(self, polygon: se.Polygon) -> Polygon:
  243. points = [
  244. _convert_point_to_3d(*point)
  245. for point in polygon
  246. ]
  247. return Polygon(*points)
  248. def polyline_to_mobject(self, polyline: se.Polyline) -> Polyline:
  249. points = [
  250. _convert_point_to_3d(*point)
  251. for point in polyline
  252. ]
  253. return Polyline(*points)
  254. def text_to_mobject(self, text: se.Text):
  255. pass
  256. class VMobjectFromSVGPath(VMobject):
  257. def __init__(
  258. self,
  259. path_obj: se.Path,
  260. **kwargs
  261. ):
  262. # Get rid of arcs
  263. path_obj.approximate_arcs_with_quads()
  264. self.path_obj = path_obj
  265. super().__init__(**kwargs)
  266. def init_points(self) -> None:
  267. # After a given svg_path has been converted into points, the result
  268. # will be saved so that future calls for the same pathdon't need to
  269. # retrace the same computation.
  270. path_string = self.path_obj.d()
  271. if path_string not in PATH_TO_POINTS:
  272. self.handle_commands()
  273. # Save for future use
  274. PATH_TO_POINTS[path_string] = self.get_points().copy()
  275. else:
  276. points = PATH_TO_POINTS[path_string]
  277. self.set_points(points)
  278. def handle_commands(self) -> None:
  279. segment_class_to_func_map = {
  280. se.Move: (self.start_new_path, ("end",)),
  281. se.Close: (self.close_path, ()),
  282. se.Line: (lambda p: self.add_line_to(p, allow_null_line=False), ("end",)),
  283. se.QuadraticBezier: (lambda c, e: self.add_quadratic_bezier_curve_to(c, e, allow_null_curve=False), ("control", "end")),
  284. se.CubicBezier: (self.add_cubic_bezier_curve_to, ("control1", "control2", "end"))
  285. }
  286. for segment in self.path_obj:
  287. segment_class = segment.__class__
  288. func, attr_names = segment_class_to_func_map[segment_class]
  289. points = [
  290. _convert_point_to_3d(*segment.__getattribute__(attr_name))
  291. for attr_name in attr_names
  292. ]
  293. func(*points)
  294. # Get rid of the side effect of trailing "Z M" commands.
  295. if self.has_new_path_started():
  296. self.resize_points(self.get_num_points() - 2)