vector_field.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512
  1. from __future__ import annotations
  2. import itertools as it
  3. import numpy as np
  4. from manimlib.constants import FRAME_HEIGHT, FRAME_WIDTH
  5. from manimlib.constants import BLUE, WHITE
  6. from manimlib.constants import ORIGIN
  7. from manimlib.animation.indication import VShowPassingFlash
  8. from manimlib.mobject.geometry import Arrow
  9. from manimlib.mobject.types.vectorized_mobject import VGroup
  10. from manimlib.mobject.types.vectorized_mobject import VMobject
  11. from manimlib.utils.bezier import interpolate
  12. from manimlib.utils.bezier import inverse_interpolate
  13. from manimlib.utils.color import get_colormap_list
  14. from manimlib.utils.color import rgb_to_color
  15. from manimlib.utils.dict_ops import merge_dicts_recursively
  16. from manimlib.utils.iterables import cartesian_product
  17. from manimlib.utils.rate_functions import linear
  18. from manimlib.utils.simple_functions import sigmoid
  19. from manimlib.utils.space_ops import get_norm
  20. from typing import TYPE_CHECKING
  21. if TYPE_CHECKING:
  22. from typing import Callable, Iterable, Sequence, TypeVar, Tuple
  23. from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array
  24. from manimlib.mobject.coordinate_systems import CoordinateSystem
  25. from manimlib.mobject.mobject import Mobject
  26. T = TypeVar("T")
  27. def get_vectorized_rgb_gradient_function(
  28. min_value: T,
  29. max_value: T,
  30. color_map: str
  31. ) -> Callable[[VectN], Vect3Array]:
  32. rgbs = np.array(get_colormap_list(color_map))
  33. def func(values):
  34. alphas = inverse_interpolate(
  35. min_value, max_value, np.array(values)
  36. )
  37. alphas = np.clip(alphas, 0, 1)
  38. scaled_alphas = alphas * (len(rgbs) - 1)
  39. indices = scaled_alphas.astype(int)
  40. next_indices = np.clip(indices + 1, 0, len(rgbs) - 1)
  41. inter_alphas = scaled_alphas % 1
  42. inter_alphas = inter_alphas.repeat(3).reshape((len(indices), 3))
  43. result = interpolate(rgbs[indices], rgbs[next_indices], inter_alphas)
  44. return result
  45. return func
  46. def get_rgb_gradient_function(
  47. min_value: T,
  48. max_value: T,
  49. color_map: str
  50. ) -> Callable[[float], Vect3]:
  51. vectorized_func = get_vectorized_rgb_gradient_function(min_value, max_value, color_map)
  52. return lambda value: vectorized_func(np.array([value]))[0]
  53. def move_along_vector_field(
  54. mobject: Mobject,
  55. func: Callable[[Vect3], Vect3]
  56. ) -> Mobject:
  57. mobject.add_updater(
  58. lambda m, dt: m.shift(
  59. func(m.get_center()) * dt
  60. )
  61. )
  62. return mobject
  63. def move_submobjects_along_vector_field(
  64. mobject: Mobject,
  65. func: Callable[[Vect3], Vect3]
  66. ) -> Mobject:
  67. def apply_nudge(mob, dt):
  68. for submob in mob:
  69. x, y = submob.get_center()[:2]
  70. if abs(x) < FRAME_WIDTH and abs(y) < FRAME_HEIGHT:
  71. submob.shift(func(submob.get_center()) * dt)
  72. mobject.add_updater(apply_nudge)
  73. return mobject
  74. def move_points_along_vector_field(
  75. mobject: Mobject,
  76. func: Callable[[float, float], Iterable[float]],
  77. coordinate_system: CoordinateSystem
  78. ) -> Mobject:
  79. cs = coordinate_system
  80. origin = cs.get_origin()
  81. def apply_nudge(self, dt):
  82. mobject.apply_function(
  83. lambda p: p + (cs.c2p(*func(*cs.p2c(p))) - origin) * dt
  84. )
  85. mobject.add_updater(apply_nudge)
  86. return mobject
  87. def get_sample_points_from_coordinate_system(
  88. coordinate_system: CoordinateSystem,
  89. step_multiple: float
  90. ) -> it.product[tuple[Vect3, ...]]:
  91. ranges = []
  92. for range_args in coordinate_system.get_all_ranges():
  93. _min, _max, step = range_args
  94. step *= step_multiple
  95. ranges.append(np.arange(_min, _max + step, step))
  96. return it.product(*ranges)
  97. # Mobjects
  98. class VectorField(VMobject):
  99. def __init__(
  100. self,
  101. func,
  102. stroke_color: ManimColor = BLUE,
  103. stroke_opacity: float = 1.0,
  104. center: Vect3 = ORIGIN,
  105. sample_points: Optional[Vect3Array] = None,
  106. x_density: float = 2.0,
  107. y_density: float = 2.0,
  108. z_density: float = 2.0,
  109. width: float = 14.0,
  110. height: float = 8.0,
  111. depth: float = 0.0,
  112. stroke_width: float = 2,
  113. tip_width_ratio: float = 4,
  114. tip_len_to_width: float = 0.01,
  115. max_vect_len: float | None = None,
  116. min_drawn_norm: float = 1e-2,
  117. flat_stroke: bool = False,
  118. norm_to_opacity_func=None,
  119. norm_to_rgb_func=None,
  120. **kwargs
  121. ):
  122. self.func = func
  123. self.stroke_width = stroke_width
  124. self.tip_width_ratio = tip_width_ratio
  125. self.tip_len_to_width = tip_len_to_width
  126. self.min_drawn_norm = min_drawn_norm
  127. self.norm_to_opacity_func = norm_to_opacity_func
  128. self.norm_to_rgb_func = norm_to_rgb_func
  129. if max_vect_len is not None:
  130. self.max_vect_len = max_vect_len
  131. else:
  132. densities = np.array([x_density, y_density, z_density])
  133. dims = np.array([width, height, depth])
  134. self.max_vect_len = 1.0 / densities[dims > 0].mean()
  135. if sample_points is None:
  136. self.sample_points = self.get_sample_points(
  137. center, width, height, depth,
  138. x_density, y_density, z_density
  139. )
  140. else:
  141. self.sample_points = sample_points
  142. self.init_base_stroke_width_array(len(self.sample_points))
  143. super().__init__(
  144. stroke_color=stroke_color,
  145. stroke_opacity=stroke_opacity,
  146. flat_stroke=flat_stroke,
  147. **kwargs
  148. )
  149. n_samples = len(self.sample_points)
  150. self.set_points(np.zeros((8 * n_samples - 1, 3)))
  151. self.set_stroke(width=stroke_width)
  152. self.set_joint_type('no_joint')
  153. self.update_vectors()
  154. def get_sample_points(
  155. self,
  156. center: np.ndarray,
  157. width: float,
  158. height: float,
  159. depth: float,
  160. x_density: float,
  161. y_density: float,
  162. z_density: float
  163. ) -> np.ndarray:
  164. to_corner = np.array([width / 2, height / 2, depth / 2])
  165. spacings = 1.0 / np.array([x_density, y_density, z_density])
  166. to_corner = spacings * (to_corner / spacings).astype(int)
  167. lower_corner = center - to_corner
  168. upper_corner = center + to_corner + spacings
  169. return cartesian_product(*(
  170. np.arange(low, high, space)
  171. for low, high, space in zip(lower_corner, upper_corner, spacings)
  172. ))
  173. def init_base_stroke_width_array(self, n_sample_points):
  174. arr = np.ones(8 * n_sample_points - 1)
  175. arr[4::8] = self.tip_width_ratio
  176. arr[5::8] = self.tip_width_ratio * 0.5
  177. arr[6::8] = 0
  178. arr[7::8] = 0
  179. self.base_stroke_width_array = arr
  180. def set_sample_points(self, sample_points: Vect3Array):
  181. self.sample_points = sample_points
  182. return self
  183. def set_stroke(self, color=None, width=None, opacity=None, behind=None, flat=None, recurse=True):
  184. super().set_stroke(color, None, opacity, behind, flat, recurse)
  185. if width is not None:
  186. self.set_stroke_width(float(width))
  187. return self
  188. def set_stroke_width(self, width: float):
  189. if self.get_num_points() > 0:
  190. self.get_stroke_widths()[:] = width * self.base_stroke_width_array
  191. self.stroke_width = width
  192. return self
  193. def update_vectors(self):
  194. tip_width = self.tip_width_ratio * self.stroke_width
  195. tip_len = self.tip_len_to_width * tip_width
  196. samples = self.sample_points
  197. # Get raw outputs and lengths
  198. outputs = self.func(samples)
  199. norms = np.linalg.norm(outputs, axis=1)[:, np.newaxis]
  200. # How long should the arrows be drawn?
  201. max_len = self.max_vect_len
  202. if max_len < np.inf:
  203. drawn_norms = max_len * np.tanh(norms / max_len)
  204. else:
  205. drawn_norms = norms
  206. # What's the distance from the base of an arrow to
  207. # the base of its head?
  208. dist_to_head_base = np.clip(drawn_norms - tip_len, 0, np.inf)
  209. # Set all points
  210. unit_outputs = np.zeros_like(outputs)
  211. np.true_divide(outputs, norms, out=unit_outputs, where=(norms > self.min_drawn_norm))
  212. points = self.get_points()
  213. points[0::8] = samples
  214. points[2::8] = samples + dist_to_head_base * unit_outputs
  215. points[4::8] = points[2::8]
  216. points[6::8] = samples + drawn_norms * unit_outputs
  217. for i in (1, 3, 5):
  218. points[i::8] = 0.5 * (points[i - 1::8] + points[i + 1::8])
  219. points[7::8] = points[6:-1:8]
  220. # Adjust stroke widths
  221. width_arr = self.stroke_width * self.base_stroke_width_array
  222. width_scalars = np.clip(drawn_norms / tip_len, 0, 1)
  223. width_scalars = np.repeat(width_scalars, 8)[:-1]
  224. self.get_stroke_widths()[:] = width_scalars * width_arr
  225. # Potentially adjust opacity and color
  226. if self.norm_to_opacity_func is not None:
  227. self.get_stroke_opacities()[:] = self.norm_to_opacity_func(
  228. np.repeat(norms, 8)[:-1]
  229. )
  230. if self.norm_to_rgb_func is not None:
  231. self.get_stroke_colors()
  232. self.data['stroke_rgba'][:, :3] = self.norm_to_rgb_func(
  233. np.repeat(norms, 8)[:-1]
  234. )
  235. self.note_changed_data()
  236. return self
  237. class TimeVaryingVectorField(VectorField):
  238. def __init__(
  239. self,
  240. # Takes in an array of points and a float for time
  241. time_func,
  242. **kwargs
  243. ):
  244. self.time = 0
  245. super().__init__(func=lambda p: time_func(p, self.time), **kwargs)
  246. self.add_updater(lambda m, dt: m.increment_time(dt))
  247. always(self.update_vectors)
  248. def increment_time(self, dt):
  249. self.time += dt
  250. class OldVectorField(VGroup):
  251. def __init__(
  252. self,
  253. func: Callable[[float, float], Sequence[float]],
  254. coordinate_system: CoordinateSystem,
  255. step_multiple: float = 0.5,
  256. magnitude_range: Tuple[float, float] = (0, 2),
  257. color_map: str = "3b1b_colormap",
  258. # Takes in actual norm, spits out displayed norm
  259. length_func: Callable[[float], float] = lambda norm: 0.45 * sigmoid(norm),
  260. opacity: float = 1.0,
  261. vector_config: dict = dict(),
  262. **kwargs
  263. ):
  264. super().__init__(**kwargs)
  265. self.func = func
  266. self.coordinate_system = coordinate_system
  267. self.step_multiple = step_multiple
  268. self.magnitude_range = magnitude_range
  269. self.color_map = color_map
  270. self.length_func = length_func
  271. self.opacity = opacity
  272. self.vector_config = dict(vector_config)
  273. self.value_to_rgb = get_rgb_gradient_function(
  274. *self.magnitude_range, self.color_map,
  275. )
  276. samples = get_sample_points_from_coordinate_system(
  277. coordinate_system, self.step_multiple
  278. )
  279. self.add(*(
  280. self.get_vector(coords)
  281. for coords in samples
  282. ))
  283. def get_vector(self, coords: Iterable[float], **kwargs) -> Arrow:
  284. vector_config = merge_dicts_recursively(
  285. self.vector_config,
  286. kwargs
  287. )
  288. output = np.array(self.func(*coords))
  289. norm = get_norm(output)
  290. if norm > 0:
  291. output *= self.length_func(norm) / norm
  292. origin = self.coordinate_system.get_origin()
  293. _input = self.coordinate_system.c2p(*coords)
  294. _output = self.coordinate_system.c2p(*output)
  295. vect = Arrow(
  296. origin, _output, buff=0,
  297. **vector_config
  298. )
  299. vect.shift(_input - origin)
  300. vect.set_color(
  301. rgb_to_color(self.value_to_rgb(norm)),
  302. opacity=self.opacity,
  303. )
  304. return vect
  305. class StreamLines(VGroup):
  306. def __init__(
  307. self,
  308. func: Callable[[float, float], Sequence[float]],
  309. coordinate_system: CoordinateSystem,
  310. step_multiple: float = 0.5,
  311. n_repeats: int = 1,
  312. noise_factor: float | None = None,
  313. # Config for drawing lines
  314. dt: float = 0.05,
  315. arc_len: float = 3,
  316. max_time_steps: int = 200,
  317. n_samples_per_line: int = 10,
  318. cutoff_norm: float = 15,
  319. # Style info
  320. stroke_width: float = 1.0,
  321. stroke_color: ManimColor = WHITE,
  322. stroke_opacity: float = 1,
  323. color_by_magnitude: bool = True,
  324. magnitude_range: Tuple[float, float] = (0, 2.0),
  325. taper_stroke_width: bool = False,
  326. color_map: str = "3b1b_colormap",
  327. **kwargs
  328. ):
  329. super().__init__(**kwargs)
  330. self.func = func
  331. self.coordinate_system = coordinate_system
  332. self.step_multiple = step_multiple
  333. self.n_repeats = n_repeats
  334. self.noise_factor = noise_factor
  335. self.dt = dt
  336. self.arc_len = arc_len
  337. self.max_time_steps = max_time_steps
  338. self.n_samples_per_line = n_samples_per_line
  339. self.cutoff_norm = cutoff_norm
  340. self.stroke_width = stroke_width
  341. self.stroke_color = stroke_color
  342. self.stroke_opacity = stroke_opacity
  343. self.color_by_magnitude = color_by_magnitude
  344. self.magnitude_range = magnitude_range
  345. self.taper_stroke_width = taper_stroke_width
  346. self.color_map = color_map
  347. self.draw_lines()
  348. self.init_style()
  349. def point_func(self, point: Vect3) -> Vect3:
  350. in_coords = self.coordinate_system.p2c(point)
  351. out_coords = self.func(*in_coords)
  352. return self.coordinate_system.c2p(*out_coords)
  353. def draw_lines(self) -> None:
  354. lines = []
  355. origin = self.coordinate_system.get_origin()
  356. for point in self.get_start_points():
  357. points = [point]
  358. total_arc_len = 0
  359. time = 0
  360. for x in range(self.max_time_steps):
  361. time += self.dt
  362. last_point = points[-1]
  363. new_point = last_point + self.dt * (self.point_func(last_point) - origin)
  364. points.append(new_point)
  365. total_arc_len += get_norm(new_point - last_point)
  366. if get_norm(last_point) > self.cutoff_norm:
  367. break
  368. if total_arc_len > self.arc_len:
  369. break
  370. line = VMobject()
  371. line.virtual_time = time
  372. step = max(1, int(len(points) / self.n_samples_per_line))
  373. line.set_points_as_corners(points[::step])
  374. line.make_smooth(approx=True)
  375. lines.append(line)
  376. self.set_submobjects(lines)
  377. def get_start_points(self) -> Vect3Array:
  378. cs = self.coordinate_system
  379. sample_coords = get_sample_points_from_coordinate_system(
  380. cs, self.step_multiple,
  381. )
  382. noise_factor = self.noise_factor
  383. if noise_factor is None:
  384. noise_factor = cs.x_range[2] * self.step_multiple * 0.5
  385. return np.array([
  386. cs.c2p(*coords) + noise_factor * np.random.random(3)
  387. for n in range(self.n_repeats)
  388. for coords in sample_coords
  389. ])
  390. def init_style(self) -> None:
  391. if self.color_by_magnitude:
  392. values_to_rgbs = get_vectorized_rgb_gradient_function(
  393. *self.magnitude_range, self.color_map,
  394. )
  395. cs = self.coordinate_system
  396. for line in self.submobjects:
  397. norms = [
  398. get_norm(self.func(*cs.p2c(point)))
  399. for point in line.get_points()
  400. ]
  401. rgbs = values_to_rgbs(norms)
  402. rgbas = np.zeros((len(rgbs), 4))
  403. rgbas[:, :3] = rgbs
  404. rgbas[:, 3] = self.stroke_opacity
  405. line.set_rgba_array(rgbas, "stroke_rgba")
  406. else:
  407. self.set_stroke(self.stroke_color, opacity=self.stroke_opacity)
  408. if self.taper_stroke_width:
  409. width = [0, self.stroke_width, 0]
  410. else:
  411. width = self.stroke_width
  412. self.set_stroke(width=width)
  413. class AnimatedStreamLines(VGroup):
  414. def __init__(
  415. self,
  416. stream_lines: StreamLines,
  417. lag_range: float = 4,
  418. line_anim_config: dict = dict(
  419. rate_func=linear,
  420. time_width=1.0,
  421. ),
  422. **kwargs
  423. ):
  424. super().__init__(**kwargs)
  425. self.stream_lines = stream_lines
  426. for line in stream_lines:
  427. line.anim = VShowPassingFlash(
  428. line,
  429. run_time=line.virtual_time,
  430. **line_anim_config,
  431. )
  432. line.anim.begin()
  433. line.time = -lag_range * np.random.random()
  434. self.add(line.anim.mobject)
  435. self.add_updater(lambda m, dt: m.update(dt))
  436. def update(self, dt: float) -> None:
  437. stream_lines = self.stream_lines
  438. for line in stream_lines:
  439. line.time += dt
  440. adjusted_time = max(line.time, 0) % line.anim.run_time
  441. line.anim.update(adjusted_time / line.anim.run_time)