ps:teflib:segmenttree
목차
segmenttree.py
imports and globals
from typing import Callable, Iterable, TypeVar
T = TypeVar('T')
SegmentTree
코드
# N SegmentTree
# I {"version": "1.41", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["T"]}
class SegmentTree:
"""Bottom-up segment tree supporting point update and range query."""
def __init__(self,
values: Iterable[T],
merge: Callable[[T, T], T] = min):
l = list(values)
self._size = len(l)
self._tree = l + l
self._merge = merge
for i in range(self._size - 1, 0, -1):
self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])
def set(self, pos: int, value: T):
i = pos + self._size
while i:
self._tree[i] = value
value = (self._merge(self._tree[i - 1], value) if i % 2
else self._merge(value, self._tree[i + 1]))
i >>= 1
def get(self, pos: int) -> T:
return self._tree[pos + self._size]
def query(self, beg: int, end: int) -> T:
if end == beg + 1:
return self._tree[beg + self._size]
l, r = beg + self._size + 1, end + self._size - 2
ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
while l <= r:
if l % 2:
ret_l = self._merge(ret_l, self._tree[l])
if not r % 2:
ret_r = self._merge(self._tree[r], ret_r)
l, r = (l + 1) >> 1, (r - 1) >> 1
return self._merge(ret_l, ret_r)
설명
- 세그먼트 트리 참고
- merge인자를 안 주면 default로 min이 들어가서 구간 최솟값을 계산하게 된다.
- 구간합 트리가 필요한 경우에는 merge에 operator.sum을 넘겨서 만드는 것 보다, teflib.fenwicktree.FenwickTree를 사용하는 것이 효율적이다.
- v1.2 → v1.3의 변화
- merge에서 처리하는 함수가 교환법칙이 성립하지 않을 때에도 (merge(a,b) != merge(b,a) 일때) 정확히 처리하도록 수정했다. 그러나 이 수정으로 인해 속도는 v1.2보다 조금 느려졌다.
- update 함수를 제거했다. 기본적으로 set으로 커버가 되기도 하고, 쓸일이 별로 없기도 하다. update가 주로 쓰이는 것은 그나마 구간합 쿼리의 경우인데, 이 경우는 어차피 펜윅트리를 쓸것이므로.
- v1.3 → v1.4 변화
- get 함수 추가
이 코드를 사용하는 문제
출처 | 문제 번호 | Page | 레벨 |
---|---|---|---|
BOJ | 12986 | 화려한 마을2 | 플래티넘 2 |
BOJ | 15561 | 구간 합 최대? 2 | 플래티넘 2 |
BOJ | 16933 | 연속합과 쿼리 | 플래티넘 2 |
BOJ | 17407 | 괄호 문자열과 쿼리 | 플래티넘 2 |
BOJ | 13557 | 수열과 쿼리 10 | 플래티넘 1 |
BOJ | 6519 | Frequent values | 플래티넘 1 |
BOJ | 11503 | 가장 긴 증가하는 부분 수열 | 실버 2 |
BOJ | 11055 | 가장 큰 증가 부분 수열 | 실버 2 |
BOJ | 10167 | 금광 | 다이아몬드 5 |
BOJ | 17975 | Strike Zone | 다이아몬드 5 |
BOJ | 19651 | 수열과 쿼리 39 | 다이아몬드 5 |
BOJ | 14002 | 가장 긴 증가하는 부분 수열 4 | 골드 4 |
BOJ | 15560 | 구간 합 최대? 1 | 골드 2 |
BOJ | 11505 | 구간 곱 구하기 | 골드 1 |
BOJ | 14427 | 수열과 쿼리 15 | 골드 1 |
BOJ | 14428 | 수열과 쿼리 16 | 골드 1 |
BOJ | 14438 | 수열과 쿼리 17 | 골드 1 |
MinSegmentTree
코드
# N MinSegmentTree
# I {"version": "1.01", "typing": ["Iterable"]}
class MinSegmentTree:
"""Bottom-up segment tree supporting point update and range min query."""
__slots__ = ('_size', '_tree')
def __init__(self,
nums_or_size: Iterable[float] | int,
*,
default: float = 0):
if isinstance(nums_or_size, int):
self._size = nums_or_size
self._tree = [default] * (nums_or_size + nums_or_size)
else:
l = list(nums_or_size)
self._size = len(l)
self._tree = l + l
it = reversed(self._tree)
for i in range(self._size - 1, 0, -1):
self._tree[i] = min(next(it), next(it))
def set(self, pos: int, value: float):
i = pos + self._size
while i:
self._tree[i] = value
adj_value = self._tree[i - 1] if i % 2 else self._tree[i + 1]
if adj_value < value:
value = adj_value
i >>= 1
def get(self, pos: int) -> float:
return self._tree[pos + self._size]
def query(self, beg: int, end: int) -> float:
if end == beg + 1:
return self._tree[beg + self._size]
l, r = beg + self._size + 1, end + self._size - 2
ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
while l <= r:
if l % 2 and self._tree[l] < ret_l:
ret_l = self._tree[l]
if not r % 2 and self._tree[r] < ret_r:
ret_r = self._tree[r]
l, r = (l + 1) >> 1, (r - 1) >> 1
return min(ret_l, ret_r)
설명
- SegmentTree 를 min 연산에 최적화 시킨 버전.
- 연산자는 고정되어있으므로 인자로 넘길 필요가 없고, 인자로는 초기값들을 넘겨주거나, 크기와 디폴트값을 넘겨주거나 하면 된다,
- merge 함수를 적용하던 부분을 그냥 비교연산을 통해서 작은값으로 업데이트하게 된다. op1과 op2의 순서가 상관 없으므로, 그부분도 간단해졌다. 그러면 이제 self._tree[i - 1] if i % 2 else self._tree[i + 1] 도 self._tree[i ^ 1] 로 간단하게 써도 되긴 하는데, 놀랍게도 속도가 더 느려진다..;
이 코드를 사용하는 문제
출처 | 문제 번호 | Page | 레벨 |
---|---|---|---|
BOJ | 15648 | 추출하는 폴도 바리스타입니다 | 플래티넘 4 |
BOJ | 9345 | 디지털 비디오 디스크(DVDs) | 플래티넘 3 |
LazySegmentTree
코드
# N LazySegmentTree
# I {"version": "1.1", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["ValueType", "ParamType"]}
class LazySegmentTree:
def __init__(self,
values: Iterable[ValueType],
merge: Callable[[ValueType, ValueType], ValueType],
update_value: Callable[[ValueType, ParamType, int], ValueType],
update_param: Callable[[ParamType, ParamType], ParamType],
should_keep_update_order: bool = True):
l = list(values)
self._size = len(l)
self._tree = l + l
self._param = [None] * self._size
self._merge = merge
self._update_value = update_value
self._update_param = update_param
self._should_keep_update_order = should_keep_update_order
for i in range(self._size - 1, 0, -1):
self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])
def _apply(self, pos: int, param: ParamType, size: int):
self._tree[pos] = self._update_value(self._tree[pos], param, size)
if pos < self._size:
cur_param = self._param[pos]
self._param[pos] = (param if cur_param is None
else self._update_param(cur_param, param))
def _push_down(self, pos: int):
h = self._size.bit_length()
size = 1 << (h - 1)
for i in range(h, 0, -1):
parent = pos >> i
param = self._param[parent]
if param is not None:
self._apply(parent * 2, param, size)
self._apply(parent * 2 + 1, param, size)
self._param[parent] = None
size >>= 1
def _build_up(self, pos: int):
s = 1
while pos > 1:
pos >>= 1
s *= 2
t = self._merge(self._tree[pos * 2], self._tree[pos * 2 + 1])
self._tree[pos] = (t if self._param[pos] is None
else self._update_value(t, self._param[pos], s))
def range_update(self, beg: int, end: int, param: ParamType):
l, r = beg + self._size, end + self._size - 1
if self._should_keep_update_order:
self._push_down(l)
self._push_down(r)
l2, r2, size = l, r, 1
while l2 <= r2:
if l2 % 2:
self._apply(l2, param, size)
if not r2 % 2:
self._apply(r2, param, size)
l2, r2 = (l2 + 1) >> 1, (r2 - 1) >> 1
size *= 2
self._build_up(l)
self._build_up(r)
def get(self, pos: int) -> ValueType:
self._push_down(pos + self._size)
return self._tree[pos + self._size]
def query(self, beg: int, end: int) -> ValueType:
if end == beg + 1:
return self.get(beg)
l, r = beg + self._size + 1, end + self._size - 2
self._push_down(l - 1)
self._push_down(r + 1)
ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
while l <= r:
if l % 2:
ret_l = self._merge(ret_l, self._tree[l])
if not r % 2:
ret_r = self._merge(self._tree[r], ret_r)
l, r = (l + 1) >> 1, (r - 1) >> 1
return self._merge(ret_l, ret_r)
설명
- 세그먼트 트리 참고
이 코드를 사용하는 문제
OrderStatisticTree
코드
# N OrderStatisticTree
# I {"version": "1.0"}
class OrderStatisticTree:
def __init__(self, counts_or_max_num):
if isinstance(counts_or_max_num, int):
self._size = 1 << ((counts_or_max_num + 1).bit_length())
self._tree = [0] * (self._size * 2)
else:
l = list(counts_or_max_num)
self._size = 1 << (len(l) - 1).bit_length()
self._tree = [0] * (self._size) + l + [0] * (self._size - len(l))
for i in range(self._size - 1, 0, -1):
self._tree[i] = self._tree[i + i] + self._tree[i + i + 1]
def size(self) -> int:
return self._tree[1]
def count(self, num: int) -> int:
return self._tree[num + self._size]
def add(self, num: int, count: int = 1):
i = num + self._size
while i:
self._tree[i] += count
i >>= 1
def kth(self, k: int) -> int:
i = 1
while i < self._size:
i += i
t = self._tree[i]
if t < k:
k -= t
i += 1
return i - self._size
def count_less_than(self, num: int) -> int:
ret = 0
i = num + self._size - 1
while i:
if not i % 2:
ret += self._tree[i]
i -= 1
i >>= 1
return ret
설명
- teflib.fenwicktree.OrderStatisticTree도 동일한 메소드들을 갖고 있고 시간 복잡도도 동일하다. 이쪽 구현이 count_less_than()에 대해서는 약간 더 빠르게 동작한다. 따라서 주로 사용할 연산이 count_less_than() 이라면 이쪽 구현체를 사용하자.
이 코드를 사용하는 문제
ps/teflib/segmenttree.txt · 마지막으로 수정됨: 2023/08/31 05:39 저자 teferi
토론