# cart_with_tree.py

  1. """Implementation of the CART algorithm to train decision tree classifiers."""
  2. import numpy as np
  3. import ray
  4. from sklearn import datasets, metrics
  5. import time
  6. import tempfile
  7. import os
  8. import json
  9. """Binary tree with decision tree semantics and ASCII visualization."""
  10. class Node:
  11. """A decision tree node."""
  12. def __init__(self, gini, num_samples, num_samples_per_class, predicted_class):
  13. self.gini = gini
  14. self.num_samples = num_samples
  15. self.num_samples_per_class = num_samples_per_class
  16. self.predicted_class = predicted_class
  17. self.feature_index = 0
  18. self.threshold = 0
  19. self.left = None
  20. self.right = None

    def debug(self, feature_names, class_names, show_details):
        """Print an ASCII visualization of the tree."""
        lines, _, _, _ = self._debug_aux(
            feature_names, class_names, show_details, root=True
        )
        for line in lines:
            print(line)

    def _debug_aux(self, feature_names, class_names, show_details, root=False):
        # See https://stackoverflow.com/a/54074933/1143396 for similar code.
        is_leaf = not self.right
        if is_leaf:
            lines = [class_names[self.predicted_class]]
        else:
            lines = [
                "{} < {:.2f}".format(feature_names[self.feature_index], self.threshold)
            ]
        if show_details:
            lines += [
                "gini = {:.2f}".format(self.gini),
                "samples = {}".format(self.num_samples),
                str(self.num_samples_per_class),
            ]
        width = max(len(line) for line in lines)
        height = len(lines)
        if is_leaf:
            lines = ["║ {:^{width}} ║".format(line, width=width) for line in lines]
            lines.insert(0, "╔" + "═" * (width + 2) + "╗")
            lines.append("╚" + "═" * (width + 2) + "╝")
        else:
            lines = ["│ {:^{width}} │".format(line, width=width) for line in lines]
            lines.insert(0, "┌" + "─" * (width + 2) + "┐")
            lines.append("└" + "─" * (width + 2) + "┘")
            lines[-2] = "┤" + lines[-2][1:-1] + "├"
        width += 4  # for padding
        if is_leaf:
            middle = width // 2
            lines[0] = lines[0][:middle] + "╧" + lines[0][middle + 1 :]
            return lines, width, height, middle

        # If this is not a leaf, it must have two children.
        left, n, p, x = self.left._debug_aux(feature_names, class_names, show_details)
        right, m, q, y = self.right._debug_aux(feature_names, class_names, show_details)
        top_lines = [n * " " + line + m * " " for line in lines[:-2]]
        # fmt: off
        middle_line = x * " " + "┌" + (n - x - 1) * "─" + lines[-2] + y * "─" + "┐" + (m - y - 1) * " "
        bottom_line = x * " " + "│" + (n - x - 1) * " " + lines[-1] + y * " " + "│" + (m - y - 1) * " "
        # fmt: on
        if p < q:
            left += [n * " "] * (q - p)
        elif q < p:
            right += [m * " "] * (p - q)
        zipped_lines = zip(left, right)
        lines = (
            top_lines
            + [middle_line, bottom_line]
            + [a + width * " " + b for a, b in zipped_lines]
        )
        middle = n + width // 2
        if not root:
            lines[0] = lines[0][:middle] + "┴" + lines[0][middle + 1 :]
        return lines, n + m + width, max(p, q) + 2 + len(top_lines), middle
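
# _debug_aux returns (lines, width, height, middle): the rendered text block,
# its dimensions, and the column at which a parent should attach its
# connector. Leaf boxes use double borders (╔═╗) while internal nodes use
# single borders (┌─┐), which is how the two are told apart in the output.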


class DecisionTreeClassifier:
    def __init__(self, max_depth=None, tree_limit=5000, feature_limit=2000):
        self.max_depth = max_depth
        self.tree_limit = tree_limit
        self.feature_limit = feature_limit

    def fit(self, X, y):
        """Build decision tree classifier."""
        self.n_classes_ = len(set(y))  # classes are assumed to go from 0 to n-1
        self.n_features_ = X.shape[1]
        self.tree_ = self._grow_tree(X, y)

    def predict(self, X):
        """Predict class for X."""
        return [self._predict(inputs) for inputs in X]

    def debug(self, feature_names, class_names, show_details=True):
        """Print ASCII visualization of decision tree."""
        self.tree_.debug(feature_names, class_names, show_details)
    def _gini(self, y):
        """Compute the Gini impurity of a non-empty node.

        Gini impurity is defined as Σ p(1-p) over all classes, with p the
        frequency of a class within the node. Since Σ p = 1, this is
        equivalent to 1 - Σ p^2.
        """
        m = y.size
        return 1.0 - sum((np.sum(y == c) / m) ** 2 for c in range(self.n_classes_))
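
    # Worked example (illustrative): for y = [0, 0, 1, 1] the class
    # frequencies are p = (0.5, 0.5), so gini = 1 - (0.25 + 0.25) = 0.5,
    # the maximum impurity for two classes; a pure node such as
    # y = [1, 1, 1] gives gini = 1 - 1.0 = 0.0.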

    def _best_split(self, X, y):
        return best_split(self, X, y)

    def _grow_tree(self, X, y, depth=0):
        future = grow_tree_remote.remote(self, X, y, depth)
        return ray.get(future)

    def _predict(self, inputs):
        """Predict class for a single sample."""
        node = self.tree_
        while node.left:
            if inputs[node.feature_index] < node.threshold:
                node = node.left
            else:
                node = node.right
        return node.predicted_class
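

# Minimal usage sketch (illustrative; assumes ray.init() has already been
# called so the remote tasks inside fit() can run):
#
#     iris = datasets.load_iris()
#     clf = DecisionTreeClassifier(max_depth=3)
#     clf.fit(iris.data, iris.target)
#     print(clf.predict(iris.data[:3]))
#     clf.debug(iris.feature_names, [str(n) for n in iris.target_names])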


def grow_tree_local(tree, X, y, depth):
    """Build a decision tree by recursively finding the best split."""
    # Population for each class in the current node. The predicted class is
    # the one with the largest population.
    num_samples_per_class = [np.sum(y == i) for i in range(tree.n_classes_)]
    predicted_class = np.argmax(num_samples_per_class)
    node = Node(
        gini=tree._gini(y),
        num_samples=y.size,
        num_samples_per_class=num_samples_per_class,
        predicted_class=predicted_class,
    )
    # Split recursively until the maximum depth is reached (a max_depth of
    # None means unbounded growth).
    if tree.max_depth is None or depth < tree.max_depth:
        idx, thr = tree._best_split(X, y)
        if idx is not None:
            indices_left = X[:, idx] < thr
            X_left, y_left = X[indices_left], y[indices_left]
            X_right, y_right = X[~indices_left], y[~indices_left]
            node.feature_index = idx
            node.threshold = thr
            node.left = grow_tree_local(tree, X_left, y_left, depth + 1)
            node.right = grow_tree_local(tree, X_right, y_right, depth + 1)
    return node


@ray.remote
def grow_tree_remote(tree, X, y, depth=0):
    """Build a decision tree by recursively finding the best split."""
    # Population for each class in the current node. The predicted class is
    # the one with the largest population.
    num_samples_per_class = [np.sum(y == i) for i in range(tree.n_classes_)]
    predicted_class = np.argmax(num_samples_per_class)
    node = Node(
        gini=tree._gini(y),
        num_samples=y.size,
        num_samples_per_class=num_samples_per_class,
        predicted_class=predicted_class,
    )
    # Split recursively until the maximum depth is reached (a max_depth of
    # None means unbounded growth).
    if tree.max_depth is None or depth < tree.max_depth:
        idx, thr = tree._best_split(X, y)
        if idx is not None:
            indices_left = X[:, idx] < thr
            X_left, y_left = X[indices_left], y[indices_left]
            X_right, y_right = X[~indices_left], y[~indices_left]
            node.feature_index = idx
            node.threshold = thr
            # Recurse as remote Ray tasks only while the subtrees are large
            # enough to amortize the task overhead.
            if len(X_left) > tree.tree_limit or len(X_right) > tree.tree_limit:
                left_future = grow_tree_remote.remote(tree, X_left, y_left, depth + 1)
                right_future = grow_tree_remote.remote(tree, X_right, y_right, depth + 1)
                node.left = ray.get(left_future)
                node.right = ray.get(right_future)
            else:
                node.left = grow_tree_local(tree, X_left, y_left, depth + 1)
                node.right = grow_tree_local(tree, X_right, y_right, depth + 1)
    return node
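

# Note on the two parallelism knobs: grow_tree_remote fans out across
# subtrees (gated by tree_limit), while best_split below fans out across
# features within a single node (gated by feature_limit). Below those sizes
# the Ray task overhead would outweigh the parallel win, so the code falls
# back to plain local execution.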


def best_split_original(tree, X, y):
    """Find the best split for a node."""
    # Need at least two elements to split a node.
    m = y.size
    if m <= 1:
        return None, None
    # Count of each class in the current node.
    num_parent = [np.sum(y == c) for c in range(tree.n_classes_)]
    # Gini of the current node.
    best_gini = 1.0 - sum((n / m) ** 2 for n in num_parent)
    best_idx, best_thr = None, None
    # Loop through all features.
    for idx in range(tree.n_features_):
        # Sort data along the selected feature.
        thresholds, classes = zip(*sorted(zip(X[:, idx], y)))
        # We could actually split the node according to each feature/threshold
        # pair and count the resulting population for each class in the
        # children, but instead we compute them in an iterative fashion,
        # making this loop linear rather than quadratic.
        num_left = [0] * tree.n_classes_
        num_right = num_parent.copy()
        for i in range(1, m):  # possible split positions
            c = classes[i - 1]
            num_left[c] += 1
            num_right[c] -= 1
            gini_left = 1.0 - sum(
                (num_left[x] / i) ** 2 for x in range(tree.n_classes_)
            )
            gini_right = 1.0 - sum(
                (num_right[x] / (m - i)) ** 2 for x in range(tree.n_classes_)
            )
            # The Gini impurity of a split is the weighted average of the
            # Gini impurity of the children.
            gini = (i * gini_left + (m - i) * gini_right) / m
            # Make sure we don't try to split two points with identical
            # values for this feature, as that is impossible (both must end
            # up on the same side of any split).
            if thresholds[i] == thresholds[i - 1]:
                continue
            if gini < best_gini:
                best_gini = gini
                best_idx = idx
                best_thr = (thresholds[i] + thresholds[i - 1]) / 2  # midpoint
    return best_idx, best_thr
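

# Worked example of the split criterion (illustrative): if a feature sorts
# y into classes (0, 0, 1, 1), the split after position i = 2 yields
# gini_left = gini_right = 0.0, so gini = (2 * 0.0 + 2 * 0.0) / 4 = 0.0,
# which beats the parent impurity of 0.5 and becomes the best split.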


def best_split_for_idx(tree, idx, X, y, num_parent, best_gini):
    """Find the best split for a node along a given feature index."""
    # Sort data along the selected feature.
    thresholds, classes = zip(*sorted(zip(X[:, idx], y)))
    # We could actually split the node according to each feature/threshold
    # pair and count the resulting population for each class in the children,
    # but instead we compute them in an iterative fashion, making this loop
    # linear rather than quadratic.
    m = y.size
    num_left = [0] * tree.n_classes_
    num_right = num_parent.copy()
    best_thr = float("NaN")
    for i in range(1, m):  # possible split positions
        c = classes[i - 1]
        num_left[c] += 1
        num_right[c] -= 1
        gini_left = 1.0 - sum((num_left[x] / i) ** 2 for x in range(tree.n_classes_))
        gini_right = 1.0 - sum(
            (num_right[x] / (m - i)) ** 2 for x in range(tree.n_classes_)
        )
        # The Gini impurity of a split is the weighted average of the Gini
        # impurity of the children.
        gini = (i * gini_left + (m - i) * gini_right) / m
        # Make sure we don't try to split two points with identical values
        # for this feature, as that is impossible (both must end up on the
        # same side of any split).
        if thresholds[i] == thresholds[i - 1]:
            continue
        if gini < best_gini:
            best_gini = gini
            best_thr = (thresholds[i] + thresholds[i - 1]) / 2  # midpoint
    return best_gini, best_thr


@ray.remote
def best_split_for_idx_remote(tree, idx, X, y, num_parent, best_gini):
    return best_split_for_idx(tree, idx, X, y, num_parent, best_gini)


def best_split(tree, X, y):
    """Find the best split for a node, searching the features in parallel."""
    # Need at least two elements to split a node.
    m = y.size
    if m <= 1:
        return None, None
    # Count of each class in the current node.
    num_parent = [np.sum(y == c) for c in range(tree.n_classes_)]
    # Gini of the current node.
    best_gini = 1.0 - sum((n / m) ** 2 for n in num_parent)
    # Only fan the per-feature search out to Ray tasks when the node is
    # large enough to amortize the task overhead.
    if m > tree.feature_limit:
        split_futures = [
            best_split_for_idx_remote.remote(tree, i, X, y, num_parent, best_gini)
            for i in range(tree.n_features_)
        ]
        best_splits = [ray.get(result) for result in split_futures]
    else:
        best_splits = [
            best_split_for_idx(tree, i, X, y, num_parent, best_gini)
            for i in range(tree.n_features_)
        ]
    ginis = np.array([gini for (gini, _) in best_splits])
    best_idx = np.argmin(ginis)
    best_thr = best_splits[best_idx][1]
    # If no feature produced an improving split, every returned threshold is
    # NaN and the node should not be split further.
    if np.isnan(best_thr):
        return None, None
    return best_idx, best_thr


@ray.remote
def run_in_cluster():
    dataset = datasets.fetch_covtype(data_home=tempfile.mkdtemp())
    X, y = dataset.data, dataset.target - 1  # relabel classes from 1..7 to 0..6
    training_size = 400000
    max_depth = 10
    clf = DecisionTreeClassifier(max_depth=max_depth)
    start = time.time()
    clf.fit(X[:training_size], y[:training_size])
    end = time.time()
    y_pred = clf.predict(X[training_size:])
    accuracy = metrics.accuracy_score(y[training_size:], y_pred)
    return end - start, accuracy
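

# run_in_cluster trains on the first 400,000 covtype rows and evaluates on
# the remaining ~180,000. fetch_covtype downloads the data on first use, so
# pointing data_home at tempfile.mkdtemp() gives each task its own copy.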


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=1)
    args = parser.parse_args()

    ray.init(address=os.environ["RAY_ADDRESS"])
    futures = []
    for i in range(args.concurrency):
        print(f"concurrent run: {i}")
        futures.append(run_in_cluster.remote())
        time.sleep(10)  # stagger the concurrent runs
    for i, f in enumerate(futures):
        treetime, accuracy = ray.get(f)
        print(f"Tree {i} building took {treetime} seconds")
        print(f"Test Accuracy: {accuracy}")
    # Record the last run's build time.
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"build_time": treetime, "success": 1}))
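

# Example invocation (illustrative; assumes a reachable Ray cluster and the
# two environment variables read above):
#
#     RAY_ADDRESS=auto TEST_OUTPUT_JSON=/tmp/result.json \
#         python cart_with_tree.py --concurrency 2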