949. 左与右 II(Left vs Right II)

Left 和 Right 用若干个单词(每个单词都由 L 与 R 组成)轮流进行游戏,Left 先手。
在 Left 的回合中,对每个单词都可以从左侧删除任意个字母(可以删 0 个),但不能把某个单词删空;并且在该回合中,至少要在某一个单词上删去至少 1 个字母。
Right 的回合规则对称:对每个单词从右侧删除任意个字母(也可为 0),同样不能删空,并且该回合至少要在某一个单词上删去至少 1 个字母。
当所有单词都被缩短到只剩 1 个字母时,游戏结束:若剩余字母里 L 的数量多于 R,则 Left 获胜;若 R 的数量多于 L,则 Right 获胜。
本题只考虑单词数量为奇数的情形,因此不会平局。
定义 \(G(n,k)\) 为:选取 \(k\) 个长度为 \(n\) 的单词(相同集合的不同排列按不同方案计数)时,在 Left 先手条件下,Right 拥有必胜策略的方案总数。
已知 \(G(2,3)=14\),并给出 \(G(4,3)=496\)\(G(8,5)=26359197010\)
\(G(20,7)\),答案对 \(1001001011\) 取模。

分析:直接在 \(2^n\) 个单词上再做 \(k\) 重博弈枚举会指数爆炸。关键在于把“一个单词对应的对子博弈”先压缩为可加的数值(并区分冷/热),然后把 \(k\) 个单词的组合问题转成“和分布”的卷积计数。

一、数学背景

把一个单词 \(w\) 看成一个“截断游戏”:

对每个局面定义上下界:

递推关系为:

$$ u(w)=\max_{\text{proper suffix }s\text{ of }w} d(s),\quad d(w)=\min_{\text{proper prefix }p\text{ of }w} u(p) $$

\(u(w)<d(w)\),该局面是冷局面,可归约为某个最简 dyadic 值;若 \(u(w)\ge d(w)\),为热局面,只能保留区间信息。

当我们把每个长度 \(n\) 单词映射为一个上界值(统一缩放到整数)后,\(k\) 个单词的和可以分两类判定 Right 必胜:

这就是最终计数公式的来源。

二、算法设计

1) 预处理长度 n 全部单词的值

2) 直方图与卷积

3) 判定计数

三、复杂度分析

四、代码实现与说明

"""Project Euler 949 - Left vs Right II.

该实现使用冷热博弈值(upper/lower stops)与分布卷积计数,计算 G(20, 7) mod 1001001011。
"""

from __future__ import annotations

from bisect import bisect_left
from time import perf_counter

MOD = 1_001_001_011
TARGET_N = 20
TARGET_K = 7


def ceil_div_by_pow2(value: int, shift: int) -> int:
    """返回 ceil(value / 2**shift)。"""
    if shift == 0:
        return value
    if value >= 0:
        return (value + (1 << shift) - 1) >> shift
    return -((-value) >> shift)


def pick_simplest_dyadic_between(lower: int, upper: int, scale_exp: int) -> int:
    """在 (lower, upper) 中挑选最简单的 dyadic(同一 2**scale_exp 缩放)。"""
    for stripped_bits in range(scale_exp + 1):
        shift = scale_exp - stripped_bits
        min_p = (lower >> shift) + 1
        max_p = ceil_div_by_pow2(upper, shift) - 1
        if min_p > max_p:
            continue

        if min_p > 0:
            p = min_p
        elif max_p < 0:
            p = max_p
        else:
            p = 0

        if stripped_bits and p and (p & 1) == 0:
            if p + 1 <= max_p and ((p + 1) & 1):
                p += 1
            elif p - 1 >= min_p and ((p - 1) & 1):
                p -= 1
        return p << shift
    return 0


def compute_upper_values_and_hot_flags(n: int) -> tuple[list[int], list[int]]:
    """计算所有长度 n 单词的 upper 值与 hot 标记。"""
    scale_exp = n
    scale = 1 << scale_exp

    total_nodes = (1 << (n + 1)) - 1
    upper = [0] * total_nodes
    lower = [0] * total_nodes

    len1_start = 1
    upper[len1_start] = scale
    lower[len1_start] = scale
    upper[len1_start + 1] = -scale
    lower[len1_start + 1] = -scale

    hot_flags = [0] * (1 << n)

    for length in range(2, n + 1):
        layer_start = (1 << length) - 1
        layer_size = 1 << length
        for bits in range(layer_size):
            best_left = -(1 << 60)
            for suffix_len in range(1, length):
                suffix_bits = bits & ((1 << suffix_len) - 1)
                idx = (1 << suffix_len) - 1 + suffix_bits
                if lower[idx] > best_left:
                    best_left = lower[idx]

            best_right = 1 << 60
            for prefix_len in range(1, length):
                prefix_bits = bits >> (length - prefix_len)
                idx = (1 << prefix_len) - 1 + prefix_bits
                if upper[idx] < best_right:
                    best_right = upper[idx]

            idx_now = layer_start + bits
            if best_left < best_right:
                canonical = pick_simplest_dyadic_between(best_left, best_right, scale_exp)
                upper[idx_now] = canonical
                lower[idx_now] = canonical
                if length == n:
                    hot_flags[bits] = 0
            else:
                upper[idx_now] = best_left
                lower[idx_now] = best_right
                if length == n:
                    hot_flags[bits] = 1

    full_start = (1 << n) - 1
    full_upper = upper[full_start : full_start + (1 << n)]
    return full_upper, hot_flags


