from typing import Callable, Iterable, TypeVar
T = TypeVar('T')
# N SegmentTree
# I {"version": "1.41", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["T"]}
class SegmentTree:
    """Bottom-up segment tree supporting point update and range query.

    The tree is one flat list of length ``2n``: leaves sit at indices
    ``[n, 2n)`` and internal node ``k`` stores
    ``merge(tree[2k], tree[2k + 1])``. Index 0 is never read as a node
    value in a query. ``merge`` always receives the left segment as its
    first argument, so order-sensitive merges (e.g. concatenation) work.
    """
    def __init__(self,
                 values: Iterable[T],
                 merge: Callable[[T, T], T] = min):
        """Build the tree over ``values`` combined with ``merge``."""
        leaves = list(values)
        self._size = len(leaves)
        self._tree = leaves * 2
        self._merge = merge
        tree = self._tree
        # Fill internal nodes bottom-up so children are ready before parents.
        for node in reversed(range(1, self._size)):
            tree[node] = merge(tree[2 * node], tree[2 * node + 1])
    def set(self, pos: int, value: T):
        """Store ``value`` at ``pos`` and recompute every ancestor."""
        tree, merge = self._tree, self._merge
        node = self._size + pos
        while node:
            tree[node] = value
            if node & 1:
                # Right child: its left sibling must come first.
                value = merge(tree[node - 1], value)
            else:
                value = merge(value, tree[node + 1])
            node >>= 1
    def get(self, pos: int) -> T:
        """Return the value currently stored at ``pos``."""
        return self._tree[self._size + pos]
    def query(self, beg: int, end: int) -> T:
        """Return the merge of the half-open range ``[beg, end)``.

        Assumes ``beg < end``. The two boundary leaves seed the left/right
        accumulators; the loop then covers the leaves strictly between them.
        """
        tree, merge, n = self._tree, self._merge, self._size
        if end - beg == 1:
            return tree[n + beg]
        lo = n + beg + 1
        hi = n + end - 2
        left, right = tree[lo - 1], tree[hi + 1]
        while lo <= hi:
            if lo & 1:
                left = merge(left, tree[lo])
            if not hi & 1:
                right = merge(tree[hi], right)
            lo = (lo + 1) >> 1
            hi = (hi - 1) >> 1
        return merge(left, right)
