사용자 도구

사이트 도구


ps:problems:boj:20390

완전그래프의 최소 스패닝 트리

ps
링크acmicpc.net/…
출처BOJ
문제 번호20390
문제명완전그래프의 최소 스패닝 트리
레벨골드 2
분류

최소 신장 트리

시간복잡도O(n^2)
인풋사이즈n<=10000
사용한 언어Python
제출기록31024KB / 14260ms
최고기록14260ms
해결날짜2021/10/24

풀이

  • 기본적인 최소 신장 트리 (Minimum Spanning Tree / MST) 문제이지만, 노드가 최대 10000인 완전그래프라는 점에서 알고리즘 선택지가 제한된다. 일반적인 크루스칼 알고리즘으로 풀려면 시간 복잡도도 복잡도지만, n*(n+1)/2 개의 엣지 웨이트를 저장하데에서 메모리 문제도 생긴다. 그래서 프림 알고리즘으로 O(n) 공간복잡도와 O(n^2) 시간복잡도로 풀어야 한다.
  • 하지만 프림 알고리즘을 사용하더라도 python에서는 시간 제한이 상당히 빡빡하다. 심지어는 pypy로도 5*3+2 = 17초 제한이 아슬아슬하다.
  • 시간을 단축시키는 방법은 각각 엣지의 웨이트를 계산할때마다 ((Xi × A + Xj × B) % C) ^ D 이 식을 그대로 계산하는 것이 아니라, Xi × A % C, 와 Xi × B % C 를 미리 계산해 놓고 쓰는것이다. 이것만으로도 pypy기준으로 10초에 가깝게 시간이 단축되는 것을 확인했다.
  • 하지만, 이것을 포함한 정상적인 최적화만으로는 python에서 TLE를 면하는데에 실패했다.. 결국 teflib의 prim_mst 함수를 호출하는 대신, 함수내용을 main 안에 복붙한후에 추가 최적화를 거쳐서 겨우 python으로 AC를 받는데에 성공했다.

코드

코드 1 - 일반 버전. (python에서는 TLE)

"""Solution code for "BOJ 20390. 완전그래프의 최소 스패닝 트리".

- Problem link: https://www.acmicpc.net/problem/20390
- Solution link: http://www.teferi.net/ps/problems/boj/20390

Tags: [MST]

This code should be submitted with PyPy3. It gets TLE with Python3.
"""

from teflib import tgraph


def main():

    def weight_func(u, v):
        t = x_a[u] + x_b[v] if u < v else x_a[v] + x_b[u]
        return (t if t < C else t - C) ^ D

    N = int(input())
    A, B, C, D = [int(x) for x in input().split()]
    X = [int(x) for x in input().split()]

    x_a = [x * A % C for x in X]
    x_b = [x * B % C for x in X]

    print(tgraph.prim_mst(N, weight_func))


if __name__ == '__main__':
    main()

코드 2 - 최적화 버전. (python에서도 AC)

"""Solution code for "BOJ 20390. 완전그래프의 최소 스패닝 트리".

- Problem link: https://www.acmicpc.net/problem/20390
- Solution link: http://www.teferi.net/ps/problems/boj/20390

Tags: [MST]
"""

INF = float('inf')


def main():
    N = int(input())
    A, B, C, D = [int(x) for x in input().split()]
    X = [int(x) for x in input().split()]

    x_a = [x * A % C for x in X]
    x_b = [x * B % C for x in X]
    weights = {x: INF for x in range(1, N)}
    u = 0
    total_weight = 0
    for _ in range(N - 1):
        min_weight_node, min_weight = None, INF
        x_a_u, x_b_u = x_a[u], x_b[u]
        for v, weight_v in weights.items():
            t = x_a_u + x_b[v] if u < v else x_a[v] + x_b_u
            w = (t if t < C else t - C) ^ D
            if w < weight_v:
                weight_v = weights[v] = w
            if weight_v < min_weight:
                min_weight_node, min_weight = v, weight_v
        u = min_weight_node
        total_weight += min_weight
        del weights[u]
    print(total_weight)


if __name__ == '__main__':
    main()

토론

댓글을 입력하세요:
Y E B​ Q​ H
 
ps/problems/boj/20390.txt · 마지막으로 수정됨: 2021/10/24 05:46 저자 teferi