====== segmenttree.py ======
* [[ps:teflib:segmenttree (old)|이전 버전 코드]]
===== imports and globals =====
from typing import Callable, Iterable, TypeVar
T = TypeVar('T')
===== SegmentTree =====
==== 코드 ====
# N SegmentTree
# I {"version": "1.41", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["T"]}
class SegmentTree:
"""Bottom-up segment tree supporting point update and range query."""
def __init__(self,
values: Iterable[T],
merge: Callable[[T, T], T] = min):
l = list(values)
self._size = len(l)
self._tree = l + l
self._merge = merge
for i in range(self._size - 1, 0, -1):
self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])
def set(self, pos: int, value: T):
i = pos + self._size
while i:
self._tree[i] = value
value = (self._merge(self._tree[i - 1], value) if i % 2
else self._merge(value, self._tree[i + 1]))
i >>= 1
def get(self, pos: int) -> T:
return self._tree[pos + self._size]
def query(self, beg: int, end: int) -> T:
if end == beg + 1:
return self._tree[beg + self._size]
l, r = beg + self._size + 1, end + self._size - 2
ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
while l <= r:
if l % 2:
ret_l = self._merge(ret_l, self._tree[l])
if not r % 2:
ret_r = self._merge(self._tree[r], ret_r)
l, r = (l + 1) >> 1, (r - 1) >> 1
return self._merge(ret_l, ret_r)
==== 설명 ====
* [[ps:세그먼트 트리]] 참고
* merge인자를 안 주면 default로 min이 들어가서 구간 최솟값을 계산하게 된다.
* 구간합 트리가 필요한 경우에는 merge에 operator.sum을 넘겨서 만드는 것 보다, [[ps:teflib:fenwicktree#FenwickTree|teflib.fenwicktree.FenwickTree]]를 사용하는 것이 효율적이다.
* v1.2 -> v1.3의 변화
* merge에서 처리하는 함수가 교환법칙이 성립하지 않을 때에도 (merge(a,b) != merge(b,a) 일때) 정확히 처리하도록 수정했다. 그러나 이 수정으로 인해 속도는 v1.2보다 조금 느려졌다.
* update 함수를 제거했다. 기본적으로 set으로 커버가 되기도 하고, 쓸일이 별로 없기도 하다. update가 주로 쓰이는 것은 그나마 구간합 쿼리의 경우인데, 이 경우는 어차피 펜윅트리를 쓸것이므로.
* v1.3 -> v1.4 변화
* get 함수 추가
==== 이 코드를 사용하는 문제 ====
---- struct table ----
schema: ps
cols: site, prob_id, %title%, prob_level
filter: teflib ~ *[SegmentTree]*
csv: 0
----
===== MinSegmentTree =====
==== 코드 ====
# N MinSegmentTree
# I {"version": "1.01", "typing": ["Iterable"]}
class MinSegmentTree:
"""Bottom-up segment tree supporting point update and range min query."""
__slots__ = ('_size', '_tree')
def __init__(self,
nums_or_size: Iterable[float] | int,
*,
default: float = 0):
if isinstance(nums_or_size, int):
self._size = nums_or_size
self._tree = [default] * (nums_or_size + nums_or_size)
else:
l = list(nums_or_size)
self._size = len(l)
self._tree = l + l
it = reversed(self._tree)
for i in range(self._size - 1, 0, -1):
self._tree[i] = min(next(it), next(it))
def set(self, pos: int, value: float):
i = pos + self._size
while i:
self._tree[i] = value
adj_value = self._tree[i - 1] if i % 2 else self._tree[i + 1]
if adj_value < value:
value = adj_value
i >>= 1
def get(self, pos: int) -> float:
return self._tree[pos + self._size]
def query(self, beg: int, end: int) -> float:
if end == beg + 1:
return self._tree[beg + self._size]
l, r = beg + self._size + 1, end + self._size - 2
ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
while l <= r:
if l % 2 and self._tree[l] < ret_l:
ret_l = self._tree[l]
if not r % 2 and self._tree[r] < ret_r:
ret_r = self._tree[r]
l, r = (l + 1) >> 1, (r - 1) >> 1
return min(ret_l, ret_r)
==== 설명 ====
* [[#SegmentTree]] 를 min 연산에 최적화 시킨 버전.
* 연산자는 고정되어있으므로 인자로 넘길 필요가 없고, 인자로는 초기값들을 넘겨주거나, 크기와 디폴트값을 넘겨주거나 하면 된다,
* merge 함수를 적용하던 부분을 그냥 비교연산을 통해서 작은값으로 업데이트하게 된다. op1과 op2의 순서가 상관 없으므로, 그부분도 간단해졌다. 그러면 이제 self._tree[i - 1] if i % 2 else self._tree[i + 1] 도 self._tree[i ^ 1] 로 간단하게 써도 되긴 하는데, 놀랍게도 속도가 더 느려진다..;
==== 이 코드를 사용하는 문제 ====
---- struct table ----
schema: ps
cols: site, prob_id, %title%, prob_level
filter: teflib ~ *[MinSegmentTree]*
csv: 0
----
===== LazySegmentTree =====
==== 코드 ====
# N LazySegmentTree
# I {"version": "1.1", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["ValueType", "ParamType"]}
class LazySegmentTree:
def __init__(self,
values: Iterable[ValueType],
merge: Callable[[ValueType, ValueType], ValueType],
update_value: Callable[[ValueType, ParamType, int], ValueType],
update_param: Callable[[ParamType, ParamType], ParamType],
should_keep_update_order: bool = True):
l = list(values)
self._size = len(l)
self._tree = l + l
self._param = [None] * self._size
self._merge = merge
self._update_value = update_value
self._update_param = update_param
self._should_keep_update_order = should_keep_update_order
for i in range(self._size - 1, 0, -1):
self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])
def _apply(self, pos: int, param: ParamType, size: int):
self._tree[pos] = self._update_value(self._tree[pos], param, size)
if pos < self._size:
cur_param = self._param[pos]
self._param[pos] = (param if cur_param is None
else self._update_param(cur_param, param))
def _push_down(self, pos: int):
h = self._size.bit_length()
size = 1 << (h - 1)
for i in range(h, 0, -1):
parent = pos >> i
param = self._param[parent]
if param is not None:
self._apply(parent * 2, param, size)
self._apply(parent * 2 + 1, param, size)
self._param[parent] = None
size >>= 1
def _build_up(self, pos: int):
s = 1
while pos > 1:
pos >>= 1
s *= 2
t = self._merge(self._tree[pos * 2], self._tree[pos * 2 + 1])
self._tree[pos] = (t if self._param[pos] is None
else self._update_value(t, self._param[pos], s))
def range_update(self, beg: int, end: int, param: ParamType):
l, r = beg + self._size, end + self._size - 1
if self._should_keep_update_order:
self._push_down(l)
self._push_down(r)
l2, r2, size = l, r, 1
while l2 <= r2:
if l2 % 2:
self._apply(l2, param, size)
if not r2 % 2:
self._apply(r2, param, size)
l2, r2 = (l2 + 1) >> 1, (r2 - 1) >> 1
size *= 2
self._build_up(l)
self._build_up(r)
def get(self, pos: int) -> ValueType:
self._push_down(pos + self._size)
return self._tree[pos + self._size]
def query(self, beg: int, end: int) -> ValueType:
if end == beg + 1:
return self.get(beg)
l, r = beg + self._size + 1, end + self._size - 2
self._push_down(l - 1)
self._push_down(r + 1)
ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
while l <= r:
if l % 2:
ret_l = self._merge(ret_l, self._tree[l])
if not r % 2:
ret_r = self._merge(self._tree[r], ret_r)
l, r = (l + 1) >> 1, (r - 1) >> 1
return self._merge(ret_l, ret_r)
==== 설명 ====
* [[ps:세그먼트 트리]] 참고
==== 이 코드를 사용하는 문제 ====
---- struct table ----
schema: ps
cols: site, prob_id, %title%, prob_level
filteror: teflib ~ *[LazySegmentTree]*
filteror: teflib ~ *[segmenttree.LazySegmentTree]*
csv: 0
----
===== OrderStatisticTree =====
==== 코드 ====
# N OrderStatisticTree
# I {"version": "1.0"}
class OrderStatisticTree:
def __init__(self, counts_or_max_num):
if isinstance(counts_or_max_num, int):
self._size = 1 << ((counts_or_max_num + 1).bit_length())
self._tree = [0] * (self._size * 2)
else:
l = list(counts_or_max_num)
self._size = 1 << (len(l) - 1).bit_length()
self._tree = [0] * (self._size) + l + [0] * (self._size - len(l))
for i in range(self._size - 1, 0, -1):
self._tree[i] = self._tree[i + i] + self._tree[i + i + 1]
def size(self) -> int:
return self._tree[1]
def count(self, num: int) -> int:
return self._tree[num + self._size]
def add(self, num: int, count: int = 1):
i = num + self._size
while i:
self._tree[i] += count
i >>= 1
def kth(self, k: int) -> int:
i = 1
while i < self._size:
i += i
t = self._tree[i]
if t < k:
k -= t
i += 1
return i - self._size
def count_less_than(self, num: int) -> int:
ret = 0
i = num + self._size - 1
while i:
if not i % 2:
ret += self._tree[i]
i -= 1
i >>= 1
return ret
==== 설명 ====
* [[ps:Order statistic tree]] 참고
* [[:ps:teflib:fenwicktree#OrderStatisticTree|teflib.fenwicktree.OrderStatisticTree]]도 동일한 메소드들을 갖고 있고 시간 복잡도도 동일하다. 이쪽 구현이 count_less_than()에 대해서는 약간 더 빠르게 동작한다. 따라서 주로 사용할 연산이 count_less_than() 이라면 이쪽 구현체를 사용하자.
==== 이 코드를 사용하는 문제 ====
---- struct table ----
schema: ps
cols: site, prob_id, %title%, prob_level
filter: teflib ~ *[segmenttree.OrderStatisticTree]*
csv: 0
----