ps:problems:boj:31415
UFO 침공
ps | |
---|---|
링크 | acmicpc.net/… |
출처 | BOJ |
문제 번호 | 31415 |
문제명 | UFO 침공 |
레벨 | 플래티넘 2 |
분류 |
이모스법 |
시간복잡도 | O(n + l*sqrt(n) + q) |
인풋사이즈 | n<=100,000, q<=100,000, l<=100,000 |
사용한 언어 | Python 3.11 |
제출기록 | 147688KB / 5792ms |
최고기록 | 5792ms |
해결날짜 | 2024/02/23 |
출처 |
풀이
- 우선 간단한 관찰들. UFO의 좌표 계산은 x축과 y축 좌표를 분리해서 계산해줘도 된다.
- 레이저 빔 쿼리들 중 y축에 평행한 쿼리들만 생각해보자. 레이저빔의 x값이 qx일때 시작위치과 이동거리가 x0, dx인 UFO가 격추되려면 x0, x0+dx, x0+2*dx, …, x0+(T-1)*dx 중에 qx와 같은 값이 있어야 한다. 만약 dx가 1이라면, qx가 [x0, x0+T) 안에 포함되면 격추된다. 그러면 이것은 모든 UFO들에 대해서 구간들을 구한다음에, qx를 포함하는 구간이 몇개인지를 세는 문제가 되고, 이것은 구간합 업데이트를 처리한 이후에 포인트 쿼리를 수행하는 전형적인 문제가 된다. 누적합과 이모스법을 이용해서, 전처리에 O(n+l), 쿼리 한개당 O(1)에 처리하는 것이 가능하다 (n=UFO의 수, l=좌표의 범위)
- 하지만 이 문제에서 어려운 점은 dx가 1로 고정되어있는게 아니라는 점이다. 물론 dx=k 라고 해도 누적합을 k단위로 p[x] = p[x-k]+d[k] 와 같은 방식으로 업데이트 해주면 되기는 한다. 그러나 dx가 각각 다른 경우에는, 그 각각 다른 dx마다 누적합을 따로 업데이트 시켜줘야 한다. n개의 UFO가 모두 다른 dx값을 갖는다면 O(l)의 누적합 계산을 n번 해줘야 한다. 전처리에서 O(nl)이 걸리면 이미 시간 초과가 된다.
- 침착하게 시간복잡도를 다시 생각해보자. dx=1 이면 l개의 칸에 대해서 누적합을 모두 계산해줘야 하니까 O(n)이 걸리는것은 맞지만, dx=k라면, 업데이트 해줘야 하는 칸수는 l/k 개이다. 두개의 UFO의 dx가 둘다 k라고 해도 x0%k 값이 다르면 업데이트해야 하는 칸들이 다르므로, dx=k 인 UFO들도 k개의 다른 그룹으로 묶일수 있다는것 까지 생각해서 최악의 상황을 생각해보자. UFO들이 모두 중복되지 않으면서 가장 많은 업데이트 범위를 갖도록 dx와 x0값을 갖는 경우에 각각의 UFO에 대해서 업데이트해야 하는 칸수는, l + l/2 + l/2 + l/3 + l/3 + l/3 + l/4 + …이다. UFO가 n개일때 이 값은 O(l*sqrt(n)) 가 된다. 이정도 시간복잡도라면 시간 안에 풀수 있다.
- 방법은 나왔지만, 구현도 간단하지는 않다. UFO들을 (dx, x0%dx) 를 기준으로 분류해서, 각 그룹마다 따로 누적합 업데이트를 돌려줘야 한다. 이와중에 dx가 음수인 경우는 반대방향으로 이동하는 것으로 바꿔서 dx를 양수로 처리해줘야 하고. 같은 작업을 x축뿐만 아니라 y축를 기준으로 다시 해줘야 한다. 이모스법을 쓸때 구간의 시작점과 끝점에 +1과 -1을 업데이트 하는 것을 그냥 배열에 처리해도 안된다. l/dx개의 칸만 업데이트하므로 O(l/dx)에 처리된다는 것이 전제였는데, 길이 l의 배열을 만들게 되면, 배열을 만드는데에 dx 하나당 O(l)의 시간이 걸리는 셈이므로 시간 초과가 된다. +1을 하는 좌표와 -1을 하는 좌표를 리스트에 저장한뒤 소팅해서 처리해야 했다.
- 총 시간복잡도는 대략 O(n + l*sqrt(n) + q)가 된다.
코드
"""Solution code for "BOJ 31415. UFO 침공".
- Problem link: https://www.acmicpc.net/problem/31415
- Solution link: http://www.teferi.net/ps/problems/boj/31415
Tags: [imos]
"""
import collections
import sys
def update(x, dx, max_x, x_counts, x_diffs, T):
if dx == 0:
if x <= max_x + 1:
x_counts[x] += 1
return
beg = x
if dx < 0:
dx = -dx
beg -= dx * (T - 1)
if beg > max_x + 1:
return
diffs = x_diffs[dx, x % dx]
diffs.append((max(beg, x % dx), 1))
end = beg + dx * T
if end <= max_x + 1:
diffs.append((end, -1))
def calc(x_diffs, x_counts, max_x):
for (dx, x), diff in x_diffs.items():
delta = 0
for p, d in sorted(diff):
while x < p:
x_counts[x] += delta
x += dx
delta += d
while x <= max_x:
x_counts[x] += delta
x += dx
def main():
N, Q, T = [int(x) for x in sys.stdin.readline().split()]
ufos = [[int(x) for x in sys.stdin.readline().split()] for _ in range(N)]
queries = [[int(x) for x in sys.stdin.readline().split()] for _ in range(Q)]
max_y = max((pos for t, pos in queries if t == 1), default=0)
max_x = max((pos for t, pos in queries if t == 2), default=0)
x_diffs = collections.defaultdict(list)
y_diffs = collections.defaultdict(list)
x_counts = [0] * (max_x + 1)
y_counts = [0] * (max_y + 1)
for x, y, dx, dy in ufos:
update(x, dx, max_x, x_counts, x_diffs, T)
update(y, dy, max_y, y_counts, y_diffs, T)
calc(x_diffs, x_counts, max_x)
calc(y_diffs, y_counts, max_y)
for t, pos in queries:
print(y_counts[pos] if t == 1 else x_counts[pos])
if __name__ == '__main__':
main()
ps/problems/boj/31415.txt · 마지막으로 수정됨: 2024/03/05 15:06 저자 teferi
토론