123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370 |
- """Implementation of the CART algorithm to train decision tree classifiers."""
- import numpy as np
- import ray
- from sklearn import datasets, metrics
- import time
- import tempfile
- import os
- import json
- """Binary tree with decision tree semantics and ASCII visualization."""
class Node:
    """A decision tree node.

    Leaves are identified by having no right child; internal nodes carry a
    (feature_index, threshold) split and two children.
    """

    def __init__(self, gini, num_samples, num_samples_per_class, predicted_class):
        # Gini impurity of the samples that reached this node.
        self.gini = gini
        # Total number of training samples in this node.
        self.num_samples = num_samples
        # Per-class sample counts (list indexed by class id).
        self.num_samples_per_class = num_samples_per_class
        # Majority class — the prediction when this node is a leaf.
        self.predicted_class = predicted_class
        # Split parameters; meaningful only for internal nodes.
        self.feature_index = 0
        self.threshold = 0
        # Children; both None for a leaf.
        self.left = None
        self.right = None

    def debug(self, feature_names, class_names, show_details):
        """Print an ASCII visualization of the tree."""
        lines, _, _, _ = self._debug_aux(
            feature_names, class_names, show_details, root=True
        )
        for line in lines:
            print(line)

    def _debug_aux(self, feature_names, class_names, show_details, root=False):
        # See https://stackoverflow.com/a/54074933/1143396 for similar code.
        # Renders this subtree as text. Returns (lines, width, height, middle)
        # where `middle` is the column at which a parent connector attaches.
        is_leaf = not self.right
        if is_leaf:
            # Leaf label: the predicted class name.
            lines = [class_names[self.predicted_class]]
        else:
            # Internal node label: the split condition.
            lines = [
                "{} < {:.2f}".format(feature_names[self.feature_index], self.threshold)
            ]
        if show_details:
            lines += [
                "gini = {:.2f}".format(self.gini),
                "samples = {}".format(self.num_samples),
                str(self.num_samples_per_class),
            ]
        width = max(len(line) for line in lines)
        height = len(lines)
        if is_leaf:
            # Leaves use double-line box drawing characters.
            lines = ["║ {:^{width}} ║".format(line, width=width) for line in lines]
            lines.insert(0, "╔" + "═" * (width + 2) + "╗")
            lines.append("╚" + "═" * (width + 2) + "╝")
        else:
            # Internal nodes use single-line boxes with side connectors.
            lines = ["│ {:^{width}} │".format(line, width=width) for line in lines]
            lines.insert(0, "┌" + "─" * (width + 2) + "┐")
            lines.append("└" + "─" * (width + 2) + "┘")
            lines[-2] = "┤" + lines[-2][1:-1] + "├"
        width += 4  # for padding
        if is_leaf:
            middle = width // 2
            # Punch the parent-connector anchor into the top border.
            lines[0] = lines[0][:middle] + "╧" + lines[0][middle + 1 :]
            return lines, width, height, middle
        # If not a leaf, must have two children.
        left, n, p, x = self.left._debug_aux(feature_names, class_names, show_details)
        right, m, q, y = self.right._debug_aux(feature_names, class_names, show_details)
        # Center this node's box between the rendered subtrees.
        top_lines = [n * " " + line + m * " " for line in lines[:-2]]
        # fmt: off
        middle_line = x * " " + "┌" + (
            n - x - 1) * "─" + lines[-2] + y * "─" + "┐" + (m - y - 1) * " "
        bottom_line = x * " " + "│" + (
            n - x - 1) * " " + lines[-1] + y * " " + "│" + (m - y - 1) * " "
        # fmt: on
        # Pad the shorter subtree so the two line lists zip evenly.
        if p < q:
            left += [n * " "] * (q - p)
        elif q < p:
            right += [m * " "] * (p - q)
        zipped_lines = zip(left, right)
        lines = (
            top_lines
            + [middle_line, bottom_line]
            + [a + width * " " + b for a, b in zipped_lines]
        )
        middle = n + width // 2
        if not root:
            # Anchor for this node's own parent connector.
            lines[0] = lines[0][:middle] + "┴" + lines[0][middle + 1 :]
        return lines, n + m + width, max(p, q) + 2 + len(top_lines), middle
