# segment_tree.py (7.7 KB)
import operator
from typing import Any, Optional
  3. class SegmentTree:
  4. """A Segment Tree data structure.
  5. https://en.wikipedia.org/wiki/Segment_tree
  6. Can be used as regular array, but with two important differences:
  7. a) Setting an item's value is slightly slower. It is O(lg capacity),
  8. instead of O(1).
  9. b) Offers efficient `reduce` operation which reduces the tree's values
  10. over some specified contiguous subsequence of items in the array.
  11. Operation could be e.g. min/max/sum.
  12. The data is stored in a list, where the length is 2 * capacity.
  13. The second half of the list stores the actual values for each index, so if
  14. capacity=8, values are stored at indices 8 to 15. The first half of the
  15. array contains the reduced-values of the different (binary divided)
  16. segments, e.g. (capacity=4):
  17. 0=not used
  18. 1=reduced-value over all elements (array indices 4 to 7).
  19. 2=reduced-value over array indices (4 and 5).
  20. 3=reduced-value over array indices (6 and 7).
  21. 4-7: values of the tree.
  22. NOTE that the values of the tree are accessed by indices starting at 0, so
  23. `tree[0]` accesses `internal_array[4]` in the above example.
  24. """
  25. def __init__(self,
  26. capacity: int,
  27. operation: Any,
  28. neutral_element: Optional[Any] = None):
  29. """Initializes a Segment Tree object.
  30. Args:
  31. capacity (int): Total size of the array - must be a power of two.
  32. operation (operation): Lambda obj, obj -> obj
  33. The operation for combining elements (eg. sum, max).
  34. Must be a mathematical group together with the set of
  35. possible values for array elements.
  36. neutral_element (Optional[obj]): The neutral element for
  37. `operation`. Use None for automatically finding a value:
  38. max: float("-inf"), min: float("inf"), sum: 0.0.
  39. """
  40. assert capacity > 0 and capacity & (capacity - 1) == 0, \
  41. "Capacity must be positive and a power of 2!"
  42. self.capacity = capacity
  43. if neutral_element is None:
  44. neutral_element = 0.0 if operation is operator.add else \
  45. float("-inf") if operation is max else float("inf")
  46. self.neutral_element = neutral_element
  47. self.value = [self.neutral_element for _ in range(2 * capacity)]
  48. self.operation = operation
  49. def reduce(self, start: int = 0, end: Optional[int] = None) -> Any:
  50. """Applies `self.operation` to subsequence of our values.
  51. Subsequence is contiguous, includes `start` and excludes `end`.
  52. self.operation(
  53. arr[start], operation(arr[start+1], operation(... arr[end])))
  54. Args:
  55. start (int): Start index to apply reduction to.
  56. end (Optional[int]): End index to apply reduction to (excluded).
  57. Returns:
  58. any: The result of reducing self.operation over the specified
  59. range of `self._value` elements.
  60. """
  61. if end is None:
  62. end = self.capacity
  63. elif end < 0:
  64. end += self.capacity
  65. # Init result with neutral element.
  66. result = self.neutral_element
  67. # Map start/end to our actual index space (second half of array).
  68. start += self.capacity
  69. end += self.capacity
  70. # Example:
  71. # internal-array (first half=sums, second half=actual values):
  72. # 0 1 2 3 | 4 5 6 7
  73. # - 6 1 5 | 1 0 2 3
  74. # tree.sum(0, 3) = 3
  75. # internally: start=4, end=7 -> sum values 1 0 2 = 3.
  76. # Iterate over tree starting in the actual-values (second half)
  77. # section.
  78. # 1) start=4 is even -> do nothing.
  79. # 2) end=7 is odd -> end-- -> end=6 -> add value to result: result=2
  80. # 3) int-divide start and end by 2: start=2, end=3
  81. # 4) start still smaller end -> iterate once more.
  82. # 5) start=2 is even -> do nothing.
  83. # 6) end=3 is odd -> end-- -> end=2 -> add value to result: result=1
  84. # NOTE: This adds the sum of indices 4 and 5 to the result.
  85. # Iterate as long as start != end.
  86. while start < end:
  87. # If start is odd: Add its value to result and move start to
  88. # next even value.
  89. if start & 1:
  90. result = self.operation(result, self.value[start])
  91. start += 1
  92. # If end is odd: Move end to previous even value, then add its
  93. # value to result. NOTE: This takes care of excluding `end` in any
  94. # situation.
  95. if end & 1:
  96. end -= 1
  97. result = self.operation(result, self.value[end])
  98. # Divide both start and end by 2 to make them "jump" into the
  99. # next upper level reduce-index space.
  100. start //= 2
  101. end //= 2
  102. # Then repeat till start == end.
  103. return result
  104. def __setitem__(self, idx: int, val: float) -> None:
  105. """
  106. Inserts/overwrites a value in/into the tree.
  107. Args:
  108. idx (int): The index to insert to. Must be in [0, `self.capacity`[
  109. val (float): The value to insert.
  110. """
  111. assert 0 <= idx < self.capacity, f"idx={idx} capacity={self.capacity}"
  112. # Index of the leaf to insert into (always insert in "second half"
  113. # of the tree, the first half is reserved for already calculated
  114. # reduction-values).
  115. idx += self.capacity
  116. self.value[idx] = val
  117. # Recalculate all affected reduction values (in "first half" of tree).
  118. idx = idx >> 1 # Divide by 2 (faster than division).
  119. while idx >= 1:
  120. update_idx = 2 * idx # calculate only once
  121. # Update the reduction value at the correct "first half" idx.
  122. self.value[idx] = self.operation(self.value[update_idx],
  123. self.value[update_idx + 1])
  124. idx = idx >> 1 # Divide by 2 (faster than division).
  125. def __getitem__(self, idx: int) -> Any:
  126. assert 0 <= idx < self.capacity
  127. return self.value[idx + self.capacity]
  128. def get_state(self):
  129. return self.value
  130. def set_state(self, state):
  131. assert len(state) == self.capacity * 2
  132. self.value = state
  133. class SumSegmentTree(SegmentTree):
  134. """A SegmentTree with the reduction `operation`=operator.add."""
  135. def __init__(self, capacity: int):
  136. super(SumSegmentTree, self).__init__(
  137. capacity=capacity, operation=operator.add)
  138. def sum(self, start: int = 0, end: Optional[Any] = None) -> Any:
  139. """Returns the sum over a sub-segment of the tree."""
  140. return self.reduce(start, end)
  141. def find_prefixsum_idx(self, prefixsum: float) -> int:
  142. """Finds highest i, for which: sum(arr[0]+..+arr[i - i]) <= prefixsum.
  143. Args:
  144. prefixsum (float): `prefixsum` upper bound in above constraint.
  145. Returns:
  146. int: Largest possible index (i) satisfying above constraint.
  147. """
  148. assert 0 <= prefixsum <= self.sum() + 1e-5
  149. # Global sum node.
  150. idx = 1
  151. # While non-leaf (first half of tree).
  152. while idx < self.capacity:
  153. update_idx = 2 * idx
  154. if self.value[update_idx] > prefixsum:
  155. idx = update_idx
  156. else:
  157. prefixsum -= self.value[update_idx]
  158. idx = update_idx + 1
  159. return idx - self.capacity
  160. class MinSegmentTree(SegmentTree):
  161. def __init__(self, capacity: int):
  162. super(MinSegmentTree, self).__init__(capacity=capacity, operation=min)
  163. def min(self, start: int = 0, end: Optional[Any] = None) -> Any:
  164. """Returns min(arr[start], ..., arr[end])"""
  165. return self.reduce(start, end)