surface.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. from __future__ import annotations
  2. import moderngl
  3. import numpy as np
  4. from manimlib.constants import GREY
  5. from manimlib.constants import OUT
  6. from manimlib.mobject.mobject import Mobject
  7. from manimlib.utils.bezier import integer_interpolate
  8. from manimlib.utils.bezier import interpolate
  9. from manimlib.utils.images import get_full_raster_image_path
  10. from manimlib.utils.iterables import listify
  11. from manimlib.utils.iterables import resize_with_interpolation
  12. from manimlib.utils.space_ops import normalize_along_axis
  13. from manimlib.utils.space_ops import cross
  14. from typing import TYPE_CHECKING
  15. if TYPE_CHECKING:
  16. from typing import Callable, Iterable, Sequence, Tuple
  17. from manimlib.camera.camera import Camera
  18. from manimlib.typing import ManimColor, Vect3, Vect3Array, Self
  19. class Surface(Mobject):
  20. render_primitive: int = moderngl.TRIANGLES
  21. shader_folder: str = "surface"
  22. data_dtype: np.dtype = np.dtype([
  23. ('point', np.float32, (3,)),
  24. ('du_point', np.float32, (3,)),
  25. ('dv_point', np.float32, (3,)),
  26. ('rgba', np.float32, (4,)),
  27. ])
  28. pointlike_data_keys = ['point', 'du_point', 'dv_point']
  29. def __init__(
  30. self,
  31. color: ManimColor = GREY,
  32. shading: Tuple[float, float, float] = (0.3, 0.2, 0.4),
  33. depth_test: bool = True,
  34. u_range: Tuple[float, float] = (0.0, 1.0),
  35. v_range: Tuple[float, float] = (0.0, 1.0),
  36. # Resolution counts number of points sampled, which for
  37. # each coordinate is one more than the the number of
  38. # rows/columns of approximating squares
  39. resolution: Tuple[int, int] = (101, 101),
  40. prefered_creation_axis: int = 1,
  41. # For du and dv steps. Much smaller and numerical error
  42. # can crop up in the shaders.
  43. epsilon: float = 1e-4,
  44. **kwargs
  45. ):
  46. self.u_range = u_range
  47. self.v_range = v_range
  48. self.resolution = resolution
  49. self.prefered_creation_axis = prefered_creation_axis
  50. self.epsilon = epsilon
  51. super().__init__(
  52. **kwargs,
  53. color=color,
  54. shading=shading,
  55. depth_test=depth_test,
  56. )
  57. self.compute_triangle_indices()
  58. def uv_func(self, u: float, v: float) -> tuple[float, float, float]:
  59. # To be implemented in subclasses
  60. return (u, v, 0.0)
  61. @Mobject.affects_data
  62. def init_points(self):
  63. dim = self.dim
  64. nu, nv = self.resolution
  65. u_range = np.linspace(*self.u_range, nu)
  66. v_range = np.linspace(*self.v_range, nv)
  67. # Get three lists:
  68. # - Points generated by pure uv values
  69. # - Those generated by values nudged by du
  70. # - Those generated by values nudged by dv
  71. uv_grid = np.array([[[u, v] for v in v_range] for u in u_range])
  72. uv_plus_du = uv_grid.copy()
  73. uv_plus_du[:, :, 0] += self.epsilon
  74. uv_plus_dv = uv_grid.copy()
  75. uv_plus_dv[:, :, 1] += self.epsilon
  76. points, du_points, dv_points = [
  77. np.apply_along_axis(
  78. lambda p: self.uv_func(*p), 2, grid
  79. ).reshape((nu * nv, dim))
  80. for grid in (uv_grid, uv_plus_du, uv_plus_dv)
  81. ]
  82. self.set_points(points)
  83. self.data['du_point'][:] = du_points
  84. self.data['dv_point'][:] = dv_points
  85. def apply_points_function(self, *args, **kwargs) -> Self:
  86. super().apply_points_function(*args, **kwargs)
  87. self.get_unit_normals()
  88. return self
  89. def compute_triangle_indices(self) -> np.ndarray:
  90. # TODO, if there is an event which changes
  91. # the resolution of the surface, make sure
  92. # this is called.
  93. nu, nv = self.resolution
  94. if nu == 0 or nv == 0:
  95. self.triangle_indices = np.zeros(0, dtype=int)
  96. return self.triangle_indices
  97. index_grid = np.arange(nu * nv).reshape((nu, nv))
  98. indices = np.zeros(6 * (nu - 1) * (nv - 1), dtype=int)
  99. indices[0::6] = index_grid[:-1, :-1].flatten() # Top left
  100. indices[1::6] = index_grid[+1:, :-1].flatten() # Bottom left
  101. indices[2::6] = index_grid[:-1, +1:].flatten() # Top right
  102. indices[3::6] = index_grid[:-1, +1:].flatten() # Top right
  103. indices[4::6] = index_grid[+1:, :-1].flatten() # Bottom left
  104. indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right
  105. self.triangle_indices = indices
  106. return self.triangle_indices
  107. def get_triangle_indices(self) -> np.ndarray:
  108. return self.triangle_indices
  109. def get_unit_normals(self) -> Vect3Array:
  110. points = self.get_points()
  111. crosses = cross(
  112. self.data['du_point'] - points,
  113. self.data['dv_point'] - points,
  114. )
  115. return normalize_along_axis(crosses, 1)
  116. @Mobject.affects_data
  117. def pointwise_become_partial(
  118. self,
  119. smobject: "Surface",
  120. a: float,
  121. b: float,
  122. axis: int | None = None
  123. ) -> Self:
  124. assert isinstance(smobject, Surface)
  125. if axis is None:
  126. axis = self.prefered_creation_axis
  127. if a <= 0 and b >= 1:
  128. self.match_points(smobject)
  129. return self
  130. nu, nv = smobject.resolution
  131. self.data['point'][:] = self.get_partial_points_array(
  132. smobject.data['point'], a, b,
  133. (nu, nv, 3),
  134. axis=axis
  135. )
  136. return self
  137. def get_partial_points_array(
  138. self,
  139. points: Vect3Array,
  140. a: float,
  141. b: float,
  142. resolution: Sequence[int],
  143. axis: int
  144. ) -> Vect3Array:
  145. if len(points) == 0:
  146. return points
  147. nu, nv = resolution[:2]
  148. points = points.reshape(resolution).copy()
  149. max_index = resolution[axis] - 1
  150. lower_index, lower_residue = integer_interpolate(0, max_index, a)
  151. upper_index, upper_residue = integer_interpolate(0, max_index, b)
  152. if axis == 0:
  153. points[:lower_index] = interpolate(
  154. points[lower_index],
  155. points[lower_index + 1],
  156. lower_residue
  157. )
  158. points[upper_index + 1:] = interpolate(
  159. points[upper_index],
  160. points[upper_index + 1],
  161. upper_residue
  162. )
  163. else:
  164. shape = (nu, 1, resolution[2])
  165. points[:, :lower_index] = interpolate(
  166. points[:, lower_index],
  167. points[:, lower_index + 1],
  168. lower_residue
  169. ).reshape(shape)
  170. points[:, upper_index + 1:] = interpolate(
  171. points[:, upper_index],
  172. points[:, upper_index + 1],
  173. upper_residue
  174. ).reshape(shape)
  175. return points.reshape((nu * nv, *resolution[2:]))
  176. @Mobject.affects_data
  177. def sort_faces_back_to_front(self, vect: Vect3 = OUT) -> Self:
  178. tri_is = self.triangle_indices
  179. points = self.get_points()
  180. dots = (points[tri_is[::3]] * vect).sum(1)
  181. indices = np.argsort(dots)
  182. for k in range(3):
  183. tri_is[k::3] = tri_is[k::3][indices]
  184. return self
  185. def always_sort_to_camera(self, camera: Camera) -> Self:
  186. def updater(surface: Surface):
  187. vect = camera.get_location() - surface.get_center()
  188. surface.sort_faces_back_to_front(vect)
  189. self.add_updater(updater)
  190. return self
  191. def get_shader_vert_indices(self) -> np.ndarray:
  192. return self.get_triangle_indices()
  193. class ParametricSurface(Surface):
  194. def __init__(
  195. self,
  196. uv_func: Callable[[float, float], Iterable[float]],
  197. u_range: tuple[float, float] = (0, 1),
  198. v_range: tuple[float, float] = (0, 1),
  199. **kwargs
  200. ):
  201. self.passed_uv_func = uv_func
  202. super().__init__(u_range=u_range, v_range=v_range, **kwargs)
  203. def uv_func(self, u, v):
  204. return self.passed_uv_func(u, v)
  205. class SGroup(Surface):
  206. def __init__(
  207. self,
  208. *parametric_surfaces: Surface,
  209. **kwargs
  210. ):
  211. super().__init__(resolution=(0, 0), **kwargs)
  212. self.add(*parametric_surfaces)
  213. def init_points(self):
  214. pass # Needed?
  215. class TexturedSurface(Surface):
  216. shader_folder: str = "textured_surface"
  217. data_dtype: Sequence[Tuple[str, type, Tuple[int]]] = [
  218. ('point', np.float32, (3,)),
  219. ('du_point', np.float32, (3,)),
  220. ('dv_point', np.float32, (3,)),
  221. ('im_coords', np.float32, (2,)),
  222. ('opacity', np.float32, (1,)),
  223. ]
  224. def __init__(
  225. self,
  226. uv_surface: Surface,
  227. image_file: str,
  228. dark_image_file: str | None = None,
  229. **kwargs
  230. ):
  231. if not isinstance(uv_surface, Surface):
  232. raise Exception("uv_surface must be of type Surface")
  233. # Set texture information
  234. if dark_image_file is None:
  235. dark_image_file = image_file
  236. self.num_textures = 1
  237. else:
  238. self.num_textures = 2
  239. texture_paths = {
  240. "LightTexture": get_full_raster_image_path(image_file),
  241. "DarkTexture": get_full_raster_image_path(dark_image_file),
  242. }
  243. self.uv_surface = uv_surface
  244. self.uv_func = uv_surface.uv_func
  245. self.u_range: Tuple[float, float] = uv_surface.u_range
  246. self.v_range: Tuple[float, float] = uv_surface.v_range
  247. self.resolution: Tuple[int, int] = uv_surface.resolution
  248. super().__init__(
  249. texture_paths=texture_paths,
  250. shading=tuple(uv_surface.shading),
  251. **kwargs
  252. )
  253. @Mobject.affects_data
  254. def init_points(self):
  255. surf = self.uv_surface
  256. nu, nv = surf.resolution
  257. self.resize_points(surf.get_num_points())
  258. self.resolution = surf.resolution
  259. self.data['point'][:] = surf.data['point']
  260. self.data['du_point'][:] = surf.data['du_point']
  261. self.data['dv_point'][:] = surf.data['dv_point']
  262. self.data['opacity'][:, 0] = surf.data["rgba"][:, 3]
  263. self.data["im_coords"] = np.array([
  264. [u, v]
  265. for u in np.linspace(0, 1, nu)
  266. for v in np.linspace(1, 0, nv) # Reverse y-direction
  267. ])
  268. def init_uniforms(self):
  269. super().init_uniforms()
  270. self.uniforms["num_textures"] = self.num_textures
  271. @Mobject.affects_data
  272. def set_opacity(self, opacity: float | Iterable[float]) -> Self:
  273. op_arr = np.array(listify(opacity))
  274. self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data))
  275. return self
  276. def set_color(
  277. self,
  278. color: ManimColor | Iterable[ManimColor] | None,
  279. opacity: float | Iterable[float] | None = None,
  280. recurse: bool = True
  281. ) -> Self:
  282. if opacity is not None:
  283. self.set_opacity(opacity)
  284. return self
  285. def pointwise_become_partial(
  286. self,
  287. tsmobject: "TexturedSurface",
  288. a: float,
  289. b: float,
  290. axis: int = 1
  291. ) -> Self:
  292. super().pointwise_become_partial(tsmobject, a, b, axis)
  293. im_coords = self.data["im_coords"]
  294. im_coords[:] = tsmobject.data["im_coords"]
  295. if a <= 0 and b >= 1:
  296. return self
  297. nu, nv = tsmobject.resolution
  298. im_coords[:] = self.get_partial_points_array(
  299. im_coords, a, b, (nu, nv, 2), axis
  300. )
  301. return self