====== 거듭제곱의 합 1 ====== ===== 풀이 ===== * [[ps:여러가지 수학 정리|거듭제곱의 합]]을 계산하는 문제. * 풀이 방법은 [[ps:여러가지 수학 정리|거듭제곱의 합]]을 참고. 라그랑주 보간법을 써서 풀면 시간복잡도는 O(nlogn). * n이 최대 1000 밖에 되지 않기 때문에, O(n^2) DP를 이용해도 충분히 풀리긴 한다 (270ms 근처) * DP를 이용하는 코드는 [[ps:problems:boj:1492]] 참고. ===== 코드 ===== """Solution code for "BOJ 25974. 거듭제곱의 합 1". - Problem link: https://www.acmicpc.net/problem/25974 - Solution link: http://www.teferi.net/ps/problems/boj/25974 Tags: [Lagrangian interpolation] """ MOD = 10**9 + 7 def multiple_mod_inv(nums, mod): a = list(nums) b = [v := 1] + [v := v * x % mod for x in reversed(a)] b_inv = pow(b.pop(), -1, mod) return [b_inv * b.pop() % mod] + [ (b_inv := b_inv * a_ % mod) * b_ % mod for a_, b_ in zip(a, reversed(b)) ] def lagrangian_interpolation(y, n, prime_mod): """Finds k-th order func f(x) from y=[f(0), ..., f(k)], and returns f(n).""" l = len(y) if n < l: return y[n] invs = multiple_mod_inv(range(n, n - l, -1), prime_mod) factorials = [v := 1] + [v := v * i % prime_mod for i in range(1, l)] finv = multiple_mod_inv(factorials, prime_mod) answer = 0 sign = 1 if l % 2 else -1 for inv_i, finv_i, finv_j, y_i in zip(invs, finv, reversed(finv), y): answer += sign * inv_i * finv_i * finv_j * y_i sign = -sign for i in range(n - l + 1, n + 1): answer = answer * i % prime_mod return answer def sum_of_powers(n, k, prime_mod): """Returns (1^k + 2^k + ... + n^k) % prime_mod.""" y = [v := 0] + [v := v + pow(i, k, prime_mod) for i in range(1, k + 2)] return lagrangian_interpolation(y, n, prime_mod) def main(): n, p = [int(x) for x in input().split()] print(sum_of_powers(n, p, MOD)) if __name__ == '__main__': main() {{tag>BOJ ps:problems:boj:플래티넘_2}}