# How To Implement Simple Linear Regression From Scratch With Python

Linear regression is a prediction method that is more than 200 years old.

Simple linear regression is a great first machine learning algorithm to implement as it requires you to estimate properties from your training dataset, but is simple enough for beginners to understand.

In this tutorial, you will discover how to implement the simple linear regression algorithm from scratch in Python.

After completing this tutorial you will know:

• How to estimate statistical quantities from training data.
• How to estimate linear regression coefficients from data.
• How to make predictions using linear regression for new data.

Let’s get started.

How To Implement Simple Linear Regression From Scratch With Python
Photo by Kamyar Adl, some rights reserved.

## Description

This section is divided into two parts: a description of the simple linear regression technique and a description of the dataset to which we will later apply it.

### Simple Linear Regression

Linear regression assumes a linear, or straight-line, relationship between the input variables (X) and the single output variable (y).

More specifically, it assumes that the output (y) can be calculated from a linear combination of the input variables (X). When there is a single input variable, the method is referred to as simple linear regression.

In simple linear regression we can use statistics on the training data to estimate the coefficients required by the model to make predictions on new data.

The line for a simple linear regression model can be written as:

$$y = b_0 + b_1 x$$

where b0 and b1 are the coefficients we must estimate from the training data.

Once the coefficients are known, we can use this equation to estimate output values for y given new input examples of x.

Estimating these coefficients requires calculating statistical properties from the data, such as the mean, variance and covariance.

All the algebra has been taken care of and we are left with some arithmetic to implement to estimate the simple linear regression coefficients.

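Briefly, we can estimate the coefficients as follows:

$$b_1 = \frac{\sum_{i=1}^{n} (x_i - \operatorname{mean}(x))(y_i - \operatorname{mean}(y))}{\sum_{i=1}^{n} (x_i - \operatorname{mean}(x))^2}$$

$$b_0 = \operatorname{mean}(y) - b_1 \operatorname{mean}(x)$$

where i refers to the ith value of the input x or output y, and n is the number of training examples.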

Don’t worry if this is not clear right now; these are the functions we will implement in this tutorial.

### Swedish Insurance Dataset

We will use a real dataset to demonstrate simple linear regression.

The dataset is called the “Auto Insurance in Sweden” dataset and involves predicting the total payment for all the claims in thousands of Swedish Kronor (y) given the total number of claims (x).

This means that for a new number of claims (x) we will be able to predict the total payment of claims (y).

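Here is a small sample of the first 5 records of the dataset, with the decimal separators converted to “.”:

| Number of claims (x) | Total payment, thousands of Kronor (y) |
|---|---|
| 108 | 392.5 |
| 19 | 46.2 |
| 13 | 15.7 |
| 124 | 422.2 |
| 40 | 119.4 |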

Using the Zero Rule algorithm, which always predicts the mean output value, a Root Mean Squared Error (RMSE) of about 72.251 (thousands of Kronor) is expected.
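As a rough sketch of that baseline, assuming the Zero Rule simply predicts the mean of the training outputs for every test row, the helpers below produce those predictions and score them with RMSE. The names `zero_rule_predictions()` and `rmse()` are hypothetical, and the numbers are made up; the 72.251 figure above comes from the article and is not reproduced here.

```python
from math import sqrt


def zero_rule_predictions(train_y, n_test):
    """Zero Rule baseline: predict the mean training output for every test row."""
    baseline = sum(train_y) / float(len(train_y))
    return [baseline] * n_test


def rmse(actual, predicted):
    """Root mean squared error between two equal-length lists."""
    squared_errors = [(a - p) ** 2 for a, p in zip(actual, predicted)]
    return sqrt(sum(squared_errors) / float(len(actual)))


# Usage with small made-up numbers (not the insurance data).
train_y = [10.0, 20.0, 30.0]
test_y = [15.0, 25.0]
predictions = zero_rule_predictions(train_y, len(test_y))
print(predictions)                # [20.0, 20.0]
print(rmse(test_y, predictions))  # 5.0
```

Any model worth keeping should beat this score on the same test data.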

Below is a scatter plot of the entire dataset.

Swedish Insurance Dataset

Download the dataset and save it to a CSV file in your local working directory with the name “insurance.csv”.

Note that you may need to convert the European decimal “,” to the decimal “.”. You will also need to change the file from whitespace-separated values to CSV format.
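One way to do that conversion is sketched below. It assumes each data line in the raw file holds two whitespace-separated numbers with “,” as the decimal mark, and that every other line (title, description, column headers, blank lines) should be dropped. The functions `convert_line()` and `convert_file()` and the file name `insurance_raw.txt` are hypothetical.

```python
def convert_line(line):
    """Convert a raw line like '108   392,5' into a CSV line like '108,392.5'.

    Returns None for anything that is not exactly two numeric fields
    (headers, descriptive text, blank lines).
    """
    fields = line.split()
    if len(fields) != 2:
        return None
    values = [field.replace(",", ".") for field in fields]
    try:
        [float(v) for v in values]
    except ValueError:
        return None
    return ",".join(values)


def convert_file(raw_path, csv_path):
    """Write only the numeric rows of raw_path to csv_path, one 'x,y' pair per line."""
    with open(raw_path, "r") as raw, open(csv_path, "w") as out:
        for line in raw:
            converted = convert_line(line)
            if converted is not None:
                out.write(converted + "\n")


# Hypothetical input file name; uncomment once the raw file is saved locally.
# convert_file("insurance_raw.txt", "insurance.csv")
```

Skipping non-numeric and blank lines up front avoids the “could not convert string to float” errors reported in the comments.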

## Tutorial

This tutorial is broken down into five parts:

1. Calculate Mean and Variance.
2. Calculate Covariance.
3. Estimate Coefficients.
4. Make Predictions.
5. Predict Insurance.

These steps will give you the foundation you need to implement and train simple linear regression models for your own prediction problems.

### 1. Calculate Mean and Variance

The first step is to estimate the mean and the variance of both the input and output variables from the training data.

The mean of a list of numbers can be calculated as:

$$\operatorname{mean}(x) = \frac{1}{n} \sum_{i=1}^{n} x_i$$

Below is a function named mean() that implements this behavior for a list of numbers.
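A minimal sketch of such a function, assuming the input is a non-empty Python list of numbers:

```python
def mean(values):
    """Arithmetic mean of a non-empty list of numbers."""
    return sum(values) / float(len(values))


print(mean([1, 2, 4, 3, 5]))  # 3.0
```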

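The variance, as used in this tutorial, is the sum of the squared differences of each value from the mean. Strictly speaking, this unscaled sum is n times the population variance (or n - 1 times the sample variance), but the scaling factor does not matter here: the covariance is left unscaled in the same way, so the factor cancels when we divide one by the other to estimate B1.

Variance for a list of numbers can be calculated as:

$$\operatorname{variance}(x) = \sum_{i=1}^{n} (x_i - \operatorname{mean}(x))^2$$
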
Below is a function named variance() that calculates the variance of a list of numbers. It requires the mean of the list to be provided as an argument, just so we don’t have to calculate it more than once.
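A sketch of variance() consistent with that definition: it returns the unscaled sum of squared deviations, so dividing the result by len(values) would give the population variance. The mean is passed in as an argument, as described above; mean() is repeated so the snippet runs on its own.

```python
def mean(values):
    """Arithmetic mean of a non-empty list of numbers."""
    return sum(values) / float(len(values))


def variance(values, mean_value):
    """Sum of squared differences from the mean (not divided by n)."""
    return sum((x - mean_value) ** 2 for x in values)


values = [1, 2, 4, 3, 5]
print(variance(values, mean(values)))  # 10.0
```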

We can put these two functions together and test them on a small contrived dataset.

Below is a small dataset of x and y values.

NOTE: delete the column headers from this data if you save it to a .CSV file for use with the final code example.

We can plot this dataset as a scatter plot:

Small Contrived Dataset For Simple Linear Regression

We can calculate the mean and variance for both the x and y values in the example below.

Running this example prints out the mean and variance for both columns.
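The sketch below shows what such an example might look like. The five (x, y) rows are illustrative values chosen for this sketch, each stored as an [x, y] list; mean() and variance() are repeated so the snippet is self-contained.

```python
def mean(values):
    return sum(values) / float(len(values))


def variance(values, mean_value):
    return sum((v - mean_value) ** 2 for v in values)


# Illustrative dataset: each row is [x, y].
dataset = [[1, 1], [2, 3], [4, 3], [3, 2], [5, 5]]
x = [row[0] for row in dataset]
y = [row[1] for row in dataset]

mean_x, mean_y = mean(x), mean(y)
var_x, var_y = variance(x, mean_x), variance(y, mean_y)
print('x stats: mean=%.3f variance=%.3f' % (mean_x, var_x))
print('y stats: mean=%.3f variance=%.3f' % (mean_y, var_y))
```

For these rows it prints a mean of 3.000 and a variance of 10.000 for x, and a mean of 2.800 and a variance of 8.800 for y.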

This is our first step; next, we need to put these values to use in calculating the covariance.

### 2. Calculate Covariance

The covariance of two groups of numbers describes how those numbers change together.

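Covariance is closely related to correlation: dividing the covariance of two variables by the product of their standard deviations normalizes it into the correlation coefficient, which always lies between -1 and 1.

We can calculate the covariance between two variables as follows:

$$\operatorname{covariance} = \sum_{i=1}^{n} (x_i - \operatorname{mean}(x))(y_i - \operatorname{mean}(y))$$

Like the variance in this tutorial, this is left as an unscaled sum; the textbook covariance divides it by n (or n - 1 for a sample), a factor that cancels when the covariance is later divided by the variance.
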
Below is a function named covariance() that implements this statistic. It builds upon the previous step and takes the lists of x and y values as well as the mean of these values as arguments.
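A minimal sketch of covariance(). The argument order (x, mean_x, y, mean_y) is an assumption for illustration, and the x and y values in the usage lines are illustrative.

```python
def covariance(x, mean_x, y, mean_y):
    """Sum of products of paired deviations from the means (not divided by n)."""
    return sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))


x = [1, 2, 4, 3, 5]  # illustrative values
y = [1, 3, 3, 2, 5]
print('%.3f' % covariance(x, 3.0, y, 2.8))  # 8.000
```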

We can test the calculation of the covariance on the same small contrived dataset as in the previous section.

Putting it all together we get the example below.

Running this example prints the covariance for the x and y variables.
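To make the link with correlation concrete, the sketch below computes the covariance for an illustrative five-row dataset and then normalizes it into Pearson's correlation coefficient. Because the covariance and variances here are all unscaled sums, the 1/n factors cancel and r = covariance / sqrt(variance(x) * variance(y)). The helper `correlation()` is hypothetical, added only for illustration.

```python
from math import sqrt


def mean(values):
    return sum(values) / float(len(values))


def variance(values, mean_value):
    return sum((v - mean_value) ** 2 for v in values)


def covariance(x, mean_x, y, mean_y):
    return sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))


def correlation(x, y):
    """Pearson's r: covariance rescaled by the spread of both variables, in [-1, 1]."""
    mx, my = mean(x), mean(y)
    return covariance(x, mx, y, my) / sqrt(variance(x, mx) * variance(y, my))


dataset = [[1, 1], [2, 3], [4, 3], [3, 2], [5, 5]]  # illustrative rows [x, y]
x = [row[0] for row in dataset]
y = [row[1] for row in dataset]
cov = covariance(x, mean(x), y, mean(y))
print('Covariance: %.3f' % cov)  # Covariance: 8.000
print('Correlation: %.3f' % correlation(x, y))
```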

We now have all the pieces in place to calculate the coefficients for our model.

### 3. Estimate Coefficients

We must estimate the values for two coefficients in simple linear regression.

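The first is B1, which can be estimated as:

$$b_1 = \frac{\sum_{i=1}^{n} (x_i - \operatorname{mean}(x))(y_i - \operatorname{mean}(y))}{\sum_{i=1}^{n} (x_i - \operatorname{mean}(x))^2}$$

The numerator is the covariance of x and y and the denominator is the variance of x, as defined above, so we can simplify this arithmetic to:

$$b_1 = \frac{\operatorname{covariance}(x, y)}{\operatorname{variance}(x)}$$
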
We already have functions to calculate covariance() and variance().

Next, we need to estimate a value for B0, also called the intercept because it sets where the line crosses the y-axis (the predicted value of y when x is 0). It can be estimated as:

$$b_0 = \operatorname{mean}(y) - b_1 \operatorname{mean}(x)$$

Again, we know how to estimate B1 and we have a function to estimate mean().

We can put all of this together into a function named coefficients() that takes the dataset as an argument and returns the coefficients.
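A sketch of coefficients(), assuming the dataset is a list of [x, y] rows and that it returns [b0, b1]; the helper functions are repeated so it runs on its own, and the five rows are illustrative.

```python
def mean(values):
    return sum(values) / float(len(values))


def variance(values, mean_value):
    return sum((v - mean_value) ** 2 for v in values)


def covariance(x, mean_x, y, mean_y):
    return sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))


def coefficients(dataset):
    """Estimate [b0, b1] for y = b0 + b1 * x from rows of [x, y]."""
    x = [row[0] for row in dataset]
    y = [row[1] for row in dataset]
    x_mean, y_mean = mean(x), mean(y)
    b1 = covariance(x, x_mean, y, y_mean) / variance(x, x_mean)
    b0 = y_mean - b1 * x_mean
    return [b0, b1]


dataset = [[1, 1], [2, 3], [4, 3], [3, 2], [5, 5]]  # illustrative rows
b0, b1 = coefficients(dataset)
print('Coefficients: B0=%.3f, B1=%.3f' % (b0, b1))  # Coefficients: B0=0.400, B1=0.800
```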

We can put this together with all of the functions from the previous two steps and test out the calculation of coefficients.

Running this example calculates and prints the coefficients.
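As a hand check of that calculation, take an illustrative five-point dataset x = (1, 2, 4, 3, 5), y = (1, 3, 3, 2, 5), chosen for this sketch. Its means are 3 and 2.8, and plugging the sums into the formulas gives:

$$
\begin{aligned}
\sum_{i} (x_i - 3)(y_i - 2.8) &= 3.6 - 0.2 + 0.2 + 0 + 4.4 = 8 \\
\sum_{i} (x_i - 3)^2 &= 4 + 1 + 1 + 0 + 4 = 10 \\
b_1 &= 8 / 10 = 0.8 \\
b_0 &= 2.8 - 0.8 \times 3 = 0.4
\end{aligned}
$$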

Now that we know how to estimate the coefficients, the next step is to use them.

### 4. Make Predictions

The simple linear regression model is a line defined by coefficients estimated from training data.

Once the coefficients are estimated, we can use them to make predictions.

The equation to make predictions with a simple linear regression model is as follows:

$$\hat{y} = b_0 + b_1 x$$

where $\hat{y}$ is the predicted output for the input x.

Below is a function named simple_linear_regression() that implements the prediction equation to make predictions on a test dataset. It also ties together the estimation of the coefficients on training data from the steps above.

The coefficients prepared from the training data are used to make predictions on the test data, which are then returned.
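A sketch of simple_linear_regression(), assuming training rows of [x, y] and test rows whose first element is x (any trailing value is ignored). The helpers are repeated so the snippet runs on its own, and the training rows are illustrative.

```python
def mean(values):
    return sum(values) / float(len(values))


def variance(values, mean_value):
    return sum((v - mean_value) ** 2 for v in values)


def covariance(x, mean_x, y, mean_y):
    return sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))


def coefficients(dataset):
    x = [row[0] for row in dataset]
    y = [row[1] for row in dataset]
    x_mean, y_mean = mean(x), mean(y)
    b1 = covariance(x, x_mean, y, y_mean) / variance(x, x_mean)
    return [y_mean - b1 * x_mean, b1]


def simple_linear_regression(train, test):
    """Fit b0 and b1 on train, then return yhat = b0 + b1 * x for each test row."""
    b0, b1 = coefficients(train)
    return [b0 + b1 * row[0] for row in test]


train = [[1, 1], [2, 3], [4, 3], [3, 2], [5, 5]]  # illustrative rows
predictions = simple_linear_regression(train, [[6], [7]])
print([round(p, 3) for p in predictions])  # [5.2, 6.0]
```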

Let’s pull together everything we have learned and make predictions for our simple contrived dataset.

As part of this example, we will also add in a function to manage the evaluation of the predictions called evaluate_algorithm() and another function to estimate the Root Mean Squared Error of the predictions called rmse_metric().

The full example is listed below.

Running this example first prints the predictions and then the RMSE of these predictions.
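A sketch of how rmse_metric() and evaluate_algorithm() could fit together with the functions above. It assumes evaluation on the same rows used for training, as in this contrived example, and the dataset rows are illustrative. The algorithm argument is the function object itself, so any other function with the same (train, test) signature can be plugged into the same harness.

```python
from math import sqrt


def mean(values):
    return sum(values) / float(len(values))


def variance(values, mean_value):
    return sum((v - mean_value) ** 2 for v in values)


def covariance(x, mean_x, y, mean_y):
    return sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))


def coefficients(dataset):
    x = [row[0] for row in dataset]
    y = [row[1] for row in dataset]
    x_mean, y_mean = mean(x), mean(y)
    b1 = covariance(x, x_mean, y, y_mean) / variance(x, x_mean)
    return [y_mean - b1 * x_mean, b1]


def simple_linear_regression(train, test):
    b0, b1 = coefficients(train)
    return [b0 + b1 * row[0] for row in test]


def rmse_metric(actual, predicted):
    """Square root of the mean squared difference between actual and predicted values."""
    sum_error = 0.0
    for a, p in zip(actual, predicted):
        sum_error += (p - a) ** 2
    return sqrt(sum_error / float(len(actual)))


def evaluate_algorithm(dataset, algorithm):
    """Fit and predict on the same rows, print the predictions, and return the RMSE.

    `algorithm` is any function called as algorithm(train, test).
    """
    # Blank out the output column so the algorithm cannot see the answers.
    test_set = [[row[0], None] for row in dataset]
    predicted = algorithm(dataset, test_set)
    print(predicted)
    actual = [row[-1] for row in dataset]
    return rmse_metric(actual, predicted)


dataset = [[1, 1], [2, 3], [4, 3], [3, 2], [5, 5]]  # illustrative rows [x, y]
rmse = evaluate_algorithm(dataset, simple_linear_regression)
print('RMSE: %.3f' % rmse)  # RMSE: 0.693
```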

Finally, we can plot the predictions as a line and compare it to the original dataset.

Predictions For Small Contrived Dataset For Simple Linear Regression

### 5. Predict Insurance

We now know how to implement a simple linear regression model.

Let’s apply it to the Swedish insurance dataset.

This section assumes that you have downloaded the dataset to the file insurance.csv and it is available in the current working directory.

We will add some convenience functions to the simple linear regression from the previous steps.

Specifically, we will add:

• load_csv(), a function to load the CSV file.
• str_column_to_float(), a function to convert the loaded string values to numbers.
• train_test_split(), a function to split the dataset into a training set and a test set.
• rmse_metric(), a function to calculate the RMSE.
• evaluate_algorithm(), a function to evaluate an algorithm on the train/test split.
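The first two of these helpers might look like the sketch below. It assumes the file has no header row; it opens the file in text mode with newline="" (as Python 3's csv module expects) and skips empty rows, which avoids two of the errors reported in the comments.

```python
from csv import reader


def load_csv(filename):
    """Load a headerless CSV file into a list of rows (lists of strings), skipping empty rows."""
    dataset = []
    with open(filename, "r", newline="") as file:
        for row in reader(file):
            if not row:
                continue  # an empty line, e.g. at the end of the file
            dataset.append(row)
    return dataset


def str_column_to_float(dataset, column):
    """Convert one column of every row from string to float, in place."""
    for row in dataset:
        row[column] = float(row[column].strip())


# Usage, assuming insurance.csv exists in the working directory:
# dataset = load_csv("insurance.csv")
# for i in range(len(dataset[0])):
#     str_column_to_float(dataset, i)
```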

The complete example is listed below.

A training dataset of 60% of the data is used to prepare the model and predictions are made on the remaining 40%.
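A sketch of train_test_split(), assuming split is the fraction of rows used for training. It moves randomly chosen rows out of a copy of the dataset, so the original list is left unchanged, and random.seed() fixes the split so repeated runs give the same RMSE. The toy rows are illustrative.

```python
from random import randrange, seed


def train_test_split(dataset, split):
    """Randomly move rows into a training set until it holds `split` of the data.

    Returns (train, test); the rows left over form the test set.
    """
    train = []
    train_size = split * len(dataset)
    remaining = list(dataset)
    while len(train) < train_size:
        index = randrange(len(remaining))
        train.append(remaining.pop(index))
    return train, remaining


seed(1)  # fix the random split so results are repeatable
rows = [[i, 2 * i] for i in range(10)]  # toy rows, not the insurance data
train, test = train_test_split(rows, 0.6)
print(len(train), len(test))  # 6 4
```

With split=0.6 and the 63 rows implied by the comments (38 training rows plus 25 test rows), train_size is 37.8, so the loop stops at 38 training rows.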

Running the algorithm prints the RMSE for the trained model on the test dataset.

A score of about 38 (thousands of Kronor) was achieved, which is much better than the Zero Rule algorithm that achieves approximately 72 (thousands of Kronor) on the same problem.

## Extensions

The best extension to this tutorial is to try out the algorithm on more problems.

Small datasets with just one input (x) column and one output (y) column are popular for demonstration in statistics books and courses. Many of these datasets are available online.

Seek out some more small datasets and make predictions using simple linear regression.

Did you apply simple linear regression to another dataset?

## Review

In this tutorial, you discovered how to implement the simple linear regression algorithm from scratch in Python.

Specifically, you learned:

• How to estimate statistics from a training dataset like mean, variance and covariance.
• How to estimate model coefficients and use them to make predictions.
• How to use simple linear regression to make predictions on a real dataset.

Do you have any questions?

### 56 Responses to How To Implement Simple Linear Regression From Scratch With Python

1. Vineeth October 27, 2016 at 7:28 pm #

Hi Jason,

I have downloaded the CSV file, but when I try to run the script against the file, I get the following error:

” could not convert string to float: ‘X’ ”

this script stops at function def train_test_split(dataset, split)

Can you confirm how your CSV file is structured?

Regards
Vineeth

• Jason Brownlee October 28, 2016 at 9:08 am #

Sorry to hear that Vineeth.

Totally my error, do not include the column headers in the small contrived dataset. Delete the first row.

I will update the example.

• En-wai October 30, 2016 at 7:58 am #

Hi Jason, I have deleted the column headers X and Y along with all other descriptive info in the file, but I keep getting this error:

” ValueError: could not convert string to float: i”

Here are the first 5 values in my CSV file after removing the white space (replacing it with commas) and changing from the European “,” to the decimal “.”:

108,392.5
19,46.2
13,15.7
124,422.2
40,119.4

• Jason Brownlee October 30, 2016 at 8:59 am #

Confirm that you do not have any empty rows on the end of the file.

2. Adrian Moldovan October 27, 2016 at 9:20 pm #

This is brilliant!
Thanks for taking the time to go through all the steps and explain literally… everything.

• Jason Brownlee October 28, 2016 at 9:09 am #

3. Nelson Silva October 28, 2016 at 2:11 am #

Hello Jason,
great tutorial!
It would be great if you also provided the code for the respective plots in python!
Especially the plot for the dataset 🙂

Thank you.

• Jason Brownlee October 28, 2016 at 9:16 am #

Great suggestion Nelson, thanks.

I was aiming to keep the use of libs to a minimum (e.g. no matplotlib or seaborn).

• Rahul Sharma June 13, 2017 at 5:46 am #

Hi Nelson, you can use the matplotlib library to create this kind of scatter plot:

Please use this code to implement a scatter plot:

import matplotlib.pyplot as plt
plt.scatter(x_axis_value, y_axis_value, color='black')
plt.show()

I hope this helps !

4. venkat dabbara October 28, 2016 at 4:58 am #

predicted = algorithm(dataset, test_set)

where is algorithm defined???

• Jason Brownlee October 28, 2016 at 9:18 am #

Great question Venkat.

The “algorithm” argument in the evaluate_algorithm() function is a function. We pass in the simple_linear_regression function itself. This means that when we execute algorithm() to make predictions in evaluate_algorithm(), we are in fact calling the simple_linear_regression() function.

I did this to separate algorithm evaluation from algorithm implementation, so that the same test harness can be used for many different algorithms.

5. En-wai October 28, 2016 at 9:14 pm #

Under section 2, Calculating Covariance, I think the two meanings there are not quite clear. Please check it.

“In fact, covariance is a generalization of correlation that is limited to two variables. Whereas covariance can be calculate between two or more variables.”???????

• Jason Brownlee October 29, 2016 at 7:42 am #

Thanks En-wai, I have updated the language.

I was trying to comment on how covariance is an abstraction of correlation to go from 2 groups of numbers to more than 2 groups of numbers.

6. Ram October 29, 2016 at 1:06 am #

Hi,

I got a clear idea of linear regression. Thank you.

We calculate linear regression with the scikit-learn library as below.

regr = linear_model.LinearRegression()

regr.fit(X_train, y_train)

Please clarify whether all this calculation will happen behind the scenes when we call the above code.

• Jason Brownlee October 29, 2016 at 9:25 am #

Hi Ram,

There are more efficient approaches to implement these algorithms using linear algebra. I expect these more efficient approaches are being used behind the scenes.

Implementing algorithms is great for learning how they work, but it is not a good idea to use these from scratch implementations in production.

7. Aliyu A. Aziz October 29, 2016 at 6:03 pm #

Hi Jason,

Many thanks for this easy to follow LR from scratch. I have noticed Line 9

file = open(filename, "rb")

is opening the file in binary mode and causing the “Error: iterator should return strings, not bytes (did you open the file in text mode?)”

Changing 'rb' to 'rt' or 'r'
file = open(filename, "rt")

fixes the error.

Best regards

• Jason Brownlee October 30, 2016 at 8:54 am #

Great, thanks Aliyu.

It does work on my platform, but I will make the example more portable.

8. saimadhu November 3, 2016 at 6:06 pm #

Hi,
Jason Brownlee

Thanks a lot for such an amazing post on simple linear regression. This post is the best tutorial to get the clear picture about simple linear regression analysis and I felt this post is the must read before learning the multi-regression analysis.

• Jason Brownlee November 4, 2016 at 9:06 am #

9. Johnny December 13, 2016 at 4:45 am #

Another great one and I love these foundation ones. Also, you get right into the steps/meat of it and you do not leave out cosmetics – just wrap those up neatly at the end. Thank you sir.

I would like to see/study this same type of process for datasets pertaining to the basic types of business. Specifically, how to produce a good dataset and properly frame problem areas for business. Do you recommend any books?

• Jason Brownlee December 13, 2016 at 8:09 am #

Thanks Johnny.

Sorry, I don’t know of good books like that. It is an empirical pursuit – more of a craft. The best education is practice.

10. Aslam March 12, 2017 at 4:39 am #

I am a beginner and found this very useful.

Thank you sir !

11. Girish March 25, 2017 at 2:58 am #

How do we plot the graph using code?

• Jason Brownlee March 25, 2017 at 7:39 am #

You can use matplotlib:

12. Nemanja April 2, 2017 at 5:33 am #

Hi, how can we plot a line of regression on our graph? And what can we do to reduce the RMSE? Thanks

• Jason Brownlee April 2, 2017 at 6:33 am #

You can evaluate the RMSE each epoch/iteration, save the RMSE values in an array and plot the array using matplotlib.

13. Sean April 3, 2017 at 4:36 am #

what is the relationship between numpy.cov() , numpy.var() methods and your covariance() , variance() calculations ? I get very different results between the two.

Thanks

14. Abhishek April 28, 2017 at 3:25 pm #

It’s a great article, thank you for helping us…

• Jason Brownlee April 29, 2017 at 7:21 am #

Thanks Abhishek, I’m glad that you found it useful.

15. Nuwan C May 4, 2017 at 1:30 pm #

Hi Jason,
Thank you for another great tutorial.
What does the Zero Based algorithm do and why is it used here?

Thank you

16. John David Kromkowski May 5, 2017 at 2:28 am #

Nice work

Maybe tiny typo:

covariance = sum((x(i) – mean(x)) * (y – mean(y)))

should be

covariance = sum((x(i) – mean(x)) * (y(i) – mean(y)))

You have it correct in the actual code

17. Etienne May 20, 2017 at 6:35 pm #

Good day Jason

My model is y = b0 + (b1 * x) – (b2 / (b3+x)), which gives an asymptotic approach in a flocculation process. While I get a good data fit using the scipy curve_fit routine, I do not know how to get the leverage, the diagonal elements of the hat matrix H. Whereas in your model, the X system matrix would be formulated as:

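$$\hat{y} = H y$$

and $H = X (X^{T} X)^{-1} X^{T}$, where $X^{T}$ is the transpose of $X$.

In your model, $X\hat{b}$ would be:

$$X\hat{b} = \begin{bmatrix} 1 & x_0 \\ 1 & x_1 \\ 1 & x_2 \\ 1 & x_3 \\ \vdots & \vdots \end{bmatrix} \begin{bmatrix} b_0 \\ b_1 \end{bmatrix}$$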

But what would it be in my case?

Another problem is how to solve for H, so I can get the diagonal elements hii.

Any help would be greatly appreciated.

18. suguna May 24, 2017 at 4:45 am #

I removed columns header from csv file(Insurance CSV)

Then I am getting the following error:

ValueError: could not convert string to float: female

• Rahul Sharma June 13, 2017 at 5:47 am #

suguna , you need to remove all the empty cells in your csv, if any are present. That is what is causing this error

19. Rahul Sharma June 14, 2017 at 2:21 am #

Hi Jason,

As per the derivation : https://en.wikipedia.org/wiki/Standard_deviation

Variance = Avg (xi – xMean)^2

But here in algorithm you have used it as : sum([(x-mean)**2 for x in values])

which is not the average but only the sum of squared differences. Is this some kind of modification?

• Rahul Sharma June 15, 2017 at 10:12 pm #

Hi Jason. Can you please clarify this doubt.

20. Digvijay Rana June 15, 2017 at 5:10 am #

Thank you very much, Sir.
I had been looking for someplace to start implementing algorithms myself. This is the best tutorial I have read by far. Waiting for other algorithms’ simple implementations.

21. Vaibhav June 17, 2017 at 11:08 pm #

Thanks a lot, sir! It’s the best description so far.

22. Kris July 6, 2017 at 11:20 pm #

I’m confused about your definition of covariance. Generally it is divided by (n - 1), where n is the number of samples, whereas no such operation is carried out anywhere in the code. Can you please clarify?

23. Soumik Rakshit July 13, 2017 at 1:45 am #

• Jason Brownlee July 13, 2017 at 9:57 am #

Here is the raw file:

You will need to convert the “,” to “.” and replace the space between columns with “,”.

24. uma maheswari July 16, 2017 at 9:09 pm #

hi jason

Can you tell us how to implement linear regression on an image dataset?

25. Pierce Ng July 31, 2017 at 1:39 am #

Hi Jason,

Great stuff! Thanks for the exposition.

I implemented a no-shuffling version of train_test_split which always takes the first 38 entries as training data and the last 25 entries as test data. The program gives RMSE of 45.23.

Your RMSE of 38.339 is from the randomization in train_test_split with seed(1). If I try with seed(2) then the RMSE is 37.734.

What’s the next step with different values of RMSE?

• Jason Brownlee July 31, 2017 at 8:16 am #

This is the variance of the method.

Ideally, we would evaluate the algorithm multiple times and report the mean and standard deviation of the model.

Does that help?

26. Eoin Kenny August 11, 2017 at 3:14 am #

That is NOT the formula for variance… you’re supposed to divide by n or n-1, what is going on?

• Jason Brownlee August 11, 2017 at 6:43 am #

Might be population vs sample variance.