probability.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. from __future__ import annotations
  2. import numpy as np
  3. from manimlib.constants import BLUE, BLUE_E, GREEN_E, GREY_B, GREY_D, MAROON_B, YELLOW
  4. from manimlib.constants import DOWN, LEFT, RIGHT, UP
  5. from manimlib.constants import MED_LARGE_BUFF, MED_SMALL_BUFF, SMALL_BUFF
  6. from manimlib.mobject.geometry import Line
  7. from manimlib.mobject.geometry import Rectangle
  8. from manimlib.mobject.mobject import Mobject
  9. from manimlib.mobject.svg.brace import Brace
  10. from manimlib.mobject.svg.tex_mobject import Tex
  11. from manimlib.mobject.svg.tex_mobject import TexText
  12. from manimlib.mobject.types.vectorized_mobject import VGroup
  13. from manimlib.utils.color import color_gradient
  14. from manimlib.utils.iterables import listify
  15. from typing import TYPE_CHECKING
  16. if TYPE_CHECKING:
  17. from typing import Iterable
  18. from manimlib.typing import ManimColor
  19. EPSILON = 0.0001
  20. class SampleSpace(Rectangle):
  21. def __init__(
  22. self,
  23. width: float = 3,
  24. height: float = 3,
  25. fill_color: ManimColor = GREY_D,
  26. fill_opacity: float = 1,
  27. stroke_width: float = 0.5,
  28. stroke_color: ManimColor = GREY_B,
  29. default_label_scale_val: float = 1,
  30. **kwargs,
  31. ):
  32. super().__init__(
  33. width, height,
  34. fill_color=fill_color,
  35. fill_opacity=fill_opacity,
  36. stroke_width=stroke_width,
  37. stroke_color=stroke_color,
  38. )
  39. self.default_label_scale_val = default_label_scale_val
  40. def add_title(
  41. self,
  42. title: str = "Sample space",
  43. buff: float = MED_SMALL_BUFF
  44. ) -> None:
  45. # TODO, should this really exist in SampleSpaceScene
  46. title_mob = TexText(title)
  47. if title_mob.get_width() > self.get_width():
  48. title_mob.set_width(self.get_width())
  49. title_mob.next_to(self, UP, buff=buff)
  50. self.title = title_mob
  51. self.add(title_mob)
  52. def add_label(self, label: str) -> None:
  53. self.label = label
  54. def complete_p_list(self, p_list: list[float]) -> list[float]:
  55. new_p_list = listify(p_list)
  56. remainder = 1.0 - sum(new_p_list)
  57. if abs(remainder) > EPSILON:
  58. new_p_list.append(remainder)
  59. return new_p_list
  60. def get_division_along_dimension(
  61. self,
  62. p_list: list[float],
  63. dim: int,
  64. colors: Iterable[ManimColor],
  65. vect: np.ndarray
  66. ) -> VGroup:
  67. p_list = self.complete_p_list(p_list)
  68. colors = color_gradient(colors, len(p_list))
  69. last_point = self.get_edge_center(-vect)
  70. parts = VGroup()
  71. for factor, color in zip(p_list, colors):
  72. part = SampleSpace()
  73. part.set_fill(color, 1)
  74. part.replace(self, stretch=True)
  75. part.stretch(factor, dim)
  76. part.move_to(last_point, -vect)
  77. last_point = part.get_edge_center(vect)
  78. parts.add(part)
  79. return parts
  80. def get_horizontal_division(
  81. self,
  82. p_list: list[float],
  83. colors: Iterable[ManimColor] = [GREEN_E, BLUE_E],
  84. vect: np.ndarray = DOWN
  85. ) -> VGroup:
  86. return self.get_division_along_dimension(p_list, 1, colors, vect)
  87. def get_vertical_division(
  88. self,
  89. p_list: list[float],
  90. colors: Iterable[ManimColor] = [MAROON_B, YELLOW],
  91. vect: np.ndarray = RIGHT
  92. ) -> VGroup:
  93. return self.get_division_along_dimension(p_list, 0, colors, vect)
  94. def divide_horizontally(self, *args, **kwargs) -> None:
  95. self.horizontal_parts = self.get_horizontal_division(*args, **kwargs)
  96. self.add(self.horizontal_parts)
  97. def divide_vertically(self, *args, **kwargs) -> None:
  98. self.vertical_parts = self.get_vertical_division(*args, **kwargs)
  99. self.add(self.vertical_parts)
  100. def get_subdivision_braces_and_labels(
  101. self,
  102. parts: VGroup,
  103. labels: str,
  104. direction: np.ndarray,
  105. buff: float = SMALL_BUFF,
  106. ) -> VGroup:
  107. label_mobs = VGroup()
  108. braces = VGroup()
  109. for label, part in zip(labels, parts):
  110. brace = Brace(
  111. part, direction,
  112. buff=buff
  113. )
  114. if isinstance(label, Mobject):
  115. label_mob = label
  116. else:
  117. label_mob = Tex(label)
  118. label_mob.scale(self.default_label_scale_val)
  119. label_mob.next_to(brace, direction, buff)
  120. braces.add(brace)
  121. label_mobs.add(label_mob)
  122. parts.braces = braces
  123. parts.labels = label_mobs
  124. parts.label_kwargs = {
  125. "labels": label_mobs.copy(),
  126. "direction": direction,
  127. "buff": buff,
  128. }
  129. return VGroup(parts.braces, parts.labels)
  130. def get_side_braces_and_labels(
  131. self,
  132. labels: str,
  133. direction: np.ndarray = LEFT,
  134. **kwargs
  135. ) -> VGroup:
  136. assert hasattr(self, "horizontal_parts")
  137. parts = self.horizontal_parts
  138. return self.get_subdivision_braces_and_labels(parts, labels, direction, **kwargs)
  139. def get_top_braces_and_labels(
  140. self,
  141. labels: str,
  142. **kwargs
  143. ) -> VGroup:
  144. assert hasattr(self, "vertical_parts")
  145. parts = self.vertical_parts
  146. return self.get_subdivision_braces_and_labels(parts, labels, UP, **kwargs)
  147. def get_bottom_braces_and_labels(
  148. self,
  149. labels: str,
  150. **kwargs
  151. ) -> VGroup:
  152. assert hasattr(self, "vertical_parts")
  153. parts = self.vertical_parts
  154. return self.get_subdivision_braces_and_labels(parts, labels, DOWN, **kwargs)
  155. def add_braces_and_labels(self) -> None:
  156. for attr in "horizontal_parts", "vertical_parts":
  157. if not hasattr(self, attr):
  158. continue
  159. parts = getattr(self, attr)
  160. for subattr in "braces", "labels":
  161. if hasattr(parts, subattr):
  162. self.add(getattr(parts, subattr))
  163. def __getitem__(self, index: int | slice) -> VGroup:
  164. if hasattr(self, "horizontal_parts"):
  165. return self.horizontal_parts[index]
  166. elif hasattr(self, "vertical_parts"):
  167. return self.vertical_parts[index]
  168. return self.split()[index]
  169. class BarChart(VGroup):
  170. def __init__(
  171. self,
  172. values: Iterable[float],
  173. height: float = 4,
  174. width: float = 6,
  175. n_ticks: int = 4,
  176. include_x_ticks: bool = False,
  177. tick_width: float = 0.2,
  178. tick_height: float = 0.15,
  179. label_y_axis: bool = True,
  180. y_axis_label_height: float = 0.25,
  181. max_value: float = 1,
  182. bar_colors: list[ManimColor] = [BLUE, YELLOW],
  183. bar_fill_opacity: float = 0.8,
  184. bar_stroke_width: float = 3,
  185. bar_names: list[str] = [],
  186. bar_label_scale_val: float = 0.75,
  187. **kwargs
  188. ):
  189. super().__init__(**kwargs)
  190. self.height = height
  191. self.width = width
  192. self.n_ticks = n_ticks
  193. self.include_x_ticks = include_x_ticks
  194. self.tick_width = tick_width
  195. self.tick_height = tick_height
  196. self.label_y_axis = label_y_axis
  197. self.y_axis_label_height = y_axis_label_height
  198. self.max_value = max_value
  199. self.bar_colors = bar_colors
  200. self.bar_fill_opacity = bar_fill_opacity
  201. self.bar_stroke_width = bar_stroke_width
  202. self.bar_names = bar_names
  203. self.bar_label_scale_val = bar_label_scale_val
  204. if self.max_value is None:
  205. self.max_value = max(values)
  206. self.n_ticks_x = len(values)
  207. self.add_axes()
  208. self.add_bars(values)
  209. self.center()
  210. def add_axes(self) -> None:
  211. x_axis = Line(self.tick_width * LEFT / 2, self.width * RIGHT)
  212. y_axis = Line(MED_LARGE_BUFF * DOWN, self.height * UP)
  213. y_ticks = VGroup()
  214. heights = np.linspace(0, self.height, self.n_ticks + 1)
  215. values = np.linspace(0, self.max_value, self.n_ticks + 1)
  216. for y, value in zip(heights, values):
  217. y_tick = Line(LEFT, RIGHT)
  218. y_tick.set_width(self.tick_width)
  219. y_tick.move_to(y * UP)
  220. y_ticks.add(y_tick)
  221. y_axis.add(y_ticks)
  222. if self.include_x_ticks == True:
  223. x_ticks = VGroup()
  224. widths = np.linspace(0, self.width, self.n_ticks_x + 1)
  225. label_values = np.linspace(0, len(self.bar_names), self.n_ticks_x + 1)
  226. for x, value in zip(widths, label_values):
  227. x_tick = Line(UP, DOWN)
  228. x_tick.set_height(self.tick_height)
  229. x_tick.move_to(x * RIGHT)
  230. x_ticks.add(x_tick)
  231. x_axis.add(x_ticks)
  232. self.add(x_axis, y_axis)
  233. self.x_axis, self.y_axis = x_axis, y_axis
  234. if self.label_y_axis:
  235. labels = VGroup()
  236. for y_tick, value in zip(y_ticks, values):
  237. label = Tex(str(np.round(value, 2)))
  238. label.set_height(self.y_axis_label_height)
  239. label.next_to(y_tick, LEFT, SMALL_BUFF)
  240. labels.add(label)
  241. self.y_axis_labels = labels
  242. self.add(labels)
  243. def add_bars(self, values: Iterable[float]) -> None:
  244. buff = float(self.width) / (2 * len(values))
  245. bars = VGroup()
  246. for i, value in enumerate(values):
  247. bar = Rectangle(
  248. height=(value / self.max_value) * self.height,
  249. width=buff,
  250. stroke_width=self.bar_stroke_width,
  251. fill_opacity=self.bar_fill_opacity,
  252. )
  253. bar.move_to((2 * i + 0.5) * buff * RIGHT, DOWN + LEFT * 5)
  254. bars.add(bar)
  255. bars.set_color_by_gradient(*self.bar_colors)
  256. bar_labels = VGroup()
  257. for bar, name in zip(bars, self.bar_names):
  258. label = Tex(str(name))
  259. label.scale(self.bar_label_scale_val)
  260. label.next_to(bar, DOWN, SMALL_BUFF)
  261. bar_labels.add(label)
  262. self.add(bars, bar_labels)
  263. self.bars = bars
  264. self.bar_labels = bar_labels
  265. def change_bar_values(self, values: Iterable[float]) -> None:
  266. for bar, value in zip(self.bars, values):
  267. bar_bottom = bar.get_bottom()
  268. bar.stretch_to_fit_height(
  269. (value / self.max_value) * self.height
  270. )
  271. bar.move_to(bar_bottom, DOWN)