출처 | 문제 번호 | Page | 레벨 |
---|---|---|---|
BOJ | 10167 | 금광 | 다이아몬드 5 |
BOJ | 11053 | 가장 긴 증가하는 부분 수열 | 실버 2 |
BOJ | 11055 | 가장 큰 증가 부분 수열 | 실버 2 |
BOJ | 11505 | 구간 곱 구하기 | 골드 1 |
BOJ | 12986 | 화려한 마을2 | 플래티넘 2 |
BOJ | 13557 | 수열과 쿼리 10 | 플래티넘 1 |
BOJ | 14002 | 가장 긴 증가하는 부분 수열 4 | 골드 4 |
BOJ | 14427 | 수열과 쿼리 15 | 골드 1 |
BOJ | 14428 | 수열과 쿼리 16 | 골드 1 |
BOJ | 14438 | 수열과 쿼리 17 | 골드 1 |
BOJ | 15560 | 구간 합 최대? 1 | 골드 2 |
BOJ | 15561 | 구간 합 최대? 2 | 플래티넘 2 |
BOJ | 16993 | 연속합과 쿼리 | 플래티넘 2 |
BOJ | 17407 | 괄호 문자열과 쿼리 | 플래티넘 2 |
BOJ | 17975 | Strike Zone | 다이아몬드 5 |
BOJ | 19651 | 수열과 쿼리 39 | 다이아몬드 5 |
BOJ | 6519 | Frequent values | 플래티넘 1 |
# N MinSegmentTree
# I {"version": "1.01", "typing": ["Iterable"]}
class MinSegmentTree:
"""Bottom-up segment tree supporting point update and range min query."""
__slots__ = ('_size', '_tree')
def __init__(self,
nums_or_size: Iterable[float] | int,
*,
default: float = 0):
if isinstance(nums_or_size, int):
self._size = nums_or_size
self._tree = [default] * (nums_or_size + nums_or_size)
else:
l = list(nums_or_size)
self._size = len(l)
self._tree = l + l
it = reversed(self._tree)
for i in range(self._size - 1, 0, -1):
self._tree[i] = min(next(it), next(it))
def set(self, pos: int, value: float):
i = pos + self._size
while i:
self._tree[i] = value
adj_value = self._tree[i - 1] if i % 2 else self._tree[i + 1]
if adj_value < value:
value = adj_value
i >>= 1
def get(self, pos: int) -> float:
return self._tree[pos + self._size]
def query(self, beg: int, end: int) -> float:
if end == beg + 1:
return self._tree[beg + self._size]
l, r = beg + self._size + 1, end + self._size - 2
ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
while l <= r:
if l % 2 and self._tree[l] < ret_l:
ret_l = self._tree[l]
if not r % 2 and self._tree[r] < ret_r:
ret_r = self._tree[r]
l, r = (l + 1) >> 1, (r - 1) >> 1
return min(ret_l, ret_r)
출처 | 문제 번호 | Page | 레벨 |
---|---|---|---|
BOJ | 15648 | 추출하는 폴도 바리스타입니다 | 플래티넘 4 |
BOJ | 9345 | 디지털 비디오 디스크(DVDs) | 플래티넘 3 |
# N LazySegmentTree
# I {"version": "1.1", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["ValueType", "ParamType"]}
class LazySegmentTree:
    """Bottom-up segment tree with lazy propagation: range update, range query.

    Layout: leaves are stored at ``_tree[n:2n]`` (``n = len(values)``) and
    internal node ``k`` holds ``merge(_tree[2k], _tree[2k + 1])``.
    ``_param[k]`` is the pending update of internal node ``k`` that has not
    yet been pushed to its children, or ``None`` if there is none.
    ``_tree[k]`` already reflects node ``k``'s own pending update.

    NOTE(review): ``ValueType`` and ``ParamType`` are not defined in this
    snippet; presumably they are module-level type variables provided
    alongside this class -- confirm.
    """
    def __init__(self,
                 values: Iterable[ValueType],
                 merge: Callable[[ValueType, ValueType], ValueType],
                 update_value: Callable[[ValueType, ParamType, int], ValueType],
                 update_param: Callable[[ParamType, ParamType], ParamType],
                 should_keep_update_order: bool = True):
        """Build the tree from ``values``.

        Args:
            values: initial leaf values, in order.
            merge: combines a left node value with the right one.
            update_value: ``update_value(value, param, size)`` returns the
                value of a node covering ``size`` leaves after ``param`` is
                applied to it.
            update_param: ``update_param(old, new)`` composes a pending
                param with a newer one applied on top of it.
            should_keep_update_order: if True, ``range_update`` first pushes
                pending params on both boundary paths down, so a new param
                is applied after older ones. If False that push is skipped;
                presumably only safe when the order in which params are
                applied does not matter.
        """
        l = list(values)
        self._size = len(l)
        self._tree = l + l
        self._param = [None] * self._size  # pending params for internal nodes only
        self._merge = merge
        self._update_value = update_value
        self._update_param = update_param
        self._should_keep_update_order = should_keep_update_order
        for i in range(self._size - 1, 0, -1):
            self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1])
    def _apply(self, pos: int, param: ParamType, size: int):
        """Apply ``param`` to node ``pos`` that covers ``size`` leaves.

        The node value is updated immediately. For internal nodes the param
        is also recorded, composed after any pending one, for a later push.
        """
        self._tree[pos] = self._update_value(self._tree[pos], param, size)
        if pos < self._size:
            cur_param = self._param[pos]
            self._param[pos] = (param if cur_param is None
                                else self._update_param(cur_param, param))
    def _push_down(self, pos: int):
        """Push pending params down along the root-to-``pos`` path.

        Ancestors are visited top-down, so a param pushed into a child is
        itself pushed further at the next level. Afterwards no strict
        ancestor of ``pos`` has a pending param. All callers pass a leaf
        index.
        """
        h = self._size.bit_length()
        # Children of the ancestor h levels up each cover 1 << (h - 1) leaves.
        size = 1 << (h - 1)
        for i in range(h, 0, -1):
            # May be 0 near the top; _param[0] is expected to stay None.
            parent = pos >> i
            param = self._param[parent]
            if param is not None:
                self._apply(parent * 2, param, size)
                self._apply(parent * 2 + 1, param, size)
                self._param[parent] = None
            size >>= 1
    def _build_up(self, pos: int):
        """Recompute every ancestor of ``pos`` from its children.

        Each ancestor's own pending param (if any) is re-applied on top of
        the merged children; ``s`` is the number of leaves it covers.
        """
        s = 1
        while pos > 1:
            pos >>= 1
            s *= 2
            t = self._merge(self._tree[pos * 2], self._tree[pos * 2 + 1])
            self._tree[pos] = (t if self._param[pos] is None
                               else self._update_value(t, self._param[pos], s))
    def range_update(self, beg: int, end: int, param: ParamType):
        """Apply ``param`` to every position in the half-open range ``[beg, end)``."""
        # Inclusive leaf indices of the first and last updated positions.
        l, r = beg + self._size, end + self._size - 1
        if self._should_keep_update_order:
            # Flush older params so the new one lands on top of them.
            self._push_down(l)
            self._push_down(r)
        # Walk the canonical decomposition of [l, r]; `size` is the leaf
        # count of a node at the current level.
        l2, r2, size = l, r, 1
        while l2 <= r2:
            if l2 % 2:
                self._apply(l2, param, size)
            if not r2 % 2:
                self._apply(r2, param, size)
            l2, r2 = (l2 + 1) >> 1, (r2 - 1) >> 1
            size *= 2
        # Ancestors of the updated nodes lie on the two boundary paths.
        self._build_up(l)
        self._build_up(r)
    def get(self, pos: int) -> ValueType:
        """Return the current value at ``pos`` after flushing pending params onto it."""
        self._push_down(pos + self._size)
        return self._tree[pos + self._size]
    def query(self, beg: int, end: int) -> ValueType:
        """Return the merge of values in the half-open range ``[beg, end)``.

        Assumes ``beg < end``.
        """
        if end == beg + 1:
            return self.get(beg)
        # [l, r] is the interior strictly between the two boundary leaves.
        l, r = beg + self._size + 1, end + self._size - 2
        # Every node visited below has its parent on one of these two
        # paths, so flushing them makes all visited values current.
        self._push_down(l - 1)
        self._push_down(r + 1)
        # The boundary leaves seed the left and right accumulators.
        ret_l, ret_r = self._tree[l - 1], self._tree[r + 1]
        while l <= r:
            if l % 2:
                ret_l = self._merge(ret_l, self._tree[l])
            if not r % 2:
                ret_r = self._merge(self._tree[r], ret_r)
            l, r = (l + 1) >> 1, (r - 1) >> 1
        return self._merge(ret_l, ret_r)