class DecisionTreeClassifier:
    """CART-style decision tree classifier whose training fans out over Ray."""

    def __init__(self, max_depth=None, tree_limit=5000, feature_limit=2000):
        # Maximum depth to grow the tree to.
        self.max_depth = max_depth
        # Node-size threshold above which subtrees are grown as Ray tasks.
        self.tree_limit = tree_limit
        # Node-size threshold above which per-feature scans run as Ray tasks.
        self.feature_limit = feature_limit

    def fit(self, X, y):
        """Build decision tree classifier."""
        # Classes are assumed to be labeled 0 .. n_classes - 1.
        self.n_classes_ = len(set(y))
        self.n_features_ = X.shape[1]
        self.tree_ = self._grow_tree(X, y)

    def predict(self, X):
        """Predict the class for every row of X, returned as a list."""
        return list(map(self._predict, X))

    def debug(self, feature_names, class_names, show_details=True):
        """Print ASCII visualization of decision tree."""
        self.tree_.debug(feature_names, class_names, show_details)

    def _gini(self, y):
        """Compute Gini impurity 1 - Σ p² of a non-empty node.

        p is the frequency of each class among the node's samples.
        """
        total = y.size
        class_freqs = (np.sum(y == label) / total for label in range(self.n_classes_))
        return 1.0 - sum(freq ** 2 for freq in class_freqs)

    def _best_split(self, X, y):
        # Delegate to the module-level (optionally Ray-parallel) search.
        return best_split(self, X, y)

    def _grow_tree(self, X, y, depth=0):
        # Launch tree construction as a Ray task and block on the result.
        return ray.get(grow_tree_remote.remote(self, X, y, depth))

    def _predict(self, inputs):
        """Predict class for a single sample."""
        node = self.tree_
        # Walk from the root until a leaf (no left child) is reached.
        while node.left:
            node = node.left if inputs[node.feature_index] < node.threshold else node.right
        return node.predicted_class
def grow_tree_local(tree, X, y, depth):
    """Build a decision subtree by recursively finding the best split.

    Runs entirely in the current process (no Ray tasks); used for small
    subtrees where task overhead would dominate.

    Args:
        tree: The classifier being fit; supplies n_classes_, max_depth,
            _gini and _best_split.
        X, y: Samples and integer class labels reaching this node.
        depth: Depth of this node (the root is 0).

    Returns:
        The Node rooted at this subtree.
    """
    # Population for each class in current node. The predicted class is the
    # one with the largest population.
    num_samples_per_class = [np.sum(y == i) for i in range(tree.n_classes_)]
    predicted_class = np.argmax(num_samples_per_class)
    node = Node(
        gini=tree._gini(y),
        num_samples=y.size,
        num_samples_per_class=num_samples_per_class,
        predicted_class=predicted_class,
    )
    # Split recursively until maximum depth is reached. max_depth=None (the
    # classifier's default) means "no depth limit"; the previous bare
    # `depth < tree.max_depth` raised TypeError for that default.
    if tree.max_depth is None or depth < tree.max_depth:
        idx, thr = tree._best_split(X, y)
        if idx is not None:
            # Partition the samples on the chosen feature/threshold.
            indices_left = X[:, idx] < thr
            X_left, y_left = X[indices_left], y[indices_left]
            X_right, y_right = X[~indices_left], y[~indices_left]
            node.feature_index = idx
            node.threshold = thr
            node.left = grow_tree_local(tree, X_left, y_left, depth + 1)
            node.right = grow_tree_local(tree, X_right, y_right, depth + 1)
    return node
