152. 平方倒数之和(Sums of Square Reciprocals)

有若干种方法可以把数 \(\dfrac{1}{2}\) 写成互不相同的整数的平方倒数之和。

例如,可以使用这些数 \(\{2,3,4,5,7,12,15,20,28,35\}\)

$$ \begin{align} \dfrac{1}{2} &= \dfrac{1}{2^2} + \dfrac{1}{3^2} + \dfrac{1}{4^2} + \dfrac{1}{5^2} +\\ &\quad \dfrac{1}{7^2} + \dfrac{1}{12^2} + \dfrac{1}{15^2} + \dfrac{1}{20^2} +\\ &\quad \dfrac{1}{28^2} + \dfrac{1}{35^2}. \end{align} $$

事实上,如果只使用 \(2\)\(45\)(含)之间的整数,恰好有三种方法可以做到;另外两种是:\(\{2,3,4,6,7,9,10,20,28,35,36,45\}\)\(\{2,3,4,6,7,9,12,15,28,30,35,36,45\}\)

使用 \(2\)\(80\)(含)之间互不相同的整数的平方倒数之和,有多少种方法可以写出 \(\dfrac{1}{2}\)

一、题意概述

本题要求统计集合 \(\{2,3,\dots,80\}\) 的子集数量,使得所选元素满足:

$$ \sum_{n \in S} \frac{1}{n^2} = \frac{1}{2} $$

所有分母必须互不相同,因此这是一个精确的子集计数问题。直接枚举 \(2^{79}\) 个子集不可行,必须利用数论约束减少候选分母,并用精确整数计算避免浮点误差。

二、数学背景

\(L=\mathrm{lcm}(2,3,\dots,N)\),则每一项都可以乘以 \(L^2\)

$$ \sum_{n \in S} \left(\frac{L}{n}\right)^2 = \frac{L^2}{2} $$

这样原问题变为整数权重的子集和计数。对每个分母 \(n\),权重为:

$$ w_n = \left(\frac{L}{n}\right)^2 $$

目标值为 \(T=L^2/2\)

进一步考虑奇素数 \(p\)。设 \(p^a\) 是不超过 \(N\) 的最高 \(p\) 幂。把等式乘到公共分母后,所有满足 \(v_p(n)<a\) 的项在模 \(p\) 意义下都会消失;只有满足 \(v_p(n)=a\) 的项可能贡献非零值。由于右侧 \(L^2/2\) 对奇素数 \(p\) 仍然被 \(p\) 整除,任何合法解中这些最高 \(p\) 幂分母的贡献必须满足模 \(p\) 的零和条件。

\(n=p^a m\)\(p \nmid m\),则在模 \(p\) 下该项只差一个公共非零因子,实际需要检查的是:

$$ \sum \frac{1}{m^2} \equiv 0 \pmod p $$

因此可以先删除那些不可能出现在任何非空零和子集中的分母。例如,对于很大的素数 \(p\),区间中只有分母 \(p\) 自身带最高 \(p\) 幂,它不可能单独构成零和,因此必然不会出现在解中。

三、算法分析

候选方案比较

方案 A:直接整数 DFS。 把所有项缩放为整数后,从 \(2\)\(80\) 做递归搜索,并用剩余和剪枝。这个方法实现简单,但候选数过多,目标规模会产生过多状态。

方案 B:完全按素数约束分组枚举。 对每个最高奇素数幂分组枚举所有合法零和选择,再组合各组。这种方法剪枝强,但跨组分母和剩余自由分母的处理复杂,容易写出难以审查的代码。

方案 C:模约束过滤 + 分割子集和。 先用最高奇素数幂的零和必要条件删除不可能分母,再把剩余分母按大小分成两半:较大的分母项权重小且数量少,直接枚举所有子集和;较小的分母项权重大,用递归和上下界剪枝搜索,并在递归末端查表匹配另一半。这是本实现采用的方案。

具体步骤

