space_ops.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504
  1. from __future__ import annotations
  2. from functools import reduce
  3. import math
  4. import operator as op
  5. import platform
  6. from mapbox_earcut import triangulate_float32 as earcut
  7. import numpy as np
  8. from scipy.spatial.transform import Rotation
  9. from tqdm.auto import tqdm as ProgressDisplay
  10. from manimlib.constants import DOWN, OUT, RIGHT, UP
  11. from manimlib.constants import PI, TAU
  12. from manimlib.utils.iterables import adjacent_pairs
  13. from manimlib.utils.simple_functions import clip
  14. from typing import TYPE_CHECKING
  15. if TYPE_CHECKING:
  16. from typing import Callable, Sequence, List, Tuple
  17. from manimlib.typing import Vect2, Vect3, Vect4, VectN, Matrix3x3, Vect3Array, Vect2Array
  18. def cross(
  19. v1: Vect3 | List[float],
  20. v2: Vect3 | List[float],
  21. out: np.ndarray | None = None
  22. ) -> Vect3 | Vect3Array:
  23. is2d = isinstance(v1, np.ndarray) and len(v1.shape) == 2
  24. if is2d:
  25. x1, y1, z1 = v1[:, 0], v1[:, 1], v1[:, 2]
  26. x2, y2, z2 = v2[:, 0], v2[:, 1], v2[:, 2]
  27. else:
  28. x1, y1, z1 = v1
  29. x2, y2, z2 = v2
  30. if out is None:
  31. out = np.empty(np.shape(v1))
  32. out.T[:] = [
  33. y1 * z2 - z1 * y2,
  34. z1 * x2 - x1 * z2,
  35. x1 * y2 - y1 * x2,
  36. ]
  37. return out
  38. def get_norm(vect: VectN | List[float]) -> float:
  39. return sum((x**2 for x in vect))**0.5
  40. def normalize(
  41. vect: VectN | List[float],
  42. fall_back: VectN | List[float] | None = None
  43. ) -> VectN:
  44. norm = get_norm(vect)
  45. if norm > 0:
  46. return np.array(vect) / norm
  47. elif fall_back is not None:
  48. return np.array(fall_back)
  49. else:
  50. return np.zeros(len(vect))
  51. def poly_line_length(points):
  52. """
  53. Return the sum of the lengths between adjacent points
  54. """
  55. diffs = points[1:] - points[:-1]
  56. return np.sqrt((diffs**2).sum(1)).sum()
  57. # Operations related to rotation
  58. def quaternion_mult(*quats: Vect4) -> Vect4:
  59. """
  60. Inputs are treated as quaternions, where the real part is the
  61. last entry, so as to follow the scipy Rotation conventions.
  62. """
  63. if len(quats) == 0:
  64. return np.array([0, 0, 0, 1])
  65. result = np.array(quats[0])
  66. for next_quat in quats[1:]:
  67. x1, y1, z1, w1 = result
  68. x2, y2, z2, w2 = next_quat
  69. result[:] = [
  70. w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
  71. w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2,
  72. w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2,
  73. w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
  74. ]
  75. return result
  76. def quaternion_from_angle_axis(
  77. angle: float,
  78. axis: Vect3,
  79. ) -> Vect4:
  80. return Rotation.from_rotvec(angle * normalize(axis)).as_quat()
  81. def angle_axis_from_quaternion(quat: Vect4) -> Tuple[float, Vect3]:
  82. rot_vec = Rotation.from_quat(quat).as_rotvec()
  83. norm = get_norm(rot_vec)
  84. return norm, rot_vec / norm
  85. def quaternion_conjugate(quaternion: Vect4) -> Vect4:
  86. result = np.array(quaternion)
  87. result[:3] *= -1
  88. return result
  89. def rotate_vector(
  90. vector: Vect3,
  91. angle: float,
  92. axis: Vect3 = OUT
  93. ) -> Vect3:
  94. rot = Rotation.from_rotvec(angle * normalize(axis))
  95. return np.dot(vector, rot.as_matrix().T)
  96. def rotate_vector_2d(vector: Vect2, angle: float) -> Vect2:
  97. # Use complex numbers...because why not
  98. z = complex(*vector) * np.exp(complex(0, angle))
  99. return np.array([z.real, z.imag])
  100. def rotation_matrix_transpose_from_quaternion(quat: Vect4) -> Matrix3x3:
  101. return Rotation.from_quat(quat).as_matrix()
  102. def rotation_matrix_from_quaternion(quat: Vect4) -> Matrix3x3:
  103. return np.transpose(rotation_matrix_transpose_from_quaternion(quat))
  104. def rotation_matrix(angle: float, axis: Vect3) -> Matrix3x3:
  105. """
  106. Rotation in R^3 about a specified axis of rotation.
  107. """
  108. return Rotation.from_rotvec(angle * normalize(axis)).as_matrix()
  109. def rotation_matrix_transpose(angle: float, axis: Vect3) -> Matrix3x3:
  110. return rotation_matrix(angle, axis).T
  111. def rotation_about_z(angle: float) -> Matrix3x3:
  112. cos_a = math.cos(angle)
  113. sin_a = math.sin(angle)
  114. return np.array([
  115. [cos_a, -sin_a, 0],
  116. [sin_a, cos_a, 0],
  117. [0, 0, 1]
  118. ])
  119. def rotation_between_vectors(v1: Vect3, v2: Vect3) -> Matrix3x3:
  120. atol = 1e-8
  121. if get_norm(v1 - v2) < atol:
  122. return np.identity(3)
  123. axis = cross(v1, v2)
  124. if get_norm(axis) < atol:
  125. # v1 and v2 align
  126. axis = cross(v1, RIGHT)
  127. if get_norm(axis) < atol:
  128. # v1 and v2 _and_ RIGHT all align
  129. axis = cross(v1, UP)
  130. return rotation_matrix(
  131. angle=angle_between_vectors(v1, v2),
  132. axis=axis,
  133. )
  134. def z_to_vector(vector: Vect3) -> Matrix3x3:
  135. return rotation_between_vectors(OUT, vector)
  136. def angle_of_vector(vector: Vect2 | Vect3) -> float:
  137. """
  138. Returns polar coordinate theta when vector is project on xy plane
  139. """
  140. return math.atan2(vector[1], vector[0])
  141. def angle_between_vectors(v1: VectN, v2: VectN) -> float:
  142. """
  143. Returns the angle between two 3D vectors.
  144. This angle will always be btw 0 and pi
  145. """
  146. n1 = get_norm(v1)
  147. n2 = get_norm(v2)
  148. if n1 == 0 or n2 == 0:
  149. return 0
  150. cos_angle = np.dot(v1, v2) / np.float64(n1 * n2)
  151. return math.acos(clip(cos_angle, -1, 1))
  152. def project_along_vector(point: Vect3, vector: Vect3) -> Vect3:
  153. matrix = np.identity(3) - np.outer(vector, vector)
  154. return np.dot(point, matrix.T)
  155. def normalize_along_axis(
  156. array: np.ndarray,
  157. axis: int,
  158. ) -> np.ndarray:
  159. norms = np.sqrt((array * array).sum(axis))
  160. norms[norms == 0] = 1
  161. return array / norms[:, np.newaxis]
  162. def get_unit_normal(
  163. v1: Vect3,
  164. v2: Vect3,
  165. tol: float = 1e-6
  166. ) -> Vect3:
  167. v1 = normalize(v1)
  168. v2 = normalize(v2)
  169. cp = cross(v1, v2)
  170. cp_norm = get_norm(cp)
  171. if cp_norm < tol:
  172. # Vectors align, so find a normal to them in the plane shared with the z-axis
  173. new_cp = cross(cross(v1, OUT), v1)
  174. new_cp_norm = get_norm(new_cp)
  175. if new_cp_norm < tol:
  176. return DOWN
  177. return new_cp / new_cp_norm
  178. return cp / cp_norm
  179. ###
  180. def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray:
  181. row_indices = np.arange(dim).repeat(dim).reshape((dim, dim))
  182. col_indices = np.transpose(row_indices)
  183. return (np.abs(row_indices - col_indices) < thickness).astype('uint8')
  184. def compass_directions(n: int = 4, start_vect: Vect3 = RIGHT) -> Vect3:
  185. angle = TAU / n
  186. return np.array([
  187. rotate_vector(start_vect, k * angle)
  188. for k in range(n)
  189. ])
  190. def complex_to_R3(complex_num: complex) -> Vect3:
  191. return np.array((complex_num.real, complex_num.imag, 0))
  192. def R3_to_complex(point: Vect3) -> complex:
  193. return complex(*point[:2])
  194. def complex_func_to_R3_func(complex_func: Callable[[complex], complex]) -> Callable[[Vect3], Vect3]:
  195. def result(p: Vect3):
  196. return complex_to_R3(complex_func(R3_to_complex(p)))
  197. return result
  198. def center_of_mass(points: Sequence[Vect3]) -> Vect3:
  199. return np.array(points).sum(0) / len(points)
  200. def midpoint(point1: VectN, point2: VectN) -> VectN:
  201. return center_of_mass([point1, point2])
  202. def line_intersection(
  203. line1: Tuple[Vect3, Vect3],
  204. line2: Tuple[Vect3, Vect3]
  205. ) -> Vect3:
  206. """
  207. return intersection point of two lines,
  208. each defined with a pair of vectors determining
  209. the end points
  210. """
  211. x_diff = (line1[0][0] - line1[1][0], line2[0][0] - line2[1][0])
  212. y_diff = (line1[0][1] - line1[1][1], line2[0][1] - line2[1][1])
  213. def det(a, b):
  214. return a[0] * b[1] - a[1] * b[0]
  215. div = det(x_diff, y_diff)
  216. if div == 0:
  217. raise Exception("Lines do not intersect")
  218. d = (det(*line1), det(*line2))
  219. x = det(d, x_diff) / div
  220. y = det(d, y_diff) / div
  221. return np.array([x, y, 0])
  222. def find_intersection(
  223. p0: Vect3 | Vect3Array,
  224. v0: Vect3 | Vect3Array,
  225. p1: Vect3 | Vect3Array,
  226. v1: Vect3 | Vect3Array,
  227. threshold: float = 1e-5,
  228. ) -> Vect3:
  229. """
  230. Return the intersection of a line passing through p0 in direction v0
  231. with one passing through p1 in direction v1. (Or array of intersections
  232. from arrays of such points/directions).
  233. For 3d values, it returns the point on the ray p0 + v0 * t closest to the
  234. ray p1 + v1 * t
  235. """
  236. d = len(p0.shape)
  237. if d == 1:
  238. is_3d = any(arr[2] for arr in (p0, v0, p1, v1))
  239. else:
  240. is_3d = any(z for arr in (p0, v0, p1, v1) for z in arr.T[2])
  241. if not is_3d:
  242. numer = np.array(cross2d(v1, p1 - p0))
  243. denom = np.array(cross2d(v1, v0))
  244. else:
  245. cp1 = cross(v1, p1 - p0)
  246. cp2 = cross(v1, v0)
  247. numer = np.array((cp1 * cp1).sum(d - 1))
  248. denom = np.array((cp1 * cp2).sum(d - 1))
  249. denom[abs(denom) < threshold] = np.inf
  250. ratio = numer / denom
  251. return p0 + (ratio * v0.T).T
  252. def line_intersects_path(
  253. start: Vect2 | Vect3,
  254. end: Vect2 | Vect3,
  255. path: Vect2Array | Vect3Array,
  256. ) -> bool:
  257. """
  258. Tests whether the line (start, end) intersects
  259. a polygonal path defined by its vertices
  260. """
  261. n = len(path) - 1
  262. p1 = np.empty((n, 2))
  263. q1 = np.empty((n, 2))
  264. p1[:] = start[:2]
  265. q1[:] = end[:2]
  266. p2 = path[:-1, :2]
  267. q2 = path[1:, :2]
  268. v1 = q1 - p1
  269. v2 = q2 - p2
  270. mis1 = cross2d(v1, p2 - p1) * cross2d(v1, q2 - p1) < 0
  271. mis2 = cross2d(v2, p1 - p2) * cross2d(v2, q1 - p2) < 0
  272. return bool((mis1 * mis2).any())
  273. def get_closest_point_on_line(a: VectN, b: VectN, p: VectN) -> VectN:
  274. """
  275. It returns point x such that
  276. x is on line ab and xp is perpendicular to ab.
  277. If x lies beyond ab line, then it returns nearest edge(a or b).
  278. """
  279. # x = b + t*(a-b) = t*a + (1-t)*b
  280. t = np.dot(p - b, a - b) / np.dot(a - b, a - b)
  281. if t < 0:
  282. t = 0
  283. if t > 1:
  284. t = 1
  285. return ((t * a) + ((1 - t) * b))
  286. def get_winding_number(points: Sequence[Vect2 | Vect3]) -> float:
  287. total_angle = 0
  288. for p1, p2 in adjacent_pairs(points):
  289. d_angle = angle_of_vector(p2) - angle_of_vector(p1)
  290. d_angle = ((d_angle + PI) % TAU) - PI
  291. total_angle += d_angle
  292. return total_angle / TAU
  293. ##
  294. def cross2d(a: Vect2 | Vect2Array, b: Vect2 | Vect2Array) -> Vect2 | Vect2Array:
  295. if len(a.shape) == 2:
  296. return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
  297. else:
  298. return a[0] * b[1] - b[0] * a[1]
  299. def tri_area(
  300. a: Vect2,
  301. b: Vect2,
  302. c: Vect2
  303. ) -> float:
  304. return 0.5 * abs(
  305. a[0] * (b[1] - c[1]) +
  306. b[0] * (c[1] - a[1]) +
  307. c[0] * (a[1] - b[1])
  308. )
  309. def is_inside_triangle(
  310. p: Vect2,
  311. a: Vect2,
  312. b: Vect2,
  313. c: Vect2
  314. ) -> bool:
  315. """
  316. Test if point p is inside triangle abc
  317. """
  318. crosses = np.array([
  319. cross2d(p - a, b - p),
  320. cross2d(p - b, c - p),
  321. cross2d(p - c, a - p),
  322. ])
  323. return bool(np.all(crosses > 0) or np.all(crosses < 0))
  324. def norm_squared(v: VectN | List[float]) -> float:
  325. return sum(x * x for x in v)
  326. # TODO, fails for polygons drawn over themselves
  327. def earclip_triangulation(verts: Vect3Array | Vect2Array, ring_ends: list[int]) -> list[int]:
  328. """
  329. Returns a list of indices giving a triangulation
  330. of a polygon, potentially with holes
  331. - verts is a numpy array of points
  332. - ring_ends is a list of indices indicating where
  333. the ends of new paths are
  334. """
  335. rings = [
  336. list(range(e0, e1))
  337. for e0, e1 in zip([0, *ring_ends], ring_ends)
  338. ]
  339. epsilon = 1e-6
  340. def is_in(point, ring_id):
  341. return abs(abs(get_winding_number([i - point for i in verts[rings[ring_id]]])) - 1) < epsilon
  342. def ring_area(ring_id):
  343. ring = rings[ring_id]
  344. s = 0
  345. for i, j in zip(ring[1:], ring):
  346. s += cross2d(verts[i], verts[j])
  347. return abs(s) / 2
  348. # Points at the same position may cause problems
  349. for i in rings:
  350. if len(i) < 2:
  351. continue
  352. verts[i[0]] += (verts[i[1]] - verts[i[0]]) * epsilon
  353. verts[i[-1]] += (verts[i[-2]] - verts[i[-1]]) * epsilon
  354. # First, we should know which rings are directly contained in it for each ring
  355. right = [max(verts[rings[i], 0]) for i in range(len(rings))]
  356. left = [min(verts[rings[i], 0]) for i in range(len(rings))]
  357. top = [max(verts[rings[i], 1]) for i in range(len(rings))]
  358. bottom = [min(verts[rings[i], 1]) for i in range(len(rings))]
  359. area = [ring_area(i) for i in range(len(rings))]
  360. # The larger ring must be outside
  361. rings_sorted = list(range(len(rings)))
  362. rings_sorted.sort(key=lambda x: area[x], reverse=True)
  363. def is_in_fast(ring_a, ring_b):
  364. # Whether a is in b
  365. return reduce(op.and_, (
  366. left[ring_b] <= left[ring_a] <= right[ring_a] <= right[ring_b],
  367. bottom[ring_b] <= bottom[ring_a] <= top[ring_a] <= top[ring_b],
  368. is_in(verts[rings[ring_a][0]], ring_b)
  369. ))
  370. chilren = [[] for i in rings]
  371. ringenum = ProgressDisplay(
  372. enumerate(rings_sorted),
  373. total=len(rings),
  374. leave=False,
  375. ascii=True if platform.system() == 'Windows' else None,
  376. dynamic_ncols=True,
  377. desc="SVG Triangulation",
  378. delay=3,
  379. )
  380. for idx, i in ringenum:
  381. for j in rings_sorted[:idx][::-1]:
  382. if is_in_fast(i, j):
  383. chilren[j].append(i)
  384. break
  385. res = []
  386. # Then, we can use earcut for each part
  387. used = [False] * len(rings)
  388. for i in rings_sorted:
  389. if used[i]:
  390. continue
  391. v = rings[i]
  392. ring_ends = [len(v)]
  393. for j in chilren[i]:
  394. used[j] = True
  395. v += rings[j]
  396. ring_ends.append(len(v))
  397. res += [v[i] for i in earcut(verts[v, :2], ring_ends)]
  398. return res