How to Identify Overfitting Machine Learning Models in Scikit-Learn

Last Updated on November 27, 2020

Overfitting is a common explanation for the poor performance of a predictive model.

An analysis of learning dynamics can help to identify whether a model has overfit the training dataset and may suggest an alternate configuration to use that could result in better predictive performance.

Performing an analysis of learning dynamics is straightforward for algorithms that learn incrementally, like neural networks, but it is less clear how we might perform the same analysis with other algorithms that do not learn incrementally, such as decision trees, k-nearest neighbors, and other general algorithms in the scikit-learn machine learning library.

In this tutorial, you will discover how to identify overfitting for machine learning models in Python.

After completing this tutorial, you will know:

  • Overfitting is a possible cause of poor generalization performance of a predictive model.
  • Overfitting can be analyzed for machine learning models by varying key model hyperparameters.
  • Although overfitting analysis is a useful diagnostic tool, it must not be confused with model selection.

Let’s get started.

Identify Overfitting Machine Learning Models With Scikit-Learn

Photo by Bonnie Moreland, some rights reserved.

Tutorial Overview

This tutorial is divided into five parts; they are:

  1. What Is Overfitting
  2. How to Perform an Overfitting Analysis
  3. Example of Overfitting in Scikit-Learn
  4. Counterexample of Overfitting in Scikit-Learn
  5. Separate Overfitting Analysis From Model Selection

What Is Overfitting

Overfitting refers to an unwanted behavior of a machine learning algorithm used for predictive modeling.

It is the case where model performance on the training dataset is improved at the cost of worse performance on data not seen during training, such as a holdout test dataset or new data.

We can identify if a machine learning model has overfit by first evaluating the model on the training dataset and then evaluating the same model on a holdout test dataset.

If the performance of the model on the training dataset is significantly better than the performance on the test dataset, then the model may have overfit the training dataset.
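A minimal sketch of this check follows. It assumes a synthetic dataset from scikit-learn's make_classification() and an unconstrained DecisionTreeClassifier as the model under study; both are illustrative choices, not something the text prescribes.

```python
# compare accuracy on the training set against accuracy on a holdout test set
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# illustrative synthetic dataset (assumed configuration)
X, y = make_classification(n_samples=1000, n_features=20, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

# fit the model on the training set only
model = DecisionTreeClassifier(random_state=1)
model.fit(X_train, y_train)

# evaluate the same fitted model on both sets
train_acc = accuracy_score(y_train, model.predict(X_train))
test_acc = accuracy_score(y_test, model.predict(X_test))
gap = train_acc - test_acc
print('train: %.3f, test: %.3f, gap: %.3f' % (train_acc, test_acc, gap))
```

A large positive gap is the warning sign; how large it must be to count as "significantly better" is a judgment call.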

We care about overfitting because it is a common cause of “poor generalization” of the model, as measured by high “generalization error,” that is, the error made by the model when making predictions on new data.

This means that if our model has poor performance, it may be because it has overfit.

But what does it mean if a model’s performance is “significantly better” on the training set compared to the test set?

After all, it is common, and perhaps normal, for a model to perform better on the training set than on the test set.

As such, we can perform an analysis of the algorithm on the dataset to better expose the overfitting behavior.

How to Perform an Overfitting Analysis

An overfitting analysis is an approach for exploring how and when a specific model is overfitting on a specific dataset.

It is a tool that can help you learn more about the learning dynamics of a machine learning model.

This might be achieved by reviewing the model behavior during a single run for algorithms like neural networks that are fit on the training dataset incrementally.

The model's performance on the train and test sets can be calculated at each point during training, and the results can be plotted. This plot is often called a learning curve plot, showing one curve for model performance on the training set and one curve for the test set at each increment of learning.

If you would like to learn more about learning curves for algorithms that learn incrementally, see the tutorial:

The common pattern for overfitting can be seen on learning curve plots, where model performance on the training dataset continues to improve (e.g. loss or error continues to fall or accuracy continues to rise) and performance on the test or validation set improves to a point and then begins to get worse.

If this pattern is observed, then, for algorithms that learn incrementally, training should stop at the point where performance begins to get worse on the test or validation set.
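Here is a minimal sketch of recording these curves and picking the stopping point. It uses scikit-learn's SGDClassifier and its partial_fit() method as a stand-in for an incrementally trained model such as a neural network. The dataset, the 30 epochs, and treating the epoch with the best validation accuracy as the stopping point are assumptions. A linear model like this may not show a strong overfitting pattern; the sketch only shows how the curves are collected.

```python
# record train and validation accuracy after each epoch of an incremental learner
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# illustrative dataset with a held-back validation split (assumed configuration)
X, y = make_classification(n_samples=2000, n_features=20, random_state=1)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=1)

model = SGDClassifier(random_state=1)
classes = np.unique(y)
n_epochs = 30
train_history, val_history = [], []
for epoch in range(n_epochs):
    # one pass over the training data; the class labels must be known on the first call
    model.partial_fit(X_train, y_train, classes=classes)
    train_history.append(accuracy_score(y_train, model.predict(X_train)))
    val_history.append(accuracy_score(y_val, model.predict(X_val)))

# assumed stopping point: the epoch with the best validation accuracy
best_epoch = int(np.argmax(val_history))
print('best epoch: %d, val accuracy: %.3f' % (best_epoch + 1, val_history[best_epoch]))
```

Plotting train_history and val_history against the epoch number gives the learning curve plot described above.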

This makes sense for algorithms that learn incrementally like neural networks, but what about other algorithms?

  • How do you perform an overfitting analysis for machine learning algorithms in scikit-learn?

One approach to performing an overfitting analysis on algorithms that do not learn incrementally is to vary a key model hyperparameter and evaluate the model performance on the train and test sets for each configuration.
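scikit-learn's validation_curve() function automates this pattern. The sketch below assumes a decision tree's max_depth as the hyperparameter and a small synthetic dataset. Note that validation_curve() scores each value with k-fold cross-validation rather than a single train/test split, so it is a swapped-in variant of the approach, not the same procedure.

```python
# vary one hyperparameter and score the training folds and held-out folds for each value
from sklearn.datasets import make_classification
from sklearn.model_selection import validation_curve
from sklearn.tree import DecisionTreeClassifier

# illustrative dataset (assumed configuration)
X, y = make_classification(n_samples=1000, n_features=20, random_state=1)
depths = list(range(1, 11))

# both returned arrays have shape (len(depths), number of folds)
train_scores, test_scores = validation_curve(
    DecisionTreeClassifier(random_state=1), X, y,
    param_name='max_depth', param_range=depths, cv=5, scoring='accuracy')

# average over folds and report one line per hyperparameter value
for depth, tr, te in zip(depths, train_scores.mean(axis=1), test_scores.mean(axis=1)):
    print('>%d, train: %.3f, test: %.3f' % (depth, tr, te))
```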

To make this clear, let’s explore a case of analyzing a model for overfitting in the next section.

Example of Overfitting in Scikit-Learn

In this section, we will look at an example of overfitting a machine learning model to a training dataset.

First, let’s define a synthetic classification dataset.

We will use the make_classification() function to define a binary (two class) classification prediction problem with 10,000 examples (rows) and 20 input features (columns).

The example below creates the dataset and summarizes the shape of the input and output components.
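A minimal version of this step might look like the following. The text fixes only the 10,000 rows and 20 features; the n_informative=5, n_redundant=15 and random_state=1 settings are assumptions.

```python
# create a synthetic binary classification dataset and summarize its shape
from sklearn.datasets import make_classification

# 10,000 rows and 20 input features (informative/redundant split and seed are assumptions)
X, y = make_classification(n_samples=10000, n_features=20, n_informative=5,
                           n_redundant=15, random_state=1)
print(X.shape, y.shape)
# (10000, 20) (10000,)
```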

Running the example creates the dataset and reports the shape, confirming our expectations.

Next, we need to split the dataset into train and test subsets.

We will use the train_test_split() function and split the data into 70 percent for training a model and 30 percent for evaluating it.
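A sketch of the split, again assuming the dataset settings and random seeds shown (only the row count, feature count and 70/30 ratio come from the text):

```python
# create the dataset and split it into 70 percent train and 30 percent test
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=10000, n_features=20, n_informative=5,
                           n_redundant=15, random_state=1)
# hold back 30 percent of the rows for evaluation
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)
# (7000, 20) (3000, 20) (7000,) (3000,)
```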

Running the example splits the dataset and we can confirm that we have 7,000 examples for training and 3,000 for evaluating a model.

Next, we can explore a machine learning model overfitting the training dataset.

We will use a decision tree via the DecisionTreeClassifier and test different tree depths with the “max_depth” argument.

Shallow decision trees (e.g. few levels) generally do not overfit but have poor performance (high bias, low variance). Deep trees (e.g. many levels), by contrast, generally do overfit and have good performance on the training set (low bias, high variance). A desirable tree is one that is not so shallow that it has low skill and not so deep that it overfits the training dataset.

We will evaluate decision tree depths from 1 to 20.

We will enumerate each tree depth, fit a tree with a given depth on the training dataset, then evaluate the tree on both the train and test sets.

The expectation is that as the depth of the tree increases, performance on train and test will improve to a point, and as the tree gets too deep, it will begin to overfit the training dataset at the expense of worse performance on the holdout test set.

At the end of the run, we will then plot all model accuracy scores on the train and test sets for visual comparison.

Tying this together, the complete example of exploring different tree depths on the synthetic binary classification dataset is listed below.
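The following is a sketch of such an example rather than a reproduction of the original listing. The dataset settings and random seeds are assumptions, and the final plotting step uses matplotlib.

```python
# evaluate decision tree performance on train and test sets with different tree depths
from matplotlib import pyplot
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# create the dataset (informative/redundant split and seeds are assumptions)
X, y = make_classification(n_samples=10000, n_features=20, n_informative=5,
                           n_redundant=15, random_state=1)
# split into 70 percent train and 30 percent test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

values = list(range(1, 21))  # tree depths to evaluate
train_scores, test_scores = [], []
for depth in values:
    # fit a tree of the given depth on the training set only
    model = DecisionTreeClassifier(max_depth=depth, random_state=1)
    model.fit(X_train, y_train)
    # score the same fitted tree on both sets
    train_acc = accuracy_score(y_train, model.predict(X_train))
    test_acc = accuracy_score(y_test, model.predict(X_test))
    train_scores.append(train_acc)
    test_scores.append(test_acc)
    print('>%d, train: %.3f, test: %.3f' % (depth, train_acc, test_acc))

# line plot of train and test accuracy versus tree depth
pyplot.plot(values, train_scores, '-o', label='Train')
pyplot.plot(values, test_scores, '-o', label='Test')
pyplot.legend()
pyplot.show()
```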

Running the example fits and evaluates a decision tree on the train and test sets for each tree depth and reports the accuracy scores.

Note: Your results may vary given the stochastic nature of the algorithm or evaluation procedure, or differences in numerical precision. Consider running the example a few times and compare the average outcome.

In this case, we can see a trend of increasing accuracy on the training dataset as the tree depth increases, up to a depth of around 19-20 levels, where the tree fits the training dataset perfectly.

We can also see that the accuracy on the test set improves with tree depth until a depth of about eight or nine levels, after which accuracy begins to get worse with each increase in tree depth.

This is exactly what we would expect to see in a pattern of overfitting.

We would choose a tree depth of eight or nine before the model begins to overfit the training dataset.

A figure is also created that shows line plots of the model accuracy on the train and test sets with different tree depths.

The plot clearly shows that increasing the tree depth in the early stages results in a corresponding improvement in both train and test sets.

This continues until a depth of around 10 levels, after which the model is shown to overfit the training dataset at the cost of worse performance on the holdout dataset.

Line Plot of Decision Tree Accuracy on Train and Test Datasets for Different Tree Depths

This analysis is interesting. It shows why the model performs worse on the holdout test set when “max_depth” is set to large values.

But it is not required.

We can just as easily choose a “max_depth” using a grid search without performing an analysis on why some values result in better performance and some result in worse performance.
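For example, a grid search over “max_depth” might be sketched with scikit-learn's GridSearchCV. The 5-fold cross-validation, the dataset settings and the seeds are assumptions; each depth is scored only on held-out folds of the training set, and the refit best model is checked once on the test set.

```python
# choose max_depth with a cross-validated grid search instead of an overfitting analysis
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=10000, n_features=20, n_informative=5,
                           n_redundant=15, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

# every candidate depth is scored on held-out folds of the training set
grid = {'max_depth': list(range(1, 21))}
search = GridSearchCV(DecisionTreeClassifier(random_state=1), grid, scoring='accuracy', cv=5)
search.fit(X_train, y_train)
print('best max_depth: %d, cv accuracy: %.3f'
      % (search.best_params_['max_depth'], search.best_score_))

# the best configuration is refit on all training data, then scored once on the test set
test_acc = accuracy_score(y_test, search.predict(X_test))
print('test accuracy: %.3f' % test_acc)
```

No train-versus-test comparison is needed here: the search simply keeps the depth with the best out-of-sample score.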

In fact, in the next section, we will show where this analysis can be misleading.

Counterexample of Overfitting in Scikit-Learn

Sometimes, we may perform an analysis of machine learning model behavior and be deceived by the results.

A good example of this is varying the number of neighbors for the k-nearest neighbors algorithm, which we can implement using the KNeighborsClassifier class and configure via the “n_neighbors” argument.

Let’s forget how KNN works for the moment.

We can perform the same analysis of the KNN algorithm as we did in the previous section for the decision tree and see if our model overfits for different configuration values. In this case, we will vary the number of neighbors from 1 to 50 to get more of the effect.

The complete example is listed below.
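A sketch of this analysis, not the original listing, follows. The dataset settings and seeds are assumptions, and the plot uses matplotlib.

```python
# evaluate knn performance on train and test sets with different numbers of neighbors
from matplotlib import pyplot
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# create and split the dataset (informative/redundant split and seeds are assumptions)
X, y = make_classification(n_samples=10000, n_features=20, n_informative=5,
                           n_redundant=15, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

values = list(range(1, 51))  # numbers of neighbors to evaluate
train_scores, test_scores = [], []
for k in values:
    # "fitting" KNN stores the training set; predictions look up the k closest rows
    model = KNeighborsClassifier(n_neighbors=k)
    model.fit(X_train, y_train)
    train_acc = accuracy_score(y_train, model.predict(X_train))
    test_acc = accuracy_score(y_test, model.predict(X_test))
    train_scores.append(train_acc)
    test_scores.append(test_acc)
    print('>%d, train: %.3f, test: %.3f' % (k, train_acc, test_acc))

# line plot of train and test accuracy versus number of neighbors
pyplot.plot(values, train_scores, '-o', label='Train')
pyplot.plot(values, test_scores, '-o', label='Test')
pyplot.legend()
pyplot.show()
```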

Running the example fits and evaluates a KNN model on the train and test sets for each number of neighbors and reports the accuracy scores.

Note: Your results may vary given the stochastic nature of the algorithm or evaluation procedure, or differences in numerical precision. Consider running the example a few times and compare the average outcome.

Recall, we are looking for a pattern where performance on the test set improves and then starts to get worse, and performance on the training set continues to improve.

We do not see this pattern.

Instead, we see that accuracy on the training dataset starts at perfect accuracy and falls with almost every increase in the number of neighbors.

We also see that the performance of the model on the holdout test set improves up to about five neighbors, holds level, and then begins a downward trend.

A figure is also created that shows line plots of the model accuracy on the train and test sets with different numbers of neighbors.

The plots make the situation clearer. It looks as though the line plot for the training set is dropping to converge with the line for the test set. Indeed, this is exactly what is happening.

Line Plot of KNN Accuracy on Train and Test Datasets for Different Numbers of Neighbors

Now, recall how KNN works.

The “model” is really just the entire training dataset stored in an efficient data structure. Skill for the “model” on the training dataset should be 100 percent and anything less is unforgivable.

In fact, this argument holds for any machine learning algorithm and cuts to the core of the confusion beginners have around overfitting.
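To see the memorization concretely, here is a small sketch with a one-neighbor KNeighborsClassifier; the dataset settings and seeds are assumptions. The n_samples_fit_ attribute reports how many training rows the fitted model holds. Because each training row (assuming no duplicate rows) is its own nearest neighbor, the model simply returns the stored label.

```python
# a 1-nearest-neighbor "model" is the stored training set
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=10000, n_features=20, n_informative=5,
                           n_redundant=15, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)

model = KNeighborsClassifier(n_neighbors=1)
model.fit(X_train, y_train)
# fitting only stores (and indexes) the training rows
print('stored rows: %d' % model.n_samples_fit_)

# each training row is its own nearest neighbor (no duplicate rows assumed)
train_acc = accuracy_score(y_train, model.predict(X_train))
test_acc = accuracy_score(y_test, model.predict(X_test))
print('train: %.3f, test: %.3f' % (train_acc, test_acc))
```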

Separate Overfitting Analysis From Model Selection

Overfitting can be an explanation for poor performance of a predictive model.

Creating learning curve plots that show the learning dynamics of a model on the train and test dataset is a helpful analysis for learning more about a model on a dataset.

But overfitting should not be confused with model selection.

We choose a predictive model or model configuration based on its out-of-sample performance. That is, its performance on new data not seen during training.

The reason we do this is that in predictive modeling, we are primarily interested in a model that makes skillful predictions. We want the model that can make the best possible predictions given the time and computational resources we have available.
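Selection by out-of-sample performance alone might be sketched with cross_val_score(). The three candidate configurations, the 10-fold cross-validation and the dataset settings are assumptions for illustration. Training-set accuracy is never consulted, so the one-neighbor KNN, which is perfect on its training data, competes on equal terms.

```python
# select among candidate models using only cross-validated (out-of-sample) accuracy
from numpy import mean
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=2000, n_features=20, n_informative=5,
                           n_redundant=15, random_state=1)

# hypothetical candidate configurations
candidates = {
    'tree depth=8': DecisionTreeClassifier(max_depth=8, random_state=1),
    'knn k=5': KNeighborsClassifier(n_neighbors=5),
    'knn k=1': KNeighborsClassifier(n_neighbors=1),
}
results = {}
for name, model in candidates.items():
    # each score comes from a fold the model was not trained on
    scores = cross_val_score(model, X, y, scoring='accuracy', cv=10)
    results[name] = mean(scores)
    print('%s: %.3f' % (name, results[name]))

# keep the configuration with the best mean out-of-sample accuracy
best = max(results, key=results.get)
print('selected: %s' % best)
```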

This might mean we choose a model that looks like it has overfit the training dataset, in which case an overfitting analysis might be misleading.

It might also mean that the model has poor or terrible performance on the training dataset.

In general, if we cared about model performance on the training dataset in model selection, then we would expect a model to have perfect performance on the training dataset. It’s data we have available; we should not tolerate anything less.

As we saw with the KNN example above, we can achieve perfect performance on the training set by storing the training set directly and returning predictions with one neighbor at the cost of poor performance on any new data.

  • Wouldn’t a model that performs well on both train and test datasets be a better model?

Maybe. But, maybe not.

This argument is based on the idea that a model that performs well on both train and test sets has a better understanding of the underlying problem.

A corollary is that a model that performs well on the test set but poorly on the training set is lucky (e.g. a statistical fluke), and a model that performs well on the training set but poorly on the test set is overfit.

I believe this is the sticking point for beginners who often ask how to fix overfitting in their scikit-learn machine learning models.

The worry is that a model must perform well on both the train and test sets; otherwise, they are in trouble.

This is not the case.

Performance on the training set is not relevant during model selection. You must focus on the out-of-sample performance only when choosing a predictive model.

Summary

In this tutorial, you discovered how to identify overfitting for machine learning models in Python.

Specifically, you learned:

  • Overfitting is a possible cause of poor generalization performance of a predictive model.
  • Overfitting can be analyzed for machine learning models by varying key model hyperparameters.
  • Although overfitting analysis is a useful diagnostic tool, it must not be confused with model selection.

Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.
