fenwicktree.py
imports and globals
from typing import Iterable, Union
FenwickTree
코드
# N FenwickTree
# I {"version": "1.1", "typing": ["Iterable", "Union"]}
class FenwickTree:
"""Fenwick tree for sum operation."""
def __init__(self, nums_or_size: Union[Iterable[float], int]):
if isinstance(nums_or_size, int):
self._size = nums_or_size
self._arr = [0] * nums_or_size
self._tree = [0] * nums_or_size
else:
self._arr = list(nums_or_size)
self._size = len(self._arr)
self._tree = self._arr[:]
for i, num in enumerate(self._tree):
if i | (i + 1) < self._size:
self._tree[i | (i + 1)] += num
def update(self, pos: int, num: float):
self._arr[pos] += num
while pos < self._size:
self._tree[pos] += num
pos |= pos + 1
def set(self, pos: int, num: float):
self.update(pos, num - self._arr[pos])
def query(self, beg: int, end: int) -> float:
res = 0
i = end - 1
while i >= 0:
res += self._tree[i]
i = (i & (i + 1)) - 1
i = beg - 1
while i >= 0:
res -= self._tree[i]
i = (i & (i + 1)) - 1
return res
설명
이 코드를 사용하는 문제
FenwickTreeForRangeUpdatePointQuery
코드
# N FenwickTreeForRangeUpdatePointQuery
# I {"version": "1.0", "abc": ["Iterable"]}
class FenwickTreeForRangeUpdatePointQuery:
"""Fenwick tree for range add and point query."""
__slots__ = ('_size', '_tree')
def __init__(self, nums_or_size: Iterable[float] | int):
if isinstance(nums_or_size, int):
self._tree = [0] * nums_or_size
self._size = nums_or_size
else:
prev = 0
self._tree = [-prev + (prev := x) for x in nums_or_size]
self._size = len(self._tree)
for i, num in enumerate(self._tree):
if (t := i | (i + 1)) < self._size:
self._tree[t] += num
def range_update(self, beg: int, end: int, num: float):
while beg < self._size:
self._tree[beg] += num
beg |= beg + 1
while end < self._size:
self._tree[end] -= num
end |= end + 1
def get(self, pos: int) -> float:
res = 0
while pos >= 0:
res += self._tree[pos]
pos = (pos & (pos + 1)) - 1
return res
설명
기본 펜윅트리를 사용하면서, 레인지 업데이트 포인트 쿼리를 내부적으로 포인트 업데이트 레인지 쿼리로 변환해서 처리해주는 클래스
그냥 FenwickTree 클래스를 써서도 동일한 기능을 구현하는 것이 별로 복잡하지 않기에 굳이 이런 클래스를 만들지 않고 사용해왔지만, 결국은 그마저도 귀찮아서 따로 클래스를 만들었다.
이 코드를 사용하는 문제
OrderStatisticTree
코드
# N OrderStatisticTree
# I {"version": "1.0", "typing": ["Iterable", "Union"]}
class OrderStatisticTree:
def __init__(self, counts_or_max_num: Union[Iterable[int], int]):
if isinstance(counts_or_max_num, int):
self._size = 1 << ((counts_or_max_num + 1).bit_length())
self._arr = [0] * self._size
self._tree = [0] * self._size
else:
self._arr = list(counts_or_max_num)
self._size = 1 << ((len(self._arr) + 1).bit_length())
self._arr += [0] * (self._size - len(self._arr))
self._tree = self._arr[:]
for i, num in enumerate(self._tree):
if i | (i + 1) < self._size:
self._tree[i | (i + 1)] += num
def add(self, num: int, count: int = 1):
self._arr[num] += count
while num < self._size:
self._tree[num] += count
num |= num + 1
def size(self) -> int:
return self._tree[-1]
def count(self, num: int) -> int:
return self._arr[num]
def count_less_than(self, num: int) -> int:
res = 0
r = num - 1
while r >= 0:
res += self._tree[r]
r = (r & (r + 1)) - 1
return res
def kth(self, k: int) -> int:
pos = -1
for i in range(self._size.bit_length() - 1, -1, -1):
p = 1 << i
v = self._tree[pos + p]
if v < k:
k -= v
pos += p
return pos + 1
설명
이 코드를 사용하는 문제