先生成所有不超过 \(N\) 的素数。对每个奇素数 \(p\),找到最高指数 \(a\),收集当前候选集中满足 \(v_p(n)=a\) 的分母组。

对该分组枚举所有非空子集,计算 \(\sum m^{-2}\bmod p\)。如果某个分母从未出现在任何零和子集中,它不可能属于任何合法解,可以删除。删除会改变其他素数分组的可选集合,因此这个过程迭代到不再发生删除。

得到候选集后,计算 \(L^2\)、每个分母的整数权重 \(w_n\) 和目标 \(T\)

接着按阈值 \(40\) 分割候选集:

由于所有计算都在整数上进行,不存在舍入误差;由于只统计子集和出现次数,不需要存储具体子集。

四、复杂度分析

设过滤后的候选数量为 \(C\),其中高半区数量为 \(H\),低半区数量为 \(L_c\)

模约束过滤中,每个最高幂分组的规模很小,最多只枚举 \(2^k\) 个组内子集,实际代价可以忽略。

高半区枚举复杂度为 \(O(2^H)\),本题中 \(H=20\)。低半区递归的理论上界为 \(O(2^{L_c})\),但上下界剪枝和缓存会合并大量状态。空间主要来自高半区子集和表与低半区递归缓存。

整体运行时间低于 60 秒,且全程使用 Python 原生大整数进行精确计算。

五、代码实现与说明

脚本位于 scripts/pe152.py。文件开头导入命令行、数学、计数器、缓存和计时工具:

import argparse
import math
from collections import Counter
from functools import lru_cache
from time import perf_counter

prime_numbers 用试除法生成不超过 limit 的素数。由于本题最大只到 \(80\),这个简单实现已经足够:

def prime_numbers(limit: int) -> list[int]:
    """Return all primes not greater than limit."""
    primes: list[int] = []
    for value in range(2, limit + 1):
        is_prime = True
        for prime in primes:
            if prime * prime > value:
                break
            if value % prime == 0:
                is_prime = False
                break
        if is_prime:
            primes.append(value)
    return primes

随后两个辅助函数处理 \(p\) 进阶与最高素数幂。p_adic_order 计算 \(v_p(n)\)max_prime_exponent 找到最大 \(a\) 使得 \(p^a\le N\)

def p_adic_order(value: int, prime: int) -> int:
    """Return the exponent of prime in value."""
    exponent = 0
    while value % prime == 0:
        exponent += 1
        value //= prime
    return exponent


def max_prime_exponent(limit: int, prime: int) -> int:
    """Return the largest exponent a such that prime**a <= limit."""
    exponent = 0
    power = prime
    while power <= limit:
        exponent += 1
        power *= prime
    return exponent

zero_subset_members 是模约束过滤的核心。对于最高 \(p\) 幂分组中的每个 \(n=p^a m\),计算 \(m^{-2}\bmod p\),再枚举组内非空子集。能出现在某个零和子集中的分母标记为可保留:

def zero_subset_members(group: list[int], prime: int, exponent: int) -> list[bool]:
    """Mark group members that can belong to a zero-sum subset modulo prime."""
    prime_power = prime**exponent
    residues = []
    for value in group:
        cofactor = value // prime_power
        residues.append(pow((cofactor * cofactor) % prime, -1, prime))

    possible = [False] * len(group)
    for mask in range(1, 1 << len(group)):
        total = 0
        for index, residue in enumerate(residues):
            if mask & (1 << index):
                total = (total + residue) % prime
        if total == 0:
            for index in range(len(group)):
                if mask & (1 << index):
                    possible[index] = True
    return possible

reduced_candidates 从完整区间开始,对所有奇素数重复应用上面的过滤。只要有分母被删除,就重新扫描,直到候选集稳定:

