Nearest Shrunken Centroids With Python

Nearest Centroids is a linear classification machine learning algorithm.

It involves predicting a class label for new examples based on which class-based centroid the example is closest to from the training dataset.

The Nearest Shrunken Centroids algorithm is an extension that involves shifting class-based centroids toward the centroid of the entire training dataset and removing those input variables that are less useful at discriminating the classes.

As such, the Nearest Shrunken Centroids algorithm performs an automatic form of feature selection, making it appropriate for datasets with very large numbers of input variables.

In this tutorial, you will discover the Nearest Shrunken Centroids classification machine learning algorithm.

After completing this tutorial, you will know:

  • The Nearest Shrunken Centroids is a simple linear machine learning algorithm for classification.
  • How to fit, evaluate, and make predictions with the Nearest Shrunken Centroids model with Scikit-Learn.
  • How to tune the hyperparameters of the Nearest Shrunken Centroids algorithm on a given dataset.

Let’s get started.

Nearest Shrunken Centroids With Python
Photo by Giuseppe Milo, some rights reserved.

Tutorial Overview

This tutorial is divided into three parts; they are:

  1. Nearest Centroids Algorithm
  2. Nearest Centroids With Scikit-Learn
  3. Tuning Nearest Centroid Hyperparameters

Nearest Centroids Algorithm

Nearest Centroids is a classification machine learning algorithm.

The algorithm involves first summarizing the training dataset into a set of centroids (centers), then using the centroids to make predictions for new examples.

For each class, the centroid of the data is found by taking the average value of each predictor (per class) in the training set. The overall centroid is computed using the data from all of the classes.

— Page 307, Applied Predictive Modeling, 2013.

A centroid is the geometric center of a data distribution, such as the mean. In multiple dimensions, this is the mean value along each dimension, giving the center point of the distribution across all variables.

The Nearest Centroids algorithm assumes that the centroids in the input feature space are different for each target label. The training data is split into groups by class label, then the centroid for each group of data is calculated. Each centroid is simply the mean value of each of the input variables. If there are two classes, then two centroids or points are calculated; three classes give three centroids, and so on.

The centroids then represent the “model.” Given new examples, such as those in the test set or new data, the distance between a given row of data and each centroid is calculated and the closest centroid is used to assign a class label to the example.
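
To make this concrete, below is a minimal sketch of the procedure using NumPy directly, on a tiny made-up dataset (all values are purely illustrative):

```python
# minimal nearest-centroid sketch on a tiny made-up dataset
import numpy as np

# four rows of data with two input variables and two class labels
X = np.array([[1.0, 2.0], [1.5, 1.8], [5.0, 8.0], [6.0, 9.0]])
y = np.array([0, 0, 1, 1])

# one centroid per class: the mean of each input variable within that class
centroids = {label: X[y == label].mean(axis=0) for label in np.unique(y)}

# classify a new row by the closest centroid (Euclidean distance)
row = np.array([1.2, 2.1])
distances = {label: np.linalg.norm(row - c) for label, c in centroids.items()}
print(min(distances, key=distances.get))  # prints 0, the closest class
```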

Distance measures such as Euclidean distance are used for numerical data, or Hamming distance for categorical data. In either case, it is best practice to scale input variables via normalization or standardization prior to training the model, ensuring that input variables with large values don’t dominate the distance calculation.

An extension to the nearest centroid method for classification is to shrink the centroids of each input variable towards the centroid of the entire training dataset. Those variables that are shrunk down to the value of the data centroid can then be removed as they do not help to discriminate between the class labels.

As such, the amount of shrinkage applied to the centroids is a hyperparameter that can be tuned for the dataset and used to perform an automatic form of feature selection. Thus, it is appropriate for a dataset with a large number of input variables, some of which may be irrelevant or noisy.

Consequently, the nearest shrunken centroid model also conducts feature selection during the model training process.

— Page 307, Applied Predictive Modeling, 2013.

This approach is referred to as “Nearest Shrunken Centroids” and was first described by Robert Tibshirani, et al. in their 2002 paper titled “Diagnosis Of Multiple Cancer Types By Shrunken Centroids Of Gene Expression.”
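
To illustrate the shrinkage idea itself, below is a deliberately simplified sketch using soft thresholding. Note that the published method also standardizes the centroid differences by pooled within-class standard deviations before thresholding; that scaling is omitted here, and all values are made up:

```python
# simplified illustration of shrinking a class centroid toward the data centroid
import numpy as np

class_centroid = np.array([2.0, 5.0, 3.1])    # per-class means for 3 variables
overall_centroid = np.array([2.0, 3.0, 3.0])  # means over the whole dataset
delta = 0.5                                   # shrinkage amount (hyperparameter)

# soft-threshold the difference between the class and overall centroids
diff = class_centroid - overall_centroid
shrunk = np.sign(diff) * np.maximum(np.abs(diff) - delta, 0.0)
print(overall_centroid + shrunk)  # [2.  4.5 3. ]
```

The third variable’s centroid collapses onto the overall centroid, so it no longer helps discriminate this class and can effectively be dropped.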

Nearest Centroids With Scikit-Learn

The Nearest Shrunken Centroids is available in the scikit-learn Python machine learning library via the NearestCentroid class.

The class allows the configuration of the distance metric used in the algorithm via the “metric” argument, which defaults to ‘euclidean’ for the Euclidean distance metric.

This can be changed to other built-in metrics such as ‘manhattan.’

By default, no shrinkage is used, but shrinkage can be specified via the “shrink_threshold” argument, which takes a positive floating point value; in this tutorial, we will explore values between 0 and 1.

We can demonstrate the Nearest Shrunken Centroids with a worked example.

First, let’s define a synthetic classification dataset.

We will use the make_classification() function to create a dataset with 1,000 examples, each with 20 input variables.

The example creates and summarizes the dataset.
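
A sketch of this example follows. The sample and feature counts match the description above; the split into informative and redundant features and the random seed are assumptions made for illustration:

```python
# create and summarize a synthetic binary classification dataset
from sklearn.datasets import make_classification
# define dataset: 1,000 rows, 20 input variables
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15,
                           n_redundant=5, random_state=1)
# summarize the number of rows and columns
print(X.shape, y.shape)
```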

Running the example creates the dataset and confirms the number of rows and columns of the dataset.

We can fit and evaluate a Nearest Shrunken Centroids model using repeated stratified k-fold cross-validation via the RepeatedStratifiedKFold class. We will use 10 folds and three repeats in the test harness.

We will use the default configuration of Euclidean distance and no shrinkage.

The complete example of evaluating the Nearest Shrunken Centroids model for the synthetic binary classification task is listed below.
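
Here is a sketch of that example, assuming the same dataset configuration as above:

```python
# evaluate a nearest centroid model with repeated stratified k-fold cross-validation
from numpy import mean, std
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score, RepeatedStratifiedKFold
from sklearn.neighbors import NearestCentroid
# define dataset (same assumed configuration as above)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15,
                           n_redundant=5, random_state=1)
# define the model: default Euclidean distance, no shrinkage
model = NearestCentroid()
# define the evaluation procedure: 10 folds, 3 repeats
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# evaluate the model and report mean and standard deviation of accuracy
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
print('Mean Accuracy: %.3f (%.3f)' % (mean(scores), std(scores)))
```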

Running the example evaluates the Nearest Shrunken Centroids algorithm on the synthetic dataset and reports the average accuracy across the three repeats of 10-fold cross-validation.

Your specific results may vary given the stochastic nature of the evaluation procedure. Consider running the example a few times.

In this case, we can see that the model achieved a mean accuracy of about 71 percent.

We may decide to use the Nearest Shrunken Centroids as our final model and make predictions on new data.

This can be achieved by fitting the model on all available data and calling the predict() function passing in a new row of data.

We can demonstrate this with a complete example listed below.
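
A sketch of such an example follows; the new row of data is hypothetical and purely illustrative:

```python
# fit a nearest centroid model on all data and predict for a new row
from sklearn.datasets import make_classification
from sklearn.neighbors import NearestCentroid
# define dataset (same assumed configuration as above)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15,
                           n_redundant=5, random_state=1)
# fit the model on all available data
model = NearestCentroid()
model.fit(X, y)
# define a hypothetical new row with 20 values, one per input variable
row = [2.47, 0.40, 1.68, 2.89, 0.92, -3.08, 4.40, 0.72, -4.87, -6.06,
       -1.22, -0.47, 1.01, -0.69, -0.53, 6.87, -3.27, -6.59, -2.21, -3.14]
# make a class label prediction for the new row
yhat = model.predict([row])
print('Predicted Class: %d' % yhat[0])
```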

Running the example fits the model and makes a class label prediction for a new row of data.

Next, we can look at configuring the model hyperparameters.

Tuning Nearest Centroid Hyperparameters

The hyperparameters for the Nearest Shrunken Centroid method must be configured for your specific dataset.

Perhaps the most important hyperparameter is the shrinkage controlled via the “shrink_threshold” argument. It is a good idea to test values between 0 and 1 on a grid with a spacing such as 0.1 or 0.01.

The example below demonstrates this using the GridSearchCV class with a grid of values we have defined.
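
A sketch of this grid search follows, assuming the same dataset configuration as above. Note that None is included in the grid to represent no shrinkage, since some scikit-learn versions reject a shrink_threshold of exactly 0:

```python
# grid search the shrink_threshold hyperparameter for nearest centroid
from numpy import arange
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold
from sklearn.neighbors import NearestCentroid
# define dataset (same assumed configuration as above)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15,
                           n_redundant=5, random_state=1)
# define the model and the evaluation procedure
model = NearestCentroid()
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# define the grid: no shrinkage plus values from 0.01 to 1.0 in steps of 0.01
grid = dict(shrink_threshold=[None] + list(arange(0.01, 1.01, 0.01)))
# perform the grid search and report the best result
search = GridSearchCV(model, grid, scoring='accuracy', cv=cv, n_jobs=-1)
results = search.fit(X, y)
print('Mean Accuracy: %.3f' % results.best_score_)
print('Config: %s' % results.best_params_)
```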

Running the example will evaluate each combination of configurations using repeated cross-validation.

Your specific results may vary given the stochastic nature of the evaluation procedure. Try running the example a few times.

In this case, we can see that we achieved slightly better results than the default, with an accuracy of 71.4 percent vs. 71.1 percent. We can also see that the grid search selected a shrink_threshold value of 0.53.

The other key configuration is the distance measure used, which can be chosen based on the distribution of the input variables.

Any of the built-in distance measures can be used. Common distance measures include:

  • ‘cityblock’, ‘cosine’, ‘euclidean’, ‘l1’, ‘l2’, ‘manhattan’

Given that our input variables are numeric, we will restrict the search to the ‘euclidean’ and ‘manhattan’ distance measures, the two metrics most appropriate for this dataset (and the only two supported by recent versions of scikit-learn for this class).

We can include these metrics in our grid search; the complete example is listed below.
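
A sketch of the extended grid search follows, assuming a scikit-learn version that supports combining the ‘manhattan’ metric with shrinkage:

```python
# grid search distance metric and shrinkage for nearest centroid
from numpy import arange
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold
from sklearn.neighbors import NearestCentroid
# define dataset (same assumed configuration as above)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=15,
                           n_redundant=5, random_state=1)
# define the model and the evaluation procedure
model = NearestCentroid()
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# search over both the distance metric and the shrinkage amount
grid = dict(metric=['euclidean', 'manhattan'],
            shrink_threshold=[None] + list(arange(0.01, 1.01, 0.01)))
search = GridSearchCV(model, grid, scoring='accuracy', cv=cv, n_jobs=-1)
results = search.fit(X, y)
print('Mean Accuracy: %.3f' % results.best_score_)
print('Config: %s' % results.best_params_)
```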

Running the example fits the model and discovers the hyperparameters that give the best results using cross-validation.

Your specific results may vary given the stochastic nature of the evaluation procedure. Try running the example a few times.

In this case, we can see that we achieved a slightly better accuracy of about 75 percent using no shrinkage and the manhattan distance measure instead of euclidean.

A good extension to these experiments would be to add data normalization or standardization to the data as part of a modeling Pipeline.

Further Reading

This section provides more resources on the topic if you are looking to go deeper.

Papers

  • Diagnosis Of Multiple Cancer Types By Shrunken Centroids Of Gene Expression, 2002.

Books

  • Applied Predictive Modeling, 2013.

APIs

  • sklearn.neighbors.NearestCentroid API.

Summary

In this tutorial, you discovered the Nearest Shrunken Centroids classification machine learning algorithm.

Specifically, you learned:

  • The Nearest Shrunken Centroids is a simple linear machine learning algorithm for classification.
  • How to fit, evaluate, and make predictions with the Nearest Shrunken Centroids model with Scikit-Learn.
  • How to tune the hyperparameters of the Nearest Shrunken Centroids algorithm on a given dataset.

Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.

6 Responses to Nearest Shrunken Centroids With Python

  1. RK_pat October 15, 2020 at 6:16 pm #

    Hi, how can we see the centroids of each dimension or feature, and the main centroids of each target class? Do you have any idea of creating a score for each entry?

  2. Cameron August 18, 2022 at 12:36 am #

    Hi,
    Thank you for the article. How do you change the dataset definition line of code to use your own dataset of numbers (gene expression data)?
    Thank you

  3. Amit April 20, 2024 at 1:55 pm #

    Thanks for the very helpful tutorial. I’m still confused how to figure out which features have been shrunk to zero with shrinkage threshold and how many remain.

    • James Carmichael April 21, 2024 at 10:18 am #

      Hi Amit…Understanding how shrunken centroids work, particularly in the context of feature selection through shrinkage to zero, involves grasping the mechanics of techniques like the Shrunken Centroids Regularized Discriminant Analysis (also known as the “nearest shrunken centroids” method). This method is often used in scenarios like gene expression data classification, where the number of features (genes) can be very large compared to the number of samples. Let’s break down how this method works and how you can identify which features have been effectively eliminated by the shrinkage process.

      ### What is Shrunken Centroids?

      Shrunken centroids, popularized by the method called Predictive Analysis of Microarrays (PAM), is used primarily for classification. It shrinks the class centroids towards the overall centroid for all classes by an amount determined by a shrinkage parameter (often lambda). The goal is to improve classification accuracy by reducing variance without significantly increasing bias.

      ### How Shrinkage Works

      1. **Centroid Calculation**: For each class, calculate the centroid of the features for the samples belonging to that class. This is the average of all feature values for samples of a specific class.

      2. **Overall Centroid**: Calculate the overall centroid of the features across all classes.

      3. **Shrinkage**: Each class centroid component is “shrunk” towards the overall centroid. The degree of this shrinkage depends on the shrinkage parameter (lambda). Higher values of lambda result in greater shrinkage.

      4. **Effect of Shrinkage**: If the shrinkage pulls a centroid component all the way to the overall centroid (or very close), the effect is that the corresponding feature does not effectively contribute to distinguishing between classes. When the centroid component for a feature across all classes is shrunk to the overall centroid, it implies that the feature has little to no discriminative power given the shrinkage penalty and can be considered as being shrunk to zero.

      ### How to Identify Features Shrunk to Zero

      To figure out which features have been shrunk to zero, you need to look at the differences between the class centroids and the overall centroid post-shrinkage:

      1. **Thresholding**: After applying the shrinkage, any feature whose class centroids (for all classes) are sufficiently close to the overall centroid can be considered as having been shrunk to zero. The closeness can be determined based on a threshold related to the lambda value.

      2. **Practical Implementation**: If using software or a library like R’s pamr package for PAM, it often provides tools to visualize or directly identify which features have coefficients reduced to zero. For example, using pamr:
      ```r
      library(pamr)
      # pamr expects a list with components x (a features-by-samples matrix) and y (class labels)
      fit <- pamr.train(list(x = data, y = labels))
      # list the features that survive a given shrinkage threshold;
      # features not listed have been shrunk to zero
      pamr.listgenes(fit, list(x = data, y = labels), threshold = 1.0)
      ```

      Here, pamr.listgenes() reports the features that remain after applying the given shrinkage threshold; features not listed are those shrunk to zero.

      ### Considerations

      - **Choice of Lambda**: The choice of the shrinkage parameter lambda is critical. It can often be selected via cross-validation, optimizing for the best classification performance while minimizing overfitting.

      - **Impact on Feature Selection**: This method acts as a form of feature selection, keeping only those features that contribute to class separation post-shrinkage.

      Understanding which features are shrunk to zero in shrunken centroids helps in simplifying the model and focusing on the most informative attributes. This makes the model both interpretable and efficient, especially in high-dimensional data scenarios.
