coordinate_systems.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. import numbers
  4. import numpy as np
  5. import itertools as it
  6. from manimlib.constants import BLACK, BLUE, BLUE_D, BLUE_E, GREEN, GREY_A, WHITE, RED
  7. from manimlib.constants import DEGREES, PI
  8. from manimlib.constants import DL, UL, DOWN, DR, LEFT, ORIGIN, OUT, RIGHT, UP
  9. from manimlib.constants import FRAME_X_RADIUS, FRAME_Y_RADIUS
  10. from manimlib.constants import MED_SMALL_BUFF, SMALL_BUFF
  11. from manimlib.mobject.functions import ParametricCurve
  12. from manimlib.mobject.geometry import Arrow
  13. from manimlib.mobject.geometry import DashedLine
  14. from manimlib.mobject.geometry import Line
  15. from manimlib.mobject.geometry import Rectangle
  16. from manimlib.mobject.number_line import NumberLine
  17. from manimlib.mobject.svg.tex_mobject import Tex
  18. from manimlib.mobject.types.dot_cloud import DotCloud
  19. from manimlib.mobject.types.surface import ParametricSurface
  20. from manimlib.mobject.types.vectorized_mobject import VGroup
  21. from manimlib.mobject.types.vectorized_mobject import VMobject
  22. from manimlib.utils.bezier import inverse_interpolate
  23. from manimlib.utils.dict_ops import merge_dicts_recursively
  24. from manimlib.utils.simple_functions import binary_search
  25. from manimlib.utils.space_ops import angle_of_vector
  26. from manimlib.utils.space_ops import get_norm
  27. from manimlib.utils.space_ops import rotate_vector
  28. from manimlib.utils.space_ops import normalize
  29. from typing import TYPE_CHECKING
  30. if TYPE_CHECKING:
  31. from typing import Callable, Iterable, Sequence, Type, TypeVar, Optional
  32. from manimlib.mobject.mobject import Mobject
  33. from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier, Self
  34. T = TypeVar("T", bound=Mobject)
  35. EPSILON = 1e-8
  36. DEFAULT_X_RANGE = (-8.0, 8.0, 1.0)
  37. DEFAULT_Y_RANGE = (-4.0, 4.0, 1.0)
  38. class CoordinateSystem(ABC):
  39. """
  40. Abstract class for Axes and NumberPlane
  41. """
  42. dimension: int = 2
  43. def __init__(
  44. self,
  45. x_range: RangeSpecifier = DEFAULT_X_RANGE,
  46. y_range: RangeSpecifier = DEFAULT_Y_RANGE,
  47. num_sampled_graph_points_per_tick: int = 5,
  48. ):
  49. self.x_range = x_range
  50. self.y_range = y_range
  51. self.num_sampled_graph_points_per_tick = num_sampled_graph_points_per_tick
  52. @abstractmethod
  53. def coords_to_point(self, *coords: float | VectN) -> Vect3 | Vect3Array:
  54. raise Exception("Not implemented")
  55. @abstractmethod
  56. def point_to_coords(self, point: Vect3 | Vect3Array) -> tuple[float | VectN, ...]:
  57. raise Exception("Not implemented")
  58. def c2p(self, *coords: float) -> Vect3 | Vect3Array:
  59. """Abbreviation for coords_to_point"""
  60. return self.coords_to_point(*coords)
  61. def p2c(self, point: Vect3) -> tuple[float | VectN, ...]:
  62. """Abbreviation for point_to_coords"""
  63. return self.point_to_coords(point)
  64. def get_origin(self) -> Vect3:
  65. return self.c2p(*[0] * self.dimension)
  66. @abstractmethod
  67. def get_axes(self) -> VGroup:
  68. raise Exception("Not implemented")
  69. @abstractmethod
  70. def get_all_ranges(self) -> list[np.ndarray]:
  71. raise Exception("Not implemented")
  72. def get_axis(self, index: int) -> NumberLine:
  73. return self.get_axes()[index]
  74. def get_x_axis(self) -> NumberLine:
  75. return self.get_axis(0)
  76. def get_y_axis(self) -> NumberLine:
  77. return self.get_axis(1)
  78. def get_z_axis(self) -> NumberLine:
  79. return self.get_axis(2)
  80. def get_x_axis_label(
  81. self,
  82. label_tex: str,
  83. edge: Vect3 = RIGHT,
  84. direction: Vect3 = DL,
  85. **kwargs
  86. ) -> Tex:
  87. return self.get_axis_label(
  88. label_tex, self.get_x_axis(),
  89. edge, direction, **kwargs
  90. )
  91. def get_y_axis_label(
  92. self,
  93. label_tex: str,
  94. edge: Vect3 = UP,
  95. direction: Vect3 = DR,
  96. **kwargs
  97. ) -> Tex:
  98. return self.get_axis_label(
  99. label_tex, self.get_y_axis(),
  100. edge, direction, **kwargs
  101. )
  102. def get_axis_label(
  103. self,
  104. label_tex: str,
  105. axis: Vect3,
  106. edge: Vect3,
  107. direction: Vect3,
  108. buff: float = MED_SMALL_BUFF,
  109. ensure_on_screen: bool = False
  110. ) -> Tex:
  111. label = Tex(label_tex)
  112. label.next_to(
  113. axis.get_edge_center(edge), direction,
  114. buff=buff
  115. )
  116. if ensure_on_screen:
  117. label.shift_onto_screen(buff=MED_SMALL_BUFF)
  118. return label
  119. def get_axis_labels(
  120. self,
  121. x_label_tex: str = "x",
  122. y_label_tex: str = "y"
  123. ) -> VGroup:
  124. self.axis_labels = VGroup(
  125. self.get_x_axis_label(x_label_tex),
  126. self.get_y_axis_label(y_label_tex),
  127. )
  128. return self.axis_labels
  129. def get_line_from_axis_to_point(
  130. self,
  131. index: int,
  132. point: Vect3,
  133. line_func: Type[T] = DashedLine,
  134. color: ManimColor = GREY_A,
  135. stroke_width: float = 2
  136. ) -> T:
  137. axis = self.get_axis(index)
  138. line = line_func(axis.get_projection(point), point)
  139. line.set_stroke(color, stroke_width)
  140. return line
  141. def get_v_line(self, point: Vect3, **kwargs):
  142. return self.get_line_from_axis_to_point(0, point, **kwargs)
  143. def get_h_line(self, point: Vect3, **kwargs):
  144. return self.get_line_from_axis_to_point(1, point, **kwargs)
  145. # Useful for graphing
  146. def get_graph(
  147. self,
  148. function: Callable[[float], float],
  149. x_range: Sequence[float] | None = None,
  150. bind: bool = False,
  151. **kwargs
  152. ) -> ParametricCurve:
  153. x_range = x_range or self.x_range
  154. t_range = np.ones(3)
  155. t_range[:len(x_range)] = x_range
  156. # For axes, the third coordinate of x_range indicates
  157. # tick frequency. But for functions, it indicates a
  158. # sample frequency
  159. t_range[2] /= self.num_sampled_graph_points_per_tick
  160. def parametric_function(t: float) -> Vect3:
  161. return self.c2p(t, function(t))
  162. graph = ParametricCurve(
  163. parametric_function,
  164. t_range=tuple(t_range),
  165. **kwargs
  166. )
  167. graph.underlying_function = function
  168. graph.x_range = x_range
  169. if bind:
  170. self.bind_graph_to_func(graph, function)
  171. return graph
  172. def get_parametric_curve(
  173. self,
  174. function: Callable[[float], Vect3],
  175. **kwargs
  176. ) -> ParametricCurve:
  177. dim = self.dimension
  178. graph = ParametricCurve(
  179. lambda t: self.coords_to_point(*function(t)[:dim]),
  180. **kwargs
  181. )
  182. graph.underlying_function = function
  183. return graph
  184. def input_to_graph_point(
  185. self,
  186. x: float,
  187. graph: ParametricCurve
  188. ) -> Vect3 | None:
  189. if hasattr(graph, "underlying_function"):
  190. return self.coords_to_point(x, graph.underlying_function(x))
  191. else:
  192. alpha = binary_search(
  193. function=lambda a: self.point_to_coords(
  194. graph.quick_point_from_proportion(a)
  195. )[0],
  196. target=x,
  197. lower_bound=self.x_range[0],
  198. upper_bound=self.x_range[1],
  199. )
  200. if alpha is not None:
  201. return graph.quick_point_from_proportion(alpha)
  202. else:
  203. return None
  204. def i2gp(self, x: float, graph: ParametricCurve) -> Vect3 | None:
  205. """
  206. Alias for input_to_graph_point
  207. """
  208. return self.input_to_graph_point(x, graph)
  209. def bind_graph_to_func(
  210. self,
  211. graph: VMobject,
  212. func: Callable[[VectN], VectN],
  213. jagged: bool = False,
  214. get_discontinuities: Optional[Callable[[], Vect3]] = None
  215. ) -> VMobject:
  216. """
  217. Use for graphing functions which might change over time, or change with
  218. conditions
  219. """
  220. x_values = np.array([self.x_axis.p2n(p) for p in graph.get_points()])
  221. def get_graph_points():
  222. xs = x_values
  223. if get_discontinuities:
  224. ds = get_discontinuities()
  225. ep = 1e-6
  226. added_xs = it.chain(*((d - ep, d + ep) for d in ds))
  227. xs[:] = sorted([*x_values, *added_xs])[:len(x_values)]
  228. return self.c2p(xs, func(xs))
  229. graph.add_updater(
  230. lambda g: g.set_points_as_corners(get_graph_points())
  231. )
  232. if not jagged:
  233. graph.add_updater(lambda g: g.make_smooth(approx=True))
  234. return graph
  235. def get_graph_label(
  236. self,
  237. graph: ParametricCurve,
  238. label: str | Mobject = "f(x)",
  239. x: float | None = None,
  240. direction: Vect3 = RIGHT,
  241. buff: float = MED_SMALL_BUFF,
  242. color: ManimColor | None = None
  243. ) -> Tex | Mobject:
  244. if isinstance(label, str):
  245. label = Tex(label)
  246. if color is None:
  247. label.match_color(graph)
  248. if x is None:
  249. # Searching from the right, find a point
  250. # whose y value is in bounds
  251. max_y = FRAME_Y_RADIUS - label.get_height()
  252. max_x = FRAME_X_RADIUS - label.get_width()
  253. for x0 in np.arange(*self.x_range)[::-1]:
  254. pt = self.i2gp(x0, graph)
  255. if abs(pt[0]) < max_x and abs(pt[1]) < max_y:
  256. x = x0
  257. break
  258. if x is None:
  259. x = self.x_range[1]
  260. point = self.input_to_graph_point(x, graph)
  261. angle = self.angle_of_tangent(x, graph)
  262. normal = rotate_vector(RIGHT, angle + 90 * DEGREES)
  263. if normal[1] < 0:
  264. normal *= -1
  265. label.next_to(point, normal, buff=buff)
  266. label.shift_onto_screen()
  267. return label
  268. def get_v_line_to_graph(self, x: float, graph: ParametricCurve, **kwargs):
  269. return self.get_v_line(self.i2gp(x, graph), **kwargs)
  270. def get_h_line_to_graph(self, x: float, graph: ParametricCurve, **kwargs):
  271. return self.get_h_line(self.i2gp(x, graph), **kwargs)
  272. def get_scatterplot(self,
  273. x_values: Vect3Array,
  274. y_values: Vect3Array,
  275. **dot_config):
  276. return DotCloud(self.c2p(x_values, y_values), **dot_config)
  277. # For calculus
  278. def angle_of_tangent(
  279. self,
  280. x: float,
  281. graph: ParametricCurve,
  282. dx: float = EPSILON
  283. ) -> float:
  284. p0 = self.input_to_graph_point(x, graph)
  285. p1 = self.input_to_graph_point(x + dx, graph)
  286. return angle_of_vector(p1 - p0)
  287. def slope_of_tangent(
  288. self,
  289. x: float,
  290. graph: ParametricCurve,
  291. **kwargs
  292. ) -> float:
  293. return np.tan(self.angle_of_tangent(x, graph, **kwargs))
  294. def get_tangent_line(
  295. self,
  296. x: float,
  297. graph: ParametricCurve,
  298. length: float = 5,
  299. line_func: Type[T] = Line
  300. ) -> T:
  301. line = line_func(LEFT, RIGHT)
  302. line.set_width(length)
  303. line.rotate(self.angle_of_tangent(x, graph))
  304. line.move_to(self.input_to_graph_point(x, graph))
  305. return line
  306. def get_riemann_rectangles(
  307. self,
  308. graph: ParametricCurve,
  309. x_range: Sequence[float] = None,
  310. dx: float | None = None,
  311. input_sample_type: str = "left",
  312. stroke_width: float = 1,
  313. stroke_color: ManimColor = BLACK,
  314. fill_opacity: float = 1,
  315. colors: Iterable[ManimColor] = (BLUE, GREEN),
  316. negative_color: ManimColor = RED,
  317. stroke_background: bool = True,
  318. show_signed_area: bool = True
  319. ) -> VGroup:
  320. if x_range is None:
  321. x_range = self.x_range[:2]
  322. if dx is None:
  323. dx = self.x_range[2]
  324. if len(x_range) < 3:
  325. x_range = [*x_range, dx]
  326. rects = []
  327. x_range[1] = x_range[1] + dx
  328. xs = np.arange(*x_range)
  329. for x0, x1 in zip(xs, xs[1:]):
  330. if input_sample_type == "left":
  331. sample = x0
  332. elif input_sample_type == "right":
  333. sample = x1
  334. elif input_sample_type == "center":
  335. sample = 0.5 * x0 + 0.5 * x1
  336. else:
  337. raise Exception("Invalid input sample type")
  338. height_vect = self.i2gp(sample, graph) - self.c2p(sample, 0)
  339. rect = Rectangle(
  340. width=self.x_axis.n2p(x1)[0] - self.x_axis.n2p(x0)[0],
  341. height=get_norm(height_vect),
  342. )
  343. rect.positive = height_vect[1] > 0
  344. rect.move_to(self.c2p(x0, 0), DL if rect.positive else UL)
  345. rects.append(rect)
  346. result = VGroup(*rects)
  347. result.set_submobject_colors_by_gradient(*colors)
  348. result.set_style(
  349. stroke_width=stroke_width,
  350. stroke_color=stroke_color,
  351. fill_opacity=fill_opacity,
  352. stroke_background=stroke_background
  353. )
  354. for rect in result:
  355. if not rect.positive:
  356. rect.set_fill(negative_color)
  357. return result
  358. def get_area_under_graph(self, graph, x_range, fill_color=BLUE, fill_opacity=0.5):
  359. if not hasattr(graph, "x_range"):
  360. raise Exception("Argument `graph` must have attribute `x_range`")
  361. alpha_bounds = [
  362. inverse_interpolate(*graph.x_range, x)
  363. for x in x_range
  364. ]
  365. sub_graph = graph.copy()
  366. sub_graph.pointwise_become_partial(graph, *alpha_bounds)
  367. sub_graph.add_line_to(self.c2p(x_range[1], 0))
  368. sub_graph.add_line_to(self.c2p(x_range[0], 0))
  369. sub_graph.add_line_to(sub_graph.get_start())
  370. sub_graph.set_stroke(width=0)
  371. sub_graph.set_fill(fill_color, fill_opacity)
  372. return sub_graph
  373. class Axes(VGroup, CoordinateSystem):
  374. default_axis_config: dict = dict()
  375. default_x_axis_config: dict = dict()
  376. default_y_axis_config: dict = dict(line_to_number_direction=LEFT)
  377. def __init__(
  378. self,
  379. x_range: RangeSpecifier = DEFAULT_X_RANGE,
  380. y_range: RangeSpecifier = DEFAULT_Y_RANGE,
  381. axis_config: dict = dict(),
  382. x_axis_config: dict = dict(),
  383. y_axis_config: dict = dict(),
  384. height: float | None = None,
  385. width: float | None = None,
  386. unit_size: float = 1.0,
  387. **kwargs
  388. ):
  389. CoordinateSystem.__init__(self, x_range, y_range, **kwargs)
  390. kwargs.pop("num_sampled_graph_points_per_tick", None)
  391. VGroup.__init__(self, **kwargs)
  392. axis_config = dict(**axis_config, unit_size=unit_size)
  393. self.x_axis = self.create_axis(
  394. self.x_range,
  395. axis_config=merge_dicts_recursively(
  396. self.default_axis_config,
  397. self.default_x_axis_config,
  398. axis_config,
  399. x_axis_config
  400. ),
  401. length=width,
  402. )
  403. self.y_axis = self.create_axis(
  404. self.y_range,
  405. axis_config=merge_dicts_recursively(
  406. self.default_axis_config,
  407. self.default_y_axis_config,
  408. axis_config,
  409. y_axis_config
  410. ),
  411. length=height,
  412. )
  413. self.y_axis.rotate(90 * DEGREES, about_point=ORIGIN)
  414. # Add as a separate group in case various other
  415. # mobjects are added to self, as for example in
  416. # NumberPlane below
  417. self.axes = VGroup(self.x_axis, self.y_axis)
  418. self.add(*self.axes)
  419. self.center()
  420. def create_axis(
  421. self,
  422. range_terms: RangeSpecifier,
  423. axis_config: dict,
  424. length: float | None
  425. ) -> NumberLine:
  426. axis = NumberLine(range_terms, width=length, **axis_config)
  427. axis.shift(-axis.n2p(0))
  428. return axis
  429. def coords_to_point(self, *coords: float | VectN) -> Vect3 | Vect3Array:
  430. origin = self.x_axis.number_to_point(0)
  431. return origin + sum(
  432. axis.number_to_point(coord) - origin
  433. for axis, coord in zip(self.get_axes(), coords)
  434. )
  435. def point_to_coords(self, point: Vect3 | Vect3Array) -> tuple[float | VectN, ...]:
  436. return tuple([
  437. axis.point_to_number(point)
  438. for axis in self.get_axes()
  439. ])
  440. def get_axes(self) -> VGroup:
  441. return self.axes
  442. def get_all_ranges(self) -> list[Sequence[float]]:
  443. return [self.x_range, self.y_range]
  444. def add_coordinate_labels(
  445. self,
  446. x_values: Iterable[float] | None = None,
  447. y_values: Iterable[float] | None = None,
  448. excluding: Iterable[float] = [0],
  449. **kwargs
  450. ) -> VGroup:
  451. axes = self.get_axes()
  452. self.coordinate_labels = VGroup()
  453. for axis, values in zip(axes, [x_values, y_values]):
  454. labels = axis.add_numbers(values, excluding=excluding, **kwargs)
  455. self.coordinate_labels.add(labels)
  456. return self.coordinate_labels
  457. class ThreeDAxes(Axes):
  458. dimension: int = 3
  459. default_z_axis_config: dict = dict()
  460. def __init__(
  461. self,
  462. x_range: RangeSpecifier = (-6.0, 6.0, 1.0),
  463. y_range: RangeSpecifier = (-5.0, 5.0, 1.0),
  464. z_range: RangeSpecifier = (-4.0, 4.0, 1.0),
  465. z_axis_config: dict = dict(),
  466. z_normal: Vect3 = DOWN,
  467. depth: float | None = None,
  468. **kwargs
  469. ):
  470. Axes.__init__(self, x_range, y_range, **kwargs)
  471. self.z_range = z_range
  472. self.z_axis = self.create_axis(
  473. self.z_range,
  474. axis_config=merge_dicts_recursively(
  475. self.default_axis_config,
  476. self.default_z_axis_config,
  477. kwargs.get("axis_config", {}),
  478. z_axis_config
  479. ),
  480. length=depth,
  481. )
  482. self.z_axis.rotate(-PI / 2, UP, about_point=ORIGIN)
  483. self.z_axis.rotate(
  484. angle_of_vector(z_normal), OUT,
  485. about_point=ORIGIN
  486. )
  487. self.z_axis.shift(self.x_axis.n2p(0))
  488. self.axes.add(self.z_axis)
  489. self.add(self.z_axis)
  490. def get_all_ranges(self) -> list[Sequence[float]]:
  491. return [self.x_range, self.y_range, self.z_range]
  492. def add_axis_labels(self, x_tex="x", y_tex="y", z_tex="z", font_size=24, buff=0.2):
  493. x_label, y_label, z_label = labels = VGroup(*(
  494. Tex(tex, font_size=font_size)
  495. for tex in [x_tex, y_tex, z_tex]
  496. ))
  497. z_label.rotate(PI / 2, RIGHT)
  498. for label, axis in zip(labels, self):
  499. label.next_to(axis, normalize(np.round(axis.get_vector()), 2), buff=buff)
  500. axis.add(label)
  501. self.axis_labels = labels
  502. def get_graph(
  503. self,
  504. func,
  505. color=BLUE_E,
  506. opacity=0.9,
  507. u_range=None,
  508. v_range=None,
  509. **kwargs
  510. ) -> ParametricSurface:
  511. xu = self.x_axis.get_unit_size()
  512. yu = self.y_axis.get_unit_size()
  513. zu = self.z_axis.get_unit_size()
  514. x0, y0, z0 = self.get_origin()
  515. u_range = u_range or self.x_range[:2]
  516. v_range = v_range or self.y_range[:2]
  517. return ParametricSurface(
  518. lambda u, v: [xu * u + x0, yu * v + y0, zu * func(u, v) + z0],
  519. u_range=u_range,
  520. v_range=v_range,
  521. color=color,
  522. opacity=opacity,
  523. **kwargs
  524. )
  525. def get_parametric_surface(
  526. self,
  527. func,
  528. color=BLUE_E,
  529. opacity=0.9,
  530. **kwargs
  531. ) -> ParametricSurface:
  532. surface = ParametricSurface(func, color=color, opacity=opacity, **kwargs)
  533. axes = [self.x_axis, self.y_axis, self.z_axis]
  534. for dim, axis in zip(range(3), axes):
  535. surface.stretch(axis.get_unit_size(), dim, about_point=ORIGIN)
  536. surface.shift(self.get_origin())
  537. return surface
  538. class NumberPlane(Axes):
  539. default_axis_config: dict = dict(
  540. stroke_color=WHITE,
  541. stroke_width=2,
  542. include_ticks=False,
  543. include_tip=False,
  544. line_to_number_buff=SMALL_BUFF,
  545. line_to_number_direction=DL,
  546. )
  547. default_y_axis_config: dict = dict(
  548. line_to_number_direction=DL,
  549. )
  550. def __init__(
  551. self,
  552. x_range: RangeSpecifier = (-8.0, 8.0, 1.0),
  553. y_range: RangeSpecifier = (-4.0, 4.0, 1.0),
  554. background_line_style: dict = dict(
  555. stroke_color=BLUE_D,
  556. stroke_width=2,
  557. stroke_opacity=1,
  558. ),
  559. # Defaults to a faded version of line_config
  560. faded_line_style: dict = dict(),
  561. faded_line_ratio: int = 4,
  562. make_smooth_after_applying_functions: bool = True,
  563. **kwargs
  564. ):
  565. super().__init__(x_range, y_range, **kwargs)
  566. self.background_line_style = dict(background_line_style)
  567. self.faded_line_style = dict(faded_line_style)
  568. self.faded_line_ratio = faded_line_ratio
  569. self.make_smooth_after_applying_functions = make_smooth_after_applying_functions
  570. self.init_background_lines()
  571. def init_background_lines(self) -> None:
  572. if not self.faded_line_style:
  573. style = dict(self.background_line_style)
  574. # For anything numerical, like stroke_width
  575. # and stroke_opacity, chop it in half
  576. for key in style:
  577. if isinstance(style[key], numbers.Number):
  578. style[key] *= 0.5
  579. self.faded_line_style = style
  580. self.background_lines, self.faded_lines = self.get_lines()
  581. self.background_lines.set_style(**self.background_line_style)
  582. self.faded_lines.set_style(**self.faded_line_style)
  583. self.add_to_back(
  584. self.faded_lines,
  585. self.background_lines,
  586. )
  587. def get_lines(self) -> tuple[VGroup, VGroup]:
  588. x_axis = self.get_x_axis()
  589. y_axis = self.get_y_axis()
  590. x_lines1, x_lines2 = self.get_lines_parallel_to_axis(x_axis, y_axis)
  591. y_lines1, y_lines2 = self.get_lines_parallel_to_axis(y_axis, x_axis)
  592. lines1 = VGroup(*x_lines1, *y_lines1)
  593. lines2 = VGroup(*x_lines2, *y_lines2)
  594. return lines1, lines2
  595. def get_lines_parallel_to_axis(
  596. self,
  597. axis1: NumberLine,
  598. axis2: NumberLine
  599. ) -> tuple[VGroup, VGroup]:
  600. freq = axis2.x_step
  601. ratio = self.faded_line_ratio
  602. line = Line(axis1.get_start(), axis1.get_end())
  603. dense_freq = (1 + ratio)
  604. step = (1 / dense_freq) * freq
  605. lines1 = VGroup()
  606. lines2 = VGroup()
  607. inputs = np.arange(axis2.x_min, axis2.x_max + step, step)
  608. for i, x in enumerate(inputs):
  609. if abs(x) < 1e-8:
  610. continue
  611. new_line = line.copy()
  612. new_line.shift(axis2.n2p(x) - axis2.n2p(0))
  613. if i % (1 + ratio) == 0:
  614. lines1.add(new_line)
  615. else:
  616. lines2.add(new_line)
  617. return lines1, lines2
  618. def get_x_unit_size(self) -> float:
  619. return self.get_x_axis().get_unit_size()
  620. def get_y_unit_size(self) -> list:
  621. return self.get_x_axis().get_unit_size()
  622. def get_axes(self) -> VGroup:
  623. return self.axes
  624. def get_vector(self, coords: Iterable[float], **kwargs) -> Arrow:
  625. kwargs["buff"] = 0
  626. return Arrow(self.c2p(0, 0), self.c2p(*coords), **kwargs)
  627. def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50) -> Self:
  628. for mob in self.family_members_with_points():
  629. num_curves = mob.get_num_curves()
  630. if num_inserted_curves > num_curves:
  631. mob.insert_n_curves(num_inserted_curves - num_curves)
  632. mob.make_smooth_after_applying_functions = True
  633. return self
  634. class ComplexPlane(NumberPlane):
  635. def number_to_point(self, number: complex | float) -> Vect3:
  636. number = complex(number)
  637. return self.coords_to_point(number.real, number.imag)
  638. def n2p(self, number: complex | float) -> Vect3:
  639. return self.number_to_point(number)
  640. def point_to_number(self, point: Vect3) -> complex:
  641. x, y = self.point_to_coords(point)
  642. return complex(x, y)
  643. def p2n(self, point: Vect3) -> complex:
  644. return self.point_to_number(point)
  645. def get_default_coordinate_values(
  646. self,
  647. skip_first: bool = True
  648. ) -> list[complex]:
  649. x_numbers = self.get_x_axis().get_tick_range()[1:]
  650. y_numbers = self.get_y_axis().get_tick_range()[1:]
  651. y_numbers = [complex(0, y) for y in y_numbers if y != 0]
  652. return [*x_numbers, *y_numbers]
  653. def add_coordinate_labels(
  654. self,
  655. numbers: list[complex] | None = None,
  656. skip_first: bool = True,
  657. font_size: int = 36,
  658. **kwargs
  659. ) -> Self:
  660. if numbers is None:
  661. numbers = self.get_default_coordinate_values(skip_first)
  662. self.coordinate_labels = VGroup()
  663. for number in numbers:
  664. z = complex(number)
  665. if abs(z.imag) > abs(z.real):
  666. axis = self.get_y_axis()
  667. value = z.imag
  668. kwargs["unit_tex"] = "i"
  669. else:
  670. axis = self.get_x_axis()
  671. value = z.real
  672. number_mob = axis.get_number_mobject(value, font_size=font_size, **kwargs)
  673. # For -i, remove the "1"
  674. if z.imag == -1:
  675. number_mob.remove(number_mob[1])
  676. number_mob[0].next_to(
  677. number_mob[1], LEFT,
  678. buff=number_mob[0].get_width() / 4
  679. )
  680. self.coordinate_labels.add(number_mob)
  681. self.add(self.coordinate_labels)
  682. return self