number_line.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from __future__ import annotations
  2. import numpy as np
  3. from manimlib.constants import DOWN, LEFT, RIGHT, UP
  4. from manimlib.constants import GREY_B
  5. from manimlib.constants import MED_SMALL_BUFF
  6. from manimlib.mobject.geometry import Line
  7. from manimlib.mobject.numbers import DecimalNumber
  8. from manimlib.mobject.types.vectorized_mobject import VGroup
  9. from manimlib.utils.bezier import interpolate
  10. from manimlib.utils.bezier import outer_interpolate
  11. from manimlib.utils.dict_ops import merge_dicts_recursively
  12. from manimlib.utils.simple_functions import fdiv
  13. from typing import TYPE_CHECKING
  14. if TYPE_CHECKING:
  15. from typing import Iterable, Optional
  16. from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier
  class NumberLine(Line):
      """A line whose points correspond to the numbers in ``x_range``.

      The line is created from ``x_min * RIGHT`` to ``x_max * RIGHT``, then
      either resized with ``set_width(width)`` or scaled by ``unit_size``,
      and finally centered. A tip, ticks and numeric labels are added on
      request. ``number_to_point`` / ``point_to_number`` convert between
      values and points using the line's current first and last points.
      """

      def __init__(
          self,
          x_range: RangeSpecifier = (-8, 8, 1),
          color: ManimColor = GREY_B,
          stroke_width: float = 2.0,
          # How big one unit of this number line is, in absolute spatial distance
          unit_size: float = 1.0,
          width: Optional[float] = None,
          include_ticks: bool = True,
          tick_size: float = 0.1,
          longer_tick_multiple: float = 1.5,
          tick_offset: float = 0.0,
          # Change name
          big_tick_spacing: Optional[float] = None,
          big_tick_numbers: list[float] = [],
          include_numbers: bool = False,
          line_to_number_direction: Vect3 = DOWN,
          line_to_number_buff: float = MED_SMALL_BUFF,
          include_tip: bool = False,
          tip_config: dict = dict(
              width=0.25,
              length=0.25,
          ),
          decimal_number_config: dict = dict(
              num_decimal_places=0,
              font_size=36,
          ),
          numbers_to_exclude: list | None = None,
          **kwargs,
      ):
          """Build the line and its optional tip, ticks and numbers.

          Parameters:
              x_range: ``(min, max)`` or ``(min, max, step)``; the step is 1
                  when only two values are given.
              color, stroke_width: Forwarded to the ``Line`` constructor.
              unit_size: Scale factor applied when ``width`` is falsy.
              width: If truthy, the line is resized with ``set_width(width)``
                  and ``unit_size`` is ignored.
              include_ticks: Whether to call ``add_ticks``.
              tick_size: Half-height of a regular tick.
              longer_tick_multiple: Factor applied to ``tick_size`` for ticks
                  listed in ``big_tick_numbers``.
              tick_offset: Stored on the instance; not used elsewhere in this
                  class.
              big_tick_spacing: If not None, big ticks are placed at
                  ``np.arange(x_range[0], x_range[1] + spacing, spacing)`` and
                  ``big_tick_numbers`` is ignored.
              big_tick_numbers: Values whose ticks are drawn longer.
              include_numbers: Whether to call ``add_numbers``.
              line_to_number_direction: Default direction passed to
                  ``next_to`` when placing a label.
              line_to_number_buff: Default buffer between line and label.
              include_tip: Whether to call ``add_tip``.
              tip_config: Copied and stored as ``self.tip_config``.
              decimal_number_config: Default keyword arguments for each
                  ``DecimalNumber`` label.
              numbers_to_exclude: Values that ``add_numbers`` skips by default.
              **kwargs: Forwarded to the ``Line`` constructor.

          The list and dict defaults are copied before being stored, so the
          shared default objects are never handed out.
          """
          self.x_range = x_range
          self.tick_size = tick_size
          self.longer_tick_multiple = longer_tick_multiple
          self.tick_offset = tick_offset
          if big_tick_spacing is not None:
              self.big_tick_numbers = np.arange(
                  x_range[0],
                  x_range[1] + big_tick_spacing,
                  big_tick_spacing,
              )
          else:
              self.big_tick_numbers = list(big_tick_numbers)
          self.line_to_number_direction = line_to_number_direction
          self.line_to_number_buff = line_to_number_buff
          self.include_tip = include_tip
          self.tip_config = dict(tip_config)
          self.decimal_number_config = dict(decimal_number_config)
          self.numbers_to_exclude = numbers_to_exclude

          self.x_min, self.x_max = x_range[:2]
          self.x_step = 1 if len(x_range) == 2 else x_range[2]

          super().__init__(
              self.x_min * RIGHT, self.x_max * RIGHT,
              color=color,
              stroke_width=stroke_width,
              **kwargs
          )

          if width:
              self.set_width(width)
          else:
              self.scale(unit_size)
          self.center()

          if include_tip:
              self.add_tip()
              self.tip.set_stroke(
                  self.stroke_color,
                  self.stroke_width,
              )
          if include_ticks:
              self.add_ticks()
          if include_numbers:
              self.add_numbers(excluding=self.numbers_to_exclude)

      def get_tick_range(self) -> np.ndarray:
          """Return the values at which ticks (and default labels) are placed.

          Without a tip the range is extended by one step so that ``x_max``
          itself can be produced by the half-open ``np.arange``; with a tip
          ``x_max`` is left out (presumably so no tick sits under the tip).
          Values greater than ``x_max`` are dropped from the result.
          """
          if self.include_tip:
              x_max = self.x_max
          else:
              x_max = self.x_max + self.x_step
          result = np.arange(self.x_min, x_max, self.x_step)
          return result[result <= self.x_max]

      def add_ticks(self) -> None:
          """Create one tick per value of ``get_tick_range`` and add them.

          Ticks whose value is close (``np.isclose``) to an entry of
          ``big_tick_numbers`` are lengthened by ``longer_tick_multiple``.
          The group is stored as ``self.ticks``.
          """
          ticks = VGroup()
          for x in self.get_tick_range():
              size = self.tick_size
              if np.isclose(self.big_tick_numbers, x).any():
                  size *= self.longer_tick_multiple
              ticks.add(self.get_tick(x, size))
          self.add(ticks)
          self.ticks = ticks

      def get_tick(self, x: float, size: float | None = None) -> Line:
          """Return a tick for value ``x``.

          The tick is a ``Line`` from ``size * DOWN`` to ``size * UP``, rotated
          by this line's angle, moved to ``number_to_point(x)`` and styled to
          match this line. ``size`` defaults to ``self.tick_size``.
          """
          if size is None:
              size = self.tick_size
          result = Line(size * DOWN, size * UP)
          result.rotate(self.get_angle())
          result.move_to(self.number_to_point(x))
          result.match_style(self)
          return result

      def get_tick_marks(self) -> VGroup:
          """Return the tick group created by ``add_ticks``.

          ``self.ticks`` only exists after ``add_ticks`` has run.
          """
          return self.ticks

      def number_to_point(self, number: float | VectN) -> Vect3 | Vect3Array:
          """Map a value (or array of values) to the matching point(s).

          The value is converted to a fraction of ``[x_min, x_max]`` and
          interpolated between the line's first and last points.
          """
          start = self.get_points()[0]
          end = self.get_points()[-1]
          alpha = (number - self.x_min) / (self.x_max - self.x_min)
          return outer_interpolate(start, end, alpha)

      def point_to_number(self, point: Vect3 | Vect3Array) -> float | VectN:
          """Map a point (or array of points) to the matching value(s).

          ``point - start`` is projected onto the start-to-end vector; the
          resulting proportion is interpolated between ``x_min`` and ``x_max``.
          """
          start = self.get_points()[0]
          end = self.get_points()[-1]
          vect = end - start
          proportion = fdiv(
              np.dot(point - start, vect),
              np.dot(end - start, vect),
          )
          return interpolate(self.x_min, self.x_max, proportion)

      def n2p(self, number: float | VectN) -> Vect3 | Vect3Array:
          """Abbreviation for number_to_point"""
          return self.number_to_point(number)

      def p2n(self, point: Vect3 | Vect3Array) -> float | VectN:
          """Abbreviation for point_to_number"""
          return self.point_to_number(point)

      def get_unit_size(self) -> float:
          """Return the current length of the line per unit of ``x_range``."""
          return self.get_length() / (self.x_max - self.x_min)

      def get_number_mobject(
          self,
          x: float,
          direction: Vect3 | None = None,
          buff: float | None = None,
          unit: float = 1.0,
          unit_tex: str = "",
          **number_config
      ) -> DecimalNumber:
          """Return a ``DecimalNumber`` label for ``x`` placed next to the line.

          Parameters:
              x: Value on the number line to label.
              direction: Side of the line for the label; defaults to
                  ``self.line_to_number_direction``.
              buff: Gap between line and label; defaults to
                  ``self.line_to_number_buff``.
              unit: The displayed value is ``x / unit``.
              unit_tex: If non-empty, passed to ``DecimalNumber`` as ``unit``.
              **number_config: Merged over ``self.decimal_number_config``.
          """
          number_config = merge_dicts_recursively(
              self.decimal_number_config, number_config,
          )
          if direction is None:
              direction = self.line_to_number_direction
          if buff is None:
              buff = self.line_to_number_buff
          if unit_tex:
              number_config["unit"] = unit_tex

          num_mob = DecimalNumber(x / unit, **number_config)
          num_mob.next_to(
              self.number_to_point(x),
              direction=direction,
              buff=buff
          )
          if x < 0 and direction[0] == 0:
              # Align without the minus sign
              num_mob.shift(num_mob[0].get_width() * LEFT / 2)
          if x == unit and unit_tex:
              # Drop the first submobject (the displayed "1") so that only the
              # unit remains, keeping the label centered where it was.
              center = num_mob.get_center()
              num_mob.remove(num_mob[0])
              num_mob.move_to(center)
          return num_mob

      def add_numbers(
          self,
          x_values: Iterable[float] | None = None,
          excluding: Iterable[float] | None = None,
          font_size: int = 24,
          **kwargs
      ) -> VGroup:
          """Create labels for ``x_values`` and add them to the line.

          Parameters:
              x_values: Values to label; defaults to ``get_tick_range()``.
              excluding: Values to skip; defaults to
                  ``self.numbers_to_exclude``.
              font_size: Passed to every label (overrides the ``font_size``
                  in ``decimal_number_config``).
              **kwargs: Forwarded to ``get_number_mobject``.

          Returns:
              The group of labels, also stored as ``self.numbers``.
          """
          if x_values is None:
              x_values = self.get_tick_range()
          kwargs["font_size"] = font_size
          if excluding is None:
              excluding = self.numbers_to_exclude

          # Materialise once: ``excluding`` may be a one-shot iterable, and an
          # array lets us compare with a tolerance below.
          excluded = np.array(
              [] if excluding is None else list(excluding),
              dtype=float,
          )

          numbers = VGroup()
          for x in x_values:
              # Values from np.arange carry rounding error (e.g. 0.1 * 3 is
              # 0.30000000000000004), so an exact ``in`` test would miss 0.3.
              # Tolerances are kept tight so neighbouring ticks never match.
              if excluded.size and np.isclose(
                  excluded, x, rtol=1e-9, atol=1e-9
              ).any():
                  continue
              numbers.add(self.get_number_mobject(x, **kwargs))
          self.add(numbers)
          self.numbers = numbers
          return numbers
  class UnitInterval(NumberLine):
      """A ``NumberLine`` preset covering the interval from 0 to 1.

      Defaults: ticks every 0.1, ten spatial units per unit of value, longer
      ticks at 0 and 1, and labels shown with one decimal place.
      """

      def __init__(
          self,
          x_range: RangeSpecifier = (0, 1, 0.1),
          unit_size: float = 10,
          big_tick_numbers: list[float] = [0, 1],
          decimal_number_config: dict = dict(
              num_decimal_places=1,
          ),
          **kwargs
      ):
          """Forward the presets, plus any extra options, to ``NumberLine``.

          ``**kwargs`` accepts every other ``NumberLine`` keyword argument
          (for example ``color``, ``include_numbers`` or ``include_tip``), so
          the preset can be customised without subclassing.
          """
          super().__init__(
              x_range=x_range,
              unit_size=unit_size,
              big_tick_numbers=big_tick_numbers,
              decimal_number_config=decimal_number_config,
              **kwargs
          )