# How To Implement The Decision Tree Algorithm From Scratch In Python

Decision trees are a powerful prediction method and extremely popular.

They are popular because the final model is so easy for practitioners and domain experts alike to understand. The final decision tree can explain exactly why a specific prediction was made, making it very attractive for operational use.

Decision trees also provide the foundation for more advanced ensemble methods such as bagging, random forests and gradient boosting.

In this tutorial, you will discover how to implement the Classification And Regression Tree algorithm from scratch with Python.

After completing this tutorial, you will know:

• How to calculate and evaluate candidate split points in a dataset.
• How to arrange splits into a decision tree structure.
• How to apply the classification and regression tree algorithm to a real problem.


Let’s get started.

• Update Jan/2017: Changed the calculation of fold_size in cross_validation_split() to always be an integer. Fixes issues with Python 3.
• Update Feb/2017: Fixed a bug in build_tree.
• Update Aug/2017: Fixed a bug in Gini calculation, added the missing weighting of group Gini scores by group size (thanks Michael!).
• Update Aug/2018: Tested and updated to work with Python 3.6.

Photo by Martin Cathrae, some rights reserved.

## Descriptions

This section provides a brief introduction to the Classification and Regression Tree algorithm and the Banknote dataset used in this tutorial.

### Classification and Regression Trees

Classification and Regression Trees, or CART for short, is a term introduced by Leo Breiman to refer to decision tree algorithms that can be used for classification or regression predictive modeling problems.

We will focus on using CART for classification in this tutorial.

The representation of the CART model is a binary tree. This is the same binary tree from algorithms and data structures, nothing too fancy (each node can have zero, one or two child nodes).

A node represents a single input variable (X) and a split point on that variable, assuming the variable is numeric. The leaf nodes (also called terminal nodes) of the tree contain an output variable (y) which is used to make a prediction.

Once created, a tree can be navigated with a new row of data, following the branch chosen by each split until a final prediction is made.

Creating a binary decision tree is actually a process of dividing up the input space. A greedy approach called recursive binary splitting is used to divide the space. This is a numerical procedure where all the values are lined up and different split points are tried and tested using a cost function.

The split with the best cost (lowest cost because we minimize cost) is selected. All input variables and all possible split points are evaluated and chosen in a greedy manner based on the cost function.

• Regression: The cost function that is minimized to choose split points is the sum of squared errors across all training samples that fall within the rectangle (the region of the input space) covered by the node.
• Classification: The Gini cost function is used, which provides an indication of how pure the nodes are, where node purity refers to how mixed the training data assigned to each node is.

Splitting continues until nodes contain a minimum number of training examples or a maximum tree depth is reached.
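To make the regression cost concrete, here is a minimal sketch of scoring a candidate split by its sum of squared errors. It assumes each row stores its numeric target in the last column and that each group predicts its own mean target; the name `sse_cost` is hypothetical, and the rest of this tutorial uses the Gini cost for classification instead.

```python
def sse_cost(groups):
    """Sum of squared errors of a candidate split (regression cost).

    Each group is a list of rows whose last element is the numeric target.
    A group predicts its mean target, so its cost is the squared deviation
    of every target from that mean.
    """
    total = 0.0
    for group in groups:
        if not group:
            continue  # an empty group contributes no error
        targets = [row[-1] for row in group]
        mean = sum(targets) / len(targets)
        total += sum((y - mean) ** 2 for y in targets)
    return total

# a split that separates low and high targets has the lower cost
good = [[[1.0, 10.0], [2.0, 12.0]], [[8.0, 30.0], [9.0, 32.0]]]
bad = [[[1.0, 10.0], [8.0, 30.0]], [[2.0, 12.0], [9.0, 32.0]]]
print(sse_cost(good), sse_cost(bad))  # 4.0 400.0
```

A greedy search over split points would keep `good` because its cost is lower.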

### Banknote Dataset

The banknote dataset involves predicting whether a given banknote is authentic given a number of measures taken from a photograph.

The dataset contains 1,372 rows with 5 numeric variables. It is a classification problem with two classes (binary classification).

Below is a list of the five variables in the dataset.

1. variance of Wavelet Transformed image (continuous).
2. skewness of Wavelet Transformed image (continuous).
3. kurtosis of Wavelet Transformed image (continuous).
4. entropy of image (continuous).
5. class (integer).

Below is a sample of the first 5 rows of the dataset

Using the Zero Rule Algorithm to predict the most common class value, the baseline accuracy on the problem is about 50%.
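As a sketch of what the Zero Rule baseline does, the hypothetical `zero_rule_classifier()` below predicts the most common class in the training rows for every test row. With two roughly balanced classes, this scores around 50%. The toy rows are made up for illustration and are not taken from the banknote dataset.

```python
from collections import Counter

def zero_rule_classifier(train, test):
    """Predict the most common training class value for every test row."""
    most_common = Counter(row[-1] for row in train).most_common(1)[0][0]
    return [most_common for _ in test]

def accuracy(actual, predicted):
    """Percentage of predictions that match the actual class values."""
    correct = sum(1 for a, p in zip(actual, predicted) if a == p)
    return correct / len(actual) * 100.0

# hypothetical toy rows: three of class 0, two of class 1
train = [[0.1, 0], [0.4, 0], [0.9, 1], [0.3, 0], [0.8, 1]]
test = [[0.2, 0], [0.7, 1]]
predicted = zero_rule_classifier(train, test)
print(predicted, accuracy([row[-1] for row in test], predicted))  # [0, 0] 50.0
```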

## Tutorial

This tutorial is broken down into 5 parts:

1. Gini Index.
2. Create Split.
3. Build a Tree.
4. Make a Prediction.
5. Banknote Case Study.

These steps will give you the foundation that you need to implement the CART algorithm from scratch and apply it to your own predictive modeling problems.

### 1. Gini Index

The Gini index is the name of the cost function used to evaluate splits in the dataset.

A split in the dataset involves one input attribute and one value for that attribute. It can be used to divide training patterns into two groups of rows.

A Gini score gives an idea of how good a split is by how mixed the classes are in the two groups created by the split. A perfect separation results in a Gini score of 0, whereas the worst case split, which results in 50/50 classes in each group, gives a Gini score of 0.5 (for a 2 class problem).

Calculating Gini is best demonstrated with an example.

We have two groups of data with 2 rows in each group. The rows in the first group all belong to class 0 and the rows in the second group belong to class 1, so it’s a perfect split.

We first need to calculate the proportion of classes in each group.

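The proportions for this example would be:

• group 1, class 0: 2 / 2 = 1.0
• group 1, class 1: 0 / 2 = 0.0
• group 2, class 0: 0 / 2 = 0.0
• group 2, class 1: 2 / 2 = 1.0

Gini is then calculated for each child node as one minus the sum of the squared class proportions:

• gini_index = 1.0 - sum(proportion * proportion)

The Gini index for each group must then be weighted by the size of the group, relative to all of the samples in the parent, i.e. all samples that are currently being grouped. We can add this weighting to the Gini calculation for a group as follows:

• gini_index = (1.0 - sum(proportion * proportion)) * (group_size / total_samples)

In this example the Gini scores for each group are calculated as follows:

• Gini(group_1) = (1.0 - (1.0 * 1.0 + 0.0 * 0.0)) * 2 / 4 = 0.0 * 0.5 = 0.0
• Gini(group_2) = (1.0 - (0.0 * 0.0 + 1.0 * 1.0)) * 2 / 4 = 0.0 * 0.5 = 0.0

The scores are then added across each child node at the split point to give a final Gini score for the split point that can be compared to other candidate split points.

The Gini for this split point would then be calculated as 0.0 + 0.0, a perfect Gini score of 0.0.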

Below is a function named gini_index() that calculates the Gini index for a list of groups and a list of known class values.

You can see that there are some safety checks in there to avoid a divide by zero for an empty group.

We can test this function with our worked example above. We can also test it for the worst case of a 50/50 split in each group. The complete example is listed below.
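The gini_index() function might look like the following minimal sketch, assuming each row is a list whose last element is the class value. It skips empty groups (the divide-by-zero safety check), computes one minus the sum of squared class proportions for each group, and weights that by the group's share of all rows. The two calls test the worst case and the best case.

```python
def gini_index(groups, classes):
    """Weighted Gini index of a candidate split.

    groups  -- list of groups; each group is a list of rows, class value last
    classes -- list of the class values that can occur
    """
    # count all samples at the split point
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        # avoid a divide by zero for an empty group
        if size == 0:
            continue
        # score the group by the sum of squared class proportions
        labels = [row[-1] for row in group]
        score = sum((labels.count(c) / size) ** 2 for c in classes)
        # weight the group's impurity by its relative size
        gini += (1.0 - score) * (size / n_instances)
    return gini

# worst case: each group is a 50/50 mix of classes 0 and 1
print(gini_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1]))  # 0.5
# best case: each group holds a single class
print(gini_index([[[1, 0], [1, 0]], [[1, 1], [1, 1]]], [0, 1]))  # 0.0
```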

Running the example prints the two Gini scores, first the score for the worst case at 0.5 followed by the score for the best case at 0.0.

Now that we know how to evaluate the results of a split, let’s look at creating splits.

### 2. Create Split

A split consists of an attribute in the dataset and a value.

We can summarize this as the index of an attribute to split and the value by which to split rows on that attribute. This is just a useful shorthand for indexing into rows of data.

Creating a split involves three parts. The first, calculating the Gini score, we have already looked at. The remaining two parts are:

1. Splitting a Dataset.
2. Evaluating All Splits.

Let’s take a look at each.

#### 2.1. Splitting a Dataset

Splitting a dataset means separating a dataset into two lists of rows given the index of an attribute and a split value for that attribute.

Once we have the two groups, we can then use our Gini score above to evaluate the cost of the split.

Splitting a dataset involves iterating over each row, checking if the attribute value is below or above the split value and assigning it to the left or right group respectively.

Below is a function named test_split() that implements this procedure.

Not much to it.

Note that the right group contains all rows with a value at the index above or equal to the split value.

#### 2.2. Evaluating All Splits

With the Gini function above and the test split function we now have everything we need to evaluate splits.

Given a dataset, we must check every value on each attribute as a candidate split, evaluate the cost of the split and find the best possible split we could make.

Once the best split is found, we can use it as a node in our decision tree.

This is an exhaustive and greedy algorithm.

We will use a dictionary to represent a node in the decision tree as we can store data by name. When selecting the best split and using it as a new node for the tree we will store the index of the chosen attribute, the value of that attribute by which to split and the two groups of data split by the chosen split point.

Each group of data is its own small dataset of just those rows assigned to the left or right group by the splitting process. You can imagine how we might split each group again, recursively as we build out our decision tree.

Below is a function named get_split() that implements this procedure. You can see that it iterates over each attribute (except the class value) and then each value for that attribute, splitting and evaluating splits as it goes.

The best split is recorded and then returned after all checks are complete.

We can contrive a small dataset to test out this function and our whole dataset splitting process.

We can plot this dataset using separate colors for each class. You can see that it would not be difficult to manually pick a value of X1 (x-axis on the plot) to split this dataset.

CART Contrived Dataset

The example below puts all of this together.

The get_split() function was modified to print out each split point and its Gini index as it was evaluated.

Running the example prints all of the Gini scores and then prints the best split in the dataset, X1 < 6.642, with a Gini index of 0.0, a perfect split.

Now that we know how to find the best split points in a dataset or list of rows, let’s see how we can use it to build out a decision tree.

### 3. Build a Tree

Creating the root node of the tree is easy.

We call the above get_split() function using the entire dataset.

Adding more nodes to our tree is more interesting.

Building a tree may be divided into 3 main parts:

1. Terminal Nodes.
2. Recursive Splitting.
3. Building a Tree.

#### 3.1. Terminal Nodes

We need to decide when to stop growing a tree.

We can do that using the depth and the number of rows that the node is responsible for in the training dataset.

• Maximum Tree Depth. This is the maximum number of nodes along a path from the root node of the tree. Once the maximum depth of the tree is reached, we must stop splitting and adding new nodes. Deeper trees are more complex and are more likely to overfit the training data.
• Minimum Node Records. This is the minimum number of training patterns that a given node is responsible for. Once at or below this minimum, we must stop splitting and adding new nodes. Nodes that account for too few training patterns are expected to be too specific and are likely to overfit the training data.

These two approaches will be user-specified arguments to our tree building procedure.

There is one more condition. It is possible to choose a split in which all rows belong to one group. In this case, we will be unable to continue splitting and adding child nodes as we will have no records to split on one side or another.

Now we have some ideas of when to stop growing the tree. When we do stop growing at a given point, that node is called a terminal node and is used to make a final prediction.

This is done by taking the group of rows assigned to that node and selecting the most common class value in the group. This will be used to make predictions.

Below is a function named to_terminal() that will select a class value for a group of rows. It returns the most common output value in a list of rows.
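A minimal sketch of to_terminal(), assuming the class value is the last element of each row. Ties between equally common classes are broken arbitrarily in this version.

```python
def to_terminal(group):
    """Return the most common class value among the rows in a group."""
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

print(to_terminal([[2.7, 0], [1.7, 1], [3.6, 1]]))  # 1
```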

#### 3.2. Recursive Splitting

We know how and when to create terminal nodes, now we can build our tree.

Building a decision tree involves calling the get_split() function developed above over and over again on the groups created for each node.

New nodes added to an existing node are called child nodes. A node may have zero children (a terminal node), one child (one side makes a prediction directly) or two child nodes. We will refer to the child nodes as left and right in the dictionary representation of a given node.

Once a node is created, we can create child nodes recursively on each group of data from the split by calling the same function again.

Below is a function that implements this recursive procedure. It takes a node as an argument as well as the maximum depth, minimum number of patterns in a node and the current depth of a node.

You can imagine how this might first be called, passing in the root node and a depth of 1. This function is best explained in steps:

1. Firstly, the two groups of data split by the node are extracted for use and deleted from the node. As we work on these groups the node no longer requires access to these data.
2. Next, we check if either left or right group of rows is empty and if so we create a terminal node using what records we do have.
3. We then check if we have reached our maximum depth and if so we create a terminal node.
4. We then process the left child, creating a terminal node if the group of rows is too small, otherwise creating and adding the left node in a depth first fashion until the bottom of the tree is reached on this branch.
5. The right side is then processed in the same manner, as we rise back up the constructed tree to the root.

#### 3.3. Building a Tree

We can now put all of the pieces together.

Building the tree involves creating the root node and calling the split() function that then calls itself recursively to build out the whole tree.

Below is the small build_tree() function that implements this procedure.

We can test out this whole procedure using the small dataset we contrived above.

Below is the complete example.

Also included is a small print_tree() function that recursively prints out nodes of the decision tree with one line per node. Although not as striking as a real decision tree diagram, it gives an idea of the tree structure and decisions made throughout.

We can vary the maximum depth argument as we run this example and see the effect on the printed tree.

With a maximum depth of 1 (the second parameter in the call to the build_tree() function), we can see that the tree uses the perfect split we discovered in the previous section. This is a tree with one node, also called a decision stump.

Increasing the maximum depth to 2, we are forcing the tree to make splits even when none are required. The X1 attribute is then used again by both the left and right children of the root node to split up groups whose classes are already perfectly separated.

Finally, and perversely, we can force one more level of splits with a maximum depth of 3.

These tests show that there is great opportunity to refine the implementation to avoid unnecessary splits. This is left as an extension.

Now that we can create a decision tree, let’s see how we can use it to make predictions on new data.

### 4. Make a Prediction

Making predictions with a decision tree involves navigating the tree with the provided row of data.

Again, we can implement this using a recursive function, where the same prediction routine is called again with the left or the right child nodes, depending on how the split affects the provided data.

We must check if a child node is either a terminal value to be returned as the prediction, or if it is a dictionary node containing another level of the tree to be considered.

Below is the predict() function that implements this procedure. You can see how the index and value in a given node are used to evaluate whether the row of provided data falls on the left or the right of the split.

We can use our contrived dataset to test this function. Below is an example that uses a hard-coded decision tree with a single node that best splits the data (a decision stump).

The example makes a prediction for each row in the dataset.
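A minimal sketch of predict() with a hard-coded decision stump, assuming the contrived dataset values below and a stump that splits on X1 < 6.642287351. Internal nodes are dicts; any other value is treated as the class prediction.

```python
def predict(node, row):
    """Navigate from `node` to a terminal value using the splits."""
    # rows below the split value follow the left branch, the rest go right
    if row[node['index']] < node['value']:
        branch = node['left']
    else:
        branch = node['right']
    # a dict is another split node to descend into; anything else is the prediction
    if isinstance(branch, dict):
        return predict(branch, row)
    return branch

# contrived dataset (assumed values): X1, X2, class
dataset = [[2.771244718, 1.784783929, 0], [1.728571309, 1.169761413, 0],
           [3.678319846, 2.81281357, 0], [3.961043357, 2.61995032, 0],
           [2.999208922, 2.209014212, 0], [7.497545867, 3.162953546, 1],
           [9.00220326, 3.339047188, 1], [7.444542326, 0.476683375, 1],
           [10.12493903, 3.234550982, 1], [6.642287351, 3.319983761, 1]]

# hard-coded decision stump: X1 < 6.642287351 predicts 0, otherwise 1
stump = {'index': 0, 'right': 1, 'value': 6.642287351, 'left': 0}
for row in dataset:
    prediction = predict(stump, row)
    print('Expected=%d, Got=%d' % (row[-1], prediction))
```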

Running the example prints the correct prediction for each row, as expected.

We now know how to create a decision tree and use it to make predictions. Now, let’s apply it to a real dataset.

### 5. Banknote Case Study

This section applies the CART algorithm to the Bank Note dataset.

The first step is to load the dataset and convert the loaded data to numbers that we can use to calculate split points. For this we will use the helper function load_csv() to load the file and str_column_to_float() to convert string numbers to floats.
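The two helpers could look like this minimal sketch, assuming a comma-separated file with no header row. Blank lines are skipped, and every column, including the class, is converted to float. The demo writes a tiny made-up file (not the banknote data) to a temporary location.

```python
import os
import tempfile
from csv import reader

def load_csv(filename):
    """Read a CSV file into a list of rows (lists of strings), skipping blank lines."""
    dataset = list()
    with open(filename, 'r', newline='') as file:
        for row in reader(file):
            if not row:
                continue
            dataset.append(row)
    return dataset

def str_column_to_float(dataset, column):
    """Convert one column from strings to floats, in place."""
    for row in dataset:
        row[column] = float(row[column].strip())

# a tiny made-up file (not the banknote data), including a blank line
with tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False) as tmp:
    tmp.write('1.5,-2.0,0\n\n3.25,4.0,1\n')
    path = tmp.name

dataset = load_csv(path)
for column in range(len(dataset[0])):
    str_column_to_float(dataset, column)
os.remove(path)
print(dataset)  # [[1.5, -2.0, 0.0], [3.25, 4.0, 1.0]]
```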

We will evaluate the algorithm using k-fold cross-validation with 5 folds. This means that 1372/5 = 274.4, or 274 records once the fold size is truncated to an integer, will be used in each fold. We will use the helper functions evaluate_algorithm() to evaluate the algorithm with cross-validation and accuracy_metric() to calculate the accuracy of predictions.
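A sketch of the evaluation helpers, assuming rows whose last element is the class value. cross_validation_split() uses an integer fold size, so with 1,372 rows and 5 folds each fold holds 274 rows and the 2 leftover rows are not used. The hypothetical `zero_rule()` stands in for decision_tree(), and the alternating-label data is synthetic, so the printed scores say nothing about CART.

```python
from collections import Counter
from random import randrange, seed

def cross_validation_split(dataset, n_folds):
    """Randomly split rows into n_folds folds of equal, integer size."""
    dataset_copy = list(dataset)
    # an integer fold size keeps the loop below from popping an empty list
    fold_size = int(len(dataset) / n_folds)
    folds = list()
    for _ in range(n_folds):
        fold = list()
        while len(fold) < fold_size:
            fold.append(dataset_copy.pop(randrange(len(dataset_copy))))
        folds.append(fold)
    return folds

def accuracy_metric(actual, predicted):
    """Percentage of predictions that match the actual class values."""
    correct = sum(1 for a, p in zip(actual, predicted) if a == p)
    return correct / float(len(actual)) * 100.0

def evaluate_algorithm(dataset, algorithm, n_folds, *args):
    """Return one accuracy score per fold; algorithm(train, test, *args) returns predictions."""
    folds = cross_validation_split(dataset, n_folds)
    scores = list()
    for i, fold in enumerate(folds):
        # train on all other folds, flattened into a single list of rows
        train_set = [row for j, other in enumerate(folds) if j != i for row in other]
        # copy the test rows with the class value hidden
        test_set = [list(row[:-1]) + [None] for row in fold]
        predicted = algorithm(train_set, test_set, *args)
        actual = [row[-1] for row in fold]
        scores.append(accuracy_metric(actual, predicted))
    return scores

def zero_rule(train, test):
    # hypothetical stand-in for decision_tree(): predict the majority class
    majority = Counter(row[-1] for row in train).most_common(1)[0][0]
    return [majority for _ in test]

seed(1)
# synthetic rows (not the banknote data): one feature and an alternating label
dataset = [[float(i), i % 2] for i in range(1372)]
folds = cross_validation_split(dataset, 5)
print([len(fold) for fold in folds])  # [274, 274, 274, 274, 274]
scores = evaluate_algorithm(dataset, zero_rule, 5)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores) / float(len(scores))))
```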

A new function named decision_tree() was developed to manage the application of the CART algorithm, first creating the tree from the training dataset, then using the tree to make predictions on a test dataset.

The complete example is listed below.

The example uses a maximum tree depth of 5 layers and a minimum number of rows per node of 10. These parameters to CART were chosen with a little experimentation, but they are by no means optimal.

Running the example prints the classification accuracy on each fold as well as the average performance across all folds.

You can see that CART and the chosen configuration achieved a mean classification accuracy of about 97% which is dramatically better than the Zero Rule algorithm that achieved 50% accuracy.

## Extensions

This section lists extensions to this tutorial that you may wish to explore.

• Algorithm Tuning. The application of CART to the Bank Note dataset was not tuned. Experiment with different parameter values and see if you can achieve better performance.
• Cross Entropy. Another cost function for evaluating splits is cross entropy (logloss). You could implement and experiment with this alternative cost function.
• Tree Pruning. An important technique for reducing overfitting of the training dataset is to prune the trees. Investigate and implement tree pruning methods.
• Categorical Dataset. The example was designed for input data with numerical or ordinal input attributes; experiment with categorical input data and splits that may use equality instead of ranking.
• Regression. Adapt the tree for regression using a different cost function and method for creating terminal nodes.
• More Datasets. Apply the algorithm to more datasets on the UCI Machine Learning Repository.

Did you explore any of these extensions?

## Review

In this tutorial, you discovered how to implement the decision tree algorithm from scratch with Python.

Specifically, you learned:

• How to select and evaluate split points in a training dataset.
• How to recursively build a decision tree from multiple splits.
• How to apply the CART algorithm to a real world classification predictive modeling problem.

Do you have any questions?


### 377 Responses to How To Implement The Decision Tree Algorithm From Scratch In Python

1. steve November 23, 2016 at 4:02 am #

Super good .. Thanks a lot for sharing

• Jason Brownlee November 23, 2016 at 9:01 am #

I’m glad you found it useful steve.

• MAk August 26, 2017 at 1:14 pm #

Hi , Could you explain what do it mean ?

[X1 < 6.642]
[X1 < 2.771]
[0]
[0]
[X1 < 7.498]
[1]
[1]

Does it mean if X1 <6.642 and X1 <2.771, it belongs to class 0, if X1 <6.642 and X1 < 7.498, it belong to class 1 ?
Mak

• Jason Brownlee August 27, 2017 at 5:46 am #

Yes, [X1 < 6.642] is the root node with two child nodes, each leaf node has a classification label.

• Furkon April 13, 2018 at 9:17 pm #

Hello, can you help me with this problem?

This project is expecting you to write two functions. The first will take the training data including the type of features and returns a decision tree best modeling the classification problem. This should not do any pruning of the tree. An optional part of the homework will require the pruning step. The second function will just apply the decision tree to a given set of data points.
You should follow the following for this homework:
1. Load the data and perform a quick analysis of what it is and what features it has. You will need to construct a vector indicating the type (values) of each of the features. In this case, you can assume that you have numeric (real or integer) and categorical values.
2. Implement the function “build_dt(X, y, attribute_types, options)”.
a. X: is the matrix of features/attributes for the training data. Each row includes a data sample.
b. y: The vector containing the class labels for each sample in the rows of X.
c. attribute_types: The vector containing (1: integer/real) or (2: categorical) indicating the type of each attributes (the columns of X).
d. options: Any options you might want to pass to your decision tree builder.
e. Returns a decision tree of the structure of your choice.
3. Implement the function “predict_dt(dt, X, options)”.
a. dt: The decision tree modeled by “build_dt” function.
b. X: is the matrix of features/attributes for the test data.
c. Returns a vector for the predicted class labels.
4. Report the performance of your implementation using an appropriate k-fold cross validation using confusion matrices on the given dataset.
[Optional] Implement the pruning strategy discussed in the class. Repeat the steps 4 above. Indicate any assumptions you might have made.

• Jason Brownlee April 14, 2018 at 6:44 am #

I think you should complete your own homework assignments.

• Edgar Panganiban July 23, 2018 at 2:44 pm #

Hi, can you help me interpret the model that I created? Thanks.

[X1 < 2.000]
[X10 < 0.538]
[0.0]
[1.0]
[X15 < 10.200]
[X20 < 2.700]
[X10 < 0.574]
[X6 < 14.000]
[1.0]
[X8 < 256.000]
[1.0]
[X16 < 5.700]
[X18 < 0.600]
[X2 < 82.000]
[X15 < 4.100]
[0.0]
[0.0]
[1.0]
[X4 < 112.000]
[X1 < 3.000]
[0.0]
[0.0]
[0.0]
[0.0]
[0.0]
[0.0]
[1.0]

• Railot Railot November 17, 2018 at 6:45 pm #

• Adam Harris September 11, 2019 at 6:35 pm #

Hi Jason, can you use a decision tree in place of a rules engine?

• Jason Brownlee September 12, 2019 at 5:16 am #

It really depends on the specifics of the application and requirements of the stakeholders.

2. jonathan November 27, 2016 at 9:22 am #

can this code be used for a multinomial Decision tree dataset?

• Jason Brownlee November 27, 2016 at 10:22 am #

It can with some modification.

• STORM RICK November 29, 2016 at 1:02 am #

What modifications would you recommend?

• Jason Brownlee November 29, 2016 at 8:53 am #

Specifically the handling of evaluating and selecting nominal values at split points.

3. Mike December 24, 2016 at 2:48 pm #

Thanks for detailed description and code.

I tried to run and got ‘ValueError: empty range for randrange()’ in line 26:

index = randrange(len(dataset_copy))

if replace dataset_copy to list(dataset) and run this line manually it works.

• Jason Brownlee December 26, 2016 at 7:39 am #

Sounds like a Python 3 issue Mike.

Replace

With:

• Jason Brownlee January 3, 2017 at 9:53 am #

I have updated the cross_validation_split() function in the above example to address issues with Python 3.

4. Mohendra Roy January 4, 2017 at 12:32 am #

How about to use of euclidian distance instead of calculating for each element in the dataset?

• Jason Brownlee January 4, 2017 at 8:55 am #

What do you mean exactly? Are you able to elaborate?

5. Selva Rani B January 12, 2017 at 4:43 pm #

Thank you very much

6. Sokrates January 21, 2017 at 4:10 am #

Hi Jason,

Great tutorial on CART!

The results of decision trees are quite dependent on the training vs test data. With this in mind, how do I set the amount of training vs test data in the code right now to changes in the result? From what I can see, it looks like they are being set in the evaluate_algorithm method.

//Kind regards
Sokrates

7. Adeshina Alani January 27, 2017 at 3:52 am #

Nice post. I would like to ask if this implementation can be used for time series data with only one feature.

8. vishal January 30, 2017 at 5:06 am #

Really helpful. Thanks a lot for sharing.

• Jason Brownlee February 1, 2017 at 10:18 am #

I’m glad you found the post useful vishal.

9. elberiver February 3, 2017 at 6:02 pm #

Hi Jason,

there is a minor point in your code. Specifically, in the following procedure:
——
# Build a decision tree
def build_tree(train, max_depth, min_size):
    root = get_split(dataset)
    split(root, max_depth, min_size, 1)
    return root
——
I think it should be root = get_split(train), even though your code still runs correctly since dataset is a global variable.

Thank you for your nice posts.
I like your blog very much.

• Jason Brownlee February 4, 2017 at 9:59 am #

I think you’re right, nice catch!

I’ll investigate and fix up the example.

10. from Thailand March 8, 2017 at 2:35 pm #

Thanks a lot Jason, really helpful

• Jason Brownlee March 9, 2017 at 9:52 am #

• eve November 5, 2018 at 1:32 am #

Excuse me, I’m confused about the get_split function in your work. Actually, I encountered a problem: the function always returns a tree with a depth of only one, even if I have two more features. As a novice, I found it a little difficult to work out the problem. So would you please answer the question for me? Just showing me a direction is OK. Thanks.

• Jason Brownlee November 5, 2018 at 6:18 am #

Perhaps your data does not benefit from more than one split?

Perhaps try a decision tree as part of the scikit-learn library?

11. Amit Moondra April 2, 2017 at 6:31 am #

I’m slowly going through your code and I’m confused about a line in your get_split function

groups = test_split(index, row[index], dataset)

Doesn’t this only return the left group? It seems we need both groups to calculate the gini_index?

Thank you.

• Jason Brownlee April 2, 2017 at 6:34 am #

Hi Amit,

The get_split() function evaluates all possible split points and returns a dict of the best found split point, gini score, split data.

• Amit Moondra April 2, 2017 at 12:15 pm #

After playing around with the code for a bit, I realized that function returns both groups (left and right) under one variable.

12. Amit Moondra April 2, 2017 at 9:59 am #

In the function split

if not left or not right:
    node['left'] = node['right'] = to_terminal(left + right)
    return

Why do you add (left + right)? Are you adding the two groups together into one group?

Thank you.

13. Amit Moondra April 2, 2017 at 12:30 pm #

Another question (line 132)

if isinstance(node['left'], dict):
    return predict(node['left'], row)

isinstance is just checking if we have already created such a dictionary instance?

Thank you.

• Jason Brownlee April 4, 2017 at 9:05 am #

It is checking if the type of the variable is a dict.

• Cynthia Allan-Gyimah December 23, 2019 at 9:24 am #

My name is Cynthia I am very interested in machine learning and at the same time python but l am very new to coding can this book help me in coding? I am a geomatic engineer aspiring to upgrade myself into environmental engineering

14. Ann April 3, 2017 at 9:25 am #

Hello,

I’ve been trying some stuff out with this code and I thought I was understanding what was going on but when I tried it on a dataset with binary values it doesn’t seem to work and I can’t figure out why. Could you help me out please?

Thanks.

• Jason Brownlee April 4, 2017 at 9:11 am #

The example assumes real-valued inputs, binary or categorical inputs should be handled differently.

I don’t have an example at hand, sorry.

• Hendra Bunyamin March 15, 2019 at 12:18 pm #

Hello there,
If the inputs are categorical, is this the right way to split the inputs? Instead of using less than and greater than, I use == or !=.

Thanks

• Jason Brownlee March 15, 2019 at 2:30 pm #

Perhaps try it and see.

• Videl December 8, 2020 at 7:14 am #

Hey, Hendra

Did it work for you? When I tried with my dataset the accuracy dropped to 30%.

15. Dimple April 17, 2017 at 1:37 am #

Hi

Could you tell me how decision trees are used for predicting an unknown function when a sample dataset is given? What I mean is, how are they used for regression?

• Jason Brownlee April 17, 2017 at 5:15 am #

Good question, sorry, I don’t have an example of decision trees for regression from scratch.

16. Dimple April 17, 2017 at 10:18 am #

How can we use weka for regression using decision trees?

• Jason Brownlee April 18, 2017 at 8:28 am #

Consider using the search function of this blog.

17. Joe April 18, 2017 at 1:31 am #

Great article, this is exactly what I was looking for!

• Jason Brownlee April 18, 2017 at 8:33 am #

I’m really glad to hear that Joe!

18. ansar April 19, 2017 at 3:13 am #

I am new to machine learning … successfully ran the code with the given data set

Now I want to run it for my own data set … will the algo always treat the last column as the column to classify?

thanks

• Jason Brownlee April 19, 2017 at 7:54 am #

Yes, that is how it was coded.

• Ansar April 20, 2017 at 2:53 am #

Thank you! It works beautifully

19. Ansar April 21, 2017 at 10:22 pm #

Apologies if I am taking too much time but I tried to run this algo on the below scenario with 10 folds

#
#
#outlook
#
#1 = sunny
#2 = overcast
#3 = rain
#
#humidity
#
#1 = high
#2 = normal
#
#wind
#
#1 = weak
#2 = strong
#
#play
#
#0 = no
#1 = yes

The tree generated does not cater to x2, x3 variables (for some reason), just generates for x1 (what am I doing wrong?) … the accuracy has dropped to 60%

[X1 < 1.000]
[1.0]
[1.0]
[X1 < 1.000]
[1.0]
[1.0]
[X1 < 1.000]
[1.0]
[1.0]
[X1 < 1.000]
[1.0]
[1.0]
[X1 < 3.000]
[X1 < 1.000]
[1.0]
[1.0]
[0.0]
[X1 < 1.000]
[1.0]
[1.0]
[X1 < 1.000]
[1.0]
[1.0]
[X1 < 1.000]
[1.0]
[1.0]
[X1 < 1.000]
[1.0]
[1.0]
[X1 < 1.000]
[1.0]
[1.0]
Scores: [100.0, 100.0, 0.0, 100.0, 0.0, 100.0, 0.0, 0.0, 100.0, 100.0]
Mean Accuracy: 60.000%

• Jason Brownlee April 22, 2017 at 9:26 am #

20. Ansar April 25, 2017 at 3:01 am #

Yes, working fine now 🙂

Would love to get my hands on a script that would print the tree in a more graphical (readable) format. The current format helps but does get confusing at times.

Thanks a lot!

• Greg May 6, 2017 at 10:15 am #

This will generate a graphviz dot file that you can use to generate a .png, .jpeg, etc.

e.g:

dot -Tpng graph1.dot > graph.png

Note it generates a new file each time it is called – graph1.dot … graphN.dot

# code begin

from itertools import count

def graph_node(f, node, ids):
    # give every node its own id; id() can collide for cached small ints such as 0 and 1
    node_id = next(ids)
    if isinstance(node, dict):
        f.write(' %d [label="[X%d < %.3f]"];\n' % (node_id, node['index'] + 1, node['value']))
        for child in (node['left'], node['right']):
            child_id = graph_node(f, child, ids)
            f.write(' %d -> %d;\n' % (node_id, child_id))
    else:
        f.write(' %d [label="[%s]"];\n' % (node_id, node))
    return node_id

def graph_tree(node):
    if not hasattr(graph_tree, 'n'): graph_tree.n = 0
    graph_tree.n += 1
    fn = 'graph' + str(graph_tree.n) + '.dot'
    with open(fn, 'w') as f:
        f.write('digraph {\n')
        f.write(' node[shape=box];\n')
        graph_node(f, node, count())
        f.write('}\n')

• Edgar Panganiban July 22, 2018 at 10:41 pm #

Hello dude, Do I only need to put this codes at the end part?? You only define functions…How can I run this and on what parts of the code….Thanks

• Giorgia February 27, 2020 at 3:38 am #

Hello Greg. Your code doesn’t work in my case. Can you correct it?

21. katana April 28, 2017 at 1:27 am #

Thanks a lot for this, Dr. Brownlee!

• Jason Brownlee April 28, 2017 at 7:47 am #

I’m glad you found it useful.

22. godavari May 4, 2017 at 11:08 pm #

excellent explanation

23. King Deng May 14, 2017 at 7:15 am #

I’m implementing AdaBoost from scratch now, and I’m having a tough time understanding how to apply the sample weights calculated in each iteration when building the decision tree. I guess I should modify the Gini index for this, but I’m not sure exactly how to do that. Could you shed some light? Thanks!
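
One common approach is to replace class counts with sums of sample weights. This is sketched below under my own assumptions and is not taken from the tutorial. (The other common option is to resample the training set according to the weights.) Each class proportion becomes the total weight of that class's rows divided by the group's total weight. Each group is then scored by its share of the total weight. With all weights equal, this reduces to the ordinary size-weighted Gini. `weighted_gini_index` and the parallel `group_weights` structure are hypothetical.

```python
def weighted_gini_index(groups, group_weights, classes):
    """Gini impurity of a split where every row carries a sample weight (hypothetical helper).

    groups:        list of groups, each a list of rows with the class label in row[-1]
    group_weights: same shape as groups; group_weights[i][j] is the weight of groups[i][j]
    """
    total_weight = sum(sum(weights) for weights in group_weights)
    gini = 0.0
    for group, weights in zip(groups, group_weights):
        group_weight = sum(weights)
        if group_weight == 0:
            continue  # an empty (or zero-weight) group contributes nothing
        score = 0.0
        for class_val in classes:
            # Weighted proportion: weight of this class's rows / weight of the whole group
            class_weight = sum(w for row, w in zip(group, weights) if row[-1] == class_val)
            p = class_weight / group_weight
            score += p * p
        # Scale each group's impurity by its share of the total weight
        gini += (1.0 - score) * (group_weight / total_weight)
    return gini


groups = [[[1, 1], [1, 0]], [[1, 1], [1, 0]]]
print(weighted_gini_index(groups, [[1.0, 1.0], [1.0, 1.0]], [0, 1]))  # 0.5, same as unweighted
print(weighted_gini_index([[[1, 1], [1, 0]]], [[3.0, 1.0]], [0, 1]))  # 0.375
```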

24. Pavithra May 19, 2017 at 7:10 pm #

Part of the code: predicted = algorithm(train_set, test_set, *args)

TypeError: ‘int’ object is not callable

• Jason Brownlee May 20, 2017 at 5:36 am #

I’m sorry to hear that.

Ensure that you have copied all of the code without any extra white space.

25. Luis Ilabaca May 25, 2017 at 9:23 am #

hey jason

honestly dude stuff like this is no joke man.

I did a BA in math and one year of an MA in math,
then an MA in Statistical Computing and Data Mining,
and then SAS certifications and a lot of R, and man, let me tell you,
when I read your work I see what a strong understanding you have of how all the different fields needed to be successful at applying machine learning fit together.

you my friend, are a killer.

26. Saurabh May 25, 2017 at 7:16 pm #

Hello Sir!!
First of all Thank You for such a great tutorial.

I would like to make a suggestion for the function get_split():
instead of calculating the Gini index for every value of that attribute in the data set, we could just use the mean of the attribute as the split value for the test_split function.

This is just my idea; please correct me if this approach is wrong.

Thank You!!

27. Habiba June 4, 2017 at 6:36 am #

Hello Sir,
I am a student and I need to develop an algorithm for both a Decision Tree and an Ensemble (preferably Random Forest), using both Python and R. I really need the book that contains everything, that is, the super bundle.

Thank you so very much for the post and the tutorials. They have been really helpful.

28. Dev June 21, 2017 at 9:17 pm #

Hello Sir,
First of all Thank You for such a great tutorial
I am new to machine learning and to Python as well. I’m slowly going through your code.
I have a question about the section below:

# Build a decision tree
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root

In this section the “split” function returns “None”. So how do the changes made in the “split” function show up in the variable “root”?
To see what values are stored in the “root” variable, I ran the code as below:

# Build a decision tree
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    print(root)  # before calling split()
    split(root, max_depth, min_size, 1)
    print(root)  # after calling split()
    return root

The values in the two cases are different, but I’m confused about how that happens.
Can anybody help me out, please?
Thank you.

• Jason Brownlee June 22, 2017 at 6:05 am #

The split function adds child nodes to the passed in root node to build the tree.
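
Jason's point can be shown in code. Python passes the function a reference to the same dict object, so a function that adds keys to its argument changes the caller's dict even though it returns None. The tutorial's split() relies on this when it assigns node['left'] and node['right']. Here is a toy example (add_children is hypothetical):

```python
def add_children(node):
    # Like split(), this returns None and works by mutating the dict it receives;
    # the caller's variable refers to the very same dict object
    node['left'] = 0
    node['right'] = 1

root = {'index': 0, 'value': 6.642}
result = add_children(root)
print(result)  # None
print(root)    # {'index': 0, 'value': 6.642, 'left': 0, 'right': 1}
```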

29. Xinglong Li June 24, 2017 at 12:48 pm #

Hello sir
Thanks for this tutorial.
I think that when summing the Gini indexes of the subgroups into a total, their group sizes should be taken into account; that is, it should be a weighted sum.

• Michael August 5, 2017 at 9:07 am #

I think I have the same question. The Gini Index computed in the above examples is not weighted by the size of the partitions. Shouldn’t it be? I posted a detailed example, but it has not been accepted (yet?).

30. sanju June 27, 2017 at 10:33 pm #

Hello Sir,

Assume we have created a model using the above algorithm.
Then, if we make predictions using this model, will the model process the whole training data?

Thank you.

• Jason Brownlee June 28, 2017 at 6:25 am #

We do not need to evaluate the model when we want to make predictions. We fit it on all the training data then start using it.

31. Rohit June 28, 2017 at 7:38 pm #

Hello Sir,

Assume I have 1 million training rows, and I created and saved a model to predict whether a patient is diabetic or not. Everything is OK with the model, and it is deployed on the client side (a hospital). Consider the cases below:

1. If a new patient comes in, then based on some input from the patient the model will predict whether the patient is diabetic or not.

2. If another patient comes in, our model will also predict.

My question is: in both cases, will the model process the one million training rows?
That is, if 100 patients come in at different times, will the model process the one million training rows 100 times?

Thank you
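
With the tutorial's approach, no. Once the tree is built, the model is just the nested dictionaries. A prediction walks a single path from the root to a leaf and compares one feature per level. The training data is not consulted at all, so 100 patients means 100 short walks, each costing at most `max_depth` comparisons. The sketch below follows the style of the tutorial's `predict()`; the two-level tree and the patient rows are hypothetical.

```python
def predict(node, row):
    """Walk from the root to a leaf; only the nodes on one path are visited."""
    if row[node['index']] < node['value']:
        branch = node['left']
    else:
        branch = node['right']
    if isinstance(branch, dict):
        return predict(branch, row)  # keep descending
    return branch                    # terminal node: the stored class value

# Hypothetical two-level tree; everything learned from the training data lives in it
tree = {'index': 0, 'value': 6.642287351,
        'left': {'index': 1, 'value': 2.209, 'left': 1, 'right': 0},
        'right': 1}

new_patients = [[2.0, 1.0], [2.0, 3.0], [8.0, 0.5]]
print([predict(tree, row) for row in new_patients])  # [1, 0, 1]
```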

32. Michael Shparber June 29, 2017 at 7:26 am #

Excellent, thank you for sharing!
Are there tools that allow you to do this without ANY coding?
I mean drag-and-drop / right-click options.
It is very useful to understand what happens behind the scenes, but in many cases it is much faster to use some sort of UI.
What would you recommend?
Thank you,
Michael

33. Ali Mesbahi July 5, 2017 at 12:35 am #

Hello,

Excellent post. It has been really helpful for me!
Is there a similar article/tutorial for the C4.5 algorithm?
Are there any implementations in R?

Thank you,
Ali.

34. Jeet July 9, 2017 at 9:35 pm #

dataset = list(lines)

Error: iterator should return strings, not bytes (did you open the file in text mode?)

How can I solve this issue?

• Jason Brownlee July 11, 2017 at 10:17 am #

This might be a Python version issue.
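
That error usually means the file was opened in binary mode (`'rb'`). That worked with the csv module in Python 2, but in Python 3 it yields bytes instead of strings. Below is a minimal Python 3 sketch of a loader in the same shape as the tutorial's `load_csv()`, opening the file in text mode:

```python
from csv import reader

def load_csv(filename):
    """Load a CSV file as a list of rows (each a list of strings)."""
    dataset = []
    # Python 3: open in text mode with newline='', as the csv module expects.
    # Opening with 'rb' gives bytes and raises "iterator should return strings, not bytes".
    with open(filename, 'r', newline='') as file:
        for row in reader(file):
            if not row:  # skip blank lines
                continue
            dataset.append(row)
    return dataset
```

A file elsewhere on a local drive can be passed as a full path. For example, the hypothetical `load_csv('/home/me/data/data_banknote_authentication.csv')`, or on Windows a raw string such as `load_csv(r'C:\data\data_banknote_authentication.csv')`.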

35. AJENG SHILVIE NURLATIFAH July 12, 2017 at 6:03 pm #

I have a final task to use a classification model, like a decision tree, to fill missing values in a data set. Is it possible?

Thank you

36. Nghi July 21, 2017 at 12:43 am #

Hi Jason,

I tried running the modified version by adding labels for the data

#to add label but not part of the training
dataset = pd.DataFrame(data)
dataset.columns = ['var', 'skew', 'curt', 'ent', 'bin']

So in the last step, I only ran

n_folds = 5
max_depth = 5
min_size = 10
scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))

and I got an error message saying

ValueError: empty range for randrange()

Could you help explain why?

Also on this section to load the CSV file

file = open(filename, "rb")
dataset = list(lines)
return dataset

Could you explain how to give the location of the file if it’s on a local drive?

Thank you very much,

37. preet shah July 24, 2017 at 9:04 pm #

What changes would you suggest i should make to solve the following problem:

A premium payer wants to improve medical care using machine learning. They want to predict the next Diagnosis, Procedure or Treatment events that are going to happen to patients. They have provided patient journey information coded using ICD9.

You have to predict the next 10 events reported by each patient, in order of occurrence, in 2014.

Data Description

The train data consists of patients information from Jan 2011 to Dec 2013. The test data consists of Patient IDs for the year 2014.

| Variable | Description |
|---|---|
| ID | Patient ID |
| Date | Period of Diagnosis |
| Event | Event ID (ICD9 format), the target variable |

38. Ikib Kilam July 26, 2017 at 10:14 pm #

Jason,

Thanks for this wonderful example. I rewrote your code, exactly as you wrote it, but cannot replicate your scores and mean accuracy. I get the following:

Scores: [28.467153284671532, 37.591240875912405, 76.64233576642336, 68.97810218978103, 70.43795620437956]
Mean Accuracy: 56.42335766423357

As another run, I changed nFolds to be 7, maxDepth to be 4 and minSize to be 4 and got:

Scores: [40.30612244897959, 38.775510204081634, 37.755102040816325, 71.42857142857143, 70.91836734693877, 70.91836734693877, 69.89795918367348]
Mean Accuracy: 57.142857142857146

I have tried all kinds of combinations of nFolds, minSize and maxDepth, and even tried not randomizing the selection of data instances into folds. However, my scores do not change: my mean accuracy has never exceeded 60%, and in fact is consistently in the 55%+ range. Strangely, my first two scores are low and then they increase, though not to 80%. I am at my wit’s end as to why I cannot replicate your high (80%+) scores/accuracy.

Any ideas on what might be happening? Thanks for taking the time to read my comment and again thanks for the wonderful illustration.

Ikib

• Jason Brownlee July 27, 2017 at 8:05 am #

That is odd.

Perhaps a copy-paste error somewhere? That would be my best guess.

39. buzznizz August 4, 2017 at 9:22 pm #

Hi Jason

It’s not useful to split groups of size 1 (as you already avoid doing for groups of size 0); you can make them terminal directly.

BR

40. Michael August 5, 2017 at 8:22 am #

Thanks Jason for a very well put together tutorial!

Read through all the replies here to see if someone had asked the question I had. I think the question Xinglong Li asked on June 24, 2017 is the same one I have, but it wasn’t answered so I’ll rephrase:

Isn’t the calculation of the Gini Index supposed to be weighted by the count in each group?

For example, in the table just above section “3. Build a Tree”, on line 15, the output lists:
X2 < 2.209 Gini=0.934

When I compute this by hand, I get this exact value if I do NOT weight the calculation by the size of the partitions. When I do this calculation the way I think it's supposed to be done (weighting by partition size), I get Gini=0.4762. This is how I compute this value:

1) Start with the test data sorted by X2 so we can easily do the split (expressed in CSV format):

X1,X2,Y
7.444542326,0.476683375,1
1.728571309,1.169761413,0
2.771244718,1.784783929,0
— group 1 above this line, group 2 below this line —
2.999208922,2.209014212,0
3.961043357,2.61995032,0
3.678319846,2.81281357,0
7.497545867,3.162953546,1
10.12493903,3.234550982,1
6.642287351,3.319983761,1
9.00220326,3.339047188,1

2) Compute the proportions (computing to 9 places, displaying 4):

P(1, 0) = group 1 (above line), class 0 = 0.6667
P(1, 1) = group 1 (above line), class 1 = 0.3333
P(2, 0) = group 2 (below line), class 0 = 0.4286
P(2, 1) = group 2 (below line), class 1 = 0.5714

3) Compute Gini Score WITHOUT weighting (computing to 9 places, displaying 4):

Gini Score =
[P(1, 0) x (1 - P(1, 0))] +
[P(1, 1) x (1 - P(1, 1))] +
[P(2, 0) x (1 - P(2, 0))] +
[P(2, 1) x (1 - P(2, 1))] =
0.2222 + 0.2222 + 0.2449 + 0.2449 = 0.934 (just what you got)

4) Compute Gini Score WITH weighting (computing to 9 places, displaying 4):

[(3/10) x (0.2222 + 0.2222)] + [(7/10) x (0.2449 + 0.2449)] = 0.4762

Could you explain why you don't weight by partition size when you compute your Gini Index?

Thanks!
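
Michael's arithmetic can be checked in a few lines. The class labels below are read off his table for the split X2 < 2.209: three rows above the line, seven below. The only difference between the two scores is whether each group's Gini is multiplied by its share of the rows.

```python
def group_gini(labels, classes):
    """Gini of one group: the sum of p * (1 - p) over the classes."""
    size = len(labels)
    return sum((labels.count(c) / size) * (1 - labels.count(c) / size) for c in classes)

group1 = [1, 0, 0]              # class labels above the line (X2 < 2.209)
group2 = [0, 0, 0, 1, 1, 1, 1]  # class labels below the line
classes = [0, 1]
n = len(group1) + len(group2)

unweighted = group_gini(group1, classes) + group_gini(group2, classes)
weighted = ((len(group1) / n) * group_gini(group1, classes)
            + (len(group2) / n) * group_gini(group2, classes))
print('unweighted: %.4f' % unweighted)  # 0.9342
print('weighted:   %.4f' % weighted)    # 0.4762
```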

• Jason Brownlee August 6, 2017 at 7:34 am #

Hi Michael, I believe the counts are weighted.

See “proportion” both in the description of the algorithm and the code calculation of gini for each class group.

See “an introduction to statistical learning with applications in r” pages 311 onwards for the equations.

Does that help?

• Michael August 9, 2017 at 7:41 am #

Hi Jason, Thanks for the reply. After reading and re-reading pages 311 and 312 several times, it seems to me that equation (8.6) in the ISL (love this book BTW, but in this rare case, it is lacking some important details) should really have a subscript m on the G because it is computing the Gini score for a particular region. Notice that in this equation: 1) the summation is over K (the total number of classes) and 2) the mk subscript on p hat. These imply that both the proportion (p hat sub mk) and G are with respect to a particular region m.

Eqn. (8.6) in the ISL is correct, but it’s not the final quantity that should be used for determining the quality of the split. There is an additional step (not described in the ISL) that needs to be done: weighting the Gini scores by the size of the proposed split regions, as shown in this example:

https://www.researchgate.net/post/How_to_compute_impurity_using_Gini_Index

I think the gini_index function should look something like what is shown below. This version gives me the values I expect and is consistent with how the gini score of a split is computed in the above example:

Thoughts?

• Michael August 9, 2017 at 7:44 am #

Sorry about the loss of indentation in the code… Seems like the webapp parses these out before posting.

• Jason Brownlee August 10, 2017 at 6:35 am #

I added some pre tags for you.

• Jason Brownlee August 10, 2017 at 6:35 am #

I’ll take a look, thanks for sharing.

• Michael August 10, 2017 at 11:56 am #

Thanks for your eyeballs. The original code I proposed works, but it has some unnecessary stuff in it. This is a cleaner (fewer lines) implementation (hopefully the pre tags I insert work…):

• Jason Brownlee August 10, 2017 at 4:42 pm #

Thanks. Added to my trello to review.

Update: Yes, you are 100% correct. Thank you so much for pointing this out and helping me see the fault Michael! I really appreciate it!

I did some homework on the calculation (checked some textbooks and read sklearn’s source) and wrote a new version of the Gini calculation function from scratch. I then updated the tutorial.

I think it’s good now, but shout if you see anything off.

41. nehasharma August 18, 2017 at 6:24 am #

print(gini_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1]))

How should I interpret the number of ZEROES and ONES (the distribution of classes in each group)? I am new to this.

• Jason Brownlee August 18, 2017 at 6:56 am #

Great question!

Each inner array is a pattern (a row), and the final value in each row is its class value. The final array, [0, 1], lists the valid class values.
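
Here is how to read that call. The first argument is the list of groups produced by a candidate split; each group is a list of rows, and the last value in each row is its class. The second argument lists the valid classes. Counting labels per group shows the class distribution. Each group here holds one row of class 1 and one of class 0. That is the worst possible 50/50 mix, which the weighted Gini scores as 0.5.

```python
groups = [[[1, 1], [1, 0]], [[1, 1], [1, 0]]]  # two groups of two rows each
classes = [0, 1]                              # the valid class values

for i, group in enumerate(groups):
    labels = [row[-1] for row in group]       # the last value of each row is its class
    counts = {c: labels.count(c) for c in classes}
    print('group %d: %s' % (i, counts))
# group 0: {0: 1, 1: 1}
# group 1: {0: 1, 1: 1}
```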

• John Sobola October 22, 2017 at 10:36 pm #

I did not notice any difference in the split point and prediction when I used both the old and revised gini codes

42. aaaaa August 28, 2017 at 12:08 am #

How do I apply this to discrete values?

43. scssek September 15, 2017 at 1:33 pm #

Do you recommend this decision tree model for binary file based data?

• Jason Brownlee September 16, 2017 at 8:37 am #

I do not recommend an algorithm. I recommend testing a suite of algorithms to see what works best.

44. Ann September 18, 2017 at 12:30 am #

Good Day sir!
I would like to ask why my code takes around 30-46 minutes to get the mean accuracy. I am running about 24,100 rows of data with 3 columns. Is it normally this slow? Thank you very much!
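
Some slowness is expected from the tutorial's design. get_split() tries every value of every column as a threshold, and each try re-partitions and re-scores the whole group. With n rows, that is roughly n candidate thresholds per column, each touching all n rows. A common speed-up is to sort the rows by one column once and sweep through them while updating class counts, so each candidate threshold costs only a few arithmetic operations. This is sketched here as an alternative and is not part of the tutorial. `best_split_sorted` is a hypothetical name. It keeps the tutorial's convention that rows with a value below the threshold go left.

```python
def best_split_sorted(dataset, index, classes):
    """Best Gini threshold on one column via a single sorted sweep (hypothetical helper)."""
    rows = sorted(dataset, key=lambda row: row[index])
    n = len(rows)
    left = {c: 0 for c in classes}
    right = {c: 0 for c in classes}
    for row in rows:
        right[row[-1]] += 1  # start with every row on the right

    def gini(counts, size):
        if size == 0:
            return 0.0
        return 1.0 - sum((counts[c] / size) ** 2 for c in classes)

    best_value, best_score = None, float('inf')
    for i, row in enumerate(rows):
        value = row[index]
        # Score a threshold only at its first occurrence, so rows[:i] are exactly
        # the rows with row[index] < value (the tutorial's left group)
        if i == 0 or rows[i - 1][index] < value:
            score = (i / n) * gini(left, i) + ((n - i) / n) * gini(right, n - i)
            if score < best_score:
                best_value, best_score = value, score
        # Move this row from the right group to the left group
        left[row[-1]] += 1
        right[row[-1]] -= 1
    return best_value, best_score


dataset = [[1, 0], [2, 0], [3, 0], [4, 1], [100, 1]]
print(best_split_sorted(dataset, 0, [0, 1]))  # (4, 0.0)
```

Sorting costs O(n log n) per column, and the sweep is linear, instead of the quadratic cost of re-running test_split() and the Gini calculation for every candidate.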

45. Jarich September 19, 2017 at 12:12 am #

Can you leave an example with a larger hard-coded decision tree, with 2 or even 3 stages?
I can’t seem to get the syntax right to work with larger decision trees.

• Jason Brownlee September 19, 2017 at 7:47 am #

Thanks for the suggestion. Note that the code can develop such a tree.

• jarich September 19, 2017 at 5:42 pm #

Yeah I just found that out myself, might be better to not even give the example and let us think for some time.

46. Maria September 22, 2017 at 1:16 pm #

I would like to ask if there is a way to store the already learned decision tree, so that when the predict function is called for only one test row, it can run much faster. Thank you very much!

• Jarich September 22, 2017 at 5:54 pm #

Hey Maria, you could save it to a .txt file and then read it back in for the prediction. I did it that way and it works flawlessly!

• Maria September 22, 2017 at 11:53 pm #

oh okay. I will save the result for build_tree to a text file? Thank you very much for this information! 😀

• Jason Brownlee September 23, 2017 at 5:37 am #

Great tip Jarich.

• Jason Brownlee September 23, 2017 at 5:35 am #

Yes, I would recommend using the sklearn library then save the fit model using pickle.

I have examples of this on my blog, use the search.
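
The pickle suggestion also works for the tutorial's own tree, because build_tree() returns ordinary nested dicts. You can fit once, save the tree, and later load it and call predict without rebuilding anything. Below is a minimal sketch; the stump and the file name `tree.pkl` are hypothetical. Only unpickle files you created yourself, since loading a pickle can run arbitrary code.

```python
import pickle

# Hypothetical fitted tree; in the tutorial script this would be the result of build_tree()
tree = {'index': 0, 'value': 6.642287351, 'left': 0, 'right': 1}

# Save once, after training
with open('tree.pkl', 'wb') as f:
    pickle.dump(tree, f)

# Later (e.g. in the prediction script): load it back instead of rebuilding
with open('tree.pkl', 'rb') as f:
    loaded_tree = pickle.load(f)

print(loaded_tree == tree)  # True
```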

47. chakri September 26, 2017 at 2:01 am #

Jason: Thank you for a great article. Can you provide me the code for generating the tree? I have got scores and accuracy and would like to view the decision tree. Please provide the code for generating the tree.

• chakri September 26, 2017 at 2:02 am #

I would need code using Graphviz.

• Jason Brownlee September 26, 2017 at 5:39 am #

Thanks for the suggestion.

48. Maria September 30, 2017 at 11:45 pm #

I would like to know if this tree can also handle multiple classes aside from 0 or 1. For example, I need to classify into 1, 2, 3, 4 or 5. Thank you very much!

• Jason Brownlee October 1, 2017 at 9:07 am #

It could be extended to multiple classes. I do not have an example sorry.
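
As far as I can tell from the tutorial's code, the parts that touch class labels already loop over whatever labels occur. The majority vote in `to_terminal()` and the per-class sum in the Gini score both work this way, so labels 1 to 5 should work unchanged. Below is a minimal check with made-up rows; both functions are re-implemented here in the tutorial's style so the snippet runs on its own.

```python
def to_terminal(group):
    # Majority vote over however many class labels appear
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

def gini_index(groups, classes):
    # Size-weighted Gini; the inner sum runs over every class value
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue
        score = sum(([row[-1] for row in group].count(c) / size) ** 2 for c in classes)
        gini += (1.0 - score) * (size / n_instances)
    return gini

group = [[0.1, 1], [0.2, 3], [0.3, 3], [0.4, 5]]  # made-up rows, labels drawn from 1..5
classes = [1, 2, 3, 4, 5]
print(to_terminal(group))            # 3
print(gini_index([group], classes))  # 0.625
```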

49. Leo October 4, 2017 at 3:45 am #

I can’t get started.
I use Python 3.5 on Spyder 3.0.0.
Error: iterator should return strings, not bytes (did you open the file in text mode?)

I’m a beginner in both Python and ML.

Can you help?

Thank you, Leo

• Jason Brownlee October 4, 2017 at 5:48 am #

The code was developed for Python 2.7. I will update it for Python 3 in the future.

https://machinelearningmastery.com/start-here/#weka

• Andisheh July 4, 2018 at 11:48 am #

Hello, I am a beginner and I have the same problem, too. Why are codes different for Python 2 and 3? Shall I save the file as text or csv if I am using your codes? What should I do now? Thank you

• Jason Brownlee July 4, 2018 at 2:56 pm #

Python 2 and 3 are different programming languages. The example is intended for Python 2 and requires modification to work with Python 3.

I will make those changes in coming months.

50. lightbandit October 12, 2017 at 4:18 am #

I have a dataset with titles on both axes (will post example below) and they have 1’s and 0’s signifying if the characteristic is present or not. Where would you suggest I start with altering your code in order to fit this dataset? Right now I am getting an error with the converting string to float function.

white blue tall
ch1 0 1 1
ch2 1 0 0
ch3 1 0 1
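
The float-conversion error comes from the header row and the row labels (`ch1`, …) being text. One way to handle this layout, assuming whitespace-separated columns as in the sample above, is to set the header and the first column aside before converting the rest. `load_labelled_grid` is a hypothetical helper. Which column serves as the class is up to you; the tutorial expects it to be the last value in each row.

```python
def load_labelled_grid(text):
    """Split a whitespace-separated table with a header row and row labels (hypothetical helper)."""
    lines = [line.split() for line in text.strip().splitlines() if line.strip()]
    header = lines[0]           # column names only; nothing to convert
    names, rows = [], []
    for parts in lines[1:]:
        names.append(parts[0])  # row label such as 'ch1' is kept aside, not used as a feature
        rows.append([float(value) for value in parts[1:]])
    return header, names, rows

sample = """white blue tall
ch1 0 1 1
ch2 1 0 0
ch3 1 0 1"""
header, names, rows = load_labelled_grid(sample)
print(header)  # ['white', 'blue', 'tall']
print(rows)    # [[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0]]
```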

51. Liu Yong October 15, 2017 at 7:32 pm #

This is great!
Do you plan to introduce cost-complexity pruning for CART?

Many thanks!

• Jason Brownlee October 16, 2017 at 5:42 am #

Not at this stage. Perhaps you could post some links?

52. Revi October 19, 2017 at 3:27 am #

Hello sir, I was trying to use your code on my data, but I ran into a problem: ValueError: invalid literal for float(): 1;0.766126609;45;2;0.802982129;9120;13;0;6;0;2. How could I fix this?
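
The whole row is reaching float() as one string because the file separates values with semicolons, not commas. Python's csv.reader accepts a delimiter argument, so one fix is to pass delimiter=';' wherever the file is read (for example, in the tutorial's load_csv()). Below is a minimal sketch; parse_rows is a hypothetical helper.

```python
from csv import reader
from io import StringIO

def parse_rows(lines, delimiter=';'):
    """Parse delimited text lines into rows of floats, skipping blank lines."""
    rows = []
    for row in reader(lines, delimiter=delimiter):
        if row:
            rows.append([float(value) for value in row])
    return rows

# The row from the error message, wrapped in StringIO to stand in for an open file
sample = StringIO('1;0.766126609;45;2;0.802982129;9120;13;0;6;0;2\n')
print(parse_rows(sample))
```

With a real file, open it with `open(filename, 'r', newline='')` and pass the file object to `parse_rows()`.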

53. Frackson October 21, 2017 at 10:27 pm #

Hello Sir,

Thanks for the tutorial. It is quite clear and concise. However, how best can you advise on altering it to suit the following data set:

54. Steven Lee October 25, 2017 at 8:22 am #

Hey Jason,

Thank you very much for this tutorial. It has been extremely helpful in my understanding of decision trees. I have one point that needs clarification. In the ‘def split (node, max_depth, min_size, depth)’ method, I can see that you recursively split the left nodes until the max depth or min size condition is met. When you perform the split you add 1 to the current depth, but once you move on to the next if statement to process the first right child/partition, there is no code that resets the depth.

Is this because in a recursive function you are saving new depth values for every iteration? Meaning unique depth values are saved at each recursion?

• Jason Brownlee October 25, 2017 at 3:58 pm #

Depth is passed down each line of the recursion.
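
This is why no reset is needed. depth is an ordinary parameter, so each call of split() has its own local depth. The call for the left child receives depth + 1, and when it returns, the parent's own depth is unchanged, so the right child also receives depth + 1. The toy recursion below records the depth at which each node is visited; the names are hypothetical.

```python
def record_depths(node, depth, visits):
    # 'depth' is local to this call; passing depth + 1 to a child
    # never changes the value seen by the caller
    visits.append((node['name'], depth))
    for key in ('left', 'right'):
        child = node.get(key)
        if isinstance(child, dict):
            record_depths(child, depth + 1, visits)

tree = {'name': 'root',
        'left': {'name': 'L', 'left': {'name': 'LL'}},
        'right': {'name': 'R'}}
visits = []
record_depths(tree, 1, visits)
print(visits)  # [('root', 1), ('L', 2), ('LL', 3), ('R', 2)]
```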

55. Steven Lee October 25, 2017 at 9:00 am #

Also how is the return of build tree read?

i.e.:

in:

tree = build_tree(dataset, 10, 1)
print(tree)

out:

{'index': 0, 'right': {'index': 0, 'right': {'index': 0, 'right': 1, 'value': 7.497545867, 'left': 1}, 'value': 7.497545867, 'left': {'index': 0, 'right': 1, 'value': 7.444542326, 'left': 1}}, 'value': 6.642287351, 'left': {'index': 0, 'right': {'index': 0, 'right': 0, 'value': 2.771244718, 'left': 0}, 'value': 2.771244718, 'left': 0}}

• Jason Brownlee October 25, 2017 at 4:01 pm #

Each node has references to its left and right children.
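
Here is how to read that output. Each dict is a decision node: `'index'` is the column tested (0-based) and `'value'` is the threshold. Rows with `row[index] < value` go to `'left'`, and the rest go to `'right'`. A plain number in `'left'` or `'right'` is a terminal node holding the predicted class. Rendering the tree as indented text, in the style of the tree printouts elsewhere in this thread, makes the nesting easier to follow. `tree_lines` is a hypothetical helper.

```python
def tree_lines(node, depth=0):
    """Return the tree as indented text lines, one per node (hypothetical helper)."""
    if isinstance(node, dict):
        # Decision node: 1-based column number and split value
        lines = ['%s[X%d < %.3f]' % (depth * ' ', node['index'] + 1, node['value'])]
        lines += tree_lines(node['left'], depth + 1)
        lines += tree_lines(node['right'], depth + 1)
        return lines
    # Terminal node: the predicted class
    return ['%s[%s]' % (depth * ' ', node)]

# The dict printed in the question above
tree = {'index': 0, 'right': {'index': 0, 'right': {'index': 0, 'right': 1, 'value': 7.497545867, 'left': 1},
                              'value': 7.497545867,
                              'left': {'index': 0, 'right': 1, 'value': 7.444542326, 'left': 1}},
        'value': 6.642287351,
        'left': {'index': 0, 'right': {'index': 0, 'right': 0, 'value': 2.771244718, 'left': 0},
                 'value': 2.771244718, 'left': 0}}
print('\n'.join(tree_lines(tree)))
```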

56. Henrik October 31, 2017 at 12:03 am #

Hey,
I have a dataset with numerical and categorical values, so I modified some of the code in load_csv and main to convert categorical values to numerical:

#file = open(filename, "r")
"shot_dist","pts_type","close_def_dist","target"]

cleanup_nums = {"location": {"H": 1, "A": 0},
"w": {"W": 1, "L": 0},
df.replace(cleanup_nums, inplace=True)

obj_df=list(df.values.flatten())
return obj_df

#main:
# Test CART on Basketball dataset
seed(1)
i=0
new_list=[]
while i < len(dataset):
    new_list.append(dataset[i:i+13])
    i += 13
n_folds = 5
max_depth = 5
min_size = 10
scores = evaluate_algorithm(new_list, decision_tree, n_folds, max_depth, min_size)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))

The program still works for your dataset, but when I try to run it on my own, the program freezes in the list comprehension on the line p = [row[-1] for row in group].count(class_val) / size, in the gini_index method.

Do you have any idea how I can break this infinite loop?

Thanks

57. amitabh November 30, 2017 at 6:28 pm #

Thanks it helps!

58. adan December 11, 2017 at 5:12 pm #

I don’t get this part: row[-1] for row in group. Please explain what’s happening here.
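
row[-1] uses Python's negative indexing: -1 means the last element of a list, which in this tutorial's rows is the class label. The surrounding [ ... for row in group] is a list comprehension that collects that value from every row in the group. Here is a minimal illustration, using three rows from the small dataset listed earlier in this thread, with the comprehension also written out as an equivalent loop:

```python
# Rows in the form [X1, X2, class]
group = [[2.771244718, 1.784783929, 0],
         [1.728571309, 1.169761413, 0],
         [7.497545867, 3.162953546, 1]]

outcomes = [row[-1] for row in group]  # list comprehension

outcomes_loop = []                     # the same thing as an explicit loop
for row in group:
    outcomes_loop.append(row[-1])      # row[-1] is the last element of the row

print(outcomes)       # [0, 0, 1]
print(outcomes_loop)  # [0, 0, 1]
```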

59. Khalida December 15, 2017 at 2:22 am #

Just thanks <3

60. John Wang January 3, 2018 at 3:45 pm #

Really helpful, thanks a lot !

• Jason Brownlee January 3, 2018 at 3:46 pm #

Thanks, I’m glad to hear that.

61. Chetan February 1, 2018 at 9:59 pm #

Easy to understand, no need of any background knowledge. Excellent Article

62. Nandit Khosa February 10, 2018 at 3:45 am #

Great tutorial. Can you help me use the entropy criterion for splitting rather than the Gini index? Is there any example for that? I tried, but it’s not working for me.

• Jason Brownlee February 10, 2018 at 8:59 am #

Sorry, I don’t have an example.

• Kevin Sequeira October 17, 2018 at 5:24 pm #

Hi Nandit,

I tried out Jason’s algorithm along with Entropy as cost function. Might be able to help you.

• Hitesh January 21, 2019 at 11:31 am #

@Kevin Sequeira

Hi Kevin,

I tried using entropy – information gain criteria in Jason’s algorithm, but it’s not working for me. Please, can you provide an example of its implementation?
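
Here is a minimal sketch of an entropy criterion. It assumes the function replaces gini_index() inside get_split(), which keeps the lowest score. Each group is scored by its Shannon entropy instead of its Gini impurity, and the scores are weighted by group size. The parent node's entropy is the same for every candidate split of that node, so picking the lowest weighted child entropy is the same as picking the highest information gain. `entropy_index` is a hypothetical name.

```python
from math import log2

def entropy_index(groups, classes):
    """Size-weighted entropy (in bits) of the groups produced by a split."""
    n_instances = float(sum(len(group) for group in groups))
    total = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue
        labels = [row[-1] for row in group]
        h = 0.0
        for class_val in classes:
            p = labels.count(class_val) / size
            if p > 0:  # 0 * log(0) is taken as 0
                h -= p * log2(p)
        total += h * (size / n_instances)
    return total

print(entropy_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1]))  # 1.0 (worst case)
print(entropy_index([[[1, 0], [1, 0]], [[1, 1], [1, 1]]], [0, 1]))  # 0.0 (perfect split)
```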

63. Mala Gupta February 14, 2018 at 1:06 pm #

Hello Sir,

Is it possible to visualize the above tree? Is it possible to implement it using Graphviz? Can you please provide an example to visualize the above decision tree?

64. dpk March 2, 2018 at 4:33 am #

Hi –

I notice that to_terminal is basically the Zero Rule algorithm. Is that just a coincidence (i.e. it might just be the least stupid thing to do with small data sets)? Or is there a deeper idea here that when you cut down the tree to whatever minimum data size you choose, you might want to apply some other predictive algorithm to the remaining subset (sort of using a decision tree to prefilter inputs to some other method)?

(If it’s just a coincidence, then I guess I’m guilty of over-fitting… 🙂 )

Thanks.

• Jason Brownlee March 2, 2018 at 5:36 am #

Just a coincidence, well done on noting it!

65. David Urive March 8, 2018 at 2:36 pm #

I’m having trouble trying to classify the iris.csv dataset. How would you suggest modifying the code to load in the attributes and a non-numeric class?

• Jason Brownlee March 8, 2018 at 2:56 pm #

Load the data as a mixed NumPy array, then convert the strings to integers via a label encoding.

I have many examples on the blog.
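
In plain Python, without NumPy, the label encoding can be sketched as mapping each distinct class string to an integer. The rows below are typed in by hand in the iris.csv layout, and `label_encode_column` is a hypothetical name. When loading from a file, the four measurement columns would still need converting from strings to floats first.

```python
def label_encode_column(dataset, column):
    """Replace the strings in one column with integers; return the mapping (hypothetical helper)."""
    class_values = sorted(set(row[column] for row in dataset))
    lookup = {value: i for i, value in enumerate(class_values)}
    for row in dataset:
        row[column] = lookup[row[column]]
    return lookup

# A few rows in iris.csv layout: four numeric measurements, then the species name
dataset = [[5.1, 3.5, 1.4, 0.2, 'Iris-setosa'],
           [7.0, 3.2, 4.7, 1.4, 'Iris-versicolor'],
           [6.3, 3.3, 6.0, 2.5, 'Iris-virginica']]
lookup = label_encode_column(dataset, 4)
print(lookup)  # {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
```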

66. harry March 8, 2018 at 4:39 pm #

I am having trouble loading the dataset; it gives the following error: "iterator should return strings, not bytes (did you open the file in text mode?)"

• Jason Brownlee March 9, 2018 at 6:19 am #

Are you able to confirm that you are using Python 2.7?

67. Suyog March 15, 2018 at 2:31 pm #

A typo under the subheading Banknote Dataset…

{The dataset contains 1,372 with 5 numeric variables.} should be
{The dataset contains 1,372 rows with 5 numeric variables.}

68. AKSHAY TENKALE March 17, 2018 at 2:38 am #

outcomes = [row[-1] for row in group]
What does this code snippet do?

For a regression tree, the to_terminal(group) function should return the mean of the target attribute values, but I am getting a group variable with all NaN values.

69. Arvind March 18, 2018 at 6:28 pm #

Thanks, This is extremely useful.

We are trying to make a decision based on the attributes found in a WebElement on a web page, to generate dynamic test cases.

For example, we want to feed in some properties like type = "text" and name = "email", etc., then conclude that it’s a text field and generate test cases applicable to text fields. An example of what we want to read properties from

Any pointers on how to begin? Any help would be appreciated. Basically, how do we start preparing the training dataset for links, images and other element types?

70. Kelvin April 6, 2018 at 5:58 am #

Hi Jason,

I must be blind. Once an attribute is used in a split, I don’t see you remove it from the next recursive split/branching.

Thanks!

• Jason Brownlee April 6, 2018 at 6:38 am #

It can be reused if it can add value, but note the subset of data at the next level down will be different – e.g. having been split.

71. Sergey Kojoian April 7, 2018 at 5:40 am #

Hello Jason,

Thanks for the great tutorial.

I am a bit confused by the definition of gini index.

gini_index = sum(proportion * (1.0 - proportion)) is just 2*p*q, where p is the proportion of one class in the group and q is the proportion of the other, i.e.
gini_index = sum(proportion * (1.0 - proportion)) == 2*p*(1-p)

Correct?
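
Yes, for two classes. Write $p_{g,k}$ for the proportion of class $k$ in group $g$, $n_g$ for the group's size and $n$ for the total. The identity, plus the group-size weighting described in the Aug/2017 update, is:

$$
G_g = \sum_{k=1}^{K} p_{g,k}\,(1 - p_{g,k}) = 1 - \sum_{k=1}^{K} p_{g,k}^2,
\qquad
K = 2:\quad G_g = p(1-p) + (1-p)\,p = 2p(1-p)
$$

$$
G_{\text{split}} = \sum_{g} \frac{n_g}{n}\, G_g
$$

For Michael's split earlier in this thread: $2 \cdot \tfrac{2}{3} \cdot \tfrac{1}{3} = \tfrac{4}{9}$ and $2 \cdot \tfrac{3}{7} \cdot \tfrac{4}{7} = \tfrac{24}{49}$, so $G_{\text{split}} = \tfrac{3}{10} \cdot \tfrac{4}{9} + \tfrac{7}{10} \cdot \tfrac{24}{49} \approx 0.4762$.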

72. Adez April 28, 2018 at 8:53 pm #

Nice tutorial sir. I really enjoy going through you work sir. It has helped me to understand what is happening under the hood better. How can someone plot the logloss for this tutorial?

• Jason Brownlee April 29, 2018 at 6:25 am #

Thanks.

How would a plot of logloss work for a decision tree exactly? As it is being constructed? It might not be appropriate.

73. eye_water May 9, 2018 at 1:18 pm #

Hi Jason,
thanks for your tutorial. I want to translate your essay into Chinese, not for business, just to share it.
Can I?
If I can, is there anything I should be aware of?

• Jason Brownlee May 9, 2018 at 2:57 pm #

Please do not translate and republish my content.

• eye_water May 9, 2018 at 4:57 pm #

74. Parikshit Bhinde June 2, 2018 at 5:30 pm #

I’d recommend that, for the recursive split explanation, you add a note that ‘node’ refers to a different dictionary on each recursive call (the child created by get_split()), not the same one. This isn’t obvious to people without extensive programming experience and can confuse them if they try to evaluate the expressions by hand for the small ‘dataset’. Otherwise, thanks for such a great article.

• Jason Brownlee June 3, 2018 at 6:21 am #

Thanks for the suggestion.

• Parikshit Bhinde June 3, 2018 at 1:39 pm #

Welcome. I ran through the whole recursive code for the small dataset manually myself, just to get the hang of the flow. I used a different variable name for the ‘node’s at each depth level. Here’s the code, in case it is helpful to anyone:
root = get_split(dataset)
max_depth = 3
min_size = 1
depth = 1
node = root.copy()
left, right = node['groups']
del(node['groups'])
node['left'] = l1 = get_split(left)
depth = depth + 1
l1left, l1right = l1['groups']
del(l1['groups'])
len(l1left) <= min_size
l1['left'] = to_terminal(l1left)
l1['right'] = l1r2 = get_split(l1right)
depth = depth + 1
l1r2left, l1r2right = l1r2['groups']
del(l1r2['groups'])
not l1r2left or not l1r2right
l1r2['left'] = l1r2['right'] = to_terminal(l1r2left + l1r2right)
depth = depth - 1
node['right'] = r1 = get_split(right)
r1left, r1right = r1['groups']
del(r1['groups'])
r1['left'] = r1l2 = get_split(r1left)
depth = depth + 1
r1l2left, r1l2right = r1l2['groups']
del(r1l2['groups'])
r1l2['left'], r1l2['right'] = to_terminal(r1l2left), to_terminal(r1l2right)
depth = depth - 1
r1['right'] = r1r2 = get_split(r1right)
depth = depth + 1
r1r2left, r1r2right = r1r2['groups']
del(r1r2['groups'])
not r1r2left or not r1r2right
r1r2['left'] = r1r2['right'] = to_terminal(r1r2left + r1r2right)

75. Vincent June 10, 2018 at 8:54 am #

What does the number after each X mean? (X1, X2, etc) In the print_tree function it stands for node[‘index’]. Is this the gini-index?
Why do only two indexes appear in your example data?
I tried some other dataset and this is the result:
[X4 < 23.570]
[X1 < 1143.631]
[X1 < 1096.228]
[X1 < 905.490]
[X1 < 714.853]
[X1 < 524.248]
[0]
[0]
[0]
[0]
[0]
[X1 < 1143.631]
[0]
[0]
[X1 < 1373.940]
[X1 < 1292.075]
[X1 < 1276.251]
[X1 < 1141.360]
[X1 < 912.083]
[1]
[1]
[1]
[1]
[1]
[X1 < 1373.940]
[1]
[1]

• Jason Brownlee June 11, 2018 at 6:02 am #

Xn is the variable (column) number, counting from 1, and the floating-point number after it is the value of that variable chosen as the split point.

76. Vincent June 10, 2018 at 8:58 am #

Another question:

How exactly do you have to modify your python-code to solve an n-dimensional problem with a decision tree?

I would be more than happy and thankful.

Anyway, thank you so much for your work here.

• Jason Brownlee June 11, 2018 at 6:04 am #

The code is to teach you how the algorithm works.

To solve real problems, I recommend using the sklearn library. I have many examples, search the blog.

77. Shivan June 12, 2018 at 7:49 pm #

Really nice writeup ! Can I know what these “groups” are exactly ? I know they’re 3d arrays having 2 groups of data but what do they hold ? Thanks !

• Jason Brownlee June 13, 2018 at 6:18 am #

Thanks. Which groups?

• Shivan June 14, 2018 at 2:56 am #

I was referring to the “groups” of data but I understood what they were after I read the code. Thank you for your time !

78. Soumyodipta Karmakar June 14, 2018 at 2:52 pm #

Hello,
I am working on a decision tree classifier. Could you please share code for the SLIQ decision tree classifier in Python?

79. Vincent June 22, 2018 at 9:55 pm #

I highly recommend publishing this code on GitHub, since it makes contributing very easy. I would then suggest a simple pruning function. And you already suggest other improvements/features in your text. Plus, you as the owner would get more insights and numbers concerning clones/usage, etc. What do you think?

• Jason Brownlee June 23, 2018 at 6:17 am #

Thanks for the suggestion, but I prefer to not have my code on github.

• Vincent June 26, 2018 at 12:37 am #

Why?

• Jason Brownlee June 26, 2018 at 6:41 am #

Because I control this blog and all aspects of the experience, I have zero control over github, it is someone else’s platform.

I write code tutorials, I don’t just dump source code online. Context is critical.

I also sell books, which allows me to keep writing more tutorials.

80. TonyStorm August 3, 2018 at 10:56 pm #

thanks a lot!!
it really helps!!!
I am a beginner in ML and I have a bit of a language problem reading these materials, but it’s not that difficult.
Q1: I wonder how to correctly use your book and your site’s tutorials together?

81. Jun Li August 20, 2018 at 4:28 pm #

Hi, the train_set.remove(fold) line raises an error:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I solved this problem by referring to this answer:
https://stackoverflow.com/questions/3157374/how-do-you-remove-a-numpy-array-from-a-list-of-numpy-arrays

removearray(train_set,fold)

Hope this can help others.

• Jason Brownlee August 21, 2018 at 6:10 am #

The code was written for Python 2.7. Are you using this version of Python?
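
Here is an alternative that avoids comparing folds at all, sketched here rather than taken from the tutorial: build each training set by fold position. The tutorial removes the test fold with list.remove(), which compares elements with ==. For NumPy arrays, == is element-wise, and that is what raises the "truth value ... is ambiguous" error. Selecting folds by index works the same for lists and arrays. `train_test_pairs` is a hypothetical name.

```python
def train_test_pairs(folds):
    """Yield (train_set, test_fold) for each fold, selecting folds by position."""
    for i, test_fold in enumerate(folds):
        # Flatten every other fold into one list of rows; no == comparison of folds
        train_set = [row for j, fold in enumerate(folds) if j != i for row in fold]
        yield train_set, test_fold

folds = [[[1.0, 0]], [[2.0, 1]], [[3.0, 0]]]
for train_set, test_fold in train_test_pairs(folds):
    print('test:', test_fold, 'train:', train_set)
```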

82. Gaurav Nakum August 27, 2018 at 3:01 am #

Hi Jason!

First of all, wonderful article and thank you very much for writing it. I wanted to ask you regarding training a perceptron decision tree, where as you might know, each node is a perceptron. Now to build such a tree, ideally I should be training each node separately and depending on the output (say if its a binary classification problem with y = 0 or 1), I would split the dataset into two child nodes, provided the impurity is greater than threshold.

Since the split at each node depends on the weights used for the perceptron, the entire tree structure also depends on the weights. Thus, during training I would be constantly changing the tree structure (which might be computationally intensive). Or should I train each perceptron separately starting from the root and do the same for the subsequent nodes?

Thanks!

• Jason Brownlee August 27, 2018 at 6:14 am #

Sorry, I don’t have an example of such an algorithm.

83. mahi September 20, 2018 at 4:05 am #

Sir, I have to implement a decision tree and a linear SVC from scratch on a dataset of comments (strings) fetched from Twitter. The CSV will contain comments under the hashtags student, refugee and trip….

• Jason Brownlee September 20, 2018 at 8:07 am #

Perhaps use a library like scikit-learn; it will be much easier!
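The tutorial’s functions expect numeric rows with the class label in the last column, so raw comment strings first need to be turned into numbers. A minimal bag-of-words sketch, assuming whitespace tokenization and integer class labels (for example one integer per hashtag); comments_to_rows is a hypothetical helper:

```python
def comments_to_rows(comments, labels):
    """Turn raw comment strings into numeric rows with the class label last.

    Each column counts how often one vocabulary word appears in a comment.
    The vocabulary is built from all comments, lower-cased and split on
    whitespace (no punctuation handling or stop-word removal).
    """
    tokenized = [comment.lower().split() for comment in comments]
    vocabulary = sorted({word for tokens in tokenized for word in tokens})
    rows = []
    for tokens, label in zip(tokenized, labels):
        rows.append([tokens.count(word) for word in vocabulary] + [label])
    return vocabulary, rows

vocabulary, rows = comments_to_rows(["Good trip", "bad trip trip"], [1, 0])
# vocabulary -> ['bad', 'good', 'trip']
# rows       -> [[0, 1, 1, 1], [1, 0, 2, 0]]
```

With scikit-learn, CountVectorizer performs this kind of conversion (with better tokenization) and its output can be fed to a library classifier.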

84. reynold October 3, 2018 at 10:58 pm #

Sir, I have implemented this decision tree, but I want to visualize the tree. How can I do that?
With sklearn I can use pydotplus.graph_from_dot_data, but I don’t know how to do it with the code in this tutorial.

• Jason Brownlee October 4, 2018 at 6:18 am #

Sorry, I don’t have an example of plotting a decision tree.
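The tutorial’s tree is just nested dictionaries, so one option is to write it out as Graphviz DOT text yourself. A minimal sketch, assuming internal nodes carry the 'index', 'value', 'left' and 'right' keys left behind by the tutorial’s split(), and leaves are class values; tree_to_dot is a hypothetical helper:

```python
def tree_to_dot(tree):
    """Render a tree of nested dicts as Graphviz DOT text.

    Internal nodes: {'index': int, 'value': float, 'left': ..., 'right': ...}.
    Leaves: class labels. The left branch is the "row[index] < value" side.
    """
    lines = ['digraph Tree {']
    next_id = 0

    def add(node):
        nonlocal next_id
        node_id = next_id
        next_id += 1
        if isinstance(node, dict):
            # Same 1-based "X1 < 6.642" labelling style as the tutorial's print_tree().
            lines.append('  n%d [label="X%d < %.3f"];'
                         % (node_id, node['index'] + 1, node['value']))
            for branch, edge in (('left', 'True'), ('right', 'False')):
                child_id = add(node[branch])
                lines.append('  n%d -> n%d [label="%s"];' % (node_id, child_id, edge))
        else:
            lines.append('  n%d [label="%s", shape=box];' % (node_id, node))
        return node_id

    add(tree)
    lines.append('}')
    return '\n'.join(lines)

stump = {'index': 0, 'value': 6.642287351, 'left': 0, 'right': 1}
dot = tree_to_dot(stump)
```

The string can be saved to a .dot file and rendered with Graphviz (for example dot -Tpng tree.dot -o tree.png), or passed to pydotplus.graph_from_dot_data as in the sklearn workflow mentioned above.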

85. Sajith Madusanka October 4, 2018 at 4:13 pm #

How can I connect to a MySQL database to get the dataset?

• Jason Brownlee October 5, 2018 at 5:28 am #

I don’t see why not.

• Latha February 25, 2020 at 3:29 am #

Did you get the code?
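Since the tutorial only needs a list of rows (lists) with the label last, any database driver that returns rows will do. A minimal sketch using Python’s DB-API 2.0 interface, with the standard-library sqlite3 module standing in for MySQL so it runs anywhere; the table name notes and its columns are made up for the example:

```python
import sqlite3

def load_rows(connection, query):
    """Run `query` on a DB-API 2.0 connection and return the rows as lists."""
    cursor = connection.cursor()
    try:
        cursor.execute(query)
        return [list(row) for row in cursor.fetchall()]
    finally:
        cursor.close()

# Stand-in database. With MySQL, `connection` would come from a MySQL driver instead.
connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE notes (x1 REAL, x2 REAL, label INTEGER)")
connection.executemany("INSERT INTO notes VALUES (?, ?, ?)",
                       [(1.5, 2.0, 0), (3.25, -1.0, 1)])
dataset = load_rows(connection, "SELECT x1, x2, label FROM notes ORDER BY rowid")
# dataset -> [[1.5, 2.0, 0], [3.25, -1.0, 1]]
```

MySQL drivers such as PyMySQL and mysql-connector-python follow the same DB-API 2.0 interface, so load_rows should work with their connections too, although their query-parameter placeholder is %s rather than sqlite3’s ?.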

86. Tarun October 5, 2018 at 7:30 pm #

Sir,
I am doing a project that will take images as data and process them to make a decision, “Yes” or “No”. Suppose I give a blue sky image as input: it has to print yes only if it’s blue, else no. This has to be done with machine learning. Can you please help me with code for that, or any source? Please help me out.

87. Lincy October 8, 2018 at 9:46 am #

Hi Jason, this is a great article. It really helped me a lot.
By any chance, do you have links to extensions of this tutorial, such as cross-entropy and tree pruning?

Thank you so much!
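As a rough sketch of the first of those extensions: an entropy-based (cross-entropy) score with the same signature as the tutorial’s gini_index(groups, classes), so it could be swapped into get_split(). Lower is better, as with Gini; minimizing the weighted entropy of the children is equivalent to maximizing information gain, because the parent’s entropy is the same for every candidate split. entropy_index is a hypothetical name:

```python
from math import log2

def entropy_index(groups, classes):
    """Size-weighted entropy (in bits) of the groups produced by a split.

    groups: lists of rows whose last element is the class label.
    classes: the list of possible class values.
    """
    n_instances = float(sum(len(group) for group in groups))
    total = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue  # an empty group contributes nothing
        labels = [row[-1] for row in group]
        h = 0.0
        for class_val in classes:
            p = labels.count(class_val) / size
            if p > 0:  # 0 * log(0) is taken as 0
                h -= p * log2(p)
        total += h * (size / n_instances)
    return total

worst = entropy_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1])    # 1.0
perfect = entropy_index([[[1, 0], [1, 0]], [[1, 1], [1, 1]]], [0, 1])  # 0.0
```

For two classes the worst case is 1.0 bit rather than Gini’s 0.5, but the ranking use inside get_split() is the same.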

88. KK October 15, 2018 at 1:39 pm #

Thanks for your article. It is really helpful for my assignment. In your case, the “class” has already been generated: either 0 or 1. How would you generate the class if that had not been done already? Any references on generating “class”? Thanks for your detailed explanation.

89. Arslan October 27, 2018 at 8:08 am #

I’m in a situation where I need a custom splitting function that splits on the basis of a difference in rates (churn rates for different demographic groups) into two nodes, one with a low difference and the other with a higher difference.

I worked out that function, but my issue now lies in the splitting. If I have 3M rows of data, then I will have to calculate my metric 3M-1 times, once for every way of splitting the rows. Is there a more efficient way to choose which splits to evaluate?
Any thoughts would be appreciated. Thanks,

• Jason Brownlee October 28, 2018 at 6:04 am #

This question sounds specific to your dataset, I’m not sure I can answer it sensibly sorry.

• Arslan October 28, 2018 at 1:14 pm #

In an n x m dataset, the get_split(dataset) function above calculates the Gini index for n-1 possible split points in each column before selecting the best split. That means m x (n-1) evaluations of gini_index().

I was able to bring the number of Gini evaluations down to the number of unique values in each column, but it is still pretty slow. Is there a way to make it more efficient? I don’t know how sklearn or rpart do it so quickly.

• Jason Brownlee October 29, 2018 at 5:53 am #

I expect there is, I would recommend using an implementation from a robust library such as sklearn. This code is for learning how the algorithm works, not for operational use.
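One standard way to avoid re-partitioning every row for every candidate (which makes get_split() roughly quadratic in the number of rows, per column) is to sort a column once and sweep through it, moving one row at a time from the right group to the left and updating class counts incrementally. A sketch of that general idea, not a description of sklearn or rpart internals; best_split_for_column and weighted_gini are hypothetical helpers:

```python
from collections import Counter

def weighted_gini(left_counts, right_counts, n_left, n_right):
    """Size-weighted Gini of two groups, computed from class counts only."""
    n = n_left + n_right
    gini = 0.0
    for counts, size in ((left_counts, n_left), (right_counts, n_right)):
        if size == 0:
            continue
        score = sum((c / size) ** 2 for c in counts.values())
        gini += (1.0 - score) * (size / n)
    return gini

def best_split_for_column(dataset, index):
    """Return (best_value, best_gini) for column `index` in O(n log n).

    Uses the tutorial's rule (row[index] < value goes left) and only scores
    thresholds where the sorted value changes, so duplicates are skipped.
    Unlike get_split(), it never scores the split with an empty left group.
    """
    rows = sorted(dataset, key=lambda r: r[index])
    n = len(rows)
    left = Counter()
    right = Counter(row[-1] for row in rows)
    best_value, best_gini = None, float('inf')
    for i in range(n):
        # A threshold equal to rows[i][index] sends rows[:i] to the left.
        if i > 0 and rows[i][index] != rows[i - 1][index]:
            g = weighted_gini(left, right, i, n - i)
            if g < best_gini:
                best_value, best_gini = rows[i][index], g
        label = rows[i][-1]
        left[label] += 1
        right[label] -= 1
    return best_value, best_gini

dataset = [[2.771244718, 1.784783929, 0], [1.728571309, 1.169761413, 0],
           [3.678319846, 2.81281357, 0], [3.961043357, 2.61995032, 0],
           [2.999208922, 2.209014212, 0], [7.497545867, 3.162953546, 1],
           [9.00220326, 3.339047188, 1], [7.444542326, 0.476683375, 1],
           [10.12493903, 3.234550982, 1], [6.642287351, 3.319983761, 1]]
best = best_split_for_column(dataset, 0)   # (6.642287351, 0.0)
```

The same sweep works for a custom metric, as long as the metric can be computed from running per-group summaries (counts, sums) rather than from the full groups.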

90. Steve October 30, 2018 at 7:28 pm #

Hi Jason, great tutorial!! I ran your code on the small dataset and then replicated the process manually in Excel to gain a deeper understanding. Finally, I tried using the same data in sklearn but could not get the same results. What is sklearn’s DecisionTreeClassifier doing differently from, or in addition to, your code? Thanks!!!

• Jason Brownlee October 31, 2018 at 6:25 am #

The sklearn library will be doing many more things to ensure that the results are robust and that the implementations use best practices. These differences add up to give slightly different resulting trees.

91. kiv November 19, 2018 at 7:30 am #

Hi. Can anyone help me understand what “algorithm” is in the line “predicted = algorithm(train_set, test_set, *args)”? Thanks a lot.
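In that line, algorithm is simply a function passed in as an argument; in the full tutorial, evaluate_algorithm() is called with decision_tree as that argument, and the extra values (max_depth, min_size) travel through *args. A small sketch of the pattern, where zero_rule_algorithm, threshold_algorithm and evaluate_once are hypothetical stand-ins for illustration:

```python
def zero_rule_algorithm(train, test):
    """Hypothetical stand-in: predict the most common training class for every test row."""
    labels = [row[-1] for row in train]
    majority = max(set(labels), key=labels.count)
    return [majority for _ in test]

def threshold_algorithm(train, test, index, value):
    """Hypothetical algorithm with extra parameters, like decision_tree(train, test, max_depth, min_size)."""
    return [0 if row[index] < value else 1 for row in test]

def evaluate_once(train_set, test_set, algorithm, *args):
    # `algorithm` is a function object supplied by the caller;
    # *args forwards whatever extra parameters that function needs.
    predicted = algorithm(train_set, test_set, *args)
    actual = [row[-1] for row in test_set]
    correct = sum(1 for a, p in zip(actual, predicted) if a == p)
    return correct / float(len(actual)) * 100.0

train = [[1.0, 0], [2.0, 0], [3.0, 1]]
test = [[1.5, 0], [4.0, 1]]
print(evaluate_once(train, test, zero_rule_algorithm))          # 50.0
print(evaluate_once(train, test, threshold_algorithm, 0, 2.5))  # 100.0
```

Passing the function rather than hard-coding it lets the same evaluation harness score any algorithm with the (train, test, ...) signature.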

92. Anant November 22, 2018 at 1:25 am #

Hi Jason,

I built a CART model for the case study in my project using similar code. But when I run it, it throws an error at p=[row[-1] for row in group].count(class_val)/size: invalid index to scalar variable.
Because the dataset is converted back and forth many times throughout the code, group finally becomes a list instead of an ndarray, so each row in group becomes a scalar variable, and accessing its [-1] index results in that error. But I am not able to figure out how to fix it. Can you help me here?
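That message comes from NumPy: it is raised when a NumPy scalar is indexed, which typically means group is a single flat row (a 1-D array) rather than a list of rows, so iterating over it yields scalars. A sketch of the symptom and one possible guard, assuming that is what happens in the converted code; as_rows is a hypothetical helper:

```python
import numpy as np

def as_rows(group):
    """Return `group` as a list of rows, wrapping a single flat row if needed."""
    if len(group) == 0:
        return []
    return np.atleast_2d(np.asarray(group)).tolist()

flat = np.array([2.7, 1.8, 0.0])   # one row where a list of rows was expected
try:
    [row[-1] for row in flat]      # iterating gives NumPy scalars, which cannot be indexed
except IndexError as err:
    print("IndexError:", err)

labels = [row[-1] for row in as_rows(flat)]   # [0.0]
```

A more robust fix is usually to keep the data as plain Python lists of rows throughout, as the tutorial does, and only convert at the boundaries.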

• Jason Brownlee November 22, 2018 at 6:25 am #

93. Dphan November 22, 2018 at 3:56 pm #

Hi Jason,
I’m building a tree based on your code, but I get stuck when I try to visualize the tree. Can you show me how to do that? Thanks

• Jason Brownlee November 23, 2018 at 7:42 am #

I show how to visualize it with a simple ASCII output.

94. AnnaFayn January 8, 2019 at 1:27 am #

Hi Jason! Thank you for the great tutorial.
I have a question regarding decision trees. Can I use a label encoder for nominal categorical variables? I found a variety of answers to this question on the web, but nothing unequivocal.

• Jason Brownlee January 8, 2019 at 6:52 am #

You can, but no need with decision trees.

I also recommend using sklearn’s implementation.
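For this tutorial’s code specifically, test_split() compares values with <, so a string-valued nominal column has to be mapped to numbers before it can be split at all; integer label encoding is one simple way to do that. The integer order is arbitrary, but the tree can still isolate any single category with at most two threshold splits. A minimal sketch; encode_column is a hypothetical helper:

```python
def encode_column(dataset, column):
    """Replace the strings in `column` with integer codes, in place.

    Categories are sorted first so the mapping is deterministic.
    Returns the lookup table so new rows can be encoded the same way.
    """
    values = sorted({row[column] for row in dataset})
    lookup = {value: i for i, value in enumerate(values)}
    for row in dataset:
        row[column] = lookup[row[column]]
    return lookup

rows = [['red', 1.0, 0], ['blue', 2.0, 1], ['red', 3.0, 1]]
lookup = encode_column(rows, 0)
# lookup -> {'blue': 0, 'red': 1}
# rows   -> [[1, 1.0, 0], [0, 2.0, 1], [1, 3.0, 1]]
```

Keep the returned lookup table: rows seen at prediction time must be encoded with the same mapping.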

95. Inbar January 16, 2019 at 11:24 pm #

Thank you so much 🙂

96. Lisa January 23, 2019 at 8:14 pm #

Hello,

Thanks a lot for the nice explanation.
I am wondering if I can add an overall constraint on the test data somewhere in the code.
For example, depending on which group an instance is assigned to, it has a specific value for a parameter. Across all the test data, the sum of these parameters has to equal a given value.

Is this somehow possible?

• Jason Brownlee January 24, 2019 at 6:44 am #

Sorry, I don’t follow. Perhaps you can elaborate on what you mean?

97. Jugal January 31, 2019 at 4:08 pm #

Hey Dr. Brownlee,

Some comments: I needed to change rb to rt, and I removed the string to float function for my dataset of >400000 obs (since they were already float). I know you didn’t design this for such a large dataset, but it’s just for fun for now! If my calculations are correct, this will take my machine ~ 200 minutes.

Just FYI, it’s still working with Python 3.7 🙂

Best,

• Jason Brownlee February 1, 2019 at 5:33 am #

Thanks.

• cupps May 25, 2019 at 6:33 am #

Hello, is there a way to modify this for a large dataset? I wouldn’t want to wait so long.
Thank you

• Jason Brownlee May 25, 2019 at 8:01 am #

Yes, I would encourage you to use the faster implementation in sklearn, this post is just for learning about the algorithm.

98. Suresh January 31, 2019 at 10:30 pm #

Hi Jason, thanks for your material. In a decision tree, after every split we should remove the splitting column, but your code is not doing that. Could you provide some sample code for this?

• Jason Brownlee February 1, 2019 at 5:37 am #

No, we do not remove the column used in the split in this algorithm.

99. Raghav February 13, 2019 at 9:43 am #

Thank you. It was super helpful in understanding the algorithm.

• Jason Brownlee February 13, 2019 at 1:55 pm #

100. Laksh February 25, 2019 at 12:23 pm #

Very nice Jason! Really clear explanation and well-written code. I really appreciate you putting out this content for free!

• Jason Brownlee February 25, 2019 at 2:18 pm #

Thanks. I hope that it helps.

101. Daniel March 4, 2019 at 8:09 pm #

Hi Jason,

Firstly I must say this website is an absolute life-saver when it comes to learning ML without any prior experience – amazing job creating and maintaining these articles!

I am a 3rd-year student and I was given my final-year individual project, which is implementing ML on an affordable embedded system to recognise human activities (in my case walking and running, to start with); the project is mainly a proof of concept. I was very stuck until I found this post. When I did, I ran it on the data I had collected (from an on-board accelerometer) and labeled, and got a staggering 93-96% accuracy on my PC.

The microcontroller I am using runs MicroPython (I am using the LoPy by Pycom, if you are curious), hence it is fairly straightforward to adapt the code to run on it with a very small, modified dataset. However, I am new to ML and still learning Python, and I was wondering, given the example above, what you would suggest I try in order to have a labelled dataset for training and an unlabelled dataset for testing. Could that be done easily?
I tried examining your cross-validation function and it seems to need labels to work properly (I looked at your post about it – https://machinelearningmastery.com/implement-resampling-methods-scratch-python/). Correct me if I’m wrong of course. Do you have any suggestions? Thank you again for what you are doing here!

• Jason Brownlee March 5, 2019 at 6:35 am #

Thanks.

Well done!

An unlabelled dataset means you are not testing, it means you are making a prediction on new data. No resampling is possible.
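In practice that means: build the tree on all the labelled rows (optionally using cross-validation on them to estimate accuracy first), then call the prediction function on the new rows. The tutorial’s predict() only reads the feature columns named in the tree, so rows without a class column work. A sketch with the same traversal rule and a hand-written tree whose values are made up; in real use the tree would come from build_tree():

```python
def predict(node, row):
    # Same rule as the tutorial's predict(): go left if row[index] < value.
    branch = node['left'] if row[node['index']] < node['value'] else node['right']
    if isinstance(branch, dict):
        return predict(branch, row)
    return branch

# Hypothetical tree, e.g. one returned by build_tree() on all the labelled data.
tree = {'index': 0, 'value': 6.642287351,
        'left': 0,
        'right': {'index': 1, 'value': 2.0, 'left': 0, 'right': 1}}

# New, unlabelled rows: features only, no class column.
unlabelled = [[3.2, 1.1], [7.5, 3.1], [8.0, 0.5]]
predictions = [predict(tree, row) for row in unlabelled]   # [0, 1, 0]
```

Because predict() is only a few comparisons per row, this part is also the easiest to run on a microcontroller once the tree has been built on a PC.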

102. Umaib March 15, 2019 at 3:57 am #

103. Alejandro Estrella Gabilondo March 29, 2019 at 3:48 am #

Man, I can only say you are solid gold. You really “help developers (like you) get results”.

Awesome tutorial and unlike many others, practical and usable out of the box.

My most sincere thank you.

104. Harshit April 4, 2019 at 6:02 pm #

Hey Jason, I know ML in C++ is not something you teach, but I am trying to implement a decision tree in C++. Can you help me by providing some code in C++?

• Jason Brownlee April 5, 2019 at 6:12 am #

Sorry, I don’t have the capacity to write some cpp for you.

105. Devanshee April 13, 2019 at 2:56 am #

Thanks a lot! I am new to algorithms, and your blog helped me a lot to understand them and, more than that, to implement them. Thanks again 🙂

106. Reynaldi May 22, 2019 at 6:31 am #

Thanks, man. This gives me a lead for implementing the C4.5 algorithm.

107. Jay May 30, 2019 at 7:01 pm #

There is one thing which I find a bit odd:
In gini_index you have the line

p=[row[-1] for row in groups].count(label)

which just counts how many times the label appears as the last value of your feature vector (I assume each row in a group is a feature vector; please correct me if I’m wrong). I don’t see how that is useful?

• Jason Brownlee May 31, 2019 at 7:44 am #

The group is scored based on the ratio of each class. More purity (less mixing) is better.

• Henry Tremblay December 2, 2019 at 5:31 pm #

Yes, I found the same thing. For example, if you pass the following to the gini function:

gini_index([[[10, 1], ['c', 'b', 0]], [['e', 1], [1, 0]]], [0, 1])

you get the same result. In other words, only the last element of each row matters. So what exactly does this data represent?

• Antonello July 31, 2020 at 10:45 pm #

Exactly. I have the same doubt. The gini_index function seems to accept multidimensional data, but then the computation is done only with regard to the last column of the data, which is a bit puzzling…
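The rows reach gini_index() only after test_split() has used a feature value to decide which group each row falls into, so at that point the features have already done their job and only the class label (last column) is needed to measure how mixed each group is. A condensed version of the tutorial’s weighted gini_index() makes this visible:

```python
def gini_index(groups, classes):
    # groups: the (left, right) lists of rows produced by test_split();
    # the feature values were already used there to assign rows to groups.
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue
        # Purity depends only on the class label in the last column.
        labels = [row[-1] for row in group]
        score = sum((labels.count(class_val) / size) ** 2 for class_val in classes)
        gini += (1.0 - score) * (size / n_instances)  # weight by group size
    return gini

mixed = gini_index([[[1, 1], [1, 0]], [[1, 1], [1, 0]]], [0, 1])   # 0.5
pure = gini_index([[[1, 0], [1, 0]], [[1, 1], [1, 1]]], [0, 1])    # 0.0
# Henry's example: different feature values, same labels as `mixed`.
henry = gini_index([[[10, 1], ['c', 'b', 0]], [['e', 1], [1, 0]]], [0, 1])  # 0.5
```

So the score is meant to be called on groups that a feature-based split has already produced; get_split() is where the features and the candidate values come in.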

108. Han Qi June 7, 2019 at 2:01 pm #

I’ve been thinking about refining the implementation to avoid unnecessary splits as suggested in the article.

Looks like the get_split function is testing every single value from a column. In the situation where all the labels are already the same but the predictors are different, if it had tested unique (sorted) values only, the smallest value in each column would be tested and set as the best gini score, and all of the rows would be allocated to the right side due to test_split checking for <, and the "if not left or not right:" check in the split function would terminate it due to an empty left side.

Assuming I don't do the unique + sort before checking for split points, another idea I have is to add the best Gini score to the node dictionary too, and add another terminating condition: "if node['gini'] == 0: node['left'], node['right'] = to_terminal(left), to_terminal(right)".

Comparing the two strategies above, the Gini check (second) strategy would allow the tree to terminate one level earlier than the "empty left side" check. However, doing unique + sort would cut down a lot of computation if the dataset has many duplicate values in its columns, such as label-encoded values from categorical data (which usually repeat).

Any other ideas on how it can be cut short, or how computation can be reduced?

I've been thinking about pruning too. What are the considerations, and which functions would I need to edit or variables would I need to track?

• Jason Brownlee June 7, 2019 at 2:36 pm #

Yes, I would encourage you to look at the implementation in sklearn, it is much faster by design.

• Tom June 19, 2019 at 12:19 pm #

Thanks for the post, code, and all the follow ups! I also was wondering about avoiding some of the unnecessary splits at the bottom of the tree. In case others are curious, one short addition to the code would be to add

as the second line of get_split, and then

if not isinstance(node, dict): return

at the beginning of split.

It looks like this avoids unnecessary splitting in the example given and in some other datasets.

Thanks again!

109. Aravinda July 4, 2019 at 9:51 pm #

I want to implement C4.5 from scratch in Python. Any links for this?
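C4.5 differs from the CART approach in this tutorial in several ways (multiway splits on categorical attributes, pruning, handling of missing values), but the piece that most directly replaces gini_index() is its split criterion, the gain ratio: information gain divided by the split information of the partition. A minimal sketch of that criterion only, working on lists of class labels rather than rows; entropy and gain_ratio are hypothetical names:

```python
from collections import Counter
from math import log2

def entropy(labels):
    """Entropy in bits of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def gain_ratio(parent_labels, child_label_groups):
    """Information gain of a split divided by its split information.

    parent_labels: class labels before the split.
    child_label_groups: one list of class labels per branch.
    """
    n = len(parent_labels)
    children = [g for g in child_label_groups if g]
    info_gain = entropy(parent_labels) - sum((len(g) / n) * entropy(g) for g in children)
    # Split information penalizes splits into many small branches.
    split_info = -sum((len(g) / n) * log2(len(g) / n) for g in children)
    return info_gain / split_info if split_info > 0 else 0.0

perfect = gain_ratio([0, 0, 1, 1], [[0, 0], [1, 1]])   # 1.0
useless = gain_ratio([0, 0, 1, 1], [[0, 1], [0, 1]])   # 0.0
```

Unlike the Gini score, a higher gain ratio is better, so a get_split() built on it would keep the maximum instead of the minimum.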

110. Hatem Alamir July 20, 2019 at 5:08 pm #

I made some modifications to deal with categorical variables.

I ran the program on another dataset (Loan Prediction) and it worked fine with about 79% accuracy.

111. Joseph M August 23, 2019 at 7:35 pm #

Hi Jason,

Thanks for producing such a great resource!

When evaluating all possible splits, how should you proceed if the impurity measure has multiple minima (assuming this is possible)? For instance, what if multiple features give a Gini index of zero, or a feature gives equal minimum Gini indices at different split points?

Should you test all possible combinations of splits?

• Joseph M August 23, 2019 at 11:04 pm #

Also, as in the example, what’s the reason for not applying a condition that terminates the node if it is pure?

• Jason Brownlee August 24, 2019 at 7:51 am #

It should (and may, I don’t recall). Great comment!

• Jason Brownlee August 24, 2019 at 7:48 am #

If several splits are tied (same score), then one of them can be chosen at random.

Yes, we test all splits as part of the algorithm.
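For reference, the tutorial’s get_split() keeps the first candidate it sees with the lowest score, because it only replaces the best split on a strict <. A random tie-break can be sketched by collecting all candidates first; choose_split and the (index, value, score) tuple format are hypothetical:

```python
import random

def choose_split(candidates, rng=random, tol=1e-12):
    """Pick one of the lowest-scoring candidate splits at random.

    candidates: list of (index, value, score) tuples, lower score is better.
    tol: scores within this distance of the best count as tied, which guards
    against tiny floating-point differences between equivalent splits.
    """
    best = min(score for _, _, score in candidates)
    tied = [c for c in candidates if c[2] - best <= tol]
    return rng.choice(tied)

candidates = [(0, 6.64, 0.0), (1, 2.5, 0.0), (0, 3.0, 0.3)]
chosen = choose_split(candidates, random.Random(1))   # one of the two 0.0 splits
```

Passing a seeded random.Random makes the chosen tree reproducible between runs.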

112. alistar October 5, 2019 at 2:59 am #

Hi, can I use some of your code for a personal project? I am trying to implement a slightly different decision tree.

113. Jordan October 30, 2019 at 3:02 pm #

Hi Jason,

I found this example super helpful. And most of the time when we want to build decision trees, we need to define our own questions, and use these questions to create and build a decision tree. Now, you have essentially created your own functions to do the job.

In your experience, do you find that, more often than not, you need to build your own decision tree from scratch rather than use models such as XGBoost, sklearn, etc. to do the underlying work? I guess I am also asking whether you would be able to do what you did using an open-source library like sklearn or XGBoost, or other open-source boosting tools. Would a scientist be able to ask their own questions and provide their own split values using open-source tools, or do you need to construct your own decision tree, prediction mechanism and accuracy calculation functions? Hope you can understand what I am getting at.

I thought this was the most interesting example because you built it from scratch.

• Jason Brownlee October 31, 2019 at 5:26 am #

Thanks, I’m happy to hear that.

No, I very rarely write code from scratch – it’s too easy to introduce bugs and write slow code. I use libs because they are tested and are fast.

Yes, I recommend using open source libs in nearly all cases.

114. himanshu November 10, 2019 at 2:02 pm #

“A perfect separation results in a Gini score of 0, whereas the worst case split that results in 50/50 classes in each group result in a Gini score of 0.5 (for a 2 class problem).” Wait, don’t you think this is wrong? 0.5 means perfect separation.

• Jason Brownlee November 11, 2019 at 6:04 am #

No, in this case it means a worst case where the classes are mixed in each group.
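Working through the tutorial’s size-weighted formula for a two-class problem shows why 0.5 is the worst case: each group’s term measures how mixed that group is, so lower is better.

```latex
G = \sum_{g \in \{\mathrm{left},\,\mathrm{right}\}} \frac{n_g}{n}\Big(1 - \sum_{k} p_{g,k}^{2}\Big)

\text{50/50 in each group:}\quad 1 - (0.5^2 + 0.5^2) = 0.5 \;\Rightarrow\; G = \tfrac{n_{\mathrm{left}}}{n}(0.5) + \tfrac{n_{\mathrm{right}}}{n}(0.5) = 0.5

\text{each group pure:}\quad 1 - (1^2 + 0^2) = 0 \;\Rightarrow\; G = 0
```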

115. Mehmet Yılmaz December 11, 2019 at 7:52 am #

The code line:
file = open(filename, "rb")

is incompatible with the data file; it should instead be:
file = open(filename, "rt")
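In Python 3, csv.reader needs text lines (str), not bytes, which is why "rb" fails and "rt" works. A minimal text-mode sketch of a load_csv() function; using a with block and skipping blank lines are additions here, and newline="" follows the csv module documentation’s recommendation for files passed to reader():

```python
from csv import reader

def load_csv(filename):
    """Load a CSV file as a list of rows (lists of strings), skipping blank lines."""
    dataset = []
    # Text mode ("rt"): csv.reader expects str lines in Python 3, not bytes.
    with open(filename, "rt", newline="") as file:
        for row in reader(file):
            if not row:
                continue  # blank line in the file
            dataset.append(row)
    return dataset
```

The values are still strings at this point; the tutorial converts them to floats in a separate step.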

116. Carol December 17, 2019 at 12:56 pm #

Hi Jason, I’m not very familiar with the Gini index, so please ignore me if I am wrong. But according to the book ‘The Elements of Statistical Learning’, the definition of the Gini index is

gini = 1-sum(p_k * (1-p_k))

This one also seems to me to be more intuitive to understand.
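For comparison: ESL writes the Gini index of a node as a sum of p_k(1 − p_k) terms, without a leading “1 −”. Because the class proportions p_k sum to one, that form expands to exactly the 1 − Σ p_k² score that the tutorial’s gini_index() computes for each group:

```latex
\sum_{k=1}^{K} p_k (1 - p_k) = \sum_{k=1}^{K} p_k - \sum_{k=1}^{K} p_k^{2} = 1 - \sum_{k=1}^{K} p_k^{2}
```

For example, with two classes at 0.5 each, both forms give 0.5.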