@ray.remote
def grow_tree_remote(tree, X, y, depth=0):
    """Ray task: build a decision subtree by recursively finding the best split.

    Children larger than tree.tree_limit are grown as further Ray tasks so
    the work spreads across the cluster; smaller children recurse locally
    via grow_tree_local to avoid per-task overhead.

    Args:
        tree: The classifier being fit; supplies n_classes_, max_depth,
            tree_limit, _gini and _best_split.
        X, y: Samples and integer class labels reaching this node.
        depth: Depth of this node (the root is 0).

    Returns:
        The Node rooted at this subtree.
    """
    # Population for each class in current node. The predicted class is the
    # one with the largest population.
    num_samples_per_class = [np.sum(y == i) for i in range(tree.n_classes_)]
    predicted_class = np.argmax(num_samples_per_class)
    node = Node(
        gini=tree._gini(y),
        num_samples=y.size,
        num_samples_per_class=num_samples_per_class,
        predicted_class=predicted_class,
    )
    # Split recursively until maximum depth is reached. max_depth=None (the
    # classifier's default) means "no depth limit"; the previous bare
    # `depth < tree.max_depth` raised TypeError for that default.
    if tree.max_depth is None or depth < tree.max_depth:
        idx, thr = tree._best_split(X, y)
        if idx is not None:
            indices_left = X[:, idx] < thr
            X_left, y_left = X[indices_left], y[indices_left]
            X_right, y_right = X[~indices_left], y[~indices_left]
            node.feature_index = idx
            node.threshold = thr
            if len(X_left) > tree.tree_limit or len(X_right) > tree.tree_limit:
                # Large children: grow both sides concurrently as Ray tasks.
                left_future = grow_tree_remote.remote(tree, X_left, y_left, depth + 1)
                right_future = grow_tree_remote.remote(
                    tree, X_right, y_right, depth + 1
                )
                node.left = ray.get(left_future)
                node.right = ray.get(right_future)
            else:
                # Small children: plain recursion in this worker.
                node.left = grow_tree_local(tree, X_left, y_left, depth + 1)
                node.right = grow_tree_local(tree, X_right, y_right, depth + 1)
    return node
def best_split_original(tree, X, y):
    """Find the best split for a node.

    Scans every (feature, threshold) pair and returns the
    (feature_index, threshold) minimizing weighted Gini impurity, or
    (None, None) when the node cannot be improved by any split.
    """
    n_samples = y.size
    # A node with fewer than two samples cannot be split.
    if n_samples <= 1:
        return None, None

    # Class counts in the parent node and its Gini impurity; any accepted
    # split must beat this impurity.
    parent_counts = [np.sum(y == c) for c in range(tree.n_classes_)]
    best_gini = 1.0 - sum((count / n_samples) ** 2 for count in parent_counts)
    best_idx, best_thr = None, None

    for feature in range(tree.n_features_):
        # Order the samples by the current feature's value.
        thresholds, classes = zip(*sorted(zip(X[:, feature], y)))
        # Maintain child class counts incrementally so each candidate split
        # costs O(n_classes) instead of re-counting both partitions — the
        # inner scan is linear rather than quadratic.
        left_counts = [0] * tree.n_classes_
        right_counts = parent_counts.copy()
        for split in range(1, n_samples):  # possible split positions
            moved = classes[split - 1]
            left_counts[moved] += 1
            right_counts[moved] -= 1
            gini_left = 1.0 - sum(
                (left_counts[k] / split) ** 2 for k in range(tree.n_classes_)
            )
            gini_right = 1.0 - sum(
                (right_counts[k] / (n_samples - split)) ** 2
                for k in range(tree.n_classes_)
            )
            # The Gini impurity of a split is the weighted average of the
            # children's impurities.
            gini = (split * gini_left + (n_samples - split) * gini_right) / n_samples
            # Identical feature values cannot be separated by any threshold
            # (both samples must land on the same side of a split).
            if thresholds[split] == thresholds[split - 1]:
                continue
            if gini < best_gini:
                best_gini = gini
                best_idx = feature
                best_thr = (thresholds[split] + thresholds[split - 1]) / 2  # midpoint
    return best_idx, best_thr
def best_split_for_idx(tree, idx, X, y, num_parent, best_gini):
    """Find the best split for a node along a single feature index.

    Returns (best_gini, best_thr). best_thr is NaN when no threshold on
    this feature improves on the supplied parent impurity `best_gini`.
    """
    # Order the samples by the selected feature's value.
    thresholds, classes = zip(*sorted(zip(X[:, idx], y)))
    m = y.size
    # Incrementally maintained class counts of the two candidate children;
    # this keeps the scan linear instead of quadratic.
    left_counts = [0] * tree.n_classes_
    right_counts = num_parent.copy()
    best_thr = float("NaN")
    for split in range(1, m):  # possible split positions
        moved = classes[split - 1]
        left_counts[moved] += 1
        right_counts[moved] -= 1
        gini_left = 1.0 - sum(
            (left_counts[k] / split) ** 2 for k in range(tree.n_classes_)
        )
        gini_right = 1.0 - sum(
            (right_counts[k] / (m - split)) ** 2 for k in range(tree.n_classes_)
        )
        # The Gini impurity of a split is the weighted average of the
        # children's impurities.
        gini = (split * gini_left + (m - split) * gini_right) / m
        # Identical feature values cannot be separated by any threshold
        # (both samples must land on the same side of a split).
        if thresholds[split] == thresholds[split - 1]:
            continue
        if gini < best_gini:
            best_gini = gini
            best_thr = (thresholds[split] + thresholds[split - 1]) / 2  # midpoint
    return best_gini, best_thr
