123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- import operator
- from typing import Any, Optional
class SegmentTree:
    """A Segment Tree data structure.

    https://en.wikipedia.org/wiki/Segment_tree

    Can be used as regular array, but with two important differences:

      a) Setting an item's value is slightly slower. It is O(lg capacity),
         instead of O(1).
      b) Offers efficient `reduce` operation which reduces the tree's values
         over some specified contiguous subsequence of items in the array.
         Operation could be e.g. min/max/sum.

    The data is stored in a list, where the length is 2 * capacity.
    The second half of the list stores the actual values for each index, so if
    capacity=8, values are stored at indices 8 to 15. The first half of the
    array contains the reduced-values of the different (binary divided)
    segments, e.g. (capacity=4):

      0=not used
      1=reduced-value over all elements (array indices 4 to 7).
      2=reduced-value over array indices (4 and 5).
      3=reduced-value over array indices (6 and 7).
      4-7: values of the tree.

    NOTE that the values of the tree are accessed by indices starting at 0, so
    `tree[0]` accesses `internal_array[4]` in the above example.
    """

    def __init__(self,
                 capacity: int,
                 operation: Any,
                 neutral_element: Optional[Any] = None):
        """Initializes a Segment Tree object.

        Args:
            capacity (int): Total size of the array - must be a power of two.
            operation (operation): Lambda obj, obj -> obj
                The operation for combining elements (eg. sum, max).
                Must be associative and commutative (`reduce` combines
                partial results from both ends of the range in mixed
                order) and must have `neutral_element` as its identity.
            neutral_element (Optional[obj]): The neutral element for
                `operation`. Use None for automatically finding a value:
                `operator.add`: 0.0, `max`: float("-inf"), anything else
                (including `min`): float("inf"). Pass it explicitly for
                any other operation.

        Raises:
            AssertionError: If `capacity` is not a positive power of two.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, \
            "Capacity must be positive and a power of 2!"
        self.capacity = capacity
        if neutral_element is None:
            # Detection is by identity, so only the exact `operator.add` and
            # builtin `max` objects are recognized; everything else falls
            # through to the `min` default.
            if operation is operator.add:
                neutral_element = 0.0
            elif operation is max:
                neutral_element = float("-inf")
            else:
                neutral_element = float("inf")
        self.neutral_element = neutral_element
        # Index 0 is unused, 1..capacity-1 hold segment reductions,
        # capacity..2*capacity-1 hold the leaf values.
        self.value = [self.neutral_element] * (2 * capacity)
        self.operation = operation

    def reduce(self, start: int = 0, end: Optional[int] = None) -> Any:
        """Applies `self.operation` to subsequence of our values.

        Subsequence is contiguous, includes `start` and excludes `end`:

            operation(arr[start], operation(arr[start+1],
                      operation(... arr[end-1])))

        Args:
            start (int): Start index to apply reduction to. Negative values
                count from the end of the array (like Python slices).
            end (Optional[int]): End index to apply reduction to (excluded).
                None means `self.capacity`; negative values count from the
                end of the array.

        Returns:
            any: The result of reducing self.operation over the specified
                range of `self.value` elements. An empty range
                (`start >= end`) yields `self.neutral_element`.

        Raises:
            IndexError: If `start` or `end` (after resolving negative
                values) lies outside [0, `self.capacity`].
        """
        capacity = self.capacity
        lo = start
        hi = capacity if end is None else end
        if hi < 0:
            hi += capacity
        if lo < 0:
            lo += capacity
        # Out-of-range bounds would otherwise be mapped onto internal nodes
        # or unrelated leaves and silently produce a wrong reduction.
        if not (0 <= lo <= capacity and 0 <= hi <= capacity):
            raise IndexError(
                f"reduce range out of bounds: start={start} end={end} "
                f"capacity={capacity}")

        # Init result with neutral element.
        result = self.neutral_element
        # Map start/end to our actual index space (second half of array).
        lo += capacity
        hi += capacity

        # Example:
        # internal-array (first half=sums, second half=actual values):
        # 0 1 2 3 | 4 5 6 7
        # - 6 1 5 | 1 0 2 3
        #
        # tree.sum(0, 3) = 3
        # internally: start=4, end=7 -> sum values 1 0 2 = 3.
        #
        # Iterate over tree starting in the actual-values (second half)
        # section.
        # 1) start=4 is even -> do nothing.
        # 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2
        # 3) int-divide start and end by 2: start=2, end=3
        # 4) start still smaller end -> iterate once more.
        # 5) start=2 is even -> do nothing.
        # 6) end=3 is odd -> end-- -> end=2 -> add value to result: result=3
        #    NOTE: This adds the sum of indices 4 and 5 to the result.

        # Iterate as long as start != end.
        while lo < hi:
            # If start is odd: Add its value to result and move start to
            # next even value.
            if lo & 1:
                result = self.operation(result, self.value[lo])
                lo += 1
            # If end is odd: Move end to previous even value, then add its
            # value to result. NOTE: This takes care of excluding `end` in any
            # situation.
            if hi & 1:
                hi -= 1
                result = self.operation(result, self.value[hi])
            # Divide both start and end by 2 to make them "jump" into the
            # next upper level reduce-index space.
            lo //= 2
            hi //= 2
            # Then repeat till start == end.
        return result

    def __setitem__(self, idx: int, val: float) -> None:
        """Inserts/overwrites a value in/into the tree.

        Args:
            idx (int): The index to insert to. Must be in [0, `self.capacity`[
            val (float): The value to insert.

        Raises:
            AssertionError: If `idx` is outside [0, `self.capacity`[.
        """
        assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}"

        # Index of the leaf to insert into (always insert in "second half"
        # of the tree, the first half is reserved for already calculated
        # reduction-values).
        idx += self.capacity
        self.value[idx] = val

        # Recalculate all affected reduction values (in "first half" of tree)
        # by walking from the leaf's parent up to the root (index 1).
        idx >>= 1
        while idx >= 1:
            left_child = 2 * idx  # calculate only once
            # Update the reduction value at the correct "first half" idx.
            self.value[idx] = self.operation(self.value[left_child],
                                             self.value[left_child + 1])
            idx >>= 1

    def __getitem__(self, idx: int) -> Any:
        """Returns the value stored at array position `idx`.

        Args:
            idx (int): The index to read. Must be in [0, `self.capacity`[

        Raises:
            AssertionError: If `idx` is outside [0, `self.capacity`[.
        """
        assert 0 <= idx < self.capacity
        return self.value[idx + self.capacity]

    def get_state(self):
        """Returns the internal value list (the list itself, not a copy)."""
        return self.value

    def set_state(self, state):
        """Replaces the internal value list with `state`.

        `state` is stored as-is (no copy) and must have exactly
        2 * `self.capacity` entries, laid out like the list returned by
        `get_state`.

        Raises:
            AssertionError: If `state` has the wrong length.
        """
        assert len(state) == self.capacity * 2
        self.value = state
class SumSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=operator.add."""

    def __init__(self, capacity: int):
        """Creates a sum tree with `capacity` slots.

        Args:
            capacity (int): Forwarded to the SegmentTree constructor.
        """
        super().__init__(capacity=capacity, operation=operator.add)

    def sum(self, start: int = 0, end: Optional[Any] = None) -> Any:
        """Returns the sum over a sub-segment of the tree.

        Args:
            start (int): First index of the sub-segment (included).
            end (Optional[int]): End index of the sub-segment (excluded).

        Returns:
            any: The result of `self.reduce(start, end)`.
        """
        total = self.reduce(start, end)
        return total

    def find_prefixsum_idx(self, prefixsum: float) -> int:
        """Finds highest i, for which: sum(arr[0]+..+arr[i - 1]) <= prefixsum.

        Args:
            prefixsum (float): `prefixsum` upper bound in above constraint.
                Must lie in [0, total sum of the tree] (a tolerance of 1e-5
                above the total is accepted).

        Returns:
            int: Largest possible index (i) satisfying above constraint.

        Raises:
            AssertionError: If `prefixsum` is outside the accepted range.
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5

        # Start the descent at the global sum node.
        node = 1
        remaining = prefixsum
        # Descend until a leaf (second half of the internal list) is reached.
        while node < self.capacity:
            left = 2 * node
            if self.value[left] > remaining:
                # Target lies inside the left subtree.
                node = left
            else:
                # Skip the whole left subtree and continue on the right.
                remaining -= self.value[left]
                node = left + 1
        # Translate the internal leaf position back to an array index.
        return node - self.capacity
class MinSegmentTree(SegmentTree):
    """A SegmentTree with the reduction `operation`=min."""

    def __init__(self, capacity: int):
        """Creates a min tree with `capacity` slots.

        Args:
            capacity (int): Forwarded to the SegmentTree constructor.
        """
        super().__init__(capacity=capacity, operation=min)

    def min(self, start: int = 0, end: Optional[Any] = None) -> Any:
        """Returns min(arr[start], ..., arr[end - 1]).

        Args:
            start (int): First index of the sub-segment (included).
            end (Optional[int]): End index of the sub-segment (excluded).

        Returns:
            any: The result of `self.reduce(start, end)`.
        """
        lowest = self.reduce(start, end)
        return lowest
|