This tutorial walks through the basic scikit-learn syntax for linear regression. Pay close attention to how the data is generated for this example. A list comprehension creates the independent variable list (the integers 0 through 9). A second list comprehension forms the dependent variable list by drawing, for each value x in the independent variable list, a Gaussian random number with mean x and standard deviation 0.75, which is equivalent to adding zero-mean Gaussian noise to x. These lists are then converted to numpy arrays and reshaped into column vectors to train the linear regression model. By construction, there is a linear relationship between the independent and dependent variables, with a true slope of 1 and a true intercept of 0. Linear regression is then used to estimate the slope and intercept, which should come out close to these values but will not match them exactly because of the noise. Finally, predictions of the dependent variable are made for 100 evenly spaced values of the independent variable between 0 and 10, most of which are not contained in the training set.
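For reference, the data-generating process and the quantities that ordinary least squares estimates can be written out as below. The symbols (x_i, y_i, the noise term, and the hatted coefficients) are notation introduced here for illustration and do not appear in the code.

```latex
% Data generation: each y_i is x_i plus zero-mean Gaussian noise
y_i = x_i + \varepsilon_i, \qquad
\varepsilon_i \sim \mathcal{N}(0,\ 0.75^2), \qquad
x_i = 0, 1, \dots, 9

% Ordinary least-squares estimates of the slope and intercept
\hat{\beta}_1 = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sum_i (x_i - \bar{x})^2}, \qquad
\hat{\beta}_0 = \bar{y} - \hat{\beta}_1 \bar{x}
```

With a true slope of 1 and a true intercept of 0, the estimates should land near those values, with some scatter from the noise.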
from sklearn import linear_model
import matplotlib.pyplot as plt
import numpy as np
import random
#----------------------------------------------------------------------------------------#
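# Step 1: training data
# Independent variable: the integers 0 through 9
X = [i for i in range(10)]
# Dependent variable: Gaussian draw with mean x and standard deviation 0.75
Y = [random.gauss(x,0.75) for x in X]
X = np.asarray(X)
Y = np.asarray(Y)
# Reshape to column vectors of shape (10, 1); scikit-learn expects 2D features
X = X[:,np.newaxis]
Y = Y[:,np.newaxis]
plt.scatter(X,Y)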
#----------------------------------------------------------------------------------------#
# Step 2: define and train a model
model = linear_model.LinearRegression()
model.fit(X, Y)
# Fitted slope and intercept; by construction they should be close to 1 and 0
print(model.coef_, model.intercept_)
#----------------------------------------------------------------------------------------#
# Step 3: prediction
x_new_min = 0.0
x_new_max = 10.0
# 100 evenly spaced x values in [0, 10] at which to evaluate the fitted line
X_NEW = np.linspace(x_new_min, x_new_max, 100)
X_NEW = X_NEW[:,np.newaxis]
Y_NEW = model.predict(X_NEW)
plt.plot(X_NEW, Y_NEW, color='coral', linewidth=3)
plt.grid()
plt.xlim(x_new_min,x_new_max)
plt.ylim(0,10)
plt.title("Simple Linear Regression using scikit-learn and python 3",fontsize=10)
plt.xlabel('x')
plt.ylabel('y')
plt.savefig("simple_linear_regression.png", bbox_inches='tight')
plt.show()
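As a sanity check on the script above, the following minimal sketch compares the scikit-learn fit against the closed-form least-squares formulas and predicts at a point outside the training set. It is an illustration under a few assumptions that are not part of the original script: a fixed seed (0) through numpy's random generator, a 1D target array, and the variable names used here.

```python
import numpy as np
from sklearn import linear_model

# Assumption: a fixed seed (not in the original script) so the run is repeatable.
rng = np.random.default_rng(0)

# Same construction as the tutorial: y = x + Gaussian noise (std 0.75).
x = np.arange(10, dtype=float)
y = x + rng.normal(loc=0.0, scale=0.75, size=x.shape)

# Fit with scikit-learn; features must be 2D, the target may stay 1D.
model = linear_model.LinearRegression()
model.fit(x[:, np.newaxis], y)
slope = float(model.coef_[0])
intercept = float(model.intercept_)

# Closed-form ordinary least-squares estimates for comparison.
x_mean, y_mean = x.mean(), y.mean()
slope_ols = float(np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2))
intercept_ols = float(y_mean - slope_ols * x_mean)

# Coefficient of determination (R^2) on the training data.
r2 = float(model.score(x[:, np.newaxis], y))

# Prediction at x = 4.5, a value that is not in the training set.
x_unseen = 4.5
y_unseen = float(model.predict(np.array([[x_unseen]]))[0])

print(f"scikit-learn: slope={slope:.3f}, intercept={intercept:.3f}")
print(f"closed form:  slope={slope_ols:.3f}, intercept={intercept_ols:.3f}")
print(f"R^2 on training data: {r2:.3f}")
print(f"prediction at x={x_unseen}: {y_unseen:.3f}")
```

The two sets of coefficients should agree to within floating-point error, and the prediction is simply the intercept plus the slope times the new x value. Changing the seed changes the estimates slightly, but they should stay near a slope of 1 and an intercept of 0.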
Source: Benjamin H.G. Marchant, https://www.moonbooks.org/Articles/How-to-implement-a-simple-linear-regression-using-scikit-learn-and-python-3-/ This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 License.