540. 统计本原毕达哥拉斯三元组(Counting Primitive Pythagorean Triples)

一个毕达哥拉斯三元组由满足 \(a^2+b^2=c^2\) 的三个正整数 \(a,b,c\) 组成。
若 \(\gcd(a,b,c)=1\)(对毕达哥拉斯三元组而言,这等价于 \(a,b,c\) 两两互素),则称该三元组为本原三元组。
记 \(P(n)\) 为满足 \(a<b<c\le n\) 的本原毕达哥拉斯三元组的个数。
例如 \(P(20)=3\),因为满足条件的本原三元组恰有三个:\((3,4,5)\)\((5,12,13)\)\((8,15,17)\)
已知 \(P(10^6)=159139\)。求 \(P(3141592653589793)\)

分析:这题的难点不在公式推导本身,而在“如何把计数复杂度压到可运行规模”。直接按 Euclid 参数化枚举 \((m,n)\) 需要遍历约 \(\pi N/8\approx 1.2\times10^{15}\) 个格点,远超可行范围。因此必须用数论反演把“互素约束”转为可批处理的求和,再配合缓存与分块降低复杂度。

一、数学背景

对本原毕达哥拉斯三元组,Euclid 参数化为:

$$ a=m^2-n^2,\quad b=2mn,\quad c=m^2+n^2,\quad m>n,\ \gcd(m,n)=1,\ m-n\text{ 为奇数}. $$

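题目要求 \(a<b<c\le N\)(下文用 \(N\) 表示上界,以免与参数 \(n\) 混淆)。两条直角边总可以交换使 \(a<b\),因此每个本原三元组恰好对应一对上述 \((m,n)\)。于是 \(P(N)\) 等于满足 \(m^2+n^2\le N\)\(\gcd(m,n)=1\)\(m-n\) 为奇数的 \((m,n)\) 的个数,其计数核心可以转写为“满足二次约束的整数点计数 + 互素筛选 + 奇偶筛选”。为此先定义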

$$ R(x)=\#\{(u,v)\mid 1\le u<v,\ u^2+v^2\le x\}, $$

它是一个不带互素条件的几何计数函数。

再定义

$$ Q(x)=\sum_{d\le \sqrt{x}} \mu(d)\,R\!\left(\left\lfloor \frac{x}{d^2}\right\rfloor\right), $$

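这里 \(\mu\) 是莫比乌斯函数。该式用莫比乌斯反演完成“互素过滤”。满足 \(d\mid\gcd(u,v)\) 的点对与 \((u/d,v/d)\) 一一对应,共 \(R(\lfloor x/d^2\rfloor)\) 个;再利用 \(\sum_{d\mid g}\mu(d)=[g=1]\),即知 \(Q(x)\) 恰为满足 \(1\le u<v\)\(\gcd(u,v)=1\)\(u^2+v^2\le x\) 的点对个数。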

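最后用 2 进制容斥处理奇偶限制。互素的 \((u,v)\) 不可能同为偶数,因此只需剔除同为奇数的情形。若 \(u<v\) 同为奇数且互素,令 \(s=(v-u)/2\)\(t=(v+u)/2\),则有:
- \(1\le s<t\)\(\gcd(s,t)=1\)
- \(s+t=v\) 为奇数;
- \(s^2+t^2=(u^2+v^2)/2\)

该映射是双射,逆映射为 \(u=t-s,\ v=t+s\)。因此同为奇数的互素对个数恰为 \(P(\lfloor N/2\rfloor)\),即 \(P(N)=Q(N)-P(\lfloor N/2\rfloor)\)。反复展开(注意 \(\lfloor\lfloor N/2^k\rfloor/2\rfloor=\lfloor N/2^{k+1}\rfloor\))可得: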

$$ P(N)=\sum_{k\ge 0}(-1)^k\,Q\!\left(\left\lfloor\frac{N}{2^k}\right\rfloor\right). $$

二、算法设计

候选方案对比

关键实现点

三、复杂度分析

四、代码实现与说明

"""Project Euler 540 - Counting Primitive Pythagorean Triples.

This solver includes an optional Numba acceleration path.
"""

from __future__ import annotations

from functools import lru_cache
from math import isqrt
from time import perf_counter

TARGET_N = 3_141_592_653_589_793  # the N asked for by the problem
SAMPLE_N = 1_000_000  # bound of the published sample value
SAMPLE_ANSWER = 159_139  # published P(10^6)

# Numba (and NumPy) are optional accelerators. Any failure while importing
# them -- not just ImportError -- is treated as "unavailable", and the
# pure-Python implementations below are used instead.
try:
    import numpy as np
    from numba import njit
except Exception:
    HAS_NUMBA = False
else:
    HAS_NUMBA = True


if HAS_NUMBA:
    # Numba-compiled kernels. Each one mirrors a pure-Python fallback defined
    # later in the file, and they exist only when the optional import worked.

    @njit(cache=False)
    def isqrt_numba(n: int) -> int:
        """Return floor(sqrt(n)) for a non-negative integer n in nopython mode.

        A float square root gives the initial guess, and the two correction
        loops make the result exact despite rounding. (Presumably written this
        way because math.isqrt is not available under njit -- TODO confirm.)
        """
        x = int(n**0.5)
        # The float guess may be too small: step up while (x+1)^2 still fits.
        while (x + 1) * (x + 1) <= n:
            x += 1
        # ...or too large: step down until x^2 <= n.
        while x * x > n:
            x -= 1
        return x


    @njit(cache=False)
    def mobius_sieve_numba(limit: int):
        """Return mu[0..limit] as an int8 array via an Eratosthenes-style sieve.

        mu[0] is set to 0 and every other entry starts at +1. Each prime p
        flips the sign of all its multiples, so the sign tracks the parity of
        the number of distinct prime factors. Multiples of p*p are then forced
        to 0 because they are not square-free.
        """
        mu = np.ones(limit + 1, dtype=np.int8)
        mu[0] = 0
        is_composite = np.zeros(limit + 1, dtype=np.uint8)

        for p in range(2, limit + 1):
            # p reaches the marking loop only if no smaller prime divides it,
            # i.e. p is prime. The loop below also flags p itself, which is
            # harmless because p has already been visited.
            if is_composite[p] != 0:
                continue
            for multiple in range(p, limit + 1, p):
                is_composite[multiple] = 1
                mu[multiple] = -mu[multiple]
            square = p * p
            if square <= limit:
                for multiple in range(square, limit + 1, square):
                    mu[multiple] = 0
        return mu


    @njit(cache=False)
    def build_mertens_prefix_numba(mu):
        """Return an int32 array M with M[0] = 0 and M[i] = mu[1] + ... + mu[i].

        mu[0] is ignored. Differences M[b] - M[a-1] give the sum of mu over
        [a, b] in O(1), which the grouped Q(x) summation relies on.
        """
        mertens = np.zeros(mu.size, dtype=np.int32)
        running = 0
        for i in range(1, mu.size):
            running += int(mu[i])
            mertens[i] = running
        return mertens


    @njit(cache=False)
    def count_ordered_pairs_numba(n: int) -> int:
        """Return R(n): the number of pairs (x, y) with 1 <= x < y and x^2 + y^2 <= n.

        This is a two-pointer sweep: x only increases and y only decreases, so
        it takes O(sqrt(n)) steps. Both squares are maintained incrementally
        instead of being recomputed.
        NOTE(review): integers are presumably typed as int64 by Numba. For
        n <= TARGET_N (~3.1e15), y_sq stays far below the int64 limit.
        """
        x = 1
        x_sq = 1
        y = isqrt_numba(n)
        y_sq = y * y
        total = 0

        while x < y:
            # Shrink y until (x, y) lies inside the disc x^2 + y^2 <= n.
            while x_sq + y_sq > n:
                # (y - 1)^2 = y^2 - (2y - 1): update the square before y.
                y_sq -= 2 * y - 1
                y -= 1
            if x >= y:
                break
            # Every y' with x < y' <= y pairs with this x.
            total += y - x
            x += 1
            # x^2 = (x - 1)^2 + (2x - 1), using the already-incremented x.
            x_sq += 2 * x - 1
        return total


def mobius_sieve(limit: int):
    """Compute the Mobius function for every integer in 0..limit.

    Delegates to the Numba kernel when it is available. Otherwise it returns
    a Python list ``mu`` with ``mu[0] == 0``. The list is built by flipping
    the sign at the multiples of each prime, then zeroing the multiples of
    that prime's square (non-square-free numbers).
    """
    if HAS_NUMBA:
        return mobius_sieve_numba(limit)

    size = limit + 1
    mu = [1] * size
    mu[0] = 0
    seen = bytearray(size)

    for prime in range(2, size):
        if seen[prime]:
            continue
        # No smaller prime marked `prime`, so it is itself prime.
        for k in range(prime, size, prime):
            seen[k] = 1
            mu[k] = -mu[k]
        # Empty range when prime*prime exceeds limit.
        for k in range(prime * prime, size, prime * prime):
            mu[k] = 0
    return mu


def build_mertens_prefix(mu):
    """Build the Mertens prefix table for the Mobius values in *mu*.

    Entry i of the result is mu[1] + ... + mu[i], and entry 0 is 0 (mu[0] is
    never added). Range sums of mu then become a single subtraction. The
    Numba kernel is used when available.
    """
    if HAS_NUMBA:
        return build_mertens_prefix_numba(mu)

    prefix = [0] * len(mu)
    for idx in range(1, len(prefix)):
        prefix[idx] = prefix[idx - 1] + mu[idx]
    return prefix


def count_ordered_pairs(n: int) -> int:
    """Return R(n): the number of integer pairs 1 <= x < y with x^2 + y^2 <= n.

    Dispatches to the Numba kernel when available. Otherwise it sweeps the
    smaller coordinate upward while shrinking the largest admissible larger
    coordinate. Both move monotonically, so the scan costs O(sqrt(n)) steps.
    """
    if HAS_NUMBA:
        return int(count_ordered_pairs_numba(n))

    count = 0
    hi = isqrt(n)
    lo = 1
    while lo < hi:
        # Pull hi back until (lo, hi) is inside the disc of radius sqrt(n).
        while lo * lo + hi * hi > n:
            hi -= 1
        if lo >= hi:
            break
        # Every y in (lo, hi] forms a valid pair with x = lo.
        count += hi - lo
        lo += 1
    return count


def primitive_triples_upto(n: int) -> int:
    """Return P(n): primitive Pythagorean triples with a<b<c<=n.

    Evaluates P(n) = sum_{k>=0} (-1)^k * Q(n // 2^k), where
    Q(x) = sum_d mu(d) * R(x // d^2) counts coprime pairs 1 <= u < v with
    u^2 + v^2 <= x, and R(x) is ``count_ordered_pairs(x)``.

    Raises:
        ValueError: if ``n`` is negative (propagated from ``math.isqrt``).
    """
    # R(x) == 0 for x < 5: the smallest admissible pair (1, 2) has
    # 1^2 + 2^2 = 5. Any term of Q with x // d^2 < 5 therefore vanishes, so d
    # never needs to exceed isqrt(n // 5). This lets the Mobius/Mertens tables
    # be about sqrt(5) ~ 2.2x smaller than isqrt(n) entries.
    min_pair_norm = 5
    max_d = isqrt(n // min_pair_norm)
    mu = mobius_sieve(max_d)
    mertens = build_mertens_prefix(mu)

    @lru_cache(maxsize=None)
    def r(value: int) -> int:
        """Memoized R(value).

        Different q(...) calls can need the same argument, e.g.
        (n // 4) // d^2 == n // (2d)^2.
        """
        return count_ordered_pairs(value)

    @lru_cache(maxsize=None)
    def q(value: int) -> int:
        """Compute Q(value) = sum mu(d) * R(value // d^2) over grouped d-ranges.

        All d in [d, right] share the same quotient value // d^2, so their
        Mobius values are summed in O(1) from the Mertens prefix table.
        """
        total = 0
        d = 1
        # For d > isqrt(value // 5) the quotient is < 5 and R(quotient) == 0.
        limit = isqrt(value // min_pair_norm)
        while d <= limit:
            quotient = value // (d * d)
            # Largest d' with value // d'^2 == quotient. Since quotient >= 5,
            # right <= limit <= max_d, so the Mertens lookups stay in range.
            right = isqrt(value // quotient)
            coeff = int(mertens[right]) - int(mertens[d - 1])
            if coeff:
                total += coeff * r(quotient)
            d = right + 1
        return total

    # Alternating sum over n, n//2, n//4, ... handles the parity condition.
    answer = 0
    sign = 1
    factor = 1
    while factor <= n:
        answer += sign * q(n // factor)
        factor <<= 1
        sign = -sign
    return answer


def solve_and_measure(n: int) -> tuple[int, float]:
    """Compute P(n) and report how long the computation took.

    Returns a ``(count, seconds)`` pair. *seconds* is the elapsed time
    measured with ``perf_counter`` around the ``primitive_triples_upto`` call.
    """
    t0 = perf_counter()
    result = primitive_triples_upto(n)
    return result, perf_counter() - t0


def main() -> None:
    """Verify the known sample values, then solve and print P(TARGET_N).

    The self-checks raise ``AssertionError`` explicitly instead of using
    ``assert`` statements. ``python -O`` strips ``assert``, which would
    silently skip verification before the long final computation.

    Raises:
        AssertionError: if P(20) != 3 or P(SAMPLE_N) != SAMPLE_ANSWER.
    """
    if primitive_triples_upto(20) != 3:
        raise AssertionError("P(20) must be 3")

    sample_value, sample_elapsed = solve_and_measure(SAMPLE_N)
    if sample_value != SAMPLE_ANSWER:
        raise AssertionError(
            f"P(10^6) mismatch: got {sample_value}, expected {SAMPLE_ANSWER}"
        )
    print(f"P(10^6) = {sample_value}")
    print(f"sample_elapsed_seconds = {sample_elapsed:.6f}")

    final_value, final_elapsed = solve_and_measure(TARGET_N)
    print(f"P({TARGET_N}) = {final_value}")
    print(f"final_elapsed_seconds = {final_elapsed:.6f}")


if __name__ == "__main__":
    # Run the sample self-checks and the full computation only when this file
    # is executed as a script, not when it is imported.
    main()

常量与开关

isqrt_numba

mobius_sieve_numba 与 mobius_sieve

build_mertens_prefix_numba 与 build_mertens_prefix

count_ordered_pairs_numba 与 count_ordered_pairs

primitive_triples_upto

solve_and_measure 与 main