123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675 |
- #!/usr/bin/env python3
- from __future__ import annotations
- import logging
- import math
- import random
- from functools import cache
logger = logging.getLogger("global_logger")

# Select the big-integer backend once, at import time.
#   gmpy_version == 2 -> gmpy2 is importable (bound to the name ``gmpy``)
#   gmpy_version == 1 -> only the legacy ``gmpy`` package is importable
#   gmpy_version == 0 -> neither is available; ``mpz`` is plain ``int``
try:
    import gmpy2 as gmpy
except ImportError:
    try:
        import gmpy
    except ImportError:
        gmpy = None
        gmpy_version = 0
        mpz = int
        # The hint names the actual package instead of a "<module>" placeholder.
        logger.warning(
            "[!] Using native python functions for math, which is slow. install gmpy2 with: 'python3 -m pip install gmpy2'."
        )
    else:
        gmpy_version = 1
        mpz = gmpy.mpz
        logger.info("[+] Using gmpy version 1 for math.")
else:
    gmpy_version = 2
    mpz = gmpy.mpz
    logger.info("[+] Using gmpy version 2 for math.")
def list_prod(list_):
    """Return the product of every item of ``list_``.

    The empty product is 1.  Any iterable is accepted, including lists:
    the previous ``@cache``-decorated recursive version raised ``TypeError``
    for unhashable arguments, hit the recursion limit on long inputs and
    kept every prefix of every argument alive in the cache.
    """
    return math.prod(list_)
def digit_sum(n):
    """Return the sum of the decimal digits of ``n`` (as printed by ``str``)."""
    total = 0
    for digit in str(n):
        total += int(digit)
    return total


def A135481(n):
    """Return ``~n & (n - 1)``: a mask of the trailing zero bits of ``n``."""
    return ~n & (n - 1)


def A007814(n):
    """Return the bit length of the trailing-zero mask of ``n``.

    For a positive ``n`` this is the number of trailing zero bits.
    """
    return A135481(n).bit_length()


def A000265(n):
    """Return ``n`` with its trailing zero bits divided out (the odd part)."""
    power_of_two = A135481(n) + 1
    return n // power_of_two
def mulmod(a, b, m):
    """Return ``(a * b) % m``.

    Python integers are arbitrary precision, so the product cannot
    overflow and the double-and-add recursion is unnecessary.  The previous
    recursive, ``@cache``-decorated version recursed roughly twice per bit
    of ``b`` (``RecursionError`` for large multipliers), never terminated
    for negative ``b`` and cached every intermediate triple forever.

    Raises ``ZeroDivisionError`` when ``m`` is 0.
    """
    return (a * b) % m
def getpubkeysz(n):
    """Return the bit length of ``n`` rounded up to the next even number."""
    bits = n.bit_length()
    # Adding the low bit bumps odd lengths by one and leaves even ones alone.
    return bits + (bits & 1)
def is_pow2(n):
    """Return True when ``n`` has at most one bit set.

    Note that ``0`` also satisfies the test and is reported as True.
    """
    cleared_lowest_bit = n & (n - 1)
    return cleared_lowest_bit == 0
