8.2 Regression Tree
A simple regression tree is built in a manner similar to a simple classification tree, and like the simple classification tree, it is rarely invoked on its own; the bagged, random forest, and gradient boosting methods build on this logic. I’ll learn by example again, using the ISLR::Carseats data set to predict Sales from the 10 feature variables.
Name | cs_dat |
---|---|
Number of rows | 400 |
Number of columns | 11 |
Column type frequency: factor | 3 |
Column type frequency: numeric | 8 |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
ShelveLoc | 0 | 1 | FALSE | 3 | Med: 219, Bad: 96, Goo: 85 |
Urban | 0 | 1 | FALSE | 2 | Yes: 282, No: 118 |
US | 0 | 1 | FALSE | 2 | Yes: 258, No: 142 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
Sales | 0 | 1 | 7.5 | 2.8 | 0 | 5.4 | 7.5 | 9.3 | 16 | ▁▆▇▃▁ |
CompPrice | 0 | 1 | 125.0 | 15.3 | 77 | 115.0 | 125.0 | 135.0 | 175 | ▁▅▇▃▁ |
Income | 0 | 1 | 68.7 | 28.0 | 21 | 42.8 | 69.0 | 91.0 | 120 | ▇▆▇▆▅ |
Advertising | 0 | 1 | 6.6 | 6.6 | 0 | 0.0 | 5.0 | 12.0 | 29 | ▇▃▃▁▁ |
Population | 0 | 1 | 264.8 | 147.4 | 10 | 139.0 | 272.0 | 398.5 | 509 | ▇▇▇▇▇ |
Price | 0 | 1 | 115.8 | 23.7 | 24 | 100.0 | 117.0 | 131.0 | 191 | ▁▂▇▆▁ |
Age | 0 | 1 | 53.3 | 16.2 | 25 | 39.8 | 54.5 | 66.0 | 80 | ▇▆▇▇▇ |
Education | 0 | 1 | 13.9 | 2.6 | 10 | 12.0 | 14.0 | 16.0 | 18 | ▇▇▃▇▇ |
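Split cs_dat (n = 400) into cs_train (80%, n = 321) and cs_test (20%, n = 79). For a numeric outcome, createDataPartition() samples within quantile groups of Sales, so both sets cover its range.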
set.seed(12345)
partition <- createDataPartition(y = cs_dat$Sales, p = 0.8, list = FALSE)
cs_train <- cs_dat[partition, ]
cs_test <- cs_dat[-partition, ]
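The first step is to build a full tree, then use k-fold cross-validation to help select the optimal cost complexity (cp). The only difference from the classification tree is the rpart() parameter method = "anova", which produces a regression tree. rpart() runs 10-fold cross-validation by default (rpart.control(xval = 10)), and "full" here means grown down to the default complexity threshold, cp = 0.01.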
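To make method = "anova" concrete, here is a minimal base-R sketch of how a regression tree scores candidate splits on one numeric predictor: try each threshold and keep the one that leaves the smallest combined SSE in the two child nodes. The helper names (sse(), best_numeric_split()) and the toy price/sales values are made up for illustration; rpart() itself also searches factor predictors and enforces minsplit/minbucket, which this sketch ignores.

```r
# Sum of squared errors around the node mean (the "deviance" of an anova node)
sse <- function(y) sum((y - mean(y))^2)

# Exhaustive threshold search for one numeric predictor x (hypothetical helper)
best_numeric_split <- function(x, y) {
  xs   <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2   # midpoints between distinct values
  child_sse <- vapply(cuts, function(thr) {
    sse(y[x < thr]) + sse(y[x >= thr])        # left child SSE + right child SSE
  }, numeric(1))
  best <- which.min(child_sse)
  list(cut         = cuts[best],
       parent_sse  = sse(y),
       child_sse   = child_sse[best],
       improvement = 1 - child_sse[best] / sse(y))  # share of parent SSE removed
}

# Made-up toy data (not Carseats values)
price <- c(80, 90, 100, 110, 120, 130, 140, 150)
sales <- c(11, 10, 10.5, 9, 6, 5.5, 5, 4)

split_info <- best_numeric_split(price, sales)
split_info$cut
```

On these toy values the best cut is at 115, and the split removes about 92% of the parent SSE (54.375 down to 4.375). Growing the tree repeats this search over every predictor at every node.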
set.seed(1234)
cs_mdl_cart_full <- rpart(Sales ~ ., cs_train, method = "anova")
print(cs_mdl_cart_full)
## n= 321
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 321 2600 7.5
## 2) ShelveLoc=Bad,Medium 251 1500 6.8
## 4) Price>=1.1e+02 168 720 6.0
## 8) ShelveLoc=Bad 50 170 4.7
## 16) Population< 2e+02 20 48 3.6 *
## 17) Population>=2e+02 30 81 5.4 *
## 9) ShelveLoc=Medium 118 430 6.5
## 18) Advertising< 12 88 290 6.1
## 36) CompPrice< 1.4e+02 69 190 5.8
## 72) Price>=1.3e+02 16 51 4.5 *
## 73) Price< 1.3e+02 53 110 6.2 *
## 37) CompPrice>=1.4e+02 19 58 7.4 *
## 19) Advertising>=12 30 83 7.8 *
## 5) Price< 1.1e+02 83 440 8.4
## 10) Age>=64 32 150 6.9
## 20) Price>=85 25 67 6.2
## 40) ShelveLoc=Bad 9 18 4.8 *
## 41) ShelveLoc=Medium 16 21 6.9 *
## 21) Price< 85 7 20 9.6 *
## 11) Age< 64 51 180 9.3
## 22) Income< 58 12 28 7.7 *
## 23) Income>=58 39 120 9.7
## 46) Age>=50 14 21 8.5 *
## 47) Age< 50 25 60 10.0 *
## 3) ShelveLoc=Good 70 420 10.0
## 6) Price>=1.1e+02 49 240 9.4
## 12) Advertising< 14 41 160 8.9
## 24) Age>=61 17 53 7.8 *
## 25) Age< 61 24 69 9.8 *
## 13) Advertising>=14 8 13 12.0 *
## 7) Price< 1.1e+02 21 61 12.0 *
The predicted Sales at the root is the mean Sales for the training data set, 7.535950 (Sales is measured in thousands of units). The deviance at the root is the SSE, 2567.768. The first split is at ShelveLoc = [Bad, Medium] vs Good. Here is the unpruned tree diagram.
The boxes show the node predicted value (mean) and the proportion of observations that are in the node (or child nodes).
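The n, deviance, and yval columns in the print() output above are plain node statistics. As a sketch, assuming made-up sales values and a made-up shelf grouping (node_summary() is a hypothetical helper), this computes them for a root node and the two children of a {Bad, Medium} vs Good split:

```r
# n, deviance (SSE around the node mean), and yval (node mean = node prediction)
node_summary <- function(y) {
  c(n = length(y), deviance = sum((y - mean(y))^2), yval = mean(y))
}

# Made-up toy sales values and shelf locations (not Carseats data)
sales <- c(4, 6, 7, 8, 9, 12, 13)
shelf <- c("Bad", "Bad", "Medium", "Medium", "Medium", "Good", "Good")

root <- node_summary(sales)

# One split, like the first split above: {Bad, Medium} vs Good
children <- lapply(split(sales, ifelse(shelf == "Good", "Good", "Bad,Medium")),
                   node_summary)
root
children
```

The children's deviances (14.8 and 0.5) sum to far less than the root's 61.71; a split is chosen to make that drop as large as possible.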
rpart() grew the full tree and used cross-validation to test the performance of the possible complexity hyperparameters. printcp() displays the candidate cp values. You can use this table to decide how to prune the tree.
printcp(cs_mdl_cart_full)
##
## Regression tree:
## rpart(formula = Sales ~ ., data = cs_train, method = "anova")
##
## Variables actually used in tree construction:
## [1] Advertising Age CompPrice Income Population Price
## [7] ShelveLoc
##
## Root node error: 2568/321 = 8
##
## n= 321
##
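## CP nsplit rel error xerror xstd
## 1 0.263 0 1.00 1.01 0.077
## 2 0.121 1 0.74 0.75 0.059
## 3 0.046 2 0.62 0.65 0.051
## 4 0.045 3 0.57 0.67 0.052
## 5 0.042 4 0.52 0.66 0.051
## 6 0.026 5 0.48 0.62 0.049
## 7 0.026 6 0.46 0.62 0.048
## 8 0.024 7 0.43 0.62 0.048
## 9 0.015 8 0.41 0.58 0.042
## 10 0.015 9 0.39 0.56 0.041
## 11 0.015 10 0.38 0.56 0.041
## 12 0.014 11 0.36 0.56 0.041
## 13 0.014 12 0.35 0.56 0.038
## 14 0.014 13 0.33 0.56 0.038
## 15 0.011 14 0.32 0.57 0.039
## 16 0.010 15 0.31 0.57 0.038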
There were 16 possible cp values in this model. The model with the smallest complexity parameter allows the most splits (nsplit). The highest complexity parameter corresponds to a tree with just a root node. rel error is the training SSE relative to the root node SSE. The root node SSE is 2567.76800, so its rel error is 2567.76800/2567.76800 = 1.0. The full tree (CP = 0.01, 15 splits) has a rel error of 0.30963, so its absolute SSE is 0.30963 * 2567.76800 = 795.058. You can verify that by calculating the SSE of the model predicted values:
data.frame(pred = predict(cs_mdl_cart_full, newdata = cs_train)) %>%
mutate(obs = cs_train$Sales,
sq_err = (obs - pred)^2) %>%
summarize(sse = sum(sq_err))
## sse
## 1 795
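The conversions between rel error, printcp()'s Root node error, and absolute SSE are simple arithmetic. This sketch only replays the numbers quoted above:

```r
# Numbers quoted in the text for the training set (n = 321)
root_sse <- 2567.768   # deviance (SSE) at the root node
n_train  <- 321

# printcp()'s "Root node error" is the root SSE divided by n
root_node_error <- root_sse / n_train

# rel error = tree SSE / root SSE, so multiplying back gives the tree's SSE
rel_error_full <- 0.30963
sse_full <- rel_error_full * root_sse

round(root_node_error, 2)  # [1] 8
round(sse_full, 1)         # [1] 795.1
```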
Finishing the CP table tour, xerror is the cross-validated error, expressed relative to the root node SSE like rel error, and xstd is its standard error. If you want the lowest possible error, then prune to the tree with the smallest cross-validated relative error (xerror). If you want to balance predictive power with simplicity, prune to the smallest tree within 1 SE of the one with the smallest xerror. The CP table is not super-helpful for finding that tree, so I’ll add a column to find it.
cs_mdl_cart_full$cptable %>%
data.frame() %>%
mutate(min_xerror_idx = which.min(cs_mdl_cart_full$cptable[, "xerror"]),
rownum = row_number(),
xerror_cap = cs_mdl_cart_full$cptable[min_xerror_idx, "xerror"] +
cs_mdl_cart_full$cptable[min_xerror_idx, "xstd"],
eval = case_when(rownum == min_xerror_idx ~ "min xerror",
xerror < xerror_cap ~ "under cap",
TRUE ~ "")) %>%
select(-rownum, -min_xerror_idx)
## CP nsplit rel.error xerror xstd xerror_cap eval
## 1 0.263 0 1.00 1.01 0.077 0.59
## 2 0.121 1 0.74 0.75 0.059 0.59
## 3 0.046 2 0.62 0.65 0.051 0.59
## 4 0.045 3 0.57 0.67 0.052 0.59
## 5 0.042 4 0.52 0.66 0.051 0.59
## 6 0.026 5 0.48 0.62 0.049 0.59
## 7 0.026 6 0.46 0.62 0.048 0.59
## 8 0.024 7 0.43 0.62 0.048 0.59
## 9 0.015 8 0.41 0.58 0.042 0.59 under cap
## 10 0.015 9 0.39 0.56 0.041 0.59 under cap
## 11 0.015 10 0.38 0.56 0.041 0.59 under cap
## 12 0.014 11 0.36 0.56 0.041 0.59 under cap
## 13 0.014 12 0.35 0.56 0.038 0.59 min xerror
## 14 0.014 13 0.33 0.56 0.038 0.59 under cap
## 15 0.011 14 0.32 0.57 0.039 0.59 under cap
## 16 0.010 15 0.31 0.57 0.038 0.59 under cap
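The 1-SE rule itself takes only a few lines. This base-R sketch uses a hypothetical one_se_row() helper (not an rpart function) on the rounded CP, nsplit, xerror, and xstd values printed in the table above, so ties in xerror may break differently than they would on rpart's unrounded cptable:

```r
# Rounded values copied from the augmented CP table above
cp_tab <- data.frame(
  CP     = c(0.263, 0.121, 0.046, 0.045, 0.042, 0.026, 0.026, 0.024,
             0.015, 0.015, 0.015, 0.014, 0.014, 0.014, 0.011, 0.010),
  nsplit = 0:15,
  xerror = c(1.01, 0.75, 0.65, 0.67, 0.66, 0.62, 0.62, 0.62,
             0.58, 0.56, 0.56, 0.56, 0.56, 0.56, 0.57, 0.57),
  xstd   = c(0.077, 0.059, 0.051, 0.052, 0.051, 0.049, 0.048, 0.048,
             0.042, 0.041, 0.041, 0.041, 0.038, 0.038, 0.039, 0.038)
)

one_se_row <- function(tab) {
  best <- which.min(tab$xerror)              # first row with the lowest xerror
  cap  <- tab$xerror[best] + tab$xstd[best]  # min xerror + its standard error
  ok   <- which(tab$xerror <= cap)           # trees within 1 SE of the minimum
  tab[ok[which.min(tab$nsplit[ok])], ]       # simplest such tree
}

one_se_row(cp_tab)
```

With these values it returns the 8-split row (CP ≈ 0.015), the first row the eval column marks as under the cap.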
Okay, so per the eval column, the minimum xerror is at CP = 0.014 (12 splits), and the simplest tree under the cap is at CP = 0.015 (8 splits). Fortunately, plotcp() presents a nice graphical representation of the relationship between xerror and cp.
The horizontal reference line is set at the minimum xerror + xstd, i.e., one standard error above the minimum cross-validated error. The top axis shows the size of the tree, which is the number of terminal nodes (nsplit + 1). The CP values on the bottom axis are close to, but not the same as, those in the table because plotcp() labels each tree with the geometric mean of its CP and the next larger CP. The smallest training rel error is at CP = 0.01000000 (15 splits). I’ll prune to an even simpler tree than the 1-SE rule selects: the 5-split tree at CP = 0.02599265. Use the prune() function to prune the tree by specifying the associated cost-complexity cp.
cs_mdl_cart <- prune(
cs_mdl_cart_full,
# CP of the cptable row whose tree has exactly 5 splits
cp = cs_mdl_cart_full$cptable[cs_mdl_cart_full$cptable[, "nsplit"] == 5, "CP"]
)
rpart.plot(cs_mdl_cart, yesno = TRUE)
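The CP labels on plotcp()'s axes can be reproduced by hand. As I understand rpart's plotcp() source, each tree is plotted at the geometric mean of its own CP and the CP of the next larger tree (the root-only tree gets Inf), labels are rounded to two significant digits, and the top "size of tree" axis is nsplit + 1. This sketch applies that to the first six CP values from the CP table above:

```r
# CP values of the first six trees, copied (rounded) from the CP table above
cp_table <- c(0.263, 0.121, 0.046, 0.045, 0.042, 0.026)
nsplit   <- 0:5

# Assumed plotcp() x-axis position: geometric mean of each CP and the previous
# (larger) CP; the root-only tree has no larger neighbor, so it gets Inf
plotcp_x <- sqrt(cp_table * c(Inf, head(cp_table, -1)))

# Top axis ("size of tree") counts terminal nodes, not splits
size_of_tree <- nsplit + 1

data.frame(nsplit, size_of_tree, cp_table, plotcp_label = signif(plotcp_x, 2))
```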
Here are the variable importance values from the pruned model.
cs_mdl_cart$variable.importance %>%
data.frame() %>%
rownames_to_column(var = "Feature") %>%
rename(Overall = '.') %>%
ggplot(aes(x = fct_reorder(Feature, Overall), y = Overall)) +
geom_pointrange(aes(ymin = 0, ymax = Overall), color = "cadetblue", size = .3) +
theme_minimal() +
coord_flip() +
labs(x = "", y = "", title = "Variable Importance with Simple Regression")
The most important indicator of Sales is ShelveLoc, then Price, then Age, all of which appear in the final model. CompPrice was also important.
The last step is to make predictions on the test data set. The root mean squared error (\(RMSE = \sqrt{(1/n) \sum{(actual - pred)^2}}\)) and mean absolute error (\(MAE = (1/n) \sum{|actual - pred|}\)) are the two most common measures of predictive accuracy. The key difference is that RMSE squares the errors before averaging, so it punishes large errors more harshly. For a regression tree, set argument type = "vector" (or do not specify it at all).
cs_preds_cart <- predict(cs_mdl_cart, cs_test, type = "vector")
cs_rmse_cart <- RMSE(
pred = cs_preds_cart,
obs = cs_test$Sales
)
cs_rmse_cart
## [1] 2.4
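To see why RMSE punishes large errors more harshly than MAE, here is a base-R sketch with two hypothetical prediction vectors that have the same MAE. The rmse() and mae() helpers are defined here for illustration; caret::RMSE() and caret::MAE() compute the same quantities.

```r
rmse <- function(actual, pred) sqrt(mean((actual - pred)^2))
mae  <- function(actual, pred) mean(abs(actual - pred))

actual     <- c(5, 5, 5, 5)
pred_even  <- c(6, 6, 6, 6)   # four errors of size 1
pred_spiky <- c(5, 5, 5, 9)   # three perfect predictions, one error of size 4

c(MAE = mae(actual, pred_even),  RMSE = rmse(actual, pred_even))
c(MAE = mae(actual, pred_spiky), RMSE = rmse(actual, pred_spiky))
```

Both prediction vectors have MAE = 1, but the one with a single large miss has RMSE = 2 versus 1.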
The pruned tree has a test RMSE of 2.363 (thousand units). Not too bad considering the standard deviation of Sales is 2.8. Here is a predicted vs actual plot.
data.frame(Predicted = cs_preds_cart, Actual = cs_test$Sales) %>%
ggplot(aes(x = Actual, y = Predicted)) +
geom_point(alpha = 0.6, color = "cadetblue") +
geom_smooth() +
geom_abline(intercept = 0, slope = 1, linetype = 2) +
labs(title = "Carseats CART, Predicted vs Actual")
## `geom_smooth()` using method = 'loess' and formula 'y ~ x'
The 6 possible predicted values do a decent job of binning the observations.
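The comparison with the standard deviation of Sales works because a model that predicts the mean for every observation has an RMSE equal to the outcome's standard deviation, computed with n rather than n - 1 in the denominator. A base-R sketch on made-up values:

```r
# Made-up outcome values
y <- c(3, 6, 7, 9, 10)

# Baseline "model": predict the mean for every observation
pred_mean     <- rep(mean(y), length(y))
rmse_baseline <- sqrt(mean((y - pred_mean)^2))

# sd() divides by n - 1, so rescale it to the divide-by-n (population) version
sd_pop <- sd(y) * sqrt((length(y) - 1) / length(y))

c(rmse_baseline = rmse_baseline, sd = sd(y), sd_pop = sd_pop)
```

By that yardstick, a test RMSE of 2.363 against an SD of 2.8 is roughly 16% below a predict-the-mean baseline, assuming the test set's spread is similar to that of the full data.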
8.2.1 Training with Caret
I can also fit the model with caret::train(), specifying method = "rpart". I’ll build the model using 10-fold cross-validation to optimize the hyperparameter cp.
cs_trControl = trainControl(
method = "cv",
number = 10,
savePredictions = "final" # save predictions for the optimal tuning parameter
)
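For intuition about what trainControl(method = "cv", number = 10) asks caret to do, here is a minimal base-R sketch of 10-fold cross-validation. It assumes a plain random fold assignment (caret's folds for a numeric outcome are stratified on quantiles), and it swaps in lm() on the built-in mtcars data for rpart() so that it runs without extra packages; cv_rmse() is a hypothetical helper.

```r
cv_rmse <- function(formula, data, k = 10, seed = 1234) {
  set.seed(seed)
  # Random fold label for every row, as balanced as nrow(data) allows
  folds   <- sample(rep(seq_len(k), length.out = nrow(data)))
  outcome <- all.vars(formula)[1]
  fold_rmse <- vapply(seq_len(k), function(i) {
    fit  <- lm(formula, data = data[folds != i, ])  # train on the other k - 1 folds
    held <- data[folds == i, ]                      # held-out fold
    pred <- predict(fit, newdata = held)
    sqrt(mean((held[[outcome]] - pred)^2))          # RMSE on the held-out fold
  }, numeric(1))
  # Summarize the per-fold RMSEs by their mean and standard deviation
  c(mean = mean(fold_rmse), sd = sd(fold_rmse))
}

cv_result <- cv_rmse(mpg ~ wt + hp, data = mtcars, k = 10)
cv_result
```

caret runs this fit-and-score loop once per candidate cp and keeps the candidate with the best resampled RMSE.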
I’ll let the model look for the best cp tuning parameter with tuneLength to get close, then fine-tune with tuneGrid.
set.seed(1234)
cs_mdl_cart2 = train(
Sales ~ .,
data = cs_train,
method = "rpart",
tuneLength = 5,
metric = "RMSE",
trControl = cs_trControl
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, :
## There were missing values in resampled performance measures.
## CART
##
## 321 samples
## 10 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 289, 289, 289, 289, 289, 289, ...
## Resampling results across tuning parameters:
##
## cp RMSE Rsquared MAE
## 0.042 2.2 0.41 1.8
## 0.045 2.2 0.38 1.8
## 0.046 2.3 0.37 1.8
## 0.121 2.4 0.29 1.9
## 0.263 2.7 0.19 2.2
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was cp = 0.042.
The smallest cp tried (0.04167149) produced the lowest RMSE. I can drill into the best value of cp using a tuning grid. I’ll try that now.
set.seed(1234)
cs_mdl_cart2 = train(
Sales ~ .,
data = cs_train,
method = "rpart",
tuneGrid = expand.grid(cp = seq(from = 0, to = 0.1, by = 0.01)),
metric = "RMSE",
trControl = cs_trControl
)
print(cs_mdl_cart2)
## CART
##
## 321 samples
## 10 predictor
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 289, 289, 289, 289, 289, 289, ...
## Resampling results across tuning parameters:
##
## cp RMSE Rsquared MAE
## 0.00 2.1 0.50 1.7
## 0.01 2.1 0.46 1.7
## 0.02 2.1 0.47 1.7
## 0.03 2.1 0.45 1.7
## 0.04 2.1 0.44 1.7
## 0.05 2.3 0.36 1.8
## 0.06 2.3 0.37 1.8
## 0.07 2.3 0.36 1.8
## 0.08 2.3 0.36 1.8
## 0.09 2.3 0.36 1.8
## 0.10 2.3 0.36 1.8
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was cp = 0.
It looks like the best performing tree is the unpruned one.
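The fine grid above was picked by hand. A coarse-then-fine search could instead center the fine grid on the coarse winner. fine_cp_grid() below is a hypothetical helper (not part of caret), and its width and step defaults are arbitrary; it returns a data frame with a single cp column, the shape tuneGrid expects for method = "rpart".

```r
fine_cp_grid <- function(best_cp, width = 0.05, step = 0.01) {
  # Clamp the lower end at 0 because cp cannot be negative
  cp <- seq(from = max(0, best_cp - width), to = best_cp + width, by = step)
  expand.grid(cp = cp)   # one column named cp
}

cp_grid <- fine_cp_grid(best_cp = 0.042)
cp_grid$cp
```

Centered on the coarse winner, cp = 0.042, it yields cp = 0, 0.01, ..., 0.09, which could then be passed to train() as tuneGrid = cp_grid.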
Let’s see the final model.
What were the most important variables?
Evaluate the model by making predictions with the test data set.
cs_preds_cart2 <- predict(cs_mdl_cart2, cs_test, type = "raw")
data.frame(Actual = cs_test$Sales, Predicted = cs_preds_cart2) %>%
ggplot(aes(x = Actual, y = Predicted)) +
geom_point(alpha = 0.6, color = "cadetblue") +
geom_smooth(method = "loess", formula = "y ~ x") +
geom_abline(intercept = 0, slope = 1, linetype = 2) +
labs(title = "Carseats CART, Predicted vs Actual (caret)")
The model overestimates at the low end and underestimates at the high end. Calculate the test data set RMSE.
cs_rmse_cart2 <- RMSE(pred = cs_preds_cart2, obs = cs_test$Sales)
cs_rmse_cart2
## [1] 2.3
The caret model performed better on the test set. Here is a summary of the RMSE values of the two models.
cs_scoreboard <- rbind(
data.frame(Model = "Single Tree", RMSE = cs_rmse_cart),
data.frame(Model = "Single Tree (caret)", RMSE = cs_rmse_cart2)
) %>% arrange(RMSE)
scoreboard(cs_scoreboard)
Model | RMSE |
---|---|
Single Tree (caret) | 2.2983 |
Single Tree | 2.3632 |