-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnumthy.py
More file actions
6398 lines (5436 loc) · 221 KB
/
numthy.py
File metadata and controls
6398 lines (5436 loc) · 221 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2026 Ini Oguntola
# Permission is granted to use, copy, modify, and redistribute this work,
# provided acknowledgement of the original author is retained.
# Supports Python 3.10 or later.
from __future__ import annotations
import bisect
import cmath
import hashlib
import hmac
import inspect
import itertools
import secrets
import sys
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from fractions import Fraction
from functools import cache, lru_cache, partial, reduce
from heapq import heappop, heappush
from math import ceil, fsum, gcd, inf, isfinite, isqrt, lcm, log, prod, sqrt
from operator import add, mul, xor
from typing import Callable, Collection, Iterator, TypeAlias, TypeVar
########################################################################
########################### Table of Contents ##########################
########################################################################
__version__ = '0.2.0'
# Public API: the names exported by `from <this module> import *`, grouped by section
__all__ = [
    'Number', 'Vector', 'Matrix', 'Monomial', 'Polynomial', 'clear_cache',
    # Primes
    'is_prime', 'next_prime', 'random_prime', 'primes', 'count_primes', 'sum_primes',
    # Factorization
    'perfect_power', 'prime_factors', 'prime_factorization', 'divisors',
    # Arithmetic Functions
    'omega', 'big_omega', 'divisor_count', 'divisor_sum', 'divisor_function',
    'partition', 'radical', 'mobius', 'totient', 'carmichael', 'valuation',
    'multiplicative_range',
    # Modular Arithmetic
    'egcd', 'crt', 'coprimes', 'multiplicative_order', 'primitive_root',
    'legendre', 'jacobi', 'kronecker', 'dirichlet_character',
    # Nonlinear Congruences
    'hensel', 'polynomial_roots', 'nth_roots', 'discrete_log',
    # Diophantine Equations
    'bezout', 'cornacchia', 'pell', 'conic', 'pythagorean_triples', 'pillai',
    # Algebraic Systems
    'solve_linear_system', 'solve_polynomial_system',
    # Lattices
    'lll_reduce', 'bkz_reduce', 'closest_vector', 'small_roots',
    # Appendix
    'integers', 'integer_pairs', 'alternating', 'below', 'lower_bound', 'permutation',
    'is_square', 'iroot', 'ilog', 'fibonacci', 'fibonacci_index', 'polygonal',
    'polygonal_index', 'periodic_continued_fraction', 'convergents', 'polynomial',
]
_NoSolutionError = type('_NoSolutionError', (Exception,), {})
_PrecisionError = type('_PrecisionError', (Exception,), {})
_T = TypeVar('T', bound='Number')
Number: TypeAlias = int | float | complex | Fraction
Real: TypeAlias = int | float | Fraction
Vector: TypeAlias = list[_T]
Matrix: TypeAlias = list[list[_T]]
Monomial: TypeAlias = tuple[int, ...]
Polynomial: TypeAlias = dict[Monomial, _T]
singleton = lru_cache(maxsize=1)
small_cache = lru_cache(maxsize=1024)
large_cache = lru_cache(maxsize=1048576)
def clear_cache():
    """
    Clear all caches defined in this module.

    Every module-level object whose ``__module__`` is this module and that
    exposes a callable ``cache_clear`` attribute (such as a
    ``functools.lru_cache`` wrapper) has that method invoked.
    """
    this_module = sys.modules[__name__]
    owned_objects = (
        candidate for candidate in vars(this_module).values()
        if getattr(candidate, '__module__', None) == __name__
    )
    for candidate in owned_objects:
        clearer = getattr(candidate, 'cache_clear', None)
        if callable(clearer):
            clearer()
########################################################################
################################ Primes ################################
########################################################################
def is_prime(n: int) -> bool:
    """
    Decide whether the integer n is prime.

    Strategy, by size of n:
      * n < 3 or even: only 2 is prime.
      * n < 256: lookup in the table of odd primes below 256.
      * any prime factor below 256 (single gcd with their product): composite.
      * n < 65536: survivors of the trial division are prime (65536 = 256²).
      * n = 2^k - 1: Lucas-Lehmer test.
      * n < 132239: Miller-Rabin with one deterministic base.
      * n < 2^32: Miller-Rabin with a hashed, per-n deterministic base.
      * n < 55245642489451: Miller-Rabin with four deterministic bases.
      * otherwise: extra-strong Baillie-PSW (no known pseudoprimes; verified
        to have none below 2^64).

    See: https://www.techneon.com/download/is.prime.32.base.data (MR hash for n < 2^32)
    See: https://miller-rabin.appspot.com (other deterministic MR base sets)
    See: https://ntheory.org/pseudoprimes.html (BPSW verification up to 2^64)

    Parameters
    ----------
    n : int
        Integer to test for primality
    """
    if n < 3 or not n & 1:
        return n == 2
    if n < 256:
        return n in _ODD_PRIMES_BELOW_256
    if gcd(n, _PRIMORIAL_ODD_PRIMES_BELOW_256) != 1:
        return False
    if n < 1 << 16:
        return True
    # All bits set means n = 2^k - 1 (a Mersenne number)
    k = n.bit_count()
    if n.bit_length() == k:
        return _lucas_lehmer(k)
    if n < 132239:
        witnesses = (814494960528 % n,)
    elif n < 1 << 32:
        slot = (0xAD625B89 * n) >> 24 & 255
        witnesses = _MILLER_RABIN_32_BIT_BASES[slot:slot + 1]
    elif n < 55245642489451:
        fixed_bases = (2, 141889084524735, 1199124725622454117, 11096072698276303650)
        witnesses = tuple(base % n for base in fixed_bases)
    else:
        return _baillie_psw(n)
    return _miller_rabin(n, witnesses)