# N OrderStatisticTree
# I {"version": "1.0"}
class OrderStatisticTree:
    """Multiset of small non-negative integers backed by a count segment tree.

    ``_tree[_size + v]`` is the count of value ``v`` and every internal node
    ``k`` holds ``_tree[2k] + _tree[2k + 1]``. ``_size`` is a power of two,
    so the tree is perfect and ``_tree[1]`` is the total count.
    """
    def __init__(self, counts_or_max_num):
        """Create the tree.

        Args:
            counts_or_max_num: either an ``int`` ``max_num`` (start empty and
                allow values ``0..max_num``), or an iterable of initial
                counts where element ``v`` is the count of value ``v``.
        """
        if isinstance(counts_or_max_num, int):
            # A power of two strictly greater than max_num + 1.
            self._size = 1 << ((counts_or_max_num + 1).bit_length())
            self._tree = [0] * (self._size * 2)
        else:
            counts = list(counts_or_max_num)
            # Smallest power of two >= len(counts) (2 for an empty input).
            self._size = 1 << (len(counts) - 1).bit_length()
            self._tree = ([0] * self._size + counts
                          + [0] * (self._size - len(counts)))
            for i in range(self._size - 1, 0, -1):
                self._tree[i] = self._tree[i + i] + self._tree[i + i + 1]
    def size(self) -> int:
        """Return the total number of stored elements, with multiplicity."""
        return self._tree[1]
    def count(self, num: int) -> int:
        """Return how many copies of ``num`` are stored."""
        return self._tree[num + self._size]
    def add(self, num: int, count: int = 1):
        """Add ``count`` copies of ``num`` (a negative ``count`` removes)."""
        i = num + self._size
        while i:
            self._tree[i] += count
            i >>= 1
    def kth(self, k: int) -> int:
        """Return the ``k``-th smallest stored value (1-based).

        Assumes ``1 <= k <= self.size()``.
        """
        i = 1
        while i < self._size:
            i += i  # descend to the left child
            t = self._tree[i]
            if t < k:
                # Skip the whole left subtree and continue in the right one.
                k -= t
                i += 1
        return i - self._size
    def count_less_than(self, num: int) -> int:
        """Return how many stored elements are strictly less than ``num``."""
        if num <= 0:
            return 0
        if num >= self._size:
            # Every representable value is below num. This case must be
            # handled here: leaf num-1 would be the last leaf, whose path to
            # the root is all odd indices, so the loop below adds nothing.
            return self._tree[1]
        ret = 0
        # Prefix sum over leaves [0, num - 1], walking up from leaf num - 1.
        i = num + self._size - 1
        while i:
            if not i % 2:
                # Left child: fully inside the prefix, take it and step left.
                ret += self._tree[i]
                i -= 1
            i >>= 1
        return ret