dsu_union_rank.cpp 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
/**
 * @file
 * @brief [DSU (Disjoint
 * sets)](https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
 * @details
 * DSU is a data structure which keeps track of a collection of elements
 * partitioned into disjoint sets (clusters), i.e. sets that do not have a
 * common element. Typical use cases: finding the connected components of a
 * graph, and Kruskal's algorithm for finding a minimum spanning tree.
 * Operations that can be performed:
 * 1) unionSet(i, j): merges the set containing i with the set containing j.
 * 2) findSet(i): returns the representative of the set to which i belongs.
 * 3) getParents(i): returns the path from i up to its representative, i.e. i,
 * the parent of i, and so on.
 * Below is the class-based approach which uses the union-by-rank heuristic.
 * With union by rank alone, findSet(i) reaches the representative of i in
 * O(log N) time, which is slower than with path compression, but it allows us
 * to keep track of the parents of i.
 * @author [AayushVyasKIIT](https://github.com/AayushVyasKIIT)
 * @see dsu_path_compression.cpp
 */
#include <cassert>   /// for assert
#include <cstdint>   /// for uint64_t
#include <iostream>  /// for IO operations
#include <utility>   /// for std::swap
#include <vector>    /// for std::vector
  24. using std::cout;
  25. using std::endl;
  26. using std::vector;
  27. /**
  28. * @brief Disjoint sets union data structure, class based representation.
  29. * @param n number of elements
  30. */
  31. class dsu {
  32. private:
  33. vector<uint64_t> p; ///< keeps track of the parent of ith element
  34. vector<uint64_t> depth; ///< tracks the depth(rank) of i in the tree
  35. vector<uint64_t> setSize; ///< size of each chunk(set)
  36. public:
  37. /**
  38. * @brief constructor for initialising all data members
  39. * @param n number of elements
  40. */
  41. explicit dsu(uint64_t n) {
  42. p.assign(n, 0);
  43. /// initially all of them are their own parents
  44. depth.assign(n, 0);
  45. setSize.assign(n, 0);
  46. for (uint64_t i = 0; i < n; i++) {
  47. p[i] = i;
  48. depth[i] = 0;
  49. setSize[i] = 1;
  50. }
  51. }
  52. /**
  53. * @brief Method to find the representative of the set to which i belongs
  54. * to, T(n) = O(logN)
  55. * @param i element of some set
  56. * @returns representative of the set to which i belongs to
  57. */
  58. uint64_t findSet(uint64_t i) {
  59. /// using union-rank
  60. while (i != p[i]) {
  61. i = p[i];
  62. }
  63. return i;
  64. }
  65. /**
  66. * @brief Method that combines two disjoint sets to which i and j belongs to
  67. * and make a single set having a common representative.
  68. * @param i element of some set
  69. * @param j element of some set
  70. * @returns void
  71. */
  72. void unionSet(uint64_t i, uint64_t j) {
  73. /// checks if both belongs to same set or not
  74. if (isSame(i, j)) {
  75. return;
  76. }
  77. /// we find representative of the i and j
  78. uint64_t x = findSet(i);
  79. uint64_t y = findSet(j);
  80. /// always keeping the min as x
  81. /// in order to create a shallow tree
  82. if (depth[x] > depth[y]) {
  83. std::swap(x, y);
  84. }
  85. /// making the shallower tree, root parent of the deeper root
  86. p[x] = y;
  87. /// if same depth, then increase one's depth
  88. if (depth[x] == depth[y]) {
  89. depth[y]++;
  90. }
  91. /// total size of the resultant set
  92. setSize[y] += setSize[x];
  93. }
  94. /**
  95. * @brief A utility function which check whether i and j belongs to same set
  96. * or not
  97. * @param i element of some set
  98. * @param j element of some set
  99. * @returns `true` if element i and j are in same set
  100. * @returns `false` if element i and j are not in same set
  101. */
  102. bool isSame(uint64_t i, uint64_t j) {
  103. if (findSet(i) == findSet(j)) {
  104. return true;
  105. }
  106. return false;
  107. }
  108. /**
  109. * @brief Method to print all the parents of i, or the path from i to
  110. * representative.
  111. * @param i element of some set
  112. * @returns void
  113. */
  114. vector<uint64_t> getParents(uint64_t i) {
  115. vector<uint64_t> ans;
  116. while (p[i] != i) {
  117. ans.push_back(i);
  118. i = p[i];
  119. }
  120. ans.push_back(i);
  121. return ans;
  122. }
  123. };
  124. /**
  125. * @brief Self-implementations, 1st test
  126. * @returns void
  127. */
  128. static void test1() {
  129. /* checks the parents in the resultant structures */
  130. uint64_t n = 10; ///< number of elements
  131. dsu d(n + 1); ///< object of class disjoint sets
  132. d.unionSet(2, 1); ///< performs union operation on 1 and 2
  133. d.unionSet(1, 4);
  134. d.unionSet(8, 1);
  135. d.unionSet(3, 5);
  136. d.unionSet(5, 6);
  137. d.unionSet(5, 7);
  138. d.unionSet(9, 10);
  139. d.unionSet(2, 10);
  140. // keeping track of the changes using parent pointers
  141. vector<uint64_t> ans = {7, 5};
  142. for (uint64_t i = 0; i < ans.size(); i++) {
  143. assert(d.getParents(7).at(i) ==
  144. ans[i]); // makes sure algorithm works fine
  145. }
  146. cout << "1st test passed!" << endl;
  147. }
  148. /**
  149. * @brief Self-implementations, 2nd test
  150. * @returns void
  151. */
  152. static void test2() {
  153. // checks the parents in the resultant structures
  154. uint64_t n = 10; ///< number of elements
  155. dsu d(n + 1); ///< object of class disjoint sets
  156. d.unionSet(2, 1); /// performs union operation on 1 and 2
  157. d.unionSet(1, 4);
  158. d.unionSet(8, 1);
  159. d.unionSet(3, 5);
  160. d.unionSet(5, 6);
  161. d.unionSet(5, 7);
  162. d.unionSet(9, 10);
  163. d.unionSet(2, 10);
  164. /// keeping track of the changes using parent pointers
  165. vector<uint64_t> ans = {2, 1, 10};
  166. for (uint64_t i = 0; i < ans.size(); i++) {
  167. assert(d.getParents(2).at(i) ==
  168. ans[i]); /// makes sure algorithm works fine
  169. }
  170. cout << "2nd test passed!" << endl;
  171. }
  172. /**
  173. * @brief Main function
  174. * @returns 0 on exit
  175. */
  176. int main() {
  177. test1(); // run 1st test case
  178. test2(); // run 2nd test case
  179. return 0;
  180. }