def next_prime(n: int) -> int:
    """
    Return the smallest prime strictly greater than n.

    Parameters
    ----------
    n : int
        Strict lower bound for prime number
    """
    if n >= 2:
        # Start at the first odd number above n; only odd candidates can be prime here
        candidate = (n + 1) | 1
        while not is_prime(candidate):
            candidate += 2
        return candidate
    return 2
def random_prime(num_bits: int, *, safe: bool = False) -> int:
    """
    Generate a random prime with exactly `num_bits` bits.

    Candidates are odd numbers with the top bit forced on and the bits in
    between drawn from `secrets`. Random bits are drawn in batches to
    amortize calls to the CSPRNG.

    Parameters
    ----------
    num_bits : int
        Number of bits in the prime to be generated
    safe : bool
        Whether or not to generate a safe prime
        (i.e. prime q of the form q = 2p + 1, where p is also prime)

    Raises
    ------
    ValueError
        If num_bits < 2 (or num_bits < 3 when safe is True)
    """
    min_bits = 3 if safe else 2
    if num_bits < min_bits:
        if safe:
            raise ValueError("Safe primes require num_bits >= 3")
        raise ValueError("Primes require num_bits >= 2")
    if not safe and num_bits == 2:
        return 2 + secrets.randbelow(2)
    # For a safe prime, p has one bit fewer than q = 2p + 1
    free_bits = num_bits - min_bits  # random bits between the forced top and bottom bits
    per_draw = max(1, int(0.4 * free_bits))
    leading_bit = 1 << (free_bits + 1)
    low_mask = (1 << free_bits) - 1
    while True:
        pool = secrets.randbits(per_draw * free_bits)
        for _ in range(per_draw):
            candidate = leading_bit | ((pool & low_mask) << 1) | 1
            pool >>= free_bits
            if not is_prime(candidate):
                continue
            if not safe:
                return candidate
            if is_prime(2 * candidate + 1):
                return 2 * candidate + 1
def primes(
    *,
    low: int = 2,
    high: int | None = None,
    count: int | None = None,
) -> Iterator[int]:
    """
    Generate at most `count` primes in increasing order within the range `[low, high]`.
    Uses the sieve of Eratosthenes, with a segmented approach for large or
    unbounded ranges. A non-positive `count`, or an empty range, yields nothing.

    Parameters
    ----------
    low : int
        Lower bound for prime numbers
    high : int
        Upper bound for prime numbers (default is infinite)
    count : int
        Maximum number of primes to generate (default is infinite)
    """
    DEFAULT_SIEVE_SIZE, MAX_SIEVE_SIZE = 1000, 100_000_000
    low = max(low, 2)
    high = inf if high is None else high
    count = inf if count is None else count
    # Initial list of small primes to use for the segmented sieve
    small_odd_primes = [
        3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41,
        43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97,
    ]
    # Fast path: answer straight from the table of the 25 primes below 100.
    # Clamp the slice bound at 0: a negative bound would drop primes from the
    # end of the table instead of yielding nothing (the general path below
    # already yields nothing for count <= 0).
    if low == 2 and count <= 25 and high <= 100:
        yield from (p for p in [2, *small_odd_primes][:max(count, 0)] if p <= high)
        return
    # Generate initial prime (the sieve below only handles odd numbers)
    if low <= 2 <= high and count > 0:
        yield 2
        count -= 1
    elif low > high or count <= 0:
        return
    # Set initial sieve size based on the prime number theorem
    # When `high` is given, sieve on range [low, high]
    # When `count` is given, sieve on range [low, n (log n + log log n)],
    # where n is an upper bound on `π(low) + count`
    if high == count == inf:
        sieve_size = DEFAULT_SIEVE_SIZE
    else:
        n = count + 1.25506 * low / max(log(low), 1)  # Rosser & Schoenfeld bound (1962)
        upper_bound = n * (log(n) + log(log(n)))  # upper bound on the nth prime
        sieve_size = int(min(MAX_SIEVE_SIZE, high - low + 1, upper_bound - low))
    # Generate additional primes
    while low <= high and count > 0:
        # If necessary, extend list of small primes via Bertrand intervals
        while (p := small_odd_primes[-1]) < isqrt(low + sieve_size):
            small_odd_primes.extend(_segmented_eratosthenes(p + 1, p, small_odd_primes))
        # Get new primes with segmented sieve
        new_primes = _segmented_eratosthenes(low, sieve_size, small_odd_primes)
        if count < inf:
            new_primes = tuple(itertools.islice(new_primes, count))
            count -= len(new_primes)
        # Yield new primes
        yield from new_primes
        # Update sieve range (segments grow geometrically up to MAX_SIEVE_SIZE)
        low += sieve_size
        sieve_size = min(2 * sieve_size, MAX_SIEVE_SIZE, high - low + 1)
def count_primes(x: int) -> int:
    """
    Prime counting function π(x): the number of primes p ≤ x.

    For x < 10000 the primes are generated and counted directly. Larger x use
    the Lagarias-Miller-Odlyzko (LMO) extension of the Meissel-Lehmer algorithm,
    with hyperparameters (k, c) picked by `_threshold_select` from the size of x.

    Parameters
    ----------
    x : int
        Upper bound for prime numbers
    """
    if x >= 10000:
        tuning = [(1000000, (5, 0.015)), (1000000000, (5, 0.008))]
        k, c = _threshold_select(x, tuning, default=(15, 0.003))
        return _lmo(x, k=k, c=c)
    return len(list(primes(high=x)))
