numbers.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING

import numpy as np

from manimlib.constants import DOWN, LEFT, RIGHT, UP
from manimlib.constants import WHITE
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.svg.text_mobject import Text
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.bezier import interpolate
from manimlib.utils.paths import straight_path

if TYPE_CHECKING:
    from typing import Callable, TypeVar

    from manimlib.mobject.mobject import Mobject
    from manimlib.typing import ManimColor, Self, Vect3

    T = TypeVar("T", bound=VMobject)
  16. @lru_cache()
  17. def char_to_cahced_mob(char: str, **text_config):
  18. return Text(char, **text_config)
  19. class DecimalNumber(VMobject):
  20. def __init__(
  21. self,
  22. number: float | complex = 0,
  23. color: ManimColor = WHITE,
  24. stroke_width: float = 0,
  25. fill_opacity: float = 1.0,
  26. fill_border_width: float = 0.5,
  27. num_decimal_places: int = 2,
  28. include_sign: bool = False,
  29. group_with_commas: bool = True,
  30. digit_buff_per_font_unit: float = 0.001,
  31. show_ellipsis: bool = False,
  32. unit: str | None = None, # Aligned to bottom unless it starts with "^"
  33. include_background_rectangle: bool = False,
  34. edge_to_fix: Vect3 = LEFT,
  35. font_size: float = 48,
  36. text_config: dict = dict(), # Do not pass in font_size here
  37. **kwargs
  38. ):
  39. self.num_decimal_places = num_decimal_places
  40. self.include_sign = include_sign
  41. self.group_with_commas = group_with_commas
  42. self.digit_buff_per_font_unit = digit_buff_per_font_unit
  43. self.show_ellipsis = show_ellipsis
  44. self.unit = unit
  45. self.include_background_rectangle = include_background_rectangle
  46. self.edge_to_fix = edge_to_fix
  47. self.font_size = font_size
  48. self.text_config = dict(text_config)
  49. super().__init__(
  50. color=color,
  51. stroke_width=stroke_width,
  52. fill_opacity=fill_opacity,
  53. fill_border_width=fill_border_width,
  54. **kwargs
  55. )
  56. self.set_submobjects_from_number(number)
  57. self.init_colors()
  58. def set_submobjects_from_number(self, number: float | complex) -> None:
  59. # Create the submobject list
  60. self.number = number
  61. self.num_string = self.get_num_string(number)
  62. # Submob_templates will be a list of cached Tex and Text mobjects,
  63. # with the intent of calling .copy or .become on them
  64. submob_templates = list(map(self.char_to_mob, self.num_string))
  65. if self.show_ellipsis:
  66. dots = self.char_to_mob("...")
  67. dots.arrange(RIGHT, buff=2 * dots[0].get_width())
  68. submob_templates.append(dots)
  69. if self.unit is not None:
  70. submob_templates.append(self.char_to_mob(self.unit))
  71. # Set internals
  72. font_size = self.get_font_size()
  73. if len(submob_templates) == len(self.submobjects):
  74. for sm, smt in zip(self.submobjects, submob_templates):
  75. sm.become(smt)
  76. sm.scale(font_size / smt.font_size)
  77. else:
  78. self.set_submobjects([
  79. smt.copy().scale(font_size / smt.font_size)
  80. for smt in submob_templates
  81. ])
  82. digit_buff = self.digit_buff_per_font_unit * font_size
  83. self.arrange(RIGHT, buff=digit_buff, aligned_edge=DOWN)
  84. # Handle alignment of special characters
  85. for i, c in enumerate(self.num_string):
  86. if c == "–" and len(self.num_string) > i + 1:
  87. self[i].align_to(self[i + 1], UP)
  88. self[i].shift(self[i + 1].get_height() * DOWN / 2)
  89. elif c == ",":
  90. self[i].shift(self[i].get_height() * DOWN / 2)
  91. if self.unit and self.unit.startswith("^"):
  92. self[-1].align_to(self, UP)
  93. if self.include_background_rectangle:
  94. self.add_background_rectangle()
  95. def get_num_string(self, number: float | complex) -> str:
  96. if isinstance(number, complex):
  97. formatter = self.get_complex_formatter()
  98. else:
  99. formatter = self.get_formatter()
  100. if self.num_decimal_places == 0 and isinstance(number, float):
  101. number = int(number)
  102. num_string = formatter.format(number)
  103. rounded_num = np.round(number, self.num_decimal_places)
  104. if num_string.startswith("-") and rounded_num == 0:
  105. if self.include_sign:
  106. num_string = "+" + num_string[1:]
  107. else:
  108. num_string = num_string[1:]
  109. num_string = num_string.replace("-", "–")
  110. return num_string
  111. def char_to_mob(self, char: str) -> Text:
  112. return char_to_cahced_mob(char, **self.text_config)
  113. def interpolate(
  114. self,
  115. mobject1: Mobject,
  116. mobject2: Mobject,
  117. alpha: float,
  118. path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path
  119. ) -> Self:
  120. super().interpolate(mobject1, mobject2, alpha, path_func)
  121. if hasattr(mobject1, "font_size") and hasattr(mobject2, "font_size"):
  122. self.font_size = interpolate(mobject1.font_size, mobject2.font_size, alpha)
  123. def get_font_size(self) -> float:
  124. return self.font_size
  125. def get_formatter(self, **kwargs) -> str:
  126. """
  127. Configuration is based first off instance attributes,
  128. but overwritten by any kew word argument. Relevant
  129. key words:
  130. - include_sign
  131. - group_with_commas
  132. - num_decimal_places
  133. - field_name (e.g. 0 or 0.real)
  134. """
  135. config = dict([
  136. (attr, getattr(self, attr))
  137. for attr in [
  138. "include_sign",
  139. "group_with_commas",
  140. "num_decimal_places",
  141. ]
  142. ])
  143. config.update(kwargs)
  144. ndp = config["num_decimal_places"]
  145. return "".join([
  146. "{",
  147. config.get("field_name", ""),
  148. ":",
  149. "+" if config["include_sign"] else "",
  150. "," if config["group_with_commas"] else "",
  151. f".{ndp}f" if ndp > 0 else "d",
  152. "}",
  153. ])
  154. def get_complex_formatter(self, **kwargs) -> str:
  155. return "".join([
  156. self.get_formatter(field_name="0.real"),
  157. self.get_formatter(field_name="0.imag", include_sign=True),
  158. "i"
  159. ])
  160. def get_tex(self):
  161. return self.num_string
  162. def set_value(self, number: float | complex) -> Self:
  163. move_to_point = self.get_edge_center(self.edge_to_fix)
  164. style = self.family_members_with_points()[0].get_style()
  165. self.set_submobjects_from_number(number)
  166. self.move_to(move_to_point, self.edge_to_fix)
  167. self.set_style(**style)
  168. for submob in self.get_family():
  169. submob.uniforms.update(self.uniforms)
  170. return self
  171. def _handle_scale_side_effects(self, scale_factor: float) -> Self:
  172. self.font_size *= scale_factor
  173. return self
  174. def get_value(self) -> float | complex:
  175. return self.number
  176. def increment_value(self, delta_t: float | complex = 1) -> Self:
  177. self.set_value(self.get_value() + delta_t)
  178. return self
  179. class Integer(DecimalNumber):
  180. def __init__(
  181. self,
  182. number: int = 0,
  183. num_decimal_places: int = 0,
  184. **kwargs,
  185. ):
  186. super().__init__(number, num_decimal_places=num_decimal_places, **kwargs)
  187. def get_value(self) -> int:
  188. return int(np.round(super().get_value()))