사용자 도구

사이트 도구


ps:teflib:combinatorics

combinatorics.py

comb

코드

# N comb
# I {"version": "1.1"}
def comb(n: int, k: int, prime_mod: int) -> int:
    if prime_mod <= k:
        raise ValueError("prime_mod must be greater than k.")
    if n < 0 or k < 0:
        raise ValueError("n and k must be a non-negative integer.")
    if k > n:
        return 0
    if n - k < k:
        k = n - k
    numer = denom = 1
    for i in range(k):
        numer = numer * (n - i) % prime_mod
        denom = denom * (i + 1) % prime_mod
    return numer * pow(denom, -1, prime_mod) % prime_mod

설명

  • 놀랍게도, n<5000 정도까지는 그냥 math.comb(n, k) % mod 로 이항계수의 정확한 값을 구해놓고 마지막에 mod만 해주는 것과 비교해서 별로 빠르지 않다..

이 코드를 사용하는 문제

CombTable

코드

# N CombTable
# I {"version": "1.1"}
class CombTable():
    """A class that repeatedly computes C(n, k) % mod in an efficient way."""

    def __init__(self, max_n: int, prime_mod: int):
        if prime_mod <= max_n:
            raise ValueError("prime_mod must be greater than max_n.")
        self._mod = prime_mod
        # fact = [1!, 2!, ..., n!]
        v = 1
        self._fact = [1] + [v := v * i % prime_mod for i in range(1, max_n + 1)]
        # fact_inv = [1, inv(n!), inv((n-1)!), ..., inv(1!)]
        v = pow(v, -1, prime_mod)
        self._fact_inv = ([1, v] +
                          [v := v * i % prime_mod for i in range(max_n, 1, -1)])

    def get(self, n: int, k: int) -> int:
        if n < 0 or k < 0:
            raise ValueError("n and k must be a non-negative integer.")
        if k > n:
            return 0
        return (self._fact[n] * self._fact_inv[-k] *
                self._fact_inv[k - n] % self._mod)

설명

    • 설명된 방법 중, O(n + logP)에 테이블을 만들어 두는 방법을 사용했다.
  • 호출 횟수가 어지간히 크지 않고서는, 그냥 fact배열만 미리 계산해놓고서, 매 쿼리마다 modinv를 계산하는 방법에 비해서 그다지 빠르지 않다..

이 코드를 사용하는 문제

출처문제 번호Page레벨
BOJ11402이항 계수 4플래티넘 5
BOJ13977이항 계수와 쿼리골드 1
BOJ14854이항 계수 6다이아몬드 5
BOJ1492플래티넘 2
BOJ15718돌아온 떡파이어플래티넘 3
프로그래머스68647짝수 행 세기Level 4

linear_homogeneous_recurrence

코드

# N linear_homogeneous_recurrence
# I {"version": "1.41", "func": ["matrix.matpow"]}
def linear_homogeneous_recurrence(
    coefs: list[int], seeds: list[int], n: int, mod: int = 0
) -> int:
    """[DEPRECATED] Computes (a[n] % mod) from linear homogeneous recurrence.

    *** Deprecated: Use combinatorics.linear_recurrence(), instead. ***

    Computes a[n] from the recurrence relation in following form in
    O(logn * k^3) time.
       a[n] = c[0]*a[n-1] + c[1]*a[n-2] + ... + c[k]*a[n-k+1]

    Args:
        coef: A list of coefficients. [c[0], c[1], ..., c[k]].
        seeds: A list of seed values. [a[0], a[1], ..., a[k]].
        n: n.
        mod: Optional modular.

    Returns:
        N-th value of the recurrence relation.

    Raises:
        ValueError: An error occurred if n is negative.
    """
    if 0 <= n < len(seeds):
        return seeds[n]
    elif n < 0:
        raise ValueError('n should be non-negative.')
    init_mat = [coefs] + [[0] * len(coefs) for _ in range(len(coefs) - 1)]
    for i in range(1, len(coefs)):
        init_mat[i][i - 1] = 1
    gen_coefs = matrix.matpow(init_mat, n - len(coefs) + 1, mod)[0]
    res = sum(c * s for c, s in zip(gen_coefs, reversed(seeds)))
    return res % mod if mod > 0 else res

설명

linear_recurrence

코드

# N linear_recurrence
# I {"version": "0.1", "func": ["_naive_convolution"]}
def linear_recurrence(
    coefs: Sequence[int], seeds: Sequence[int], n: int, mod_param: int = 0
):
    """Computes (a[n] % mod) from linear homogeneous recurrence relation.

    This function computes a[n] from the recurrence relation in following form.
       a[n] = c[0]*a[n-1] + c[1]*a[n-2] + ... + c[k]*a[n-k+1]
       (a[i] and c[i] should be integers)

    The implementation is based on Bostan-Mori algorithm with naive convolution
    function, with O(k^2*logn) time complexity. If k is large (>60), use
    linear_recurrence_large, instead.

    Args:
        coefs: A list of coefficients. [c[0], c[1], ..., c[k]].
        seeds: A list of seed values. [a[0], a[1], ..., a[k]].
        n: n.
        mod_param: Optional modular.

    Returns:
        a[n] % mod_param.

    Raises:
        ValueError: An error occurred if n is negative.
    """

    if 0 <= n < len(seeds):
        return seeds[n]
    elif n < 0:
        raise ValueError('n should be non-negative.')
    mod = mod_param or (1 << 64)
    d = len(coefs)
    if len(seeds) > d:
        seeds, n = seeds[-d:], n - (len(seeds) - d)
    q = [1] + [(-x) % mod for x in coefs]
    p = [x % mod for x in _naive_convolution(seeds, q)[:d]]
    while n:
        r = q[:]
        r[1::2] = [mod - x for x in q[1::2]]
        p = [x % mod for x in _naive_convolution(p, r)[n & 1 :: 2]]
        q = [x % mod for x in _naive_convolution(q, r)[::2]]
        n >>= 1
    return p[0] - mod if mod_param == 0 and p[0] > (mod // 2) else p[0]

설명

  • 보스탄-모리 알고리즘을 이용해서, 인접 k항의 선형 동차 점화식이 주어졌을때, n번째 항을 O(k^2logn)에 계산한다.
  • k가 60이상으로 커질수 있다면, linear_recurrence_large 를 사용하는 편이 효율적이다.

이 코드를 사용하는 문제

fibonacci

코드

# N fibonacci
# I {"version": "1.2"}
def fibonacci(n: int, mod: int = 0) -> int:
    """Returns n-th Fibonacci number. f(1)=1, f(2)=1, f(3)=2, f(4)=3, ..."""

    c, d = 0, 1
    for bit in bin(n)[2:]:
        c, d = c * (d + d - c), c * c + d * d
        if bit == '1':
            c, d = d, c + d
        if mod:
            c, d = c % mod, d % mod
    return c

설명

이 코드를 사용하는 문제

출처문제 번호Page레벨
BOJ11440피보나치 수의 제곱의 합플래티넘 5
BOJ11444피보나치 수 6골드 3
BOJ117262×n 타일링실버 3
BOJ11778피보나치 수와 최대공약수골드 1
BOJ2193이친수실버 3
BOJ2749피보나치 수 3골드 2
프로그래머스129002 x n 타일링Level 3

토론

댓글을 입력하세요:
E C B X D
 
ps/teflib/combinatorics.txt · 마지막으로 수정됨: 2023/08/20 16:11 저자 teferi