def build_histogram(values: list[int], mod: int) -> dict[int, int]:
    """把数值列表压成频次直方图(模 mod)。"""
    hist: dict[int, int] = {}
    for value in values:
        hist[value] = (hist.get(value, 0) + 1) % mod
    return {key: cnt for key, cnt in hist.items() if cnt}


def convolve_histograms(
    left_hist: dict[int, int],
    right_hist: dict[int, int],
    mod: int,
) -> dict[int, int]:
    """离散和分布卷积。"""
    if not left_hist or not right_hist:
        return {}
    if len(left_hist) > len(right_hist):
        left_hist, right_hist = right_hist, left_hist
    out: dict[int, int] = {}
    for x, cx in left_hist.items():
        for y, cy in right_hist.items():
            key = x + y
            out[key] = (out.get(key, 0) + (cx * cy) % mod) % mod
    return {key: cnt for key, cnt in out.items() if cnt}


def histogram_power(hist: dict[int, int], times: int, mod: int) -> dict[int, int]:
    """重复卷积 times 次(times 很小,直接迭代即可)。"""
    if times == 0:
        return {0: 1}
    result = dict(hist)
    for _ in range(1, times):
        result = convolve_histograms(result, hist, mod)
    return result


def count_pairs_with_negative_sum(
    left_dist: dict[int, int],
    right_dist: dict[int, int],
    mod: int,
) -> int:
    """统计 x+y<0 的配对数(模 mod)。"""
    sorted_items = sorted(right_dist.items())
    sums = [value for value, _ in sorted_items]
    prefix = [0]
    running = 0
    for _, cnt in sorted_items:
        running = (running + cnt) % mod
        prefix.append(running)

    ans = 0
    for x, cx in left_dist.items():
        idx = bisect_left(sums, -x)
        ans = (ans + cx * prefix[idx]) % mod
    return ans


def count_pairs_with_zero_sum(
    left_dist: dict[int, int],
    right_dist: dict[int, int],
    mod: int,
) -> int:
    """统计 x+y=0 的配对数(模 mod)。"""
    if len(left_dist) > len(right_dist):
        left_dist, right_dist = right_dist, left_dist
    ans = 0
    for x, cx in left_dist.items():
        ans = (ans + cx * right_dist.get(-x, 0)) % mod
    return ans


def compute_g(n: int, k: int, mod: int = MOD) -> int:
    """计算题目定义的 G(n, k)(k 需为奇数)。"""
    if n <= 0:
        raise ValueError("n must be positive")
    if k % 2 == 0:
        raise ValueError("k must be odd")

    upper_values, hot_flags = compute_upper_values_and_hot_flags(n)
    all_hist = build_histogram(upper_values, mod)
    cold_values = [v for v, hot in zip(upper_values, hot_flags) if hot == 0]
    cold_hist = build_histogram(cold_values, mod)

    left_count = k // 2
    right_count = k - left_count

    all_left_dist = histogram_power(all_hist, left_count, mod)
    all_right_dist = histogram_power(all_hist, right_count, mod)
    strictly_negative = count_pairs_with_negative_sum(all_left_dist, all_right_dist, mod)

    cold_left_dist = histogram_power(cold_hist, left_count, mod)
    cold_right_dist = histogram_power(cold_hist, right_count, mod)
    zero_and_all_cold = count_pairs_with_zero_sum(cold_left_dist, cold_right_dist, mod)

    return (strictly_negative + zero_and_all_cold) % mod


def main() -> None:
    """运行样例与目标规模。"""
    t0 = perf_counter()
    sample_1 = compute_g(2, 3, 1 << 63)
    sample_2 = compute_g(4, 3, 1 << 63)
    sample_3 = compute_g(8, 5, 1 << 63)
    prep_elapsed = perf_counter() - t0

    print(f"G(2, 3) = {sample_1}")
    print(f"G(4, 3) = {sample_2}")
    print(f"G(8, 5) = {sample_3}")
    assert sample_1 == 14, f"sample mismatch: {sample_1}"
    assert sample_2 == 496, f"sample mismatch: {sample_2}"
    assert sample_3 == 26359197010, f"sample mismatch: {sample_3}"

    t1 = perf_counter()
    answer = compute_g(TARGET_N, TARGET_K, MOD)
    solve_elapsed = perf_counter() - t1

    print(f"G({TARGET_N}, {TARGET_K}) mod {MOD} = {answer}")
    print(f"sample_elapsed_seconds = {prep_elapsed:.3f}")
    print(f"solve_elapsed_seconds = {solve_elapsed:.3f}")


if __name__ == "__main__":
    main()

代码说明(按代码顺序):

模块导入与常量

ceil_div_by_pow2

pick_simplest_dyadic_between

compute_upper_values_and_hot_flags

build_histogram

convolve_histograms

histogram_power

count_pairs_with_negative_sum

count_pairs_with_zero_sum

compute_g

main