simple_functions.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from __future__ import annotations
  2. from functools import lru_cache
  3. import hashlib
  4. import inspect
  5. import math
  6. import numpy as np
  7. from typing import TYPE_CHECKING
  8. if TYPE_CHECKING:
  9. from typing import Callable, TypeVar, Iterable
  10. from manimlib.typing import FloatArray
  11. Scalable = TypeVar("Scalable", float, FloatArray)
  def sigmoid(x: float | FloatArray):
      """
      Logistic sigmoid 1 / (1 + exp(-x)), evaluated elementwise.

      Uses the numerically stable split
          x >= 0:  1 / (1 + exp(-x))
          x <  0:  exp(x) / (1 + exp(x))
      so exp is only ever applied to a non-positive argument. The naive
      formula overflows (with a RuntimeWarning) for x below roughly -709.

      :param x: scalar or array-like input.
      :return: a numpy scalar for scalar input, otherwise an array with the
          same shape as ``x``.
      """
      # exp(-|x|) lies in (0, 1], so it can never overflow.
      z = np.exp(-np.abs(x))
      result = np.where(np.greater_equal(x, 0), 1.0 / (1.0 + z), z / (1.0 + z))
      # np.where always returns an ndarray; indexing with () unwraps a 0-d
      # result to a numpy scalar (what the plain formula gave for scalar
      # input) and is a whole-array view for n-d results.
      return result[()]
  @lru_cache(maxsize=10)
  def choose(n: int, k: int) -> int:
      """
      Binomial coefficient "n choose k" (0 when k > n).

      Computed as the number of ordered k-selections divided by k!, which
      is exact integer arithmetic. The 10 most recently used (n, k) pairs
      are memoized.

      :raises ValueError: if n or k is negative.
      """
      ordered_selections = math.perm(n, k)
      return ordered_selections // math.factorial(k)
  def gen_choose(n: int, r: int) -> int:
      """
      Generalized binomial coefficient n(n-1)...(n-r+1) / r! for integer n.

      Unlike ``math.comb`` this accepts negative ``n`` (e.g.
      gen_choose(-2, 3) == -4). For 0 <= n < r the falling factorial
      contains a zero factor, so the result is 0.

      :raises ValueError: if r is negative (from ``math.factorial``).
      """
      # Python ints have arbitrary precision, so the product cannot overflow
      # (np.prod on int64 wraps silently for arguments like n=60, r=30).
      falling_factorial = math.prod(range(n, n - r, -1))
      # A product of r consecutive integers is always divisible by r!, so
      # floor division is exact; float "/" would lose precision here.
      return falling_factorial // math.factorial(r)
  def get_num_args(function: Callable) -> int:
      """
      Count the positional parameters declared by ``function``'s code object.

      This is ``co_argcount``: positional-only and positional-or-keyword
      parameters, including those with defaults. ``*args``, keyword-only
      parameters and ``**kwargs`` are not counted. ``function`` must expose
      ``__code__`` (plain functions and lambdas do); otherwise
      AttributeError is raised. For a bound method the lookup reaches the
      underlying function, so ``self`` is included in the count.
      """
      code = function.__code__
      return code.co_argcount
  def get_parameters(function: Callable) -> Iterable[str]:
      """
      Return the names of every parameter in ``function``'s signature, in order.

      The names come from ``inspect.signature`` and include ``*args``,
      keyword-only and ``**kwargs`` parameters. The result is a live keys
      view, so membership tests such as ``"dt" in get_parameters(f)`` work
      directly.
      """
      signature = inspect.signature(function)
      return signature.parameters.keys()
  # fdiv, defined below after clip and arr_clip, is just a less heavyweight
  # name for the extremely common operation of (numpy true) division.
  #
  # We may wish to have more fine-grained control over division-by-zero
  # behavior in the future (separate specifiable values for 0/0 and for x/0
  # with x != 0), but for now fdiv only offers an option for the
  # indeterminate 0/0 case.
  def clip(a: float, min_a: float, max_a: float) -> float:
      """
      Clamp the scalar ``a`` into the range [min_a, max_a].

      The lower bound is tested first, so when min_a > max_a any value below
      min_a comes back as min_a. A NaN ``a`` fails both comparisons and is
      returned unchanged.
      """
      if a < min_a:
          return min_a
      return max_a if a > max_a else a
  def arr_clip(arr: np.ndarray, min_a: float, max_a: float) -> np.ndarray:
      """
      Clamp ``arr`` into [min_a, max_a] in place and return the same object.

      Entries below min_a are set to min_a first, then entries above max_a
      are set to max_a, so if min_a > max_a every entry ends up as max_a.
      NaN entries compare False both times and are left untouched. The
      assigned bounds are cast to ``arr``'s dtype (e.g. truncated for
      integer arrays).
      """
      below = arr < min_a
      arr[below] = min_a
      # Computed only after the first assignment, preserving the two-step order.
      above = arr > max_a
      arr[above] = max_a
      return arr
  def fdiv(a: Scalable, b: Scalable, zero_over_zero_value: Scalable | None = None) -> Scalable:
      """
      Divide ``a`` by ``b`` elementwise with ``np.true_divide``.

      :param a: numerator, a scalar or array-like.
      :param b: denominator, a scalar or array-like broadcast-compatible
          with ``a``.
      :param zero_over_zero_value: if given, every position where both ``a``
          and ``b`` are 0 takes this value (broadcast, so a scalar or an
          array both work) instead of NaN. Division of a nonzero value by 0
          still gives +/-inf along with numpy's divide-by-zero warning.
      :return: the quotient with the broadcast shape of ``a`` and ``b``. When
          ``zero_over_zero_value`` is given and both inputs are scalars, this
          is a 0-d array.
      """
      if zero_over_zero_value is None:
          return np.true_divide(a, b)
      # Convert first so the != 0 tests are elementwise even for list inputs
      # (a plain list compared to 0 evaluates to a single bool, not a mask).
      a_arr = np.asarray(a)
      b_arr = np.asarray(b)
      defined = np.logical_or(a_arr != 0, b_arr != 0)
      # Divide everywhere and overwrite the 0/0 slots afterwards, rather than
      # writing into a preallocated out=np.full_like(a, ...) buffer. That
      # buffer has a's dtype and shape, so it fails for integer ``a`` (a float
      # quotient can't be cast into it) and when ``a`` is smaller than the
      # broadcast shape. Only "invalid" warnings are silenced, so other
      # NaN-producing cases such as inf/inf also stay quiet.
      with np.errstate(invalid="ignore"):
          quotient = np.true_divide(a_arr, b_arr)
      return np.where(defined, quotient, zero_over_zero_value)
  def binary_search(function: Callable[[float], float],
                    target: float,
                    lower_bound: float,
                    upper_bound: float,
                    tolerance: float = 1e-4) -> float | None:
      """
      Find an input h between the bounds with function(h) == target by bisection.

      The function values at the two bounds must bracket ``target`` (one
      below, one above). The interval is then halved repeatedly, keeping the
      half that still brackets it, until it is at most ``tolerance`` wide.
      This is only reliable when ``function`` is monotonic on the interval;
      both increasing and decreasing functions are handled.

      :param function: the function to invert, evaluated at interval points.
      :param target: the output value being searched for.
      :param lower_bound: one end of the search interval.
      :param upper_bound: the other end of the search interval.
      :param tolerance: interval width at which the search stops.
      :return: an input (never an output) of ``function``. This is an
          endpoint or midpoint where ``function`` equals ``target`` exactly,
          otherwise the midpoint of the final bracketing interval. Returns
          None if the values at the bounds do not bracket ``target`` or a NaN
          is encountered. If the initial interval is already within
          ``tolerance``, its midpoint is returned without calling
          ``function``.
      """
      lh, rh = lower_bound, upper_bound
      if abs(rh - lh) <= tolerance:
          return (lh + rh) / 2

      lx, rx = function(lh), function(rh)
      # Return the *input* that hits the target; returning lx/rx would hand
      # back the target value itself.
      if lx == target:
          return lh
      if rx == target:
          return rh
      if lx > target > rx:
          # Decreasing function: swap the ends so the below-target side is
          # always lh.
          lh, rh = rh, lh
      elif not (lx < target < rx):
          # target is not bracketed (or an endpoint value is NaN).
          return None

      # Invariant: function(lh) < target < function(rh). Only the midpoint
      # is evaluated each step; the endpoint values are already known.
      while abs(rh - lh) > tolerance:
          mh = (lh + rh) / 2
          if mh == lh or mh == rh:
              # Floating-point resolution reached; the interval can't shrink.
              break
          mx = function(mh)
          if mx == target:
              return mh
          if mx > target:
              rh = mh
          elif mx < target:
              lh = mh
          else:
              # mx is NaN, so the bracket can no longer be trusted.
              return None
      return (lh + rh) / 2
  def hash_string(string: str) -> str:
      """
      Return a short, deterministic fingerprint of ``string``.

      The string is UTF-8 encoded and hashed with SHA-256. Only the first
      16 hexadecimal characters of the digest (8 bytes, 64 bits) are kept,
      to keep the result short.
      """
      digest = hashlib.sha256(string.encode()).hexdigest()
      return digest[:16]