Implementing Simple Linear Regression with scikit-learn

This tutorial covers the basics of scikit-learn syntax for linear regression. Pay close attention to how the data is generated for this example. List comprehensions create the lists for the independent and dependent variables, and each dependent value is formed by adding random Gaussian noise (standard deviation 0.75) to the corresponding independent value. These lists are then converted to NumPy arrays and used to train the linear regression model. By construction, the true relationship between the two variables is linear, with a slope of 1 and an intercept of 0, so the slope and intercept estimated by the regression should be close to those values. Finally, the model predicts the dependent variable for new values of the independent variable between 0 and 10, almost none of which appear in the training set.
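To make the "slope near 1, intercept near 0" expectation concrete, here is a minimal sketch of the same data-generating process with one assumption added: a fixed seed (`random.seed(0)`, an arbitrary choice not in the original) so the run is reproducible. `random.gauss(x, 0.75)` draws from a normal distribution with mean `x` and standard deviation 0.75, which is the same as `x` plus zero-mean noise.

```python
import random

import numpy as np
from sklearn.linear_model import LinearRegression

# Assumption: a fixed seed for reproducibility (the tutorial does not seed).
random.seed(0)

# Independent variable: the integers 0..9, built with a list comprehension.
X = [i for i in range(10)]
# Dependent variable: each x plus Gaussian noise with standard deviation 0.75.
Y = [random.gauss(x, 0.75) for x in X]

# scikit-learn expects a 2D feature matrix of shape (n_samples, n_features).
X_arr = np.asarray(X).reshape(-1, 1)
Y_arr = np.asarray(Y)  # a 1D target is also accepted

model = LinearRegression().fit(X_arr, Y_arr)
slope = model.coef_[0]
intercept = model.intercept_

# With only 10 noisy points the estimates will not be exactly 1 and 0.
print(f"slope={slope:.3f}, intercept={intercept:.3f}")
```

With 10 points spread over 0 to 9, the slope estimate is typically within about 0.1 of 1, while the intercept is less precise (typically within about 0.5 of 0).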

To perform a simple linear regression in Python 3, one option is to use the scikit-learn library. A complete example implementation is given below.
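At its core, the scikit-learn workflow is three steps: construct a `LinearRegression` estimator, call `fit` with a 2D feature array and a target, then read the fitted parameters from `coef_` and `intercept_` (or call `predict`). The minimal sketch below uses hypothetical noiseless data on the line y = 2x + 1, so the recovered slope and intercept should be 2 and 1 up to floating-point error. It imports `LinearRegression` directly from `sklearn.linear_model`, which is the same class the full listing reaches through `linear_model.LinearRegression()`.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Hypothetical noiseless data on the line y = 2x + 1.
X = np.arange(5).reshape(-1, 1)  # shape (5, 1): 5 samples, 1 feature
y = 2.0 * X.ravel() + 1.0        # shape (5,)

model = LinearRegression()  # ordinary least squares, intercept fitted by default
model.fit(X, y)             # estimates coef_ (slope) and intercept_

print(model.coef_, model.intercept_)     # approximately [2.] and 1.0
print(model.predict(np.array([[10.0]])))  # approximately [21.]
```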

[Figure: training data (scatter) and fitted regression line, saved as simple_linear_regression.png]

How to implement a simple linear regression using scikit-learn and Python 3?

from sklearn import linear_model

import matplotlib.pyplot as plt
import numpy as np
import random

#----------------------------------------------------------------------------------------#
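# Step 1: training data

# Independent variable: the integers 0 to 9
X = [i for i in range(10)]
# Dependent variable: each x plus Gaussian noise with standard deviation 0.75
Y = [random.gauss(x,0.75) for x in X]

X = np.asarray(X)
Y = np.asarray(Y)

# Reshape from (10,) to (10, 1): scikit-learn expects X as (n_samples, n_features)
X = X[:,np.newaxis]
Y = Y[:,np.newaxis]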

plt.scatter(X,Y)

#----------------------------------------------------------------------------------------#
# Step 2: define and train a model

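# Ordinary least squares linear regression (fits an intercept by default)
model = linear_model.LinearRegression()
model.fit(X, Y)

# Fitted slope and intercept; since Y is 2D, coef_ has shape (1, 1) and intercept_ shape (1,)
print(model.coef_, model.intercept_)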

#----------------------------------------------------------------------------------------#
# Step 3: prediction

# 100 evenly spaced new x values between 0 and 10
x_new_min = 0.0
x_new_max = 10.0

X_NEW = np.linspace(x_new_min, x_new_max, 100)
X_NEW = X_NEW[:,np.newaxis]

# Evaluate the fitted line at the new points
Y_NEW = model.predict(X_NEW)

plt.plot(X_NEW, Y_NEW, color='coral', linewidth=3)

plt.grid()
plt.xlim(x_new_min,x_new_max)
plt.ylim(-2,12)  # leave room for noisy points that fall below 0 or above 10

plt.title("Simple Linear Regression using scikit-learn and python 3",fontsize=10)
plt.xlabel('x')
plt.ylabel('y')

plt.savefig("simple_linear_regression.png", bbox_inches='tight')
plt.show()
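The listing reshapes both arrays with `[:, np.newaxis]` because `LinearRegression.fit` expects the features as a 2D array of shape (n_samples, n_features). The sketch below, on a hypothetical exact line y = 3x - 2, checks that `reshape(-1, 1)` is an equivalent spelling, that a 1D feature array is rejected with a `ValueError`, and how the shapes of `coef_` and `intercept_` depend on whether the target is 2D (as in the listing) or 1D.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

x = np.arange(10, dtype=float)
y = 3.0 * x - 2.0  # hypothetical exact line, used only to inspect shapes

# x[:, np.newaxis] and x.reshape(-1, 1) both turn shape (10,) into (10, 1).
X = x[:, np.newaxis]
assert X.shape == x.reshape(-1, 1).shape == (10, 1)

# A 1D feature array is rejected: scikit-learn cannot tell samples from features.
rejected = False
try:
    LinearRegression().fit(x, y)
except ValueError:
    rejected = True
print("1D feature array rejected:", rejected)

# Target kept 2D, as in the tutorial: one row of coefficients per target column.
model_2d = LinearRegression().fit(X, y[:, np.newaxis])
print(model_2d.coef_.shape, model_2d.intercept_.shape)

# Target left 1D: coef_ is a 1D array and intercept_ is a scalar.
model_1d = LinearRegression().fit(X, y)
print(model_1d.coef_.shape, np.ndim(model_1d.intercept_))
```

Keeping the target 1D is often simpler when there is a single dependent variable, since `coef_[0]` and `intercept_` can then be used directly as numbers.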

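Step 3 relies on `predict` evaluating the fitted straight line, intercept + slope * x. As a sanity check, this sketch compares `predict` against that formula and against `np.polyfit` with `deg=1`, a separate NumPy least-squares routine that should give the same slope and intercept. It assumes NumPy's seeded `default_rng` (seed 42, an arbitrary choice) in place of the tutorial's unseeded `random.gauss` calls, with the same true slope of 1 and noise level of 0.75.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Assumption: NumPy's seeded generator replaces the tutorial's random.gauss calls.
rng = np.random.default_rng(42)
x_train = np.arange(10, dtype=float)
y_train = x_train + rng.normal(0.0, 0.75, size=x_train.size)

model = LinearRegression().fit(x_train[:, np.newaxis], y_train)

# New inputs spanning [0, 10], as in Step 3.
x_new = np.linspace(0.0, 10.0, 100)
y_new = model.predict(x_new[:, np.newaxis])

# predict() evaluates the fitted line: intercept + slope * x.
manual = model.intercept_ + model.coef_[0] * x_new
print(np.allclose(y_new, manual))

# Cross-check with numpy's degree-1 least-squares fit, which returns
# coefficients highest degree first: [slope, intercept].
slope_np, intercept_np = np.polyfit(x_train, y_train, deg=1)
print(np.isclose(slope_np, model.coef_[0]), np.isclose(intercept_np, model.intercept_))
```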


Source: Benjamin H.G. Marchant, https://www.moonbooks.org/Articles/How-to-implement-a-simple-linear-regression-using-scikit-learn-and-python-3-/
This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 License.

Last modified: Wednesday, September 28, 2022, 12:04 PM