@ray.remote
def best_split_for_idx_remote(tree, idx, X, y, num_parent, best_gini):
    """Ray task wrapper around best_split_for_idx, so per-feature split
    searches can run in parallel across the cluster."""
    return best_split_for_idx(tree, idx, X, y, num_parent, best_gini)
def best_split(tree, X, y):
    """Find the best split for a node across all features.

    Each feature is scanned with best_split_for_idx — as parallel Ray tasks
    when the node is larger than tree.feature_limit, serially otherwise.

    Returns:
        (best_idx, best_thr), or (None, None) when the node has fewer than
        two samples or no split improves on the parent's impurity.
    """
    # Need at least two elements to split a node.
    m = y.size
    if m <= 1:
        return None, None
    # Count of each class in the current node.
    num_parent = [np.sum(y == c) for c in range(tree.n_classes_)]
    # Gini of current node; any accepted split must beat it.
    best_gini = 1.0 - sum((n / m) ** 2 for n in num_parent)
    if m > tree.feature_limit:
        # Large node: fan the per-feature scans out as Ray tasks.
        split_futures = [
            best_split_for_idx_remote.remote(tree, i, X, y, num_parent, best_gini)
            for i in range(tree.n_features_)
        ]
        best_splits = [ray.get(result) for result in split_futures]
    else:
        best_splits = [
            best_split_for_idx(tree, i, X, y, num_parent, best_gini)
            for i in range(tree.n_features_)
        ]
    ginis = np.array([g for (g, _) in best_splits])
    best_idx = int(np.argmin(ginis))
    best_thr = best_splits[best_idx][1]
    # best_split_for_idx returns a NaN threshold when no split on that
    # feature beats the parent impurity. Previously the NaN leaked out with
    # a real index, so callers' `idx is not None` guard passed and
    # `X[:, idx] < NaN` produced a degenerate all-False partition. Report
    # "no split" instead, matching best_split_original's contract.
    if np.isnan(best_thr):
        return None, None
    return best_idx, best_thr
@ray.remote
def run_in_cluster():
    """Train and evaluate one decision tree on the covertype dataset.

    Returns:
        Tuple of (training time in seconds, test-set accuracy).
    """
    # Fresh temp dir per task so concurrent downloads don't collide.
    dataset = datasets.fetch_covtype(data_home=tempfile.mkdtemp())
    # Shift labels down by one; the classifier assumes classes numbered
    # from 0 (covtype targets are presumably 1-based — dataset convention).
    X, y = dataset.data, dataset.target - 1
    training_size = 400000
    max_depth = 10
    clf = DecisionTreeClassifier(max_depth=max_depth)
    # Time only the fit; prediction/evaluation is excluded.
    start = time.time()
    clf.fit(X[:training_size], y[:training_size])
    end = time.time()
    # Evaluate on the held-out remainder of the dataset.
    y_pred = clf.predict(X[training_size:])
    accuracy = metrics.accuracy_score(y[training_size:], y_pred)
    return end - start, accuracy
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # Number of independent training runs to launch on the cluster.
    parser.add_argument("--concurrency", type=int, default=1)
    args = parser.parse_args()
    # Connect to an existing Ray cluster; address comes from the environment.
    ray.init(address=os.environ["RAY_ADDRESS"])
    futures = []
    for i in range(args.concurrency):
        print(f"concurrent run: {i}")
        futures.append(run_in_cluster.remote())
        # Stagger submissions so runs don't all start at once.
        time.sleep(10)
    for i, f in enumerate(futures):
        treetime, accuracy = ray.get(f)
        print(f"Tree {i} building took {treetime} seconds")
        print(f"Test Accuracy: {accuracy}")
    # Write the harness result file. NOTE(review): records only the LAST
    # run's build_time — presumably fine for the default concurrency of 1;
    # confirm if aggregate timing is wanted for concurrency > 1.
    with open(os.environ["TEST_OUTPUT_JSON"], "w") as f:
        f.write(json.dumps({"build_time": treetime, "success": 1}))
|