number_theory__fixed 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import logging
  4. import math
  5. import random
  6. from functools import cache
  7. logger = logging.getLogger("global_logger")
  8. try:
  9. import gmpy2 as gmpy
  10. gmpy_version = 2
  11. mpz = gmpy.mpz
  12. logger.info("[+] Using gmpy version 2 for math.")
  13. except ImportError:
  14. try:
  15. import gmpy
  16. gmpy_version = 1
  17. mpz = gmpy.mpz
  18. logger.info("[+] Using gmpy version 1 for math.")
  19. except ImportError:
  20. gmpy_version = 0
  21. mpz = int
  22. gmpy = None
  23. logger.warning(
  24. "[!] Using native python functions for math, which is slow. install gmpy2 with: 'python3 -m pip install <module>'."
  25. )
  26. @cache
  27. def list_prod(list_):
  28. list_ = tuple(list_)
  29. if (l := len(list_)) == 0:
  30. return 1
  31. return list_prod(list_[: l - 1]) * list_[-1]
  32. digit_sum = lambda n: sum(int(d) for d in str(n))
  33. A007814 = lambda n: (~n & n - 1).bit_length()
  34. A135481 = lambda n: (~n & n - 1)
  35. A000265 = lambda n: n // (A135481(n) + 1)
  36. @cache
  37. def mulmod(a, b, m):
  38. if b == 0:
  39. return 0
  40. if b == 1:
  41. return a % m
  42. if b & 1 == 0:
  43. return mulmod((a << 1) % m, b >> 1, m)
  44. else:
  45. return (a + mulmod(a, b - 1, m)) % m
  46. def getpubkeysz(n):
  47. if (size := n.bit_length()) & 1 != 0:
  48. size += 1
  49. return size
  50. is_pow2 = lambda n: n & (n - 1) == 0
  51. def _gcdext(a, b):
  52. if a == 0:
  53. return [b, 0, 1]
  54. d, r = divmod(b, a)
  55. g, y, x = _gcdext(r, a)
  56. return [g, x - d * y, y]
  57. def _isqrt(n):
  58. if n == 0:
  59. return 0
  60. x, y = n, (n + 1) >> 1
  61. while y < x:
  62. x, y = y, (y + n // y) >> 1
  63. return x
  64. def _isqrt_rem(n):
  65. i2 = _isqrt(n)
  66. return i2, n - (i2 * i2)
  67. def _isqrt_rem_gmpy(n):
  68. i2 = _isqrt_gmpy(n)
  69. return i2, n - (i2 * i2)
  70. def _gcd(a, b):
  71. while b:
  72. a, b = b, a % b
  73. return abs(a)
  74. def _remove(n, p):
  75. r = n
  76. c = 0
  77. while r % p == 0:
  78. r //= p
  79. c += 1
  80. return r, c
  81. def _introot(n, r=2):
  82. if n < 0:
  83. return None if r & 1 == 0 else -_introot(-n, r)
  84. if n < 2:
  85. return n
  86. if r == 2:
  87. return _isqrt(n)
  88. lower, upper = 0, n
  89. while lower != upper - 1:
  90. mid = lower + ((upper - lower) >> 1)
  91. m = pow(mid, r)
  92. if m == n:
  93. return mid
  94. lower = mid * (m < n) + lower * (m >= n)
  95. upper = mid * (m > n) + upper * (m <= n)
  96. return lower
  97. def _iroot(n, p):
  98. b = introot(n, p)
  99. return b, b**p == n
  100. def _introot_gmpy(n, r=2):
  101. if n < 0:
  102. return None if r & 1 == 0 else -_introot_gmpy(-n, r)
  103. return gmpy.root(n, r)[0]
  104. def _introot_gmpy2(n, r=2):
  105. if n < 0:
  106. return None if r & 1 == 0 else -_introot_gmpy2(-n, r)
  107. return gmpy.iroot(n, r)[0]
  108. def _invmod(a, m):
  109. a, x, u = a % m, 0, 1
  110. while a:
  111. x, u, m, a = u, x - (m // a) * u, a, m % a
  112. return x
  113. def _is_square(n):
  114. if (h := n & 0xF) > 9 or h in [2, 3, 5, 6, 7, 8]:
  115. return False
  116. t = _isqrt(n)
  117. return t * t == n
  118. def _powmod_base_list(base_lst, exp, mod):
  119. return list(powmod(i, exp, mod) for i in base_lst)
  120. def _powmod_exp_list(base, exp_lst, mod):
  121. return list(powmod(base, i, mod) for i in exp_lst)
  122. def miller_rabin(n, k=40):
  123. """ "
  124. Taken from https://gist.github.com/Ayrx/5884790
  125. Implementation uses the Miller-Rabin Primality Test
  126. The optimal number of rounds for this test is 40
  127. See http://stackoverflow.com/questions/6325576/how-many-iterations-of-rabin-miller-should-i-use-for-cryptographic-safe-primes
  128. for justification
  129. """
  130. if n == 2:
  131. return True
  132. if (n & 1 == 0) or (digit_sum(n) % 9 in [0, 3, 6]):
  133. return False
  134. r, s = 0, n - 1
  135. while s & 1 == 0:
  136. r += 1
  137. s >>= 1
  138. i = 0
  139. for _ in range(0, k):
  140. a = random.randrange(2, n - 1)
  141. if (x := pow(a, s, n)) in [1, n - 1]:
  142. continue
  143. j = 0
  144. while j <= r - 1:
  145. if (x := pow(x, 2, n)) == (n - 1):
  146. break
  147. j += 1
  148. else:
  149. return False
  150. return True
  151. def _fermat_prime_criterion(n, b=2):
  152. """Fermat's prime criterion
  153. Returns False if n is definitely composite, True if possible prime."""
  154. return pow(b, n - 1, n) == 1
  155. def _is_prime(n):
  156. """
  157. If fermats prime criterion is false by short circuit we dont need to keep testing bases, so we return false for a guaranteed composite.
  158. Otherwise we keep trying with primes 3 and 5 as base. The sweet spot is primes 2,3,5, it doesn't improvee the runing time adding more primes to test as base.
  159. If all the previous tests pass then we try with rabin miller.
  160. All the tests are probabilistic.
  161. """
  162. if all(
  163. (
  164. _fermat_prime_criterion(n),
  165. _fermat_prime_criterion(n, b=3),
  166. _fermat_prime_criterion(n, b=5),
  167. )
  168. ):
  169. return miller_rabin(n)
  170. else:
  171. return False
  172. def _next_prime(n):
  173. while True:
  174. if _is_prime(n):
  175. return n
  176. n += 1
  177. def erathostenes_sieve(n):
  178. """
  179. Returns a list of primes < n
  180. """
  181. sieve = [True] * n
  182. for i in range(3, isqrt(n) + 1, 2):
  183. if sieve[i]:
  184. sieve[pow(i, 2) :: (i << 1)] = [False] * ((n - pow(i, 2) - 1) // (i << 1) + 1)
  185. return [2] + [i for i in range(3, n, 2) if sieve[i]]
  186. _primes = erathostenes_sieve
  187. def _primes_yield(n):
  188. p = i = 1
  189. while i <= n:
  190. p = next_prime(p)
  191. yield p
  192. i += 1
  193. def _primes_yield_gmpy(n):
  194. p = i = 1
  195. while i <= n:
  196. p = gmpy.next_prime(p)
  197. yield p
  198. i += 1
  199. def _fib(n):
  200. a, b = 0, 1
  201. i = 0
  202. while i <= n:
  203. a, b = b, a + b
  204. i += 1
  205. return a
  206. def ilogb(x, b):
  207. """
  208. greatest integer l such that b**l < = x.
  209. """
  210. l = 0
  211. while x >= b:
  212. x /= b
  213. l += 1
  214. return l
  215. _primes_gmpy = lambda n: list(_primes_yield_gmpy(n))
  216. _isqrt_gmpy = lambda n: int(gmpy.sqrt(n))
  217. _invert = lambda a, b: pow(a, b - 2, b)
  218. _lcm = lambda x, y: (x * y) // _gcd(x, y)
  219. _ilog2_gmpy = lambda n: int(gmpy.log2(n))
  220. _ilog_gmpy = lambda n: int(gmpy.log(n))
  221. _ilog2_math = lambda n: int(math.log2(n))
  222. _ilog_math = lambda n: int(math.log(n))
  223. _ilog10_math = lambda n: int(math.log10(n))
  224. _ilog10_gmpy = lambda n: int(gmpy.log10(n))
  225. _mod = lambda a, b: a % b
  226. _mul = lambda a, b: a * b
  227. _is_divisible = lambda n, p: n % p == 0
  228. _is_congruent = lambda a, b, m: (a - b) % m == 0
  229. def _powmod(b, e, m):
  230. r = 1
  231. b %= m
  232. while e > 0:
  233. r = ((r * b) % m) * (e & 1) + r * ((e + 1) & 1)
  234. e >>= 1
  235. b = (b * b) % m
  236. return r
  237. def _fac(n):
  238. """
  239. Factorial
  240. """
  241. tmp = 1
  242. for m in range(n, 1, -1):
  243. tmp *= m
  244. return tmp
  245. @cache
  246. def _lucas(n):
  247. if n == 0:
  248. return 2
  249. if n == 1:
  250. return 1
  251. return _lucas(n - 1) + _lucas(n - 2)
  252. if gmpy_version > 0:
  253. gcd = gmpy.gcd
  254. gcdext = gmpy.gcdext
  255. is_square = gmpy.is_square
  256. next_prime = gmpy.next_prime
  257. is_prime = gmpy.is_prime
  258. fib = gmpy.fib
  259. primes = _primes_gmpy
  260. lcm = gmpy.lcm
  261. invert = gmpy.invert
  262. invmod = gmpy.invert
  263. remove = gmpy.remove
  264. fac = gmpy.fac
  265. if gmpy_version == 2:
  266. iroot = gmpy.iroot
  267. ilog = _ilog_gmpy
  268. ilog2 = _ilog2_gmpy
  269. ilog10 = _ilog10_gmpy
  270. log = gmpy.log
  271. log2 = gmpy.log2
  272. log10 = gmpy.log10
  273. mod = gmpy.f_mod
  274. mul = gmpy.mul
  275. powmod = gmpy.powmod
  276. isqrt_rem = gmpy.isqrt_rem
  277. introot = _introot_gmpy2
  278. is_divisible = gmpy.is_divisible
  279. is_congruent = gmpy.is_congruent
  280. fdivmod = gmpy.f_divmod
  281. lucas = gmpy.lucas
  282. powmod_base_list = gmpy.powmod_base_list
  283. powmod_exp_list = gmpy.powmod_exp_list
  284. else:
  285. iroot = gmpy.root
  286. ilog = _ilog_math
  287. ilog2 = _ilog2_math
  288. ilog10 = _ilog10_math
  289. log = math.log
  290. log2 = math.log2
  291. log10 = math.log10
  292. mul = _mul
  293. mod = _mod
  294. powmod = pow
  295. isqrt_rem = gmpy.sqrtrem
  296. introot = _introot_gmpy
  297. is_divisible = _is_divisible
  298. is_congruent = _is_congruent
  299. fdivmod = gmpy.fdivmod
  300. lucas = _lucas
  301. powmod_base_list = _powmod_base_list
  302. powmod_exp_list = _powmod_exp_list
  303. isqrt = gmpy.isqrt
  304. else:
  305. remove = _remove
  306. iroot = _iroot
  307. gcd = _gcd
  308. isqrt = _isqrt
  309. isqrt_rem = _isqrt_rem
  310. introot = _introot
  311. invmod = _invmod
  312. gcdext = _gcdext
  313. is_square = _is_square
  314. next_prime = _next_prime
  315. fib = _fib
  316. primes = erathostenes_sieve
  317. is_prime = _is_prime
  318. fib = _fib
  319. primes = _primes
  320. lcm = _lcm
  321. invert = _invmod
  322. powmod = _powmod
  323. ilog = _ilog_math
  324. ilog2 = _ilog2_math
  325. ilog10 = _ilog10_math
  326. log = math.log
  327. log2 = math.log2
  328. log10 = math.log10
  329. mod = _mod
  330. mul = _mul
  331. is_divisible = _is_divisible
  332. is_congruent = _is_congruent
  333. fac = _fac
  334. fdivmod = divmod
  335. lucas = _lucas
  336. powmod_base_list = _powmod_base_list
  337. powmod_exp_list = _powmod_exp_list
  338. legendre = lambda a, p: powmod(a, (p - 1) >> 1, p)
  339. cuberoot = lambda n: introot(n, 3)
  340. def factor_ned_probabilistic(n, e, d):
  341. """
  342. 800-56B R1 Recommendation for Pair-Wise Key Establishment Schemes Using Integer Factorization Cryptography in Appendix C.
  343. """
  344. n1, k = n - 1, d * e - 1
  345. if k & 1 == 1:
  346. return None
  347. t, r = 0, k
  348. while r & 1 == 0:
  349. r >>= 1
  350. t += 1
  351. for _ in range(1, 101):
  352. g = random.randint(0, n1)
  353. if (y := pow(g, r, n)) == 1 or y == n1:
  354. continue
  355. for _ in range(1, t):
  356. if (x := pow(y, 2, n)) == 1:
  357. p = gcd(y - 1, n)
  358. return p, n // p
  359. if x == n1:
  360. continue
  361. y = x
  362. if (x := pow(y, 2, n)) == 1:
  363. p = gcd(x - 1, n)
  364. return p, n // p
  365. def trivial_factorization_with_n_b(n, b):
  366. if (b2n4 := (b * b) - (n << 2)) > 0:
  367. i = isqrt(b2n4)
  368. p, q = int((b - i) >> 1), int((b + i) >> 1)
  369. if p * q == n:
  370. return p, q
  371. def factor_ned_deterministic(n, e, d):
  372. """
  373. 800-56B R2 Recommendation for Pair-Wise Key Establishment Schemes Using Integer Factorization Cryptography in Appendix C.2.
  374. """
  375. k = d * e - 1
  376. m, r = divmod(k * gcd(n - 1, k), n)
  377. return trivial_factorization_with_n_b(n, ((n - r) // (m + 1)) + 1)
  378. factor_ned = factor_ned_deterministic
  379. trivial_factorization_with_n_phi = lambda n, phi: trivial_factorization_with_n_b(n, n - phi + 1)
  380. def neg_pow(a, b, n):
  381. """
  382. Calculates a^{b} mod n when b is negative
  383. """
  384. assert b < 0
  385. assert gcd(a, n) == 1
  386. res = int(invert(a, n))
  387. return powmod(res, b * (-1), n)
  388. def common_modulus_related_message(e1, e2, n, c1, c2):
  389. """
  390. e1 --> Public Key exponent used to encrypt message m and get ciphertext c1
  391. e2 --> Public Key exponent used to encrypt message m and get ciphertext c2
  392. n --> Modulus
  393. The following attack works only when m^{GCD(e1, e2)} < n
  394. """
  395. g, a, b = gcdext(e1, e2)
  396. if g == 1:
  397. return None
  398. c1 = neg_pow(c1, a, n) if a < 0 else powmod(c1, a, n)
  399. c2 = neg_pow(c2, b, n) if a < 0 else powmod(c2, b, n)
  400. ct = c1 * c2 % n
  401. return int(introot(ct, g))
  402. def phi(n, factors):
  403. """
  404. Euler totient function
  405. """
  406. if is_prime(n):
  407. return n - 1
  408. elif is_square(n):
  409. i2 = isqrt(n)
  410. return phi(i2, factors) * i2
  411. else:
  412. y = n
  413. for p in factors:
  414. if n % p == 0:
  415. y //= p
  416. y *= p - 1
  417. n, _ = remove(n, p)
  418. if n > 1:
  419. y //= n
  420. y *= n - 1
  421. return y
  422. def chinese_remainder(m, a):
  423. S = 0
  424. N = list_prod(tuple(m))
  425. for i in range(0, len(m)):
  426. Ni = N // m[i]
  427. S += Ni * invert(Ni, m[i]) * a[i]
  428. return S % N
  429. def tonelli(n, p):
  430. """
  431. tonelli-shanks modular squareroot algorithm
  432. """
  433. assert legendre(n, p) == 1, "not a square (mod p)"
  434. q = p - 1
  435. q >>= (s := A007814(q))
  436. if s == 1:
  437. return powmod(n, (p + 1) >> 2, p)
  438. for z in range(2, p):
  439. if p - 1 == legendre(z, p):
  440. break
  441. c, r, t, m = powmod(z, q, p), powmod(n, (q + 1) >> 1, p), powmod(n, q, p), s
  442. while (t - 1) % p != 0:
  443. t2 = powmod(t, 2, p)
  444. for i in range(1, m):
  445. if (t2 - 1) % p == 0:
  446. break
  447. t2 = powmod(t2, 2, p)
  448. b = powmod(c, 1 << (m - i - 1), p)
  449. # r = (r * b) % p
  450. r = mulmod(r, b, p)
  451. c = powmod(b, 2, p)
  452. # t = (t * c) % p
  453. t = mulmod(t, c, p)
  454. m = i
  455. return r
  456. def is_cube(n):
  457. b = False
  458. if (n % 9) in [0, 1, 8]:
  459. a, b = iroot(n, 3)
  460. return b
  461. def dlp_bruteforce(g, h, p):
  462. """
  463. Try to solve the discrete logarithm problem:
  464. x for g^x == h (mod p) with brute force.
  465. """
  466. for x in range(1, p):
  467. if h == powmod(g, x, p):
  468. return x
  469. def rational_to_contfrac(x, y):
  470. """Rational_to_contfrac implementation"""
  471. a = x // y
  472. if a * y == x:
  473. return [a]
  474. pquotients = rational_to_contfrac(y, x - a * y)
  475. pquotients.insert(0, a)
  476. return pquotients
  477. def contfrac_to_rational(frac):
  478. """Contfrac_to_rational implementation"""
  479. if len(frac) == 0:
  480. return (0, 1)
  481. elif len(frac) == 1:
  482. return (frac[0], 1)
  483. else:
  484. remainder = frac[1:]
  485. (num, denom) = contfrac_to_rational(remainder)
  486. return (frac[0] * num + denom, num)
  487. def convergents_from_contfrac(frac, progress=False):
  488. """Convergents_from_contfrac implementation"""
  489. return [contfrac_to_rational(frac[:i]) for i in range(0, len(frac))]
  490. def inv_mod_pow_of_2(factor, bit_count):
  491. """
  492. its orders of magnitude faster than invert(a, 2^k)
  493. code borrowed from: https://algassert.com/post/1709
  494. """
  495. rest = factor & -2
  496. acc = 1
  497. for i in range(bit_count):
  498. acc -= (acc & (1 << i)) * (rest << i)
  499. mask = (1 << bit_count) - 1
  500. return acc & mask
  501. def mlucas(v, a, n):
  502. """Helper function for williams_pp1(). Multiplies along a Lucas sequence modulo n."""
  503. v1, v2 = v, (v * v - 2) % n
  504. while a > 0:
  505. v1, v2 = ((v1 * v1 - 2) % n, (v1 * v2 - v) % n) if a & 1 == 0 else ((v1 * v2 - v) % n, (v2 * v2 - 2) % n)
  506. a >>= 1
  507. return v1
  508. __all__ = [
  509. getpubkeysz,
  510. gcd,
  511. isqrt,
  512. introot,
  513. invmod,
  514. gcdext,
  515. is_square,
  516. is_cube,
  517. next_prime,
  518. is_prime,
  519. fib,
  520. primes,
  521. lcm,
  522. invert,
  523. powmod,
  524. ilog2,
  525. ilog,
  526. ilog10,
  527. mod,
  528. log,
  529. log2,
  530. log10,
  531. trivial_factorization_with_n_phi,
  532. factor_ned,
  533. neg_pow,
  534. common_modulus_related_message,
  535. phi,
  536. list_prod,
  537. chinese_remainder,
  538. ilogb,
  539. mul,
  540. cuberoot,
  541. isqrt_rem,
  542. is_divisible,
  543. is_congruent,
  544. iroot,
  545. dlp_bruteforce,
  546. fac,
  547. rational_to_contfrac,
  548. contfrac_to_rational,
  549. convergents_from_contfrac,
  550. fdivmod,
  551. inv_mod_pow_of_2,
  552. mlucas,
  553. lucas,
  554. mulmod,
  555. A000265,
  556. powmod_base_list,
  557. powmod_exp_list,
  558. is_pow2,
  559. ]