사용자 도구

사이트 도구


ps:teflib:combinatorics

combinatorics.py

comb

코드

# N comb
# I {"version": "1.1"}
def comb(n: int, k: int, prime_mod: int) -> int:
    if prime_mod <= k:
        raise ValueError("prime_mod must be greater than k.")
    if n < 0 or k < 0:
        raise ValueError("n and k must be a non-negative integer.")
    if k > n:
        return 0
    if n - k < k:
        k = n - k
    numer = denom = 1
    for i in range(k):
        numer = numer * (n - i) % prime_mod
        denom = denom * (i + 1) % prime_mod
    return numer * pow(denom, -1, prime_mod) % prime_mod

설명

  • 놀랍게도, n<5000 정도까지는 그냥 math.comb(n, k) % mod 로 이항계수의 정확한 값을 구해놓고 마지막에 mod만 해주는 것과 비교해서 별로 빠르지 않다..

이 코드를 사용하는 문제

CombTable

코드

# N CombTable
# I {"version": "1.1"}
class CombTable():
    """A class that repeatedly computes C(n, k) % mod in an efficient way."""

    def __init__(self, max_n: int, prime_mod: int):
        if prime_mod <= max_n:
            raise ValueError("prime_mod must be greater than max_n.")
        self._mod = prime_mod
        # fact = [1!, 2!, ..., n!]
        v = 1
        self._fact = [1] + [v := v * i % prime_mod for i in range(1, max_n + 1)]
        # fact_inv = [1, inv(n!), inv((n-1)!), ..., inv(1!)]
        v = pow(v, -1, prime_mod)
        self._fact_inv = ([1, v] +
                          [v := v * i % prime_mod for i in range(max_n, 1, -1)])

    def get(self, n: int, k: int) -> int:
        if n < 0 or k < 0:
            raise ValueError("n and k must be a non-negative integer.")
        if k > n:
            return 0
        return (self._fact[n] * self._fact_inv[-k] *
                self._fact_inv[k - n] % self._mod)

설명

    • 설명된 방법 중, O(n + logP)에 테이블을 만들어 두는 방법을 사용했다.
  • 호출 횟수가 어지간히 크지 않고서는, 그냥 fact배열만 미리 계산해놓고서, 매 쿼리마다 modinv를 계산하는 방법에 비해서 그다지 빠르지 않다..

이 코드를 사용하는 문제

출처문제 번호Page레벨
BOJ11402이항 계수 4플래티넘 5
BOJ13977이항 계수와 쿼리골드 1
BOJ14854이항 계수 6다이아몬드 5
BOJ1492플래티넘 2
BOJ15718돌아온 떡파이어플래티넘 3
프로그래머스68647짝수 행 세기Level 4

linear_homogeneous_recurrence

코드

# N linear_homogeneous_recurrence
# I {"version": "1.0", "typing": ["List"]}
class _SqMat(object):
    """A very simple implementation for n x n matrix."""
    def __init__(self, mat):
        self._mat = mat

    def __getitem__(self, row):
        return self._mat[row]

    def __matmul__(self, other):
        ret = []
        for row in self:
            ret_row = []
            for c in range(len(row)):
                col_vec = (row[c] for row in other._mat)
                ret_row.append(sum(r * c for r, c in zip(row, col_vec)))
            ret.append(ret_row)
        return _SqMat(ret)

    def __mod__(self, mod):
        return _SqMat([[item % mod for item in row] for row in self._mat])

    def __pow__(self, n, mod=None):
        res = _SqMat([[0] * len(self._mat) for _ in self._mat])
        for i, row in enumerate(res):
            row[i] = 1
        m = _SqMat([row[:] for row in self._mat])
        while n:
            if n % 2 == 1:
                res @= m
            m @= m
            if mod:
                m %= mod
            n //= 2
        return res % mod if mod else res


def linear_homogeneous_recurrence(coef: List[int],
                                  seeds: List[int],
                                  n: int,
                                  mod: int = None) -> int:
    """Computes (a[n] % mod) from linear homogeneous recurrence relation.

    Computes a[n] from the recurrence relation in following form in
    O(logn * k^3) time.
       a[n] = c[0]*a[n-1] + c[1]*a[n-2] + ... + c[k]*a[n-k+1]

    Args:
        coef: A list of coefficients. [c[0], c[1], ..., c[k]].
        seeds: A list of seed values. [a[0], a[1], ..., a[k]].
        n: n.
        mod: Optional modular.

    Returns:
        N-th value of the recurrence relation.

    Raises:
        ValueError: An error occurred if n is negative.
    """
    if 0 <= n < len(seeds):
        return seeds[n]
    elif n < 0:
        raise ValueError('n should be non-negative.')
    init_mat = [coef] + [[0] * len(coef) for _ in range(len(coef) - 1)]
    for i in range(1, len(coef)):
        init_mat[i][i - 1] = 1
    gen_coefs = pow(_SqMat(init_mat), n - 1, mod)[0]
    res = sum(c * s for c, s in zip(gen_coefs, reversed(seeds)))
    return res % mod if mod else res

설명

이 코드를 사용하는 문제

출처문제 번호Page레벨
BOJ117272×n 타일링 2실버 3
BOJ12727Numbers (Small)골드 3
BOJ12728n제곱 계산플래티넘 1
BOJ12925Numbers플래티넘 1
BOJ13976타일 채우기 2골드 1
BOJ2133타일 채우기실버 2
프로그래머스129023 x n 타일링Level 4

토론

댓글을 입력하세요:
U R Z I C
 
ps/teflib/combinatorics.txt · 마지막으로 수정됨: 2021/07/31 16:12 저자 teferi