Chapter 7 Regularization
These notes draw on a DataCamp tutorial, the DataCamp course Machine Learning Toolbox, and Interpretable Machine Learning (Molnar 2020).
Regularization is a set of methods that manage the bias-variance trade-off problem in linear regression.
The linear regression model is \(Y = X \beta + \epsilon\), where \(\epsilon \sim N(0, \sigma^2 I)\). OLS estimates the coefficients by minimizing the loss function
\[L = \sum_{i = 1}^n \left(y_i - x_i^{'} \hat\beta \right)^2.\]
The resulting estimate for the coefficients is
\[\hat{\beta} = \left(X'X\right)^{-1}\left(X'Y\right).\]
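The closed-form estimate can be checked numerically. The following is a minimal sketch, not part of the original notes: it assumes an arbitrary two-predictor model (wt and hp) on mtcars, solves the normal equations directly, and compares the result with lm().

```r
# Sketch: closed-form OLS on mtcars, checked against lm().
data("mtcars")

# Design matrix with an intercept column; wt and hp are an arbitrary choice.
X <- cbind(Intercept = 1, as.matrix(mtcars[, c("wt", "hp")]))
y <- mtcars$mpg

# Solve the normal equations (X'X) beta = X'y instead of inverting X'X explicitly.
beta_hat <- solve(crossprod(X), crossprod(X, y))

# Reference fit from lm() for comparison.
fit <- lm(mpg ~ wt + hp, data = mtcars)
all.equal(as.numeric(beta_hat), as.numeric(coef(fit)))
```

Solving the linear system with solve(A, b) is generally preferable to computing the inverse and multiplying, although lm() itself uses a QR decomposition.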
There are two important characteristics of any estimator: its bias and its variance. For OLS, these are
\[Bias(\hat{\beta}) = E(\hat{\beta}) - \beta = 0\] and
\[Var(\hat{\beta}) = \sigma^2(X'X)^{-1}\]
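where the unknown population variance \(\sigma^2\) is estimated from the residuals \(\hat\epsilon = Y - X\hat\beta\), with \(n\) observations and \(k\) estimated coefficients,
\[\hat\sigma^2 = \frac{\hat\epsilon' \hat\epsilon}{n - k}.\]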
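The variance formulas above can be reproduced by hand. This is a hedged sketch under the same assumed two-predictor model (wt and hp, chosen only for illustration); the names resid_hat, sigma2_hat, and vcov_hat are hypothetical.

```r
# Sketch: estimate sigma^2 and Var(beta_hat) by hand, then compare with lm().
data("mtcars")
X <- cbind(Intercept = 1, as.matrix(mtcars[, c("wt", "hp")]))
y <- mtcars$mpg
n <- nrow(X)
k <- ncol(X)  # k counts the intercept as an estimated coefficient

beta_hat  <- solve(crossprod(X), crossprod(X, y))
resid_hat <- y - X %*% beta_hat

# Residual variance with n - k degrees of freedom.
sigma2_hat <- sum(resid_hat^2) / (n - k)

# Estimated covariance matrix of the coefficients: sigma^2 (X'X)^{-1}.
vcov_hat <- sigma2_hat * solve(crossprod(X))

fit <- lm(mpg ~ wt + hp, data = mtcars)
all.equal(sigma2_hat, summary(fit)$sigma^2)
all.equal(unname(vcov_hat), unname(vcov(fit)))
```

The square roots of the diagonal of vcov_hat are the coefficient standard errors that summary(fit) reports.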
The OLS estimator is unbiased, but it can have a large variance when the predictor variables are highly correlated with each other, or when there are many predictors relative to the number of observations (as \(k \rightarrow n\), only \(n - k\) residual degrees of freedom remain and \(\hat{\sigma}^2\) becomes unstable). Stepwise selection addresses the trade-off by eliminating variables, but this throws away information. Regularization keeps all the predictors, but shrinks the coefficient magnitudes, reducing variance at the expense of some bias.
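The effect of correlated predictors on the coefficient variance can be illustrated with simulated data. The sketch below is an assumption-laden illustration rather than anything from the original notes: the sample size, seed, correlation values, and the helper coef_var_factor() are all hypothetical. It compares the diagonal of \((X'X)^{-1}\), which scales \(Var(\hat{\beta})\), for uncorrelated and highly correlated predictors.

```r
# Sketch: correlation between two predictors inflates Var(beta_hat).
set.seed(42)  # arbitrary seed, for reproducibility only
n <- 200      # hypothetical sample size

# Returns the diagonal of (X'X)^{-1} for a design whose two predictors have
# population correlation rho. Var(beta_hat) is sigma^2 times these values.
coef_var_factor <- function(rho) {
  x1 <- rnorm(n)
  x2 <- rho * x1 + sqrt(1 - rho^2) * rnorm(n)
  X <- cbind(Intercept = 1, x1 = x1, x2 = x2)
  diag(solve(crossprod(X)))
}

low  <- coef_var_factor(0)     # uncorrelated predictors
high <- coef_var_factor(0.95)  # highly correlated predictors

# Ratio of the variance factors for the slope on x1 (second coefficient).
high[2] / low[2]
```

For two standardized predictors the inflation is roughly \(1 / (1 - \rho^2)\), so the ratio grows quickly as the correlation approaches 1.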
In the sections below, I’ll use the mtcars data set to predict mpg from the other variables, fitting each model with caret::train() and method = "glmnet". The glmnet package uses penalized maximum likelihood to fit generalized linear models with ridge, lasso, and elastic net penalties. I’ll compare the model performances by creating a training and validation set, and a common trainControl object so the models share the same cross-validation settings; setting the same seed before each train() call then gives them the same observations in the cross-validation folds.
library(tidyverse)
library(caret)
data("mtcars")
set.seed(123)
partition <- createDataPartition(mtcars$mpg, p = 0.8, list = FALSE)
training <- mtcars[partition, ]
testing <- mtcars[-partition, ]
train_control <- trainControl(
  method = "repeatedcv",
  number = 5,
  repeats = 5,
  savePredictions = "final" # saves predictions from optimal tuning parameters
)
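A trainControl object on its own stores the resampling settings, not the folds, which are drawn when train() runs. One way to pin the folds explicitly is to precompute them and pass them through the index argument. The following is a sketch of that alternative, not the approach used in the original notes; cv_folds, train_control_fixed, and the second seed are hypothetical.

```r
# Sketch: precompute repeated CV folds so every model sees identical resamples.
library(caret)
data("mtcars")

set.seed(123)
partition <- createDataPartition(mtcars$mpg, p = 0.8, list = FALSE)
training <- mtcars[partition, ]

# 5-fold CV repeated 5 times gives 25 sets of training-row indices.
set.seed(1234)  # arbitrary seed for the fold assignment
cv_folds <- createMultiFolds(training$mpg, k = 5, times = 5)

train_control_fixed <- trainControl(
  method = "repeatedcv",
  number = 5,
  repeats = 5,
  index = cv_folds,           # reuse exactly these folds in every train() call
  savePredictions = "final"
)

length(cv_folds)
```

Passing train_control_fixed as the trControl argument of each train() call would then resample every model on the same rows, regardless of the seed in effect at fitting time.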
References
Molnar, Christoph. 2020. Interpretable Machine Learning. https://christophm.github.io/interpretable-ml-book/.