사용자 도구

사이트 도구


ps:problems:boj:25974

거듭제곱의 합 1

ps
링크acmicpc.net/…
출처BOJ
문제 번호25974
문제명거듭제곱의 합 1
레벨플래티넘 2
분류

수학

시간복잡도O(plogp)
인풋사이즈p<=1000
사용한 언어Python 3.11
제출기록31256KB / 44ms
최고기록44ms
해결날짜2023/02/09

풀이

* 거듭제곱의 합을 계산하는 문제.

  • 풀이 방법은 거듭제곱의 합을 참고. 라그랑주 보간법을 써서 풀면 시간복잡도는 O(nlogn).
  • n이 최대 1000 밖에 되지 않기 때문에, O(n^2) DP를 이용해도 충분히 풀리긴 한다 (270ms 근처)
    • DP를 이용하는 코드는 참고.

코드

"""Solution code for "BOJ 25974. 거듭제곱의 합 1".

- Problem link: https://www.acmicpc.net/problem/25974
- Solution link: http://www.teferi.net/ps/problems/boj/25974

Tags: [Lagrangian interpolation]
"""


MOD = 10**9 + 7


def multiple_mod_inv(nums, mod):
    a = list(nums)
    b = [v := 1] + [v := v * x % mod for x in reversed(a)]
    b_inv = pow(b.pop(), -1, mod)
    return [b_inv * b.pop() % mod] + [
        (b_inv := b_inv * a_ % mod) * b_ % mod for a_, b_ in zip(a, reversed(b))
    ]


def lagrangian_interpolation(y, n, prime_mod):
    """Finds k-th order func f(x) from y=[f(0), ..., f(k)], and returns f(n)."""
    l = len(y)
    if n < l:
        return y[n]
    invs = multiple_mod_inv(range(n, n - l, -1), prime_mod)
    factorials = [v := 1] + [v := v * i % prime_mod for i in range(1, l)]
    finv = multiple_mod_inv(factorials, prime_mod)
    answer = 0
    sign = 1 if l % 2 else -1
    for inv_i, finv_i, finv_j, y_i in zip(invs, finv, reversed(finv), y):
        answer += sign * inv_i * finv_i * finv_j * y_i
        sign = -sign
    for i in range(n - l + 1, n + 1):
        answer = answer * i % prime_mod
    return answer


def sum_of_powers(n, k, prime_mod):
    """Returns (1^k + 2^k + ... + n^k) % prime_mod."""
    y = [v := 0] + [v := v + pow(i, k, prime_mod) for i in range(1, k + 2)]
    return lagrangian_interpolation(y, n, prime_mod)


def main():
    n, p = [int(x) for x in input().split()]
    print(sum_of_powers(n, p, MOD))


if __name__ == '__main__':
    main()

토론

댓글을 입력하세요:
J L C J N
 
ps/problems/boj/25974.txt · 마지막으로 수정됨: 2023/02/09 16:31 저자 teferi