- def _gcdext(a, b):
- if a == 0:
- return [b, 0, 1]
- d, r = divmod(b, a)
- g, y, x = _gcdext(r, a)
- return [g, x - d * y, y]
- def _isqrt(n):
- if n == 0:
- return 0
- x, y = n, (n + 1) >> 1
- while y < x:
- x, y = y, (y + n // y) >> 1
- return x
def _isqrt_rem(n):
    """Return ``(s, n - s*s)`` where ``s`` is ``_isqrt(n)``."""
    root = _isqrt(n)
    remainder = n - root * root
    return root, remainder
def _isqrt_rem_gmpy(n):
    """Return ``(s, n - s*s)`` where ``s`` is ``_isqrt_gmpy(n)``."""
    root = _isqrt_gmpy(n)
    remainder = n - root * root
    return root, remainder
- def _gcd(a, b):
- while b:
- a, b = b, a % b
- return abs(a)
- def _remove(n, p):
- r = n
- c = 0
- while r % p == 0:
- r //= p
- c += 1
- return r, c
- def _introot(n, r=2):
- if n < 0:
- return None if r & 1 == 0 else -_introot(-n, r)
- if n < 2:
- return n
- if r == 2:
- return _isqrt(n)
- lower, upper = 0, n
- while lower != upper - 1:
- mid = lower + ((upper - lower) >> 1)
- m = pow(mid, r)
- if m == n:
- return mid
- lower = mid * (m < n) + lower * (m >= n)
- upper = mid * (m > n) + upper * (m <= n)
- return lower
def _iroot(n, p):
    """Return ``(root, exact)`` for the integer ``p``-th root of ``n``.

    ``root`` comes from the module-level ``introot``; ``exact`` tells
    whether ``root ** p`` reproduces ``n``.
    """
    root = introot(n, p)
    is_exact = root**p == n
    return root, is_exact
def _introot_gmpy(n, r=2):
    """Integer ``r``-th root via ``gmpy.root`` (first element of its result).

    Negative ``n`` yields ``None`` for even ``r`` and the negated root of
    ``-n`` for odd ``r``.
    """
    if n >= 0:
        return gmpy.root(n, r)[0]
    if r & 1 == 0:
        return None
    return -_introot_gmpy(-n, r)
def _introot_gmpy2(n, r=2):
    """Integer ``r``-th root via ``gmpy.iroot`` (first element of its result).

    Negative ``n`` yields ``None`` for even ``r`` and the negated root of
    ``-n`` for odd ``r``.
    """
    if n >= 0:
        return gmpy.iroot(n, r)[0]
    if r & 1 == 0:
        return None
    return -_introot_gmpy2(-n, r)
- def _invmod(a, m):
- a, x, u = a % m, 0, 1
- while a:
- x, u, m, a = u, x - (m // a) * u, a, m % a
- return x
- def _is_square(n):
- if (h := n & 0xF) > 9 or h in [2, 3, 5, 6, 7, 8]:
- return False
- t = _isqrt(n)
- return t * t == n
def _powmod_base_list(base_lst, exp, mod):
    """Return ``[powmod(b, exp, mod) for each b in base_lst]``."""
    results = []
    for base in base_lst:
        results.append(powmod(base, exp, mod))
    return results
def _powmod_exp_list(base, exp_lst, mod):
    """Return ``[powmod(base, e, mod) for each e in exp_lst]``."""
    results = []
    for exponent in exp_lst:
        results.append(powmod(base, exponent, mod))
    return results
def miller_rabin(n, k=40):
    """Miller-Rabin probabilistic primality test with ``k`` random rounds.

    Taken from https://gist.github.com/Ayrx/5884790
    The optimal number of rounds for this test is 40
    See http://stackoverflow.com/questions/6325576/how-many-iterations-of-rabin-miller-should-i-use-for-cryptographic-safe-primes
    for justification

    Returns False when ``n`` is certainly composite (or smaller than 2) and
    True when ``n`` is 2, 3 or passed every round.
    """
    if n < 2:
        # Also avoids the endless decomposition loop for n == 1.
        return False
    if n in (2, 3):
        return True
    # n % 3 == 0 is the same test as "digit sum mod 9 in {0, 3, 6}" but
    # does not reject the prime 3 (handled above) and needs no str().
    if n & 1 == 0 or n % 3 == 0:
        return False
    # Write n - 1 as 2**r * s with s odd.
    r, s = 0, n - 1
    while s & 1 == 0:
        r += 1
        s >>= 1
    for _ in range(k):
        a = random.randrange(2, n - 1)
        x = pow(a, s, n)
        if x in (1, n - 1):
            continue
        # Only r - 1 squarings are meaningful: reaching -1 on the r-th one
        # would mean a**(n-1) == -1, which already proves n composite.
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True
- def _fermat_prime_criterion(n, b=2):
- """Fermat's prime criterion
- Returns False if n is definitely composite, True if possible prime."""
- return pow(b, n - 1, n) == 1
def _is_prime(n):
    """Probabilistic primality test.

    Values below 2 are not prime; 2, 3 and 5 are answered directly because
    the Fermat test with a base equal to ``n`` always fails.  Otherwise the
    Fermat criterion is tried with bases 2, 3 and 5, stopping at the first
    failure (a guaranteed composite), and only numbers that pass all three
    are handed to ``miller_rabin``.

    All the tests are probabilistic.
    """
    if n < 2:
        return False
    if n in (2, 3, 5):
        return True
    # A generator makes all() really short-circuit on the first failing base.
    if not all(_fermat_prime_criterion(n, b=base) for base in (2, 3, 5)):
        return False
    return miller_rabin(n)
def _next_prime(n):
    """Return the smallest probable prime strictly greater than ``n``.

    The search starts at ``n + 1`` so that a prime argument is not returned
    unchanged; repeatedly feeding the result back in (as ``_primes_yield``
    does) then walks through successive primes, which is the contract of
    the gmpy ``next_prime`` this function stands in for.
    """
    candidate = max(n + 1, 2)
    while not _is_prime(candidate):
        candidate += 1
    return candidate
- def erathostenes_sieve(n):
- """
- Returns a list of primes < n
- """
- sieve = [True] * n
- for i in range(3, isqrt(n) + 1, 2):
- if sieve[i]:
- sieve[pow(i, 2) :: (i << 1)] = [False] * ((n - pow(i, 2) - 1) // (i << 1) + 1)
- return [2] + [i for i in range(3, n, 2) if sieve[i]]
- _primes = erathostenes_sieve
def _primes_yield(n):
    """Yield ``n`` values obtained by iterating ``next_prime`` from 1."""
    current = 1
    remaining = n
    while remaining >= 1:
        current = next_prime(current)
        yield current
        remaining -= 1
def _primes_yield_gmpy(n):
    """Yield ``n`` values obtained by iterating ``gmpy.next_prime`` from 1."""
    current = 1
    remaining = n
    while remaining >= 1:
        current = gmpy.next_prime(current)
        yield current
        remaining -= 1
- def _fib(n):
- a, b = 0, 1
- i = 0
- while i <= n:
- a, b = b, a + b
- i += 1
- return a
def ilogb(x, b):
    """
    greatest integer l such that b**l <= x.

    Integer operands are reduced with floor division, which stays exact for
    arbitrarily large values (true division converted to float, losing
    precision and raising ``OverflowError`` on big integers).  Other
    operand types keep using true division.

    Returns 0 when ``x < b``.  Raises ``ValueError`` when ``x >= b`` and
    ``b <= 1``, because no greatest exponent exists then.
    """
    if x < b:
        return 0
    if b <= 1:
        raise ValueError("base must be greater than 1")
    # int (and mpz) expose bit_length; floats and fractions do not.
    integral = hasattr(x, "bit_length") and hasattr(b, "bit_length")
    l = 0
    while x >= b:
        x = x // b if integral else x / b
        l += 1
    return l
- _primes_gmpy = lambda n: list(_primes_yield_gmpy(n))
- _isqrt_gmpy = lambda n: int(gmpy.sqrt(n))
- _invert = lambda a, b: pow(a, b - 2, b)
- _lcm = lambda x, y: (x * y) // _gcd(x, y)
- _ilog2_gmpy = lambda n: int(gmpy.log2(n))
- _ilog_gmpy = lambda n: int(gmpy.log(n))
- _ilog2_math = lambda n: int(math.log2(n))
- _ilog_math = lambda n: int(math.log(n))
- _ilog10_math = lambda n: int(math.log10(n))
- _ilog10_gmpy = lambda n: int(gmpy.log10(n))
- _mod = lambda a, b: a % b
- _mul = lambda a, b: a * b
- _is_divisible = lambda n, p: n % p == 0
- _is_congruent = lambda a, b, m: (a - b) % m == 0
- def _powmod(b, e, m):
- r = 1
- b %= m
- while e > 0:
- r = ((r * b) % m) * (e & 1) + r * ((e + 1) & 1)
- e >>= 1
- b = (b * b) % m
- return r
- def _fac(n):
- """
- Factorial
- """
- tmp = 1
- for m in range(n, 1, -1):
- tmp *= m
- return tmp
- @cache
- def _lucas(n):
- if n == 0:
- return 2
- if n == 1:
- return 1
- return _lucas(n - 1) + _lucas(n - 2)
# Bind the public helper names to the implementation chosen by
# ``gmpy_version``: gmpy2 (2), legacy gmpy (1) or the pure-Python
# fallbacks defined above (0).
# NOTE(review): ``primes`` is built on next_prime with gmpy (a list of n
# primes) but is the sieve (primes below n) in the fallback - confirm which
# meaning callers expect.
if gmpy_version > 0:
    # Bindings shared by both gmpy flavours.
    fac = gmpy.fac
    fib = gmpy.fib
    gcd = gmpy.gcd
    gcdext = gmpy.gcdext
    invert = gmpy.invert
    invmod = gmpy.invert
    is_prime = gmpy.is_prime
    is_square = gmpy.is_square
    lcm = gmpy.lcm
    next_prime = gmpy.next_prime
    primes = _primes_gmpy
    remove = gmpy.remove
    if gmpy_version == 2:
        fdivmod = gmpy.f_divmod
        ilog = _ilog_gmpy
        ilog10 = _ilog10_gmpy
        ilog2 = _ilog2_gmpy
        introot = _introot_gmpy2
        iroot = gmpy.iroot
        is_congruent = gmpy.is_congruent
        is_divisible = gmpy.is_divisible
        isqrt_rem = gmpy.isqrt_rem
        log = gmpy.log
        log10 = gmpy.log10
        log2 = gmpy.log2
        lucas = gmpy.lucas
        mod = gmpy.f_mod
        mul = gmpy.mul
        powmod = gmpy.powmod
        powmod_base_list = gmpy.powmod_base_list
        powmod_exp_list = gmpy.powmod_exp_list
    else:
        fdivmod = gmpy.fdivmod
        ilog = _ilog_math
        ilog10 = _ilog10_math
        ilog2 = _ilog2_math
        introot = _introot_gmpy
        iroot = gmpy.root
        is_congruent = _is_congruent
        is_divisible = _is_divisible
        isqrt_rem = gmpy.sqrtrem
        log = math.log
        log10 = math.log10
        log2 = math.log2
        lucas = _lucas
        mod = _mod
        mul = _mul
        powmod = pow
        powmod_base_list = _powmod_base_list
        powmod_exp_list = _powmod_exp_list
    # Bound for either gmpy flavour, after the version-specific names.
    isqrt = gmpy.isqrt
else:
    # Pure-Python fallbacks.
    fac = _fac
    fdivmod = divmod
    fib = _fib
    gcd = _gcd
    gcdext = _gcdext
    ilog = _ilog_math
    ilog10 = _ilog10_math
    ilog2 = _ilog2_math
    introot = _introot
    invert = _invmod
    invmod = _invmod
    iroot = _iroot
    is_congruent = _is_congruent
    is_divisible = _is_divisible
    is_prime = _is_prime
    is_square = _is_square
    isqrt = _isqrt
    isqrt_rem = _isqrt_rem
    lcm = _lcm
    log = math.log
    log10 = math.log10
    log2 = math.log2
    lucas = _lucas
    mod = _mod
    mul = _mul
    next_prime = _next_prime
    powmod = _powmod
    powmod_base_list = _powmod_base_list
    powmod_exp_list = _powmod_exp_list
    primes = _primes
    remove = _remove
def legendre(a, p):
    """Return ``a ** ((p - 1) // 2) mod p`` (Euler's criterion).

    The value is whatever ``powmod`` returns, so a non-residue shows up as
    ``p - 1`` rather than ``-1``.
    """
    half_order = (p - 1) >> 1
    return powmod(a, half_order, p)


def cuberoot(n):
    """Return the integer cube root of ``n`` via ``introot``."""
    return introot(n, 3)
def factor_ned_probabilistic(n, e, d):
    """
    800-56B R1 Recommendation for Pair-Wise Key Establishment Schemes Using Integer Factorization Cryptography in Appendix C.

    Recover the prime factors of ``n`` from a matching key pair ``(e, d)``.

    Returns ``(p, q)`` with ``p * q == n`` on success, or ``None`` when
    ``d * e - 1`` is odd or not positive, or when none of the 100 random
    bases reveals a factor.
    """
    n1, k = n - 1, d * e - 1
    if k <= 0 or k & 1 == 1:
        # k == 0 would make the decomposition loop below spin forever.
        return None
    # Write k as 2**t * r with r odd.
    t, r = 0, k
    while r & 1 == 0:
        r >>= 1
        t += 1
    for _ in range(100):
        g = random.randint(0, n1)
        y = pow(g, r, n)
        if y == 1 or y == n1:
            continue
        for _ in range(t):
            x = pow(y, 2, n)
            if x == 1:
                # y is a square root of 1 other than +-1, so y - 1 shares
                # exactly one prime with n.  (Using x - 1 == 0 here would
                # give gcd(0, n) == n and the useless answer (n, 1).)
                p = gcd(y - 1, n)
                return p, n // p
            if x == n1:
                # This base cannot help any more; try the next one.
                break
            y = x
    return None
def trivial_factorization_with_n_b(n, b):
    """Factor ``n = p * q`` given ``b = p + q``.

    ``p`` and ``q`` are the roots of ``x**2 - b*x + n``; they are computed
    from the integer square root of the discriminant ``b**2 - 4*n`` and
    verified by multiplication.  A zero discriminant (``p == q``) is
    accepted as well.

    Returns ``(p, q)`` or ``None`` when the candidates do not multiply
    back to ``n``.
    """
    discriminant = (b * b) - (n << 2)
    if discriminant < 0:
        return None
    i = isqrt(discriminant)
    p, q = int((b - i) >> 1), int((b + i) >> 1)
    if p * q == n:
        return p, q
    return None
def factor_ned_deterministic(n, e, d):
    """
    800-56B R2 Recommendation for Pair-Wise Key Establishment Schemes Using Integer Factorization Cryptography in Appendix C.2.

    Derives a candidate for ``p + q`` from ``d * e - 1`` and hands it to
    ``trivial_factorization_with_n_b``, whose result is returned.
    """
    k = d * e - 1
    scaled = k * gcd(n - 1, k)
    quotient, remainder = divmod(scaled, n)
    b = (n - remainder) // (quotient + 1) + 1
    return trivial_factorization_with_n_b(n, b)
# Default strategy for factoring n from a known (e, d) pair.
factor_ned = factor_ned_deterministic


def trivial_factorization_with_n_phi(n, phi):
    """Factor ``n`` from ``phi`` by passing ``n - phi + 1`` as the sum of factors."""
    factor_sum = n - phi + 1
    return trivial_factorization_with_n_b(n, factor_sum)
def neg_pow(a, b, n):
    """
    Calculates a^{b} mod n when b is negative

    ``a`` is inverted modulo ``n`` and the inverse is raised to ``-b``.
    Asserts that ``b < 0`` and that ``a`` and ``n`` are coprime.
    """
    assert b < 0
    assert gcd(a, n) == 1
    inverse = int(invert(a, n))
    return powmod(inverse, -b, n)
def common_modulus_related_message(e1, e2, n, c1, c2):
    """
    e1 --> Public Key exponent used to encrypt message m and get ciphertext c1
    e2 --> Public Key exponent used to encrypt message m and get ciphertext c2
    n --> Modulus
    The following attack works only when m^{GCD(e1, e2)} < n

    With ``g, a, b = gcdext(e1, e2)`` the product ``c1**a * c2**b mod n``
    equals ``m**g mod n``; its integer ``g``-th root is returned.
    Returns ``None`` when ``g == 1``.
    """
    g, a, b = gcdext(e1, e2)
    if g == 1:
        return None
    # Each ciphertext is handled according to the sign of ITS OWN
    # coefficient (the c2 branch used to look at ``a``).
    c1 = neg_pow(c1, a, n) if a < 0 else powmod(c1, a, n)
    c2 = neg_pow(c2, b, n) if b < 0 else powmod(c2, b, n)
    ct = c1 * c2 % n
    return int(introot(ct, g))
def phi(n, factors):
    """
    Euler totient function

    ``factors`` is an iterable of candidate prime factors of ``n``.
    Whatever remains of ``n`` after all listed factors are removed is
    treated as one further prime factor.

    ``phi(1)`` is 1 and ``phi(0)`` is 0; both used to recurse without end
    through the perfect-square branch.
    """
    if n == 0 or n == 1:
        return n
    if is_prime(n):
        return n - 1
    if is_square(n):
        # phi(m**2) == m * phi(m)
        i2 = isqrt(n)
        return phi(i2, factors) * i2
    y = n
    for p in factors:
        if n % p == 0:
            y //= p
            y *= p - 1
            n, _ = remove(n, p)
    if n > 1:
        y //= n
        y *= n - 1
    return y
def chinese_remainder(m, a):
    """Combine the residues ``a[i]`` modulo ``m[i]`` into one residue.

    Returns the sum of ``(N / m[i]) * invert(N / m[i], m[i]) * a[i]``
    reduced modulo ``N``, the product of all moduli.
    """
    N = list_prod(tuple(m))
    total = 0
    for index, modulus in enumerate(m):
        cofactor = N // modulus
        total += cofactor * invert(cofactor, modulus) * a[index]
    return total % N
def tonelli(n, p):
    """
    tonelli-shanks modular squareroot algorithm

    Returns ``r`` with ``r * r == n (mod p)`` for a prime ``p``.
    Asserts that ``n`` is a quadratic residue modulo ``p`` (for odd ``p``).
    """
    if p == 2:
        # Every residue is its own square root modulo 2; the search for a
        # non-residue below would find nothing.
        return n % p
    assert legendre(n, p) == 1, "not a square (mod p)"
    # Write p - 1 as q * 2**s with q odd.
    q = p - 1
    s = A007814(q)
    q >>= s
    if s == 1:
        # p == 3 (mod 4): the root is a single exponentiation.
        return powmod(n, (p + 1) >> 2, p)
    # Smallest quadratic non-residue z (legendre reports it as p - 1).
    for z in range(2, p):
        if p - 1 == legendre(z, p):
            break
    c = powmod(z, q, p)
    r = powmod(n, (q + 1) >> 1, p)
    t = powmod(n, q, p)
    m = s
    while (t - 1) % p != 0:
        # Find the least i with t**(2**i) == 1 (mod p).
        t2 = powmod(t, 2, p)
        for i in range(1, m):
            if (t2 - 1) % p == 0:
                break
            t2 = powmod(t2, 2, p)
        b = powmod(c, 1 << (m - i - 1), p)
        # Plain modular products: exact for Python integers and free of the
        # recursion depth of the mulmod helper.
        r = (r * b) % p
        c = powmod(b, 2, p)
        t = (t * c) % p
        m = i
    return r
def is_cube(n):
    """Return True when ``n`` is a perfect cube.

    Cubes are congruent to 0, 1 or 8 modulo 9, so other residues are
    rejected without computing a root.
    """
    if n % 9 not in (0, 1, 8):
        return False
    _, exact = iroot(n, 3)
    return exact
def dlp_bruteforce(g, h, p):
    """
    Try to solve the discrete logarithm problem:
    x for g^x == h (mod p) with brute force.

    Exponents 1 .. p - 1 are tried in order and the first match is
    returned; ``None`` is returned when there is none.  The running power
    is updated with one multiplication per step instead of a full modular
    exponentiation for every candidate.
    """
    power = g % p
    for x in range(1, p):
        if h == power:
            return x
        power = (power * g) % p
    return None
def rational_to_contfrac(x, y):
    """Return the continued-fraction partial quotients of ``x / y`` as a list.

    Iterative Euclidean expansion: no recursion-depth limit for long
    expansions (e.g. RSA-sized ``e / n``) and no quadratic ``insert(0, ...)``.

    Raises ``ZeroDivisionError`` when ``y`` is 0.
    """
    quotients = []
    while True:
        a = x // y
        quotients.append(a)
        remainder = x - a * y
        if remainder == 0:
            return quotients
        x, y = y, remainder
def contfrac_to_rational(frac):
    """Evaluate the continued fraction ``frac`` to ``(numerator, denominator)``.

    An empty sequence gives ``(0, 1)``.  The fraction is folded from its
    last term towards the first, which yields the same pair as the
    recursive definition without recursion or repeated slicing.
    """
    if len(frac) == 0:
        return (0, 1)
    num, denom = frac[-1], 1
    for a in reversed(frac[:-1]):
        num, denom = a * num + denom, num
    return (num, denom)
def convergents_from_contfrac(frac, progress=False):
    """Return the values of the proper prefixes of ``frac`` as rationals.

    Element ``i`` of the result is ``(numerator, denominator)`` of
    ``frac[:i]`` for ``i`` in ``0 .. len(frac) - 1``: the list starts with
    ``(0, 1)`` for the empty prefix and does not include the value of the
    complete fraction.  ``progress`` is accepted but not used.

    The convergents are produced with the recurrence
    ``h_k = a_k * h_(k-1) + h_(k-2)`` (same for the denominators), i.e. one
    step per term instead of re-evaluating every prefix from scratch.
    """
    if len(frac) == 0:
        return []
    convergents = [(0, 1)]
    h_prev, k_prev = 0, 1  # h_(-2), k_(-2)
    h_cur, k_cur = 1, 0  # h_(-1), k_(-1)
    for a in frac[:-1]:
        h_prev, h_cur = h_cur, a * h_cur + h_prev
        k_prev, k_cur = k_cur, a * k_cur + k_prev
        convergents.append((h_cur, k_cur))
    return convergents
def inv_mod_pow_of_2(factor, bit_count):
    """
    its orders of magnitude faster than invert(a, 2^k)
    code borrowed from: https://algassert.com/post/1709

    Returns the inverse of the odd number ``factor`` modulo
    ``2**bit_count``.  The low bit of ``factor`` is ignored (treated as
    set), so the result is only meaningful for odd ``factor``.
    """
    rest = factor & -2
    acc = 1
    for i in range(bit_count):
        # Subtract the shifted factor once when bit i is set.  Multiplying
        # by the masked value (2**i rather than 1) shifted it twice and
        # produced wrong inverses.
        if acc & (1 << i):
            acc -= rest << i
    mask = (1 << bit_count) - 1
    return acc & mask
def mlucas(v, a, n):
    """Helper function for williams_pp1(). Multiplies along a Lucas sequence modulo n."""
    v1 = v
    v2 = (v * v - 2) % n
    while a > 0:
        # The mixed term is needed whichever way the current bit goes.
        cross = (v1 * v2 - v) % n
        if a & 1 == 0:
            v1, v2 = (v1 * v1 - 2) % n, cross
        else:
            v1, v2 = cross, (v2 * v2 - 2) % n
        a >>= 1
    return v1
# Public API.  ``__all__`` must list names as strings: with the objects
# themselves ``from <module> import *`` fails with a TypeError.
__all__ = [
    "getpubkeysz",
    "gcd",
    "isqrt",
    "introot",
    "invmod",
    "gcdext",
    "is_square",
    "is_cube",
    "next_prime",
    "is_prime",
    "fib",
    "primes",
    "lcm",
    "invert",
    "powmod",
    "ilog2",
    "ilog",
    "ilog10",
    "mod",
    "log",
    "log2",
    "log10",
    "trivial_factorization_with_n_phi",
    "factor_ned",
    "neg_pow",
    "common_modulus_related_message",
    "phi",
    "list_prod",
    "chinese_remainder",
    "ilogb",
    "mul",
    "cuberoot",
    "isqrt_rem",
    "is_divisible",
    "is_congruent",
    "iroot",
    "dlp_bruteforce",
    "fac",
    "rational_to_contfrac",
    "contfrac_to_rational",
    "convergents_from_contfrac",
    "fdivmod",
    "inv_mod_pow_of_2",
    "mlucas",
    "lucas",
    "mulmod",
    "A000265",
    "powmod_base_list",
    "powmod_exp_list",
    "is_pow2",
]
|