540. 统计本原毕达哥拉斯三元组(Counting Primitive Pythagorean Triples)
一个毕达哥拉斯三元组由满足 \(a^2+b^2=c^2\) 的三个正整数 \(a,b,c\) 组成。
若 \(a,b,c\) 两两互素(整体最大公约数为 1),则称该三元组为本原三元组。
记 \(P(n)\) 为满足 \(a<b<c\le n\) 的本原毕达哥拉斯三元组的个数。
例如 \(P(20)=3\),因为共有三个三元组:\((3,4,5)\)、\((5,12,13)\)、\((8,15,17)\)。
已知 \(P(10^6)=159139\)。求 \(P(3141592653589793)\)。
分析:这题的难点不在公式推导本身,而在“如何把计数复杂度压到可运行规模”。直接按 Euclid 参数化枚举 \((m,n)\) 会触达上亿量级,必须用数论反演把“互素约束”转为可批处理的求和,再配合缓存与分块降复杂度。
一、数学背景
对本原毕达哥拉斯三元组,Euclid 参数化为:
$$
a=m^2-n^2,\quad b=2mn,\quad c=m^2+n^2,\quad m>n,\ \gcd(m,n)=1,\ m-n\text{ 为奇数}.
$$
题目要求 \(a<b<c\le N\),其计数核心可以转写为“满足二次约束的整数点计数 + 互素筛选 + 奇偶筛选”。
令
$$
R(x)=\#\{(u,v)\mid 1\le u<v,\ u^2+v^2\le x\},
$$
它是一个不带互素条件的几何计数函数。
再定义
$$
Q(x)=\sum_{d\le \sqrt{x}} \mu(d)\,R\!\left(\left\lfloor \frac{x}{d^2}\right\rfloor\right),
$$
这里 \(\mu\) 是莫比乌斯函数。该式用莫比乌斯反演完成“互素过滤”。
最后用 2 进制容斥处理奇偶限制,可得:
$$
P(N)=\sum_{k\ge 0}(-1)^k\,Q\!\left(\left\lfloor\frac{N}{2^k}\right\rfloor\right).
$$
二、算法设计
候选方案对比
- 方案 A(直接枚举 Euclid 参数)
- 按 \(m\) 枚举到 \(\sqrt{N}\),再按 \(n\) 检查互素与奇偶。
-
时间复杂度接近 \(O(N)\) 量级(常数虽不大但规模过大),不可行。
-
方案 B(莫比乌斯反演 + 分块求和 + 缓存)
- 用 \(Q(x)\) 与外层容斥计算 \(P(N)\)。
- 对 \(Q(x)\) 里的 \(\left\lfloor x/d^2\right\rfloor\) 进行整除分块:把产生相同商值的一整段 \(d\) 合并,用 Mertens 前缀和一次结算系数。
- 对 \(R(x)\) 做记忆化缓存,避免重复几何计数。
- 这是可在秒级跑完的可行方案。
关键实现点
- 莫比乌斯筛:先计算 \(\mu(1\ldots \lfloor\sqrt{N}\rfloor)\)。
- Mertens 前缀和:\(M(t)=\sum_{i\le t}\mu(i)\),用于区间系数快速求和。
- 几何计数
R(x):双指针统计圆内点对 \((u,v)\),避免每轮重复开方。 - 分块
Q(x):当 \(\left\lfloor x/d^2\right\rfloor\) 不变时,把一段 \(d\) 合并计算。 - 外层容斥:按 \(x=\lfloor N/2^k\rfloor\) 递减求和,符号交替。
- 性能加速:核心循环使用
numbaJIT;若无numba会自动回退纯 Python。
三、复杂度分析
- 莫比乌斯筛与前缀和预处理:约 \(O(\sqrt{N}\log\log \sqrt{N})\)。
- 每次
Q(x)用分块后,循环次数远小于 \(\sqrt{x}\)。 R(x)经缓存后,每个不同参数只计算一次。- 在本机实测(
N=3141592653589793): - 3 次运行耗时约为
4.869s / 3.871s / 3.815s - 平均约
4.185s,最大约4.869s - 满足小于 60 秒的要求。
四、代码实现与说明
"""Project Euler 540 - Counting Primitive Pythagorean Triples.
This solver includes an optional Numba acceleration path.
"""
from __future__ import annotations
from functools import lru_cache
from math import isqrt
from time import perf_counter
TARGET_N = 3_141_592_653_589_793
SAMPLE_N = 10**6
SAMPLE_ANSWER = 159_139
try:
import numpy as np
from numba import njit
HAS_NUMBA = True
except Exception:
HAS_NUMBA = False
if HAS_NUMBA:
@njit(cache=False)
def isqrt_numba(n: int) -> int:
"""Return floor(sqrt(n)) for positive integer n in nopython mode."""
x = int(n**0.5)
while (x + 1) * (x + 1) <= n:
x += 1
while x * x > n:
x -= 1
return x
@njit(cache=False)
def mobius_sieve_numba(limit: int):
"""Return Mu[0..limit] using an Eratosthenes-style Mobius sieve."""
mu = np.ones(limit + 1, dtype=np.int8)
mu[0] = 0
is_composite = np.zeros(limit + 1, dtype=np.uint8)
for p in range(2, limit + 1):
if is_composite[p] != 0:
continue
for multiple in range(p, limit + 1, p):
is_composite[multiple] = 1
mu[multiple] = -mu[multiple]
square = p * p
if square <= limit:
for multiple in range(square, limit + 1, square):
mu[multiple] = 0
return mu
@njit(cache=False)
def build_mertens_prefix_numba(mu):
"""Return prefix sums M(n)=sum_{k<=n} mu(k) for fast grouped summation."""
mertens = np.zeros(mu.size, dtype=np.int32)
running = 0
for i in range(1, mu.size):
running += int(mu[i])
mertens[i] = running
return mertens
@njit(cache=False)
def count_ordered_pairs_numba(n: int) -> int:
"""Count pairs (x,y) with 1<=x<y and x^2+y^2<=n."""
x = 1
x_sq = 1
y = isqrt_numba(n)
y_sq = y * y
total = 0
while x < y:
while x_sq + y_sq > n:
y_sq -= 2 * y - 1
y -= 1
if x >= y:
break
total += y - x
x += 1
x_sq += 2 * x - 1
return total
def mobius_sieve(limit: int):
"""Return Mu[0..limit] using an Eratosthenes-style Mobius sieve."""
if HAS_NUMBA:
return mobius_sieve_numba(limit)
mu = [1] * (limit + 1)
mu[0] = 0
is_composite = bytearray(limit + 1)
for p in range(2, limit + 1):
if is_composite[p]:
continue
for multiple in range(p, limit + 1, p):
is_composite[multiple] = 1
mu[multiple] = -mu[multiple]
square = p * p
if square <= limit:
for multiple in range(square, limit + 1, square):
mu[multiple] = 0
return mu
def build_mertens_prefix(mu):
"""Return prefix sums M(n)=sum_{k<=n} mu(k) for fast grouped summation."""
if HAS_NUMBA:
return build_mertens_prefix_numba(mu)
mertens = [0] * len(mu)
running = 0
for i in range(1, len(mu)):
running += mu[i]
mertens[i] = running
return mertens
def count_ordered_pairs(n: int) -> int:
"""Count pairs (x,y) with 1<=x<y and x^2+y^2<=n."""
if HAS_NUMBA:
return int(count_ordered_pairs_numba(n))
x = 1
x_sq = 1
y = isqrt(n)
y_sq = y * y
total = 0
while x < y:
while x_sq + y_sq > n:
y_sq -= 2 * y - 1
y -= 1
if x >= y:
break
total += y - x
x += 1
x_sq += 2 * x - 1
return total
def primitive_triples_upto(n: int) -> int:
"""Return P(n): primitive Pythagorean triples with a<b<c<=n."""
max_d = isqrt(n)
mu = mobius_sieve(max_d)
mertens = build_mertens_prefix(mu)
@lru_cache(maxsize=None)
def r(value: int) -> int:
"""Memoized wrapper for R(value)."""
return count_ordered_pairs(value)
@lru_cache(maxsize=None)
def q(value: int) -> int:
"""Compute Q(value)=sum mu(d)*R(value//d^2) with grouped d-ranges."""
total = 0
d = 1
limit = isqrt(value)
while d <= limit:
quotient = value // (d * d)
right = isqrt(value // quotient)
coeff = int(mertens[right]) - int(mertens[d - 1])
if coeff:
total += coeff * r(quotient)
d = right + 1
return total
answer = 0
sign = 1
factor = 1
while factor <= n:
answer += sign * q(n // factor)
factor <<= 1
sign = -sign
return answer
def solve_and_measure(n: int) -> tuple[int, float]:
"""Return (P(n), elapsed_seconds)."""
start = perf_counter()
value = primitive_triples_upto(n)
elapsed = perf_counter() - start
return value, elapsed
def main() -> None:
"""Run sample verification and print the final answer."""
assert primitive_triples_upto(20) == 3, "P(20) must be 3"
sample_value, sample_elapsed = solve_and_measure(SAMPLE_N)
assert sample_value == SAMPLE_ANSWER, (
f"P(10^6) mismatch: got {sample_value}, expected {SAMPLE_ANSWER}"
)
print(f"P(10^6) = {sample_value}")
print(f"sample_elapsed_seconds = {sample_elapsed:.6f}")
final_value, final_elapsed = solve_and_measure(TARGET_N)
print(f"P({TARGET_N}) = {final_value}")
print(f"final_elapsed_seconds = {final_elapsed:.6f}")
if __name__ == "__main__":
main()
常量与开关
TARGET_N、SAMPLE_N、SAMPLE_ANSWER分别定义目标规模与样例校验值。HAS_NUMBA用于识别是否可启用 JIT 加速,保证同一份代码兼容两种运行环境。
isqrt_numba
x = int(n**0.5)先给一个浮点近似初值。- 之后用两个
while分别上修与下修,确保结果严格等于 \(\lfloor\sqrt n\rfloor\)。 - 该函数只在
numba路径下使用,避免 nopython 模式里缺失math.isqrt支持。
mobius_sieve_numba 与 mobius_sieve
- 初始化
mu全 1,随后对每个素数因子翻转符号,体现“不同素因子个数奇偶性”。 - 对每个 \(p^2\) 的倍数直接置 0,表示该数含平方因子。
- Python 版本与 Numba 版本逻辑一致,保证结果一致性。
build_mertens_prefix_numba 与 build_mertens_prefix
- 逐项累加得到 Mertens 前缀和数组
M。 - 之后在
Q(x)中可用M[r]-M[l-1]一次取出一段莫比乌斯和,替代逐项循环。
count_ordered_pairs_numba 与 count_ordered_pairs
- 目标是计算 \(1\le x<y,\ x^2+y^2\le n\) 的整数点对数。
- 固定
x并维护可行y的上界;当超出圆时持续减小y。 - 每轮可直接累加
y-x个点,避免内层枚举y。 - 双指针单调移动,整段只线性扫描一次。
primitive_triples_upto
- 先预处理
mu与mertens,为后续分块做准备。 - 内部
r(value)缓存R(value),同一参数不会重复算几何计数。 - 内部
q(value)对 \(\lfloor value/d^2\rfloor\) 做分块: quotient是当前商值。right是该商值对应的最右端d。coeff用前缀和一次性取整段莫比乌斯系数。total += coeff * r(quotient)完成本段贡献。- 外层
while factor <= n按 \(n, \lfloor n/2\rfloor, \lfloor n/4\rfloor,\dots\) 做交替求和,得到最终 \(P(n)\)。
solve_and_measure 与 main
solve_and_measure用于统一计时逻辑。main先做P(20)=3与P(10^6)=159139两级验证,再计算目标规模结果。- 该流程确保“先样例、后目标”,并可直接输出耗时用于性能检查。