def sum_primes(
    x: int,
    f: Callable[[int], Number] | None = None,
    f_prefix_sum: Callable[[int], Number] | None = None,
) -> Number:
    """
    Compute F(x) = Σ f(p) over all primes p ≤ x, for a completely
    multiplicative f (f(n) = n when neither function is supplied).

    For x < 10000 the primes are enumerated directly; otherwise a generalized
    LMO prime-counting algorithm is used, which performs best when `f()` and
    `f_prefix_sum()` have O(1) closed forms.

    Parameters
    ----------
    x : int
        Upper bound for prime numbers
    f : Callable(int) -> Number
        Completely multiplicative function f(n),
        where f(1) = 1 and f(ab) = f(a) * f(b) for all a, b > 0
    f_prefix_sum : Callable(int) -> Number
        Function to compute the cumulative sum Σ_{1 ≤ k ≤ n} f(k)

    Raises
    ------
    ValueError
        If exactly one of `f` and `f_prefix_sum` is given
    """
    if (f is None) != (f_prefix_sum is None):
        raise ValueError("Both f() and f_prefix_sum() must be provided")
    if f is None:
        if x < 10000:
            return sum(primes(high=x))
        # Default weight f(n) = n, whose prefix sum is the triangular number
        f, f_prefix_sum = _identity, (lambda n: n * (n + 1) // 2)
    if x < 10000:
        return sum(f(p) for p in primes(high=x))
    tuning = [(100000, (5, 0.025)), (1000000, (5, 0.015)), (10000000, (5, 0.01))]
    k, c = _threshold_select(x, tuning, default=(15, 0.005))
    return _lmo(x, k=k, c=c, f=f, f_prefix_sum=f_prefix_sum)
def _miller_rabin(n: int, bases: Iterable[int] | int = (2,)) -> bool:
    """
    Miller-Rabin primality test of n over the given bases.

    `bases` is either an iterable of explicit bases or an int giving how many
    random bases to draw. Returns False as soon as any base proves n composite.
    See: https://www.sciencedirect.com/science/article/pii/0022314X80900840

    Complexity
    ----------
    O(k log³n) for k bases, with worst-case error probability 4⁻ᵏ
    """
    even_part = n - 1
    # Exponent of 2 in n - 1: isolate the lowest set bit and measure its position
    twos = (even_part & -even_part).bit_length() - 1
    return _miller_rabin_worker(n, twos, even_part >> twos, bases)
def _miller_rabin_worker(n: int, s: int, d: int, bases: Iterable[int] | int) -> bool:
    """
    Miller-Rabin primality test for n over the given bases,
    where n - 1 = 2^s * d with d odd.
    See: https://www.sciencedirect.com/science/article/pii/0022314X80900840

    Parameters
    ----------
    n : int
        Odd integer under test
    s, d : int
        Decomposition n - 1 = 2^s * d with d odd
    bases : Iterable[int] | int
        Explicit bases, or the number of random bases in [2, n - 2] to draw

    Returns
    -------
    bool
        False if some base witnesses that n is composite, True otherwise.

    Notes
    -----
    Bases divisible by n are skipped. For such a base a, a^d ≡ 0 (mod n) never
    reaches ±1, so without the skip a prime n that happens to divide a fixed
    base would be reported as composite. Callers such as `is_prime` pass large
    fixed bases reduced mod n, so this case is reachable.
    """
    # Generate random bases, if specific bases have not been given
    if isinstance(bases, int):
        bases = (secrets.randbelow(n - 3) + 2 for _ in range(bases))
    minus_one = n - 1
    # Run a Miller-Rabin test for each base
    for a in bases:
        if a % n == 0:
            continue  # a ≡ 0 (mod n) carries no information about n
        x = pow(a, d, n)
        if x == minus_one or x == 1:
            continue  # probable prime for this base
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == minus_one:
                break  # probable prime for this base
        else:
            return False  # composite
    return True  # All bases passed
def _baillie_psw(n: int) -> bool:
    """
    Baillie-PSW primality test for n. Uses an extra strong Lucas step.
    There are no known counterexamples to this primality test,
    and it has been computationally verified for all n < 2^64.
    See: https://math.dartmouth.edu/~carlp/PDF/paper25.pdf
    See: https://ntheory.org/pseudoprimes.html

    Parameters
    ----------
    n : int
        Odd integer to test. `is_prime` only calls this after trial division
        by all primes below 256 (assumed precondition for direct callers).

    Returns
    -------
    bool
        False if n is certainly composite, True if n is a BPSW probable prime.

    Complexity
    ----------
    O(log³n) time
    """
    # Perform a strong probable-prime (Miller-Rabin) test with base a = 2
    if not _miller_rabin(n, bases=(2,)):
        return False
    # Reject perfect squares: for a square n the Jacobi symbol (D/n) is never -1,
    # so the search for P below would not terminate
    if is_square(n):
        return False
    # Find a suitable D for the extra-strong Lucas test (D = P^2 - 4Q with Q = 1):
    # the smallest P >= 3 with Jacobi(P^2 - 4, n) = -1
    P = 3
    while jacobi(P*P - 4, n) != -1:
        P += 1
    # Write n + 1 = 2^s * d with d odd
    d = n + 1
    s = (d & -d).bit_length() - 1
    d >>= s
    # Generate the Lucas sequence element V_d(P, Q) via binary Lucas chain
    # Invariant: (V, V_next) = (V_k, V_{k+1}); with Q = 1,
    #   V_{2k} = V_k^2 - 2,  V_{2k+1} = V_k V_{k+1} - P,  V_{2k+2} = V_{k+1}^2 - 2
    P %= n
    V, V_next = P, (P*P - 2) % n # these represent V_k, V_{k+1}
    for bit in format(d, 'b')[1:]:
        if bit != '0':
            V, V_next = (V * V_next - P) % n, (V_next * V_next - 2) % n
        else:
            V, V_next = (V * V - 2) % n, (V * V_next - P) % n
    # 1st extra-strong condition: U_d = 0 (mod n) and V_d = ± 2 (mod n)
    # Since gcd(D, n) = 1, U_d = 0 (mod n) <=> D * U_d = 2V_{d+1} - PV_d = 0 (mod n)
    if V in (2, n - 2) and (2 * V_next - P * V) % n == 0:
        return True
    # 2nd extra-strong condition: V_{2^r * d} = 0 (mod n) for some 0 <= r < s - 1
    # The loop checks r = 0 .. s - 3 and the final return checks r = s - 2
    for _ in range(s - 2):
        if V == 0: return True
        V = (V*V - 2) % n
    return s > 1 and V == 0
def _lucas_lehmer(p: int) -> bool:
    """
    Lucas-Lehmer test: decide whether the Mersenne number M_p = 2^p - 1 is prime.

    A composite exponent p (found by trial division) implies M_p is composite.
    Otherwise iterate s ← s² - 2 (mod M_p) from s = 4, p - 2 times; M_p is
    prime exactly when the result is 0.

    Complexity
    ----------
    O(p) multiplications
    """
    if p == 2:
        return True
    if any(p % q == 0 for q in primes(high=isqrt(p))):
        return False
    mersenne = (1 << p) - 1
    residue = 4
    for _ in range(p - 2):
        residue = (residue * residue - 2) % mersenne
    return residue == 0
def _segmented_eratosthenes(
    start: int,
    sieve_size: int,
    odd_primes: Sequence[int],
) -> Iterable[int]:
    """
    Segmented sieve of Eratosthenes over the window [start, start + sieve_size).

    Returns an iterator over the odd primes in that window (2 is never produced).
    The caller must supply the sorted odd primes up to at least
    √(start + sieve_size).

    Complexity
    ----------
    O(n log log n) time and O(n) space for segment of size n
    """
    first_odd = start | 1
    stop = start + sieve_size
    slots = (stop - first_odd + 1) >> 1  # slot i <-> first_odd + 2i
    alive = bytearray(b'\x01') * slots
    blank = bytearray(b'\x00') * slots

    def strike(slot: int, p: int) -> None:
        # Clear every p-th slot from `slot` on (a stride of 2p in value space)
        alive[slot::p] = blank[:(slots - slot + p - 1) // p]

    split = bisect.bisect_right(odd_primes, isqrt(first_odd))
    # Primes with p² <= first_odd: begin at the first odd multiple inside the window
    for p in odd_primes[:split]:
        strike(((p - first_odd) % (2 * p)) >> 1, p)
    # Larger primes: begin at p², since smaller multiples were cleared by smaller primes
    for p in odd_primes[split:]:
        square = p * p
        if square >= stop:
            break
        strike((square - first_odd) >> 1, p)
    return itertools.compress(range(first_odd, first_odd + 2 * slots, 2), alive)
def _lmo(
    x: int,
    k: int = 15,
    c: float = 0.003,
    f: Callable[[int], Number] | None = None,
    f_prefix_sum: Callable[[int], Number] | None = None,
) -> Number:
    """
    Lagarias-Miller-Odlyzko (LMO) extension of the Meissel-Lehmer algorithm.
    Returns the value of the prime counting function π(x), i.e. the number of
    primes less than or equal to x.
    See: https://www-users.cse.umn.edu/~odlyzko/doc/arch/meissel.lehmer.pdf
    See: https://arxiv.org/pdf/2111.15545
    Also includes a generalized version that calculates the sum F(x) = Σ f(p)
    for all primes p ≤ x, where f is any arbitrary completely multiplicative function.
    The generalized LMO sub-expressions become:
    P₂ = Σ f(p) * [F(x/p) − F(p − 1)] for y < p ≤ sqrt(x)
    φ_f(x, a) = φ_f(x, a - 1) - f(pₐ) * φ_f(x/pₐ, a - 1)
    S₁ = Σ μ(n) f(n) φ_f(x/n, k) over ordinary leaves (n, k)
    S₂ = Σ μ(n) f(n) φ_f(x/n, b) over special leaves (n, b)
    and the generalized Meissel-Lehmer formula becomes:
    F(x) = F(y) - 1 - P₂ + φ_f(x, a) = F(y) - 1 - P₂ + S₁ + S₂.
    Ideally `f()` and `f_prefix_sum()` can be calculated efficiently in O(1) time
    via closed-form expression.

    Parameters
    ----------
    x : int
        Upper bound for prime numbers (returns 0 when x < 2)
    k : int
        Number of leading primes handled by the direct φ evaluation in S₁;
        clamped to [1, a], where a = π(y)
    c : float
        Tuning constant for y = c * x¹ᐟ³ log²x
    f, f_prefix_sum : Callable(int) -> Number, optional
        Weight function and its prefix sum; both None means f(n) = 1 (plain π(x))

    Complexity
    ----------
    O(x²ᐟ³ / log x) time and O(x¹ᐟ³ log²x) space with hyperparameter
    y = c * x¹ᐟ³ log² x, assuming f() and f_prefix_sum() are O(1).
    """
    if x < 2:
        return 0
    # Set hyperparameter y = cx^(1/3) log^2(x) such that x^(1/3) <= y <= x^(2/5)
    # where y is the upper bound on the small primes that are computed directly
    y = int(c * iroot(x, 3) * (log(x) ** 2))
    y = min(max(y, iroot(x, 3)), iroot(x * x, 5))
    y = max(y, 2) # we need y >= 2 to use an odd-only sieve starting at y + 1
    # Count primes up to y
    small_primes = tuple(primes(high=y))
    a = len(small_primes)
    F_y = a if f is None else sum(map(f, small_primes))
    # Set number of precomputed stages of special leaf sieving
    k = min(max(k, 1), a)
    # Evaluate the 2nd-order partial sieve function P2(x, a)
    # This is the prefix sum Σ f(n) over all n <= x with exactly 2 prime factors,
    # that are both greater than p_a
    P2 = _lmo_p2(x, y, F_y, small_primes, f)
    # Compute the least prime factor (lpf) and Mobius (μ) functions
    # for integers 1 ... y by iterating over the primes in reverse order
    # (so the smallest prime dividing n is the last one written to lpf[n])
    lpf, mu = [0] * (y + 1), [1] * (y + 1)
    for p in reversed(small_primes):
        mu[p*p::p*p] = [0] * (y // (p*p)) # multiples of p² are not squarefree
        mu[p::p] = [-value for value in mu[p::p]] # one sign flip per prime factor
        lpf[p::p] = [p] * (y // p)
    # Sum the leaves in the tree created by either
    # the standard recurrence φ(x, a) = φ(x, a - 1) - φ(x/p_a, a - 1)
    # or the weighted recurrence φ_f(x, a) = φ_f(x, a - 1) - f(p_a) * φ_f(x/p_a, a - 1)
    S1 = _lmo_s1(x, k, mu, small_primes, f, f_prefix_sum) # sum over ordinary leaves
    S2 = _lmo_s2(x, k, lpf, mu, small_primes, f) # sum over special leaves
    # The "- 1" removes the contribution of n = 1 (f(1) = 1) counted by φ
    return F_y - 1 - P2 + S1 + S2
def _lmo_p2(
    x: int,
    y: int,
    F_y: Number,
    small_primes: tuple[int, ...],
    f: Callable[[int], Number] | None = None,
    block_size: int = 64,
) -> Number:
    """
    Compute P2(x, a) from the LMO algorithm.
    This is the prefix sum Σ f(n) over all n ≤ x with exactly 2 prime factors,
    both greater than the a-th prime.

    Parameters
    ----------
    x : int
        Upper bound of the LMO computation
    y : int
        Small-prime bound; `small_primes` holds all primes ≤ y
    F_y : Number
        F(y) = Σ f(p) over primes p ≤ y (equals π(y) when f is None)
    small_primes : tuple[int, ...]
        All primes ≤ y, in increasing order (small_primes[0] == 2)
    f : Callable(int) -> Number, optional
        Weight function; None means f(n) = 1
    block_size : int
        Size of the blocks of sieve bytes whose prime counts are prefix-summed
        for fast π() queries in the unweighted case
    """
    sqrt_x = isqrt(x)
    sieve_limit = x // y
    sieve_start = (y + 1) | 1 # round up to odd
    sieve_size = y + (y & 1) # round up to even (keeps each segment start odd)
    # Compute a generalized P2(x, a) = sum_{y < p <= sqrt(x)} f(p) * [F(x/p) − F(p − 1)]
    # Find the weighted sum f(p) * F(x/p) for all primes in the interval (y, sqrt(x)]
    # Or equivalently, the sum over all x/p in the inverse interval [sqrt(x), x/y)
    # Also accumulate the sum f(p)^2 for all primes in the interval (y, sqrt(x)]
    P2 = 0
    sum_f2 = 0
    F_sqrt_x = F_prev = F_y # F_prev: F at the number just below the current segment
    F_segment = [F_y]
    for low in range(sieve_start, sieve_limit + 1, sieve_size):
        # Sieve the interval [low, high)
        # Only odd numbers are stored in the sieve (sieve[i] corresponds to low + 2i)
        high = min(low + sieve_size, sieve_limit + 1)
        sieve = _lmo_odd_sieve(low, high - low, small_primes[1:], max_prime=isqrt(high))
        # Find all primes p ∈ (y, sqrt(x)] such that low <= x/p < high
        # by similarly sieving the inverse interval (x/high, x/low]
        low_ = (max(x // high, y) + 1) | 1
        high_ = min(x // low, sqrt_x)
        sieve_ = _lmo_odd_sieve(
            low_, high_ - low_ + 1, small_primes[1:], max_prime=isqrt(high_))
        segment_primes = itertools.compress(range(low_, high_ + 1, 2), sieve_)
        # Get f(t) for t ∈ [low, high)
        # Also calculate prime sums F(t) = sum_{p <= t} f(p) for t ∈ [low, high)
        if f is not None:
            f2_primes = itertools.compress(range(low, min(high, sqrt_x + 1), 2), sieve)
            sum_f2 += sum(f(p)**2 for p in f2_primes)
            f_segment = [f(low + 2*i) if sieve[i] else 0 for i in range(len(sieve))]
            F_segment = list(itertools.accumulate(f_segment, initial=F_prev))[1:]
            if low <= sqrt_x < high:
                F_sqrt_x = F_segment[(sqrt_x - low) >> 1]
            # Accumulate over all x/p in our main interval [low, high)
            # (an even x/p maps to the slot of x/p - 1, which has the same F value)
            P2 += sum(f(p) * F_segment[(x // p - low) >> 1] for p in segment_primes)
            F_prev = F_segment[-1]
        else:
            # Unweighted case: answer π(t) queries from per-block prime counts
            # plus a partial count inside the block containing t
            blocks = [sieve[i:i+block_size] for i in range(0, len(sieve), block_size)]
            block_sums = (block.count(1) for block in blocks)
            block_prefix_sums = list(itertools.accumulate(block_sums, initial=0))
            def pi(x):
                # π(x) for x in [low, high); shadows the outer x on purpose
                index = (x - low) >> 1
                block_index, offset = divmod(index, block_size)
                block = blocks[block_index]
                count = F_prev + block_prefix_sums[block_index]
                return count + block[:offset+1].count(1)
            if low <= sqrt_x < high:
                F_sqrt_x = pi(sqrt_x)
            # Accumulate over all x/p in our main interval [low, high)
            P2 += sum(pi(x // p) for p in segment_primes)
            F_prev = F_prev + block_prefix_sums[-1]
    if f is None:
        # With f = 1, Σ f(p)^2 over (y, sqrt(x)] is just the number of such primes
        sum_f2 = F_sqrt_x - F_y
    # Now subtract sum_{y < p <= sqrt(x)} f(p) * F(p − 1)
    # We can use the telescoping identity with a_i = f(p_i), A_i = F(p_i)
    # which is A_i^2 - A_{i-1}^2 = 2 a_i A_{i-1} + a_i^2
    # Over y < p_i <= sqrt(x), the sum Σ f(p) * F(p − 1) = Σ a_i A_{i-1}
    # becomes 1/2 [F(sqrt(x))^2 - F(y)^2 - Σ f(p)^2]
    is_int = isinstance(sum_f2, int) # keep integer results exact
    double_count_sum = F_sqrt_x*F_sqrt_x - F_y*F_y - sum_f2
    double_count_sum = double_count_sum // 2 if is_int else double_count_sum / 2
    return P2 - double_count_sum
def _lmo_s1(
    x: int,
    k: int,
    mu: list[int],
    small_primes: tuple[int, ...],
    f: Callable[[int], Number] | None = None,
    f_prefix_sum: Callable[[int], Number] | None = None,
) -> Number:
    """
    Calculate the S₁ portion of the LMO algorithm.
    This is the sum over "ordinary leaves" (i.e. of the form ± φ(x/n, k) with n <= y)
    in the tree created by the standard recurrence φ(x, a) = φ(x, a-1) - φ(x/pₐ, a-1),
    or the weighted recurrence φ_f(x, a) = φ_f(x, a-1) - f(pₐ) * φ_f(x/pₐ, a-1).

    Parameters
    ----------
    x : int
        Upper bound of the LMO computation
    k : int
        Leaves are evaluated as φ(x/n, k) over the first k primes
    mu : list[int]
        Möbius values μ(0..y); its length fixes y = len(mu) - 1
    small_primes : tuple[int, ...]
        All primes ≤ y, in increasing order
    f, f_prefix_sum : Callable(int) -> Number, optional
        Weight function and its prefix sum; f None means unweighted counting
    """
    # Choose the φ evaluator: cached closed-form helpers for f = 1 and f(n) = n,
    # otherwise the plain (uncached) weighted recurrence
    if f is None:
        phi = partial(_phi_prime_count, small_primes=small_primes[:k])
    elif f == _identity:
        phi = partial(_phi_prime_sum, small_primes=small_primes[:k])
    else:
        phi = lambda x, a: f_prefix_sum(x) if a == 0 else (
            phi(x, a - 1) - f(p := small_primes[a - 1]) * phi(x // p, a - 1))
    # Leaf n = 1 contributes μ(1) f(1) φ(x, k) = φ(x, k)
    S1 = phi(x, k)
    a, y = len(small_primes), len(mu) - 1
    # Depth-first enumeration of squarefree n <= y whose prime factors all exceed p_k.
    # Each stack entry (b, n) means "extend n only by primes small_primes[b:]",
    # which keeps prime factors increasing and each n generated exactly once.
    leaves = [(i + 1, small_primes[i]) for i in range(k, a)]
    while leaves:
        b, n = leaves.pop()
        S1 += mu[n] * phi(x // n, k) * (f(n) if f else 1)
        for i in range(b, a):
            m = n * small_primes[i]
            if m > y: break # primes are increasing, so larger products also exceed y
            leaves.append((i + 1, m))
    return S1
def _lmo_s2(
    x: int,
    k: int,
    lpf: list[int],
    mu: list[int],
    small_primes: tuple[int, ...],
    f: Callable[[int], Number] | None = None,
) -> Number:
    """
    Calculate the S₂ portion of the LMO algorithm.
    This is the sum over "special leaves" (i.e. of the form ± φ(x/n, b) with n > y)
    in the tree created by the standard recurrence φ(x, a) = φ(x, a-1) - φ(x/pₐ, a-1),
    or the weighted recurrence φ_f(x, a) = φ_f(x, a-1) - f(pₐ) * φ_f(x/pₐ, a-1).

    Special leaves n = p_b * m with m <= y, n > y and p_b < lpf(m) are found by
    sieving [1, x/y) in segments. A Fenwick tree over each segment answers the
    partial sums of φ_f restricted to that segment.

    Parameters
    ----------
    x : int
        Upper bound of the LMO computation
    k : int
        Number of primes sieved up front in every segment (leaves start at b = k)
    lpf, mu : list[int]
        Least prime factor and Möbius values for 0..y; y = len(mu) - 1
    small_primes : tuple[int, ...]
        All primes ≤ y, in increasing order (small_primes[0] == 2)
    f : Callable(int) -> Number, optional
        Weight function; None means f(n) = 1
    """
    S2 = 0
    a, y = len(small_primes), len(mu) - 1
    if k >= a: return 0
    # phi[b]: accumulated φ_f contribution of all segments before the current one,
    # counting integers coprime to the first b primes
    phi = [0] * a
    sieve_limit = x // y
    sieve_size = isqrt(sieve_limit) - 1
    sieve_size = 2**(sieve_size.bit_length()) # round up to next power of 2
    tree_size = sieve_size // 2 # number of odd slots per segment
    for low in range(1, sieve_limit, sieve_size):
        # Sieve the segment [low, high) with the first k primes
        # Only odd numbers are stored in the sieve (sieve[i] corresponds to low + 2i)
        # sieve[i] is True if and only if low + 2i is coprime to the first k primes
        high = min(low + sieve_size, sieve_limit)
        odd_sieve = _lmo_odd_sieve(low, sieve_size, small_primes[1:k])
        # Initialize a Binary Indexed Tree over the (weighted) unsieved elements
        if f is None:
            tree = _fenwick_tree_init(odd_sieve)
        else:
            values = [f(low + 2*i) if s else 0 for i, s in enumerate(odd_sieve)]
            tree = _fenwick_tree_init(values)
        # Sieve the segment [low, high) with the remaining primes
        # Any part of the sieve or tree outside this range is ignored
        for b in range(k, a):
            p = small_primes[b]
            # Range of m such that y < p*m and low <= x/(p*m) < high
            min_m = max(x // (p * high), y // p)
            max_m = min(x // (p * low), y)
            # A leaf needs p < lpf(m) <= m; once p >= max_m no larger b has leaves here
            if p >= max_m: break
            # Find special leaves in the tree (i.e. φ(x/n, b) where n > y)
            for m in range(max_m, min_m, -1):
                if p < lpf[m] and mu[m] != 0:
                    # Compute φ(x/(pm), b) by adding contributions from remaining
                    # elements after sieving the first b primes
                    # μ(pm) * f(pm) * φ_f(x/(pm), b) = -μ(m) * f(p) * f(m) * φ_f(...)
                    index = (x // (p * m) - low) >> 1
                    phi_xn = phi[b] + _fenwick_tree_query(tree, index)
                    S2 -= mu[m] * phi_xn * (f(m) * f(p) if f else 1)
            # Store the accumulated sum over unsieved elements
            phi[b] += _fenwick_tree_query(tree, tree_size - 1)
            # Mark odd prime multiples in the sieve
            # Update the tree for each element being marked for the first time
            next_odd_prime_multiple = (((low + p - 1) // p) | 1) * p
            for index in range((next_odd_prime_multiple - low) >> 1, tree_size, p):
                if odd_sieve[index]:
                    odd_sieve[index] = False
                    value = values[index] if f else 1
                    _fenwick_tree_update(tree, index, -value, tree_size)
    return S2
def _lmo_odd_sieve(
    start: int,
    sieve_size: int,
    odd_primes: Sequence[int],
    max_prime: int | None = None,
) -> bytearray:
    """
    Sieve the odd numbers of [start, start + sieve_size) by the given primes.

    Byte i of the result refers to (start | 1) + 2i and is 1 exactly when that
    number has no factor among the primes used. Every odd multiple in the
    window is cleared, including p itself if it lies in the window. A truthy
    `max_prime` stops at the first prime exceeding it.
    """
    first_odd = start | 1
    stop = start + sieve_size
    slots = (stop - first_odd + 1) >> 1
    survivors = bytearray(b'\x01') * slots
    blank = bytearray(b'\x00') * slots
    for p in odd_primes:
        if max_prime and p > max_prime:
            break
        # Slot of the smallest odd multiple of p that is >= first_odd
        slot = ((p - first_odd) % (2 * p)) >> 1
        survivors[slot::p] = blank[:(slots - slot + p - 1) // p]
    return survivors
def _fenwick_tree_init(values: Iterable[Number]) -> list[Number]:
    """
    Build a Binary Indexed Tree (Fenwick tree) from the given values in O(n).

    Each node's value is pushed into its parent once, in increasing index order.
    """
    tree = [*values]
    for child, parent in _fenwick_tree_edges(len(tree)):
        tree[parent] += tree[child]
    return tree
def _fenwick_tree_query(tree: list[Number], index: int) -> Number:
    """
    Return the prefix sum of the underlying values at positions 0 .. index.
    """
    acc = 0
    for node in _fenwick_tree_query_path(index):
        acc = acc + tree[node]
    return acc
def _fenwick_tree_update(tree: list[Number], index: int, value: Number, tree_size: int):
    """
    Add `value` to the element at `index`, updating every tree node that covers it.
    """
    for node in _fenwick_tree_update_path(index, tree_size):
        tree[node] = tree[node] + value
@small_cache
def _fenwick_tree_edges(tree_size: int) -> tuple[tuple[int, int], ...]:
    """
    List every (child, parent) pair of a 0-based Fenwick tree of the given size,
    in increasing child order. The parent of i is i | (i + 1) when that is in range.
    """
    edges = []
    for child in range(tree_size - 1):
        parent = child | (child + 1)
        if parent < tree_size:
            edges.append((child, parent))
    return tuple(edges)
@large_cache
def _fenwick_tree_query_path(index: int) -> tuple[int, ...]:
    """
    Return the node indices whose sum is the prefix sum through `index`.
    """
    def walk(position: int):
        # position is 1-based; dropping its lowest set bit moves to the next node
        while position > 0:
            yield position - 1
            position &= position - 1
    return tuple(walk(index + 1))
@large_cache
def _fenwick_tree_update_path(index: int, tree_size: int) -> tuple[int, ...]:
    """
    Return the node indices that must change when the element at `index` changes.
    """
    def walk(node: int):
        # Setting the lowest unset bit moves to the next covering node
        while node < tree_size:
            yield node
            node |= node + 1
    return tuple(walk(index))
@large_cache
def _phi_prime_count(x: int, a: int, small_primes: tuple[int, ...]) -> int:
    """
    Legendre's partial sieve function φ(x, a): how many positive integers ≤ x
    have no prime factor among the first a primes.

    `small_primes` must start with (at least) the first a primes.
    """
    if a == 0:
        return x
    if a >= 8:
        # Recurrence φ(x, a) = φ(x, a - 1) - φ(x / pₐ, a - 1)
        p_a = small_primes[a - 1]
        keep = _phi_prime_count(x, a - 1, small_primes)
        return keep - _phi_prime_count(x // p_a, a - 1, small_primes)
    # Periodicity modulo the primorial P: φ(x, a) = ⌊x / P⌋ · φ(P) + φ(x mod P, a)
    primorial = _primorial(a)
    full_periods, remainder = divmod(x, primorial)
    totient_of_primorial = prod(p - 1 for p in small_primes[:a])
    return full_periods * totient_of_primorial + _phi_prime_count_offsets(primorial)[remainder]
@small_cache
def _phi_prime_count_offsets(d: int) -> tuple[int, ...]:
    """
    Running totals of `_coprime_range(d)`, giving φ(r, a) for r = 0, 1 ... d - 1
    where d is the product of the first a primes.
    """
    running_counts = itertools.accumulate(_coprime_range(d))
    return tuple(running_counts)
@large_cache
def _phi_prime_sum(x: int, a: int, small_primes: tuple[int, ...]) -> int:
    """
    Weighted Legendre function φ_f(x, a) for f(n) = n: the sum of the positive
    integers ≤ x that have no prime factor among the first a primes.

    `small_primes` must start with (at least) the first a primes.
    """
    if a == 0:
        return x * (x + 1) // 2  # 1 + 2 + ... + x
    if a == 1:
        return ((x + 1) // 2) ** 2  # sum of the odd integers <= x
    if a >= 8:
        # Recurrence φ_f(x, a) = φ_f(x, a - 1) - pₐ · φ_f(x / pₐ, a - 1)
        p_a = small_primes[a - 1]
        keep = _phi_prime_sum(x, a - 1, small_primes)
        return keep - p_a * _phi_prime_sum(x // p_a, a - 1, small_primes)
    # Coprimes to the primorial P repeat with period P; add complete periods in
    # closed form, then the partial period from the precomputed offsets
    primorial = _primorial(a)
    full_periods, remainder = divmod(x, primorial)
    n_coprime, coprime_total = _phi_prime_sum_offsets(primorial)[remainder]
    period_term = full_periods * totient(primorial) // 2 + n_coprime
    return primorial * full_periods * period_term + coprime_total
@small_cache
def _phi_prime_sum_offsets(d: int) -> tuple[tuple[int, int], ...]:
    """
    Running (count, sum) of the integers r = 0, 1 ... d - 1 flagged by
    `_coprime_range(d)`, i.e. offsets[r] = (φ(r, a), φ_f(r, a)) with f(n) = n,
    where d is the product of the first a primes.
    """
    flags = _coprime_range(d)
    running_counts = itertools.accumulate(flags)
    running_sums = itertools.accumulate(r * flag for r, flag in zip(range(d), flags))
    return tuple(zip(running_counts, running_sums))
@small_cache
def _primorial(n: int) -> int:
    """
    Return the product of the first n primes (1 for n = 0).
    """
    product = 1
    for p in primes(count=n):
        product *= p
    return product
########################################################################
############################ Factorization #############################
########################################################################
def perfect_power(n: int) -> tuple[int, int]:
    """
    Find integers a, b such that a^b = n.
    Returns the solution (a, b) with minimal b > 1 if there are any such solutions,
    otherwise returns the trivial solution (n, 1).

    Parameters
    ----------
    n : int
        Integer target (negative values are allowed; they need an odd exponent)
    """
    if n in (0, 1):
        return (n, 2)
    if n == -1:
        return (-1, 3)
    # Work with |n| and remember the sign
    n = -n if (is_negative := n < 0) else n
    # Handle square roots (squares mod 16 are 0, 1, 4, 9)
    if not is_negative and (n & 0xF) in (0, 1, 4, 9) and (r := isqrt(n)) * r == n:
        return (r, 2)
    # Try to find a small prime factor and its multiplicity
    multiplicity = 0
    if n & 1 == 0:
        multiplicity = (n & -n).bit_length() - 1
    elif (g := gcd(n, _PRIMORIAL_ODD_PRIMES_BELOW_256)) > 1:
        multiplicity = next(valuation(n, p) for p in _ODD_PRIMES_BELOW_256 if not g % p)
    # Calculate maximum possible exponent
    max_exponent = n.bit_length() - 1
    if multiplicity == 0:
        # All prime factors are >= 257, so any root a satisfies a >= 257
        max_exponent = min(max_exponent, ilog(n, 257))
    # Any exponent b must divide `multiplicity`. b = 2 is already excluded
    # (n is a positive non-square, or negative), so a multiplicity of 1 or 2
    # leaves no candidate. This skips testing every odd prime exponent.
    if multiplicity in (1, 2) or max_exponent < 3:
        return (-n if is_negative else n, 1)
    # If we know multiplicity, only check its odd prime divisors
    if multiplicity > 2:
        m = multiplicity
        m >>= (m & -m).bit_length() - 1 # remove factors of 2
        # Trial division to find and check odd prime factors in order
        for p in _ODD_PRIMES_BELOW_256:
            if m % p == 0:
                if p <= max_exponent and pow(r := iroot(n, p), p) == n:
                    return ((-r if is_negative else r), p)
                while m % p == 0:
                    m //= p
        # Whatever remains of m is 1 or a prime above 256: test it as an exponent
        if 1 < m <= max_exponent and pow(r := iroot(n, m), m) == n:
            return ((-r if is_negative else r), m)
    else:
        # No small prime factor: check all odd prime exponents up to the bound
        for p in primes(low=3, high=max_exponent):
            if pow(r := iroot(n, p), p) == n:
                return ((-r if is_negative else r), p)
    return (-n if is_negative else n, 1)
def prime_factors(n: int) -> tuple[int, ...]:
    """
    Return the prime factors of n in ascending order, repeated by multiplicity.

    Factors are produced by `_gen_prime_factors`, which (per the module's design)
    combines trial division, Fermat's method, Brent's variant of Pollard's rho,
    Lenstra's elliptic curve method (ECM) and a self-initializing quadratic
    sieve (SIQS).

    Parameters
    ----------
    n : int
        Integer to factor
    """
    factors = list(_gen_prime_factors(n))
    factors.sort()
    return tuple(factors)
def prime_factorization(n: int) -> dict[int, int]:
"""