def reduced_candidates(limit: int) -> list[int]:
    """Remove denominators forbidden by highest odd-prime-power constraints."""
    primes = prime_numbers(limit)
    candidates = set(range(2, limit + 1))

    changed = True
    while changed:
        changed = False
        for prime in primes:
            if prime == 2:
                continue

            exponent = max_prime_exponent(limit, prime)
            group = sorted(
                value
                for value in candidates
                if p_adic_order(value, prime) == exponent
            )
            if not group:
                continue

            possible = zero_subset_members(group, prime, exponent)
            rejected = [
                value for value, can_appear in zip(group, possible) if not can_appear
            ]
            if rejected:
                candidates.difference_update(rejected)
                changed = True

    return sorted(candidates)

square_lcm_scale 计算 \(L^2\),之后所有平方倒数都会用这个值缩放为整数权重:

def square_lcm_scale(limit: int) -> int:
    """Return LCM(2..limit)^2."""
    lcm_value = 1
    for value in range(2, limit + 1):
        lcm_value = math.lcm(lcm_value, value)
    return lcm_value * lcm_value

subset_sum_counts 枚举高半区所有子集和。Counter 的值表示达到同一个和的子集数量:

def subset_sum_counts(numbers: list[int], weights: dict[int, int]) -> Counter[int]:
    """Return a frequency table for subset sums over numbers."""
    counts: Counter[int] = Counter({0: 1})
    for number in numbers:
        weight = weights[number]
        additions = [(total + weight, count) for total, count in counts.items()]
        for total, count in additions:
            counts[total] += count
    return counts

count_scaled_subsets 完成最终计数。它先建立整数权重,再把候选数按 \(40\) 分为低半区和高半区。高半区查表,低半区递归搜索:

def count_scaled_subsets(candidates: list[int], scale: int, target: int) -> int:
    """Count candidate subsets whose scaled weights sum to target."""
    weights = {number: scale // (number * number) for number in candidates}
    lower = [number for number in candidates if number <= 40]
    upper = [number for number in candidates if number > 40]

    upper_counts = subset_sum_counts(upper, weights)
    upper_total = sum(weights[number] for number in upper)

    lower_weights = [weights[number] for number in lower]
    suffix = [0] * (len(lower_weights) + 1)
    for index in range(len(lower_weights) - 1, -1, -1):
        suffix[index] = suffix[index + 1] + lower_weights[index]

内部的 search 使用 lru_cache 记忆化。若剩余目标已经不可能由“当前低半区剩余项 + 任意高半区项”补齐,就返回 \(0\);低半区结束后,直接查询高半区的剩余和出现次数:

    @lru_cache(maxsize=None)
    def search(index: int, remaining: int) -> int:
        """Count completions using lower[index:] plus any upper subset."""
        if remaining < 0 or remaining > suffix[index] + upper_total:
            return 0
        if index == len(lower_weights):
            return upper_counts.get(remaining, 0)
        return search(index + 1, remaining) + search(
            index + 1, remaining - lower_weights[index]
        )

    return search(0, target)

solve 串联全部步骤:先过滤候选分母,再构造公共缩放因子,最后统计缩放后的精确子集和:

def solve(limit: int = 80) -> int:
    """Count representations for 1/2 using distinct square reciprocals."""
    candidates = reduced_candidates(limit)
    scale = square_lcm_scale(limit)
    return count_scaled_subsets(candidates, scale, scale // 2)

脚本还提供 --self-test 入口,用于检查题面给出的 \(N=45\) 小规模结果;普通运行则输出目标规模结果和运行时间:

def run_self_tests() -> None:
    """Run the published small-scale verification case."""
    assert solve(45) == 3


def main() -> None:
    """Run the solver or its self-tests from the command line."""
    parser = argparse.ArgumentParser(description="Solve Project Euler 152.")
    parser.add_argument("--limit", type=int, default=80)
    parser.add_argument("--self-test", action="store_true")
    args = parser.parse_args()

    if args.self_test:
        run_self_tests()
        print("Self-tests passed.")
        return

    start = perf_counter()
    answer = solve(args.limit)
    elapsed = perf_counter() - start
    print(f"Result: {answer}")
    print(f"Time:   {elapsed:.3f}s")