Decision tree classifier

A simple implementation of the ID3 algorithm.


This post presents a simple but fully functional Python 3 implementation of a decision tree classifier. It is not intended as a tutorial on machine learning, classification or even decision trees: there are plenty of resources on the web already. The main idea is to provide an example implementation for those who are familiar or comfortable with Python.

Several decision tree algorithms have been developed over time, each one improving or optimizing something over its predecessor. The implementation presented in this post corresponds to the first well-known algorithm on the subject: the Iterative Dichotomiser 3 (ID3), developed in 1986 by Ross Quinlan.

For those familiar with the scikit-learn library, its documentation includes a section devoted to decision trees. That API provides a production-ready, highly configurable implementation of an optimized version of the CART algorithm.


The code is here:

Basically, a tree is represented using Python dicts. A couple of very simple classes that extend dict were created to distinguish between branch and leaf nodes. Also, a namedtuple was defined to match each training sample with its corresponding class. The necessary information_gain and entropy functions were also written; their implementations are really simple thanks to Python's standard collections module.
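
The post does not show these definitions, so the following is only a minimal sketch of what they could look like. The namedtuple name Sample and the attributes feature (on branches) and klass (on leaves) are assumptions; the field names sample and klass, and the class names DecisionTreeNode and DecisionTreeLeaf, are inferred from the tree-creation method. Features are assumed to be integer indices into each sample's tuple of values.

```python
from collections import namedtuple

# Hypothetical: the post only says a namedtuple pairs each sample with its class.
Sample = namedtuple('Sample', ['sample', 'klass'])


class DecisionTreeNode(dict):
    """Branch: a dict mapping each value of `feature` to a child node."""

    def __init__(self, feature):
        super().__init__()
        self.feature = feature  # assumed attribute name


class DecisionTreeLeaf(dict):
    """Leaf: an (empty) dict that only carries the predicted class."""

    def __init__(self, klass):
        super().__init__()
        self.klass = klass  # assumed attribute name


# A one-level tree that splits on feature 0 (say, the outlook).
root = DecisionTreeNode(0)
root['sunny'] = DecisionTreeLeaf('no')
root['overcast'] = DecisionTreeLeaf('yes')

example = Sample(sample=('sunny', 'high'), klass='no')
print(root[example.sample[root.feature]].klass)  # no
```

Subclassing dict keeps child lookup as plain indexing (node[value]), while isinstance checks tell branches and leaves apart.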

Finally, the main piece of code is the tree-creation method: this is where all the magic happens.

def create_decision_tree(self, training_samples, predicting_features):
    """Recursively create a decision tree and return its root node."""

    if not predicting_features:
        # No more predicting features: return a leaf with the majority class.
        default_klass = self.get_most_common_class(training_samples)
        root_node = DecisionTreeLeaf(default_klass)
    else:
        klasses = [sample.klass for sample in training_samples]
        if len(set(klasses)) == 1:
            # All samples share the same class: return a leaf for it.
            target_klass = training_samples[0].klass
            root_node = DecisionTreeLeaf(target_klass)
        else:
            best_feature = self.select_best_feature(training_samples,
                                                    predicting_features)
            # Create the node to return and create the sub-tree.
            root_node = DecisionTreeNode(best_feature)
            # The chosen feature is consumed: children split on the rest.
            remaining_features = [f for f in predicting_features
                                  if f != best_feature]
            best_feature_values = {s.sample[best_feature]
                                   for s in training_samples}
            for value in best_feature_values:
                samples = [s for s in training_samples
                           if s.sample[best_feature] == value]
                # Recursively, create a child node.
                child = self.create_decision_tree(samples,
                                                  remaining_features)
                root_node[value] = child
    return root_node
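
The method above calls select_best_feature and get_most_common_class, which the post does not list, and relies on the entropy and information_gain functions mentioned earlier. The following is a minimal sketch of all four as standalone functions (in the post they are methods), assuming a hypothetical Sample namedtuple with sample and klass fields and integer feature indices. Entropy is measured in bits, and the best feature is the one with the highest information gain, which is the ID3 splitting criterion.

```python
import math
from collections import Counter, defaultdict, namedtuple

# Hypothetical sample type: `sample` holds the feature values, `klass` the label.
Sample = namedtuple('Sample', ['sample', 'klass'])


def entropy(samples):
    """Shannon entropy, in bits, of the class distribution of `samples`."""
    counts = Counter(s.klass for s in samples)
    total = len(samples)
    # p * log2(1/p) equals -p * log2(p); written this way to avoid a -0.0 result.
    return sum((c / total) * math.log2(total / c) for c in counts.values())


def information_gain(samples, feature):
    """Entropy reduction obtained by splitting `samples` on `feature`."""
    subsets = defaultdict(list)
    for s in samples:
        subsets[s.sample[feature]].append(s)
    remainder = sum(len(subset) / len(samples) * entropy(subset)
                    for subset in subsets.values())
    return entropy(samples) - remainder


def select_best_feature(samples, features):
    """ID3 criterion: pick the feature with the highest information gain."""
    return max(features, key=lambda f: information_gain(samples, f))


def get_most_common_class(samples):
    """Majority class among `samples` (ties go to the first one encountered)."""
    return Counter(s.klass for s in samples).most_common(1)[0][0]


# Toy data: feature 0 is the outlook, feature 1 the humidity.
data = [
    Sample(('sunny', 'high'), 'no'),
    Sample(('sunny', 'normal'), 'yes'),
    Sample(('overcast', 'high'), 'yes'),
    Sample(('overcast', 'normal'), 'yes'),
    Sample(('rain', 'high'), 'yes'),
]
print(select_best_feature(data, [0, 1]))  # 0 (outlook)
print(get_most_common_class(data))        # yes
```

On this toy data the outlook wins because two of its three branches are already pure, while splitting on humidity leaves a mixed "high" group of three samples.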

Following the interface of the scikit-learn library mentioned above, the algorithm is implemented as a class with the following methods:

  • fit(training_samples, known_labels): creates the decision tree from the training data.
  • predict(samples): given a fitted model, predicts the label of each sample in the given array and returns the predicted labels.
  • score(samples, known_labels): predicts the labels for the given samples and compares them with the ground truth provided in known_labels. Returns a score between 0 (no matches) and 1 (perfect match).
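
The post only describes score by its range. Assuming it is plain accuracy (the fraction of samples whose predicted label equals the known one), a minimal sketch could look like this; passing the predict function in as a callable is an assumption made to keep the example self-contained, whereas the post's score is a method that calls its own predict.

```python
def score(predict, samples, known_labels):
    """Fraction of samples whose predicted label matches the known one.

    `predict` is any callable mapping a list of samples to a list of labels,
    such as a fitted classifier's predict method (an assumption of this sketch).
    """
    if len(samples) != len(known_labels):
        raise ValueError('samples and known_labels must have the same length')
    if not samples:
        raise ValueError('cannot score an empty set of samples')
    predicted = predict(samples)
    matches = sum(1 for p, k in zip(predicted, known_labels) if p == k)
    return matches / len(known_labels)


def always_yes(samples):
    """A trivial stand-in predictor that labels everything 'yes'."""
    return ['yes'] * len(samples)


print(score(always_yes, [('a',), ('b',), ('c',), ('d',)],
            ['yes', 'no', 'yes', 'yes']))  # 0.75
```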

Other than that, the code is pretty much self-explanatory. Thanks to the standard collections module, the auxiliary methods (select_best_feature, information_gain, entropy) are very concise. The tree itself is easily implemented using dicts:

  • Each node is either a leaf or a branch: a leaf represents a class, and a branch represents a feature.
  • Each branch has as many children as there are distinct values of its feature among the training samples that reach it.
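
To make the structure concrete, here is a hand-built tree for a hypothetical weather dataset (feature 0 is the outlook, feature 1 the humidity) and a small, hypothetical render helper that prints one line per branch. The node classes are assumed dict subclasses with a feature attribute on branches and a klass attribute on leaves; render sorts the keys, so it assumes feature values are mutually comparable.

```python
class DecisionTreeNode(dict):
    """Branch (assumed): maps each value of `feature` to a child node."""

    def __init__(self, feature):
        super().__init__()
        self.feature = feature


class DecisionTreeLeaf(dict):
    """Leaf (assumed): an empty dict carrying the class label."""

    def __init__(self, klass):
        super().__init__()
        self.klass = klass


def render(node, depth=0):
    """Return the tree as indented text lines, one line per edge."""
    pad = '  ' * depth
    if isinstance(node, DecisionTreeLeaf):
        return [pad + '-> ' + str(node.klass)]
    lines = []
    for value in sorted(node):
        child = node[value]
        if isinstance(child, DecisionTreeLeaf):
            lines.append('{}feature {} = {!r} -> {}'.format(
                pad, node.feature, value, child.klass))
        else:
            lines.append('{}feature {} = {!r}:'.format(pad, node.feature, value))
            lines.extend(render(child, depth + 1))
    return lines


# Outlook has three values, so the root has three children; humidity has
# two values among the sunny samples, so the sunny branch has two children.
sunny = DecisionTreeNode(1)
sunny['high'] = DecisionTreeLeaf('no')
sunny['normal'] = DecisionTreeLeaf('yes')
tree = DecisionTreeNode(0)
tree['sunny'] = sunny
tree['overcast'] = DecisionTreeLeaf('yes')
tree['rain'] = DecisionTreeLeaf('yes')

print('\n'.join(render(tree)))
# feature 0 = 'overcast' -> yes
# feature 0 = 'rain' -> yes
# feature 0 = 'sunny':
#   feature 1 = 'high' -> no
#   feature 1 = 'normal' -> yes
```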

Then, to classify a given vector X = [f0, ..., fn]:

  1. Take the root node of the generated tree. It is usually a branch; it is a leaf only in degenerate cases, for example when all training samples belong to the same class.
  2. That node has an associated feature fi, so look up the value of X for that feature: v = X[fi].
  3. If node[v] is a leaf, assign the leaf's class to X.
  4. If v is not a key in node (the value was never seen at this point during training), the existing tree cannot assign a class, so a default class (the most probable one) is assigned instead.
  5. If node[v] is another branch, repeat this procedure using that branch as the new node.
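
The five steps translate into a short loop. This sketch assumes the same kind of dict subclasses (a feature attribute on branches, klass on leaves), integer feature indices, and that the caller supplies the default class; classify is a hypothetical name, and the post's predict method may be organized differently.

```python
class DecisionTreeNode(dict):
    """Branch (assumed): maps each value of `feature` to a child node."""

    def __init__(self, feature):
        super().__init__()
        self.feature = feature


class DecisionTreeLeaf(dict):
    """Leaf (assumed): an empty dict carrying the class label."""

    def __init__(self, klass):
        super().__init__()
        self.klass = klass


def classify(tree, x, default_klass):
    """Walk the tree from the root and return the class assigned to `x`."""
    node = tree                              # step 1: start at the root
    while isinstance(node, DecisionTreeNode):
        v = x[node.feature]                  # step 2: v = X[fi]
        if v not in node:                    # step 4: value unseen in training
            return default_klass
        node = node[v]                       # steps 3 and 5: descend one level
    return node.klass                        # reached a leaf


sunny = DecisionTreeNode(1)
sunny['high'] = DecisionTreeLeaf('no')
sunny['normal'] = DecisionTreeLeaf('yes')
tree = DecisionTreeNode(0)
tree['sunny'] = sunny
tree['overcast'] = DecisionTreeLeaf('yes')
tree['rain'] = DecisionTreeLeaf('yes')

print(classify(tree, ('sunny', 'high'), 'yes'))  # no
print(classify(tree, ('snow', 'high'), 'yes'))   # yes ('snow' never seen)
```

Using a loop rather than recursion keeps the lookup iterative; both are equivalent here because each step only moves one level down the tree.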

The code works in Python 3.4. Some minor modifications are needed for it to work properly in Python 2: for example, the entropy function relies on the fact that in Python 3 the / operator performs true division, returning a float even when both operands are integers.
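
One way to make that part portable is sketched below, assuming an entropy function that receives a plain list of class labels (the post's version works on its own sample type). Converting the total to float forces true division on Python 2 as well, and math.log(x, 2) is used because math.log2 does not exist in Python 2. Adding from __future__ import division at the top of the module is the other common option.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy, in bits, of a list of class labels (Python 2 and 3)."""
    counts = Counter(labels)
    total = float(len(labels))  # float() makes / true division on Python 2 too
    # math.log(x, 2) instead of math.log2, which Python 2 lacks
    return sum((c / total) * math.log(total / c, 2) for c in counts.values())


print(entropy(['yes', 'yes', 'no', 'no']))  # 1.0
```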

I won't dig further into the details, as this is not meant to be a tutorial or course on decision trees; some minimal prior knowledge should be enough to understand the code. In any case, don't hesitate to post your questions or comments.

To keep updated about Machine Learning, Data Processing and Complex Web Development follow us on @machinalis.
