====== segmenttree.py ====== * [[ps:teflib:segmenttree (old)|이전 버전 코드]] ===== imports and globals ===== from typing import Callable, Iterable, TypeVar T = TypeVar('T') ===== SegmentTree ===== ==== 코드 ==== # N SegmentTree # I {"version": "1.41", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["T"]} class SegmentTree: """Bottom-up segment tree supporting point update and range query.""" def __init__(self, values: Iterable[T], merge: Callable[[T, T], T] = min): l = list(values) self._size = len(l) self._tree = l + l self._merge = merge for i in range(self._size - 1, 0, -1): self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1]) def set(self, pos: int, value: T): i = pos + self._size while i: self._tree[i] = value value = (self._merge(self._tree[i - 1], value) if i % 2 else self._merge(value, self._tree[i + 1])) i >>= 1 def get(self, pos: int) -> T: return self._tree[pos + self._size] def query(self, beg: int, end: int) -> T: if end == beg + 1: return self._tree[beg + self._size] l, r = beg + self._size + 1, end + self._size - 2 ret_l, ret_r = self._tree[l - 1], self._tree[r + 1] while l <= r: if l % 2: ret_l = self._merge(ret_l, self._tree[l]) if not r % 2: ret_r = self._merge(self._tree[r], ret_r) l, r = (l + 1) >> 1, (r - 1) >> 1 return self._merge(ret_l, ret_r) ==== 설명 ==== * [[ps:세그먼트 트리]] 참고 * merge인자를 안 주면 default로 min이 들어가서 구간 최솟값을 계산하게 된다. * 구간합 트리가 필요한 경우에는 merge에 operator.sum을 넘겨서 만드는 것 보다, [[ps:teflib:fenwicktree#FenwickTree|teflib.fenwicktree.FenwickTree]]를 사용하는 것이 효율적이다. * v1.2 -> v1.3의 변화 * merge에서 처리하는 함수가 교환법칙이 성립하지 않을 때에도 (merge(a,b) != merge(b,a) 일때) 정확히 처리하도록 수정했다. 그러나 이 수정으로 인해 속도는 v1.2보다 조금 느려졌다. * update 함수를 제거했다. 기본적으로 set으로 커버가 되기도 하고, 쓸일이 별로 없기도 하다. update가 주로 쓰이는 것은 그나마 구간합 쿼리의 경우인데, 이 경우는 어차피 펜윅트리를 쓸것이므로. * v1.3 -> v1.4 변화 * get 함수 추가 ==== 이 코드를 사용하는 문제 ==== ---- struct table ---- schema: ps cols: site, prob_id, %title%, prob_level filter: teflib ~ *[SegmentTree]* csv: 0 ---- ===== MinSegmentTree ===== ==== 코드 ==== # N MinSegmentTree # I {"version": "1.01", "typing": ["Iterable"]} class MinSegmentTree: """Bottom-up segment tree supporting point update and range min query.""" __slots__ = ('_size', '_tree') def __init__(self, nums_or_size: Iterable[float] | int, *, default: float = 0): if isinstance(nums_or_size, int): self._size = nums_or_size self._tree = [default] * (nums_or_size + nums_or_size) else: l = list(nums_or_size) self._size = len(l) self._tree = l + l it = reversed(self._tree) for i in range(self._size - 1, 0, -1): self._tree[i] = min(next(it), next(it)) def set(self, pos: int, value: float): i = pos + self._size while i: self._tree[i] = value adj_value = self._tree[i - 1] if i % 2 else self._tree[i + 1] if adj_value < value: value = adj_value i >>= 1 def get(self, pos: int) -> float: return self._tree[pos + self._size] def query(self, beg: int, end: int) -> float: if end == beg + 1: return self._tree[beg + self._size] l, r = beg + self._size + 1, end + self._size - 2 ret_l, ret_r = self._tree[l - 1], self._tree[r + 1] while l <= r: if l % 2 and self._tree[l] < ret_l: ret_l = self._tree[l] if not r % 2 and self._tree[r] < ret_r: ret_r = self._tree[r] l, r = (l + 1) >> 1, (r - 1) >> 1 return min(ret_l, ret_r) ==== 설명 ==== * [[#SegmentTree]] 를 min 연산에 최적화 시킨 버전. * 연산자는 고정되어있으므로 인자로 넘길 필요가 없고, 인자로는 초기값들을 넘겨주거나, 크기와 디폴트값을 넘겨주거나 하면 된다, * merge 함수를 적용하던 부분을 그냥 비교연산을 통해서 작은값으로 업데이트하게 된다. op1과 op2의 순서가 상관 없으므로, 그부분도 간단해졌다. 그러면 이제 self._tree[i - 1] if i % 2 else self._tree[i + 1] 도 self._tree[i ^ 1] 로 간단하게 써도 되긴 하는데, 놀랍게도 속도가 더 느려진다..; ==== 이 코드를 사용하는 문제 ==== ---- struct table ---- schema: ps cols: site, prob_id, %title%, prob_level filter: teflib ~ *[MinSegmentTree]* csv: 0 ---- ===== LazySegmentTree ===== ==== 코드 ==== # N LazySegmentTree # I {"version": "1.1", "typing": ["Callable", "Iterable", "TypeVar"], "const": ["ValueType", "ParamType"]} class LazySegmentTree: def __init__(self, values: Iterable[ValueType], merge: Callable[[ValueType, ValueType], ValueType], update_value: Callable[[ValueType, ParamType, int], ValueType], update_param: Callable[[ParamType, ParamType], ParamType], should_keep_update_order: bool = True): l = list(values) self._size = len(l) self._tree = l + l self._param = [None] * self._size self._merge = merge self._update_value = update_value self._update_param = update_param self._should_keep_update_order = should_keep_update_order for i in range(self._size - 1, 0, -1): self._tree[i] = merge(self._tree[i * 2], self._tree[i * 2 + 1]) def _apply(self, pos: int, param: ParamType, size: int): self._tree[pos] = self._update_value(self._tree[pos], param, size) if pos < self._size: cur_param = self._param[pos] self._param[pos] = (param if cur_param is None else self._update_param(cur_param, param)) def _push_down(self, pos: int): h = self._size.bit_length() size = 1 << (h - 1) for i in range(h, 0, -1): parent = pos >> i param = self._param[parent] if param is not None: self._apply(parent * 2, param, size) self._apply(parent * 2 + 1, param, size) self._param[parent] = None size >>= 1 def _build_up(self, pos: int): s = 1 while pos > 1: pos >>= 1 s *= 2 t = self._merge(self._tree[pos * 2], self._tree[pos * 2 + 1]) self._tree[pos] = (t if self._param[pos] is None else self._update_value(t, self._param[pos], s)) def range_update(self, beg: int, end: int, param: ParamType): l, r = beg + self._size, end + self._size - 1 if self._should_keep_update_order: self._push_down(l) self._push_down(r) l2, r2, size = l, r, 1 while l2 <= r2: if l2 % 2: self._apply(l2, param, size) if not r2 % 2: self._apply(r2, param, size) l2, r2 = (l2 + 1) >> 1, (r2 - 1) >> 1 size *= 2 self._build_up(l) self._build_up(r) def get(self, pos: int) -> ValueType: self._push_down(pos + self._size) return self._tree[pos + self._size] def query(self, beg: int, end: int) -> ValueType: if end == beg + 1: return self.get(beg) l, r = beg + self._size + 1, end + self._size - 2 self._push_down(l - 1) self._push_down(r + 1) ret_l, ret_r = self._tree[l - 1], self._tree[r + 1] while l <= r: if l % 2: ret_l = self._merge(ret_l, self._tree[l]) if not r % 2: ret_r = self._merge(self._tree[r], ret_r) l, r = (l + 1) >> 1, (r - 1) >> 1 return self._merge(ret_l, ret_r) ==== 설명 ==== * [[ps:세그먼트 트리]] 참고 ==== 이 코드를 사용하는 문제 ==== ---- struct table ---- schema: ps cols: site, prob_id, %title%, prob_level filteror: teflib ~ *[LazySegmentTree]* filteror: teflib ~ *[segmenttree.LazySegmentTree]* csv: 0 ---- ===== OrderStatisticTree ===== ==== 코드 ==== # N OrderStatisticTree # I {"version": "1.0"} class OrderStatisticTree: def __init__(self, counts_or_max_num): if isinstance(counts_or_max_num, int): self._size = 1 << ((counts_or_max_num + 1).bit_length()) self._tree = [0] * (self._size * 2) else: l = list(counts_or_max_num) self._size = 1 << (len(l) - 1).bit_length() self._tree = [0] * (self._size) + l + [0] * (self._size - len(l)) for i in range(self._size - 1, 0, -1): self._tree[i] = self._tree[i + i] + self._tree[i + i + 1] def size(self) -> int: return self._tree[1] def count(self, num: int) -> int: return self._tree[num + self._size] def add(self, num: int, count: int = 1): i = num + self._size while i: self._tree[i] += count i >>= 1 def kth(self, k: int) -> int: i = 1 while i < self._size: i += i t = self._tree[i] if t < k: k -= t i += 1 return i - self._size def count_less_than(self, num: int) -> int: ret = 0 i = num + self._size - 1 while i: if not i % 2: ret += self._tree[i] i -= 1 i >>= 1 return ret ==== 설명 ==== * [[ps:Order statistic tree]] 참고 * [[:ps:teflib:fenwicktree#OrderStatisticTree|teflib.fenwicktree.OrderStatisticTree]]도 동일한 메소드들을 갖고 있고 시간 복잡도도 동일하다. 이쪽 구현이 count_less_than()에 대해서는 약간 더 빠르게 동작한다. 따라서 주로 사용할 연산이 count_less_than() 이라면 이쪽 구현체를 사용하자. ==== 이 코드를 사용하는 문제 ==== ---- struct table ---- schema: ps cols: site, prob_id, %title%, prob_level filter: teflib ~ *[segmenttree.OrderStatisticTree]* csv: 0 ----