8.2 Regression Tree

A simple regression tree is built in a manner similar to a simple classification tree, and like the simple classification tree, it is rarely invoked on its own; the bagged, random forest, and gradient boosting methods build on this logic. I’ll learn by example again, using the ISLR::Carseats data set to predict Sales from the 10 feature variables.
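
The code in this section assumes the following packages are attached (an assumption; they may have been loaded in an earlier chapter). skimr and ISLR are only referenced with `::`.

library(rpart)        # rpart(), prune(), printcp(), plotcp()
library(rpart.plot)   # rpart.plot()
library(caret)        # createDataPartition(), trainControl(), train(), RMSE(), varImp()
library(tidyverse)    # dplyr, ggplot2, tibble, and forcats used throughout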

cs_dat <- ISLR::Carseats
skimr::skim(cs_dat)
Table 8.2: Data summary

Name                     cs_dat
Number of rows           400
Number of columns        11
_______________________
Column type frequency:
  factor                 3
  numeric                8
________________________
Group variables          None

Variable type: factor

skim_variable  n_missing  complete_rate  ordered  n_unique  top_counts
ShelveLoc              0              1  FALSE           3  Med: 219, Bad: 96, Goo: 85
Urban                  0              1  FALSE           2  Yes: 282, No: 118
US                     0              1  FALSE           2  Yes: 258, No: 142

Variable type: numeric

skim_variable  n_missing  complete_rate   mean     sd  p0    p25    p50    p75  p100  hist
Sales                  0              1    7.5    2.8   0    5.4    7.5    9.3    16  ▁▆▇▃▁
CompPrice              0              1  125.0   15.3  77  115.0  125.0  135.0   175  ▁▅▇▃▁
Income                 0              1   68.7   28.0  21   42.8   69.0   91.0   120  ▇▆▇▆▅
Advertising            0              1    6.6    6.6   0    0.0    5.0   12.0    29  ▇▃▃▁▁
Population             0              1  264.8  147.4  10  139.0  272.0  398.5   509  ▇▇▇▇▇
Price                  0              1  115.8   23.7  24  100.0  117.0  131.0   191  ▁▂▇▆▁
Age                    0              1   53.3   16.2  25   39.8   54.5   66.0    80  ▇▆▇▇▇
Education              0              1   13.9    2.6  10   12.0   14.0   16.0    18  ▇▇▃▇▇

Split cs_dat (n = 400) into cs_train (80%, n = 321) and cs_test (20%, n = 79).

set.seed(12345)
partition <- createDataPartition(y = cs_dat$Sales, p = 0.8, list = FALSE)
cs_train <- cs_dat[partition, ]
cs_test <- cs_dat[-partition, ]

The first step is to build a full tree, then perform k-fold cross-validation to help select the optimal cost complexity (cp). The only difference from the classification tree is the rpart() parameter method = "anova", which produces a regression tree.
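
For reference, the cost complexity that rpart minimizes when pruning (as described in the rpart vignette) balances fit against tree size:

\[R_{cp}(T) = R(T) + cp \cdot |T| \cdot R(T_1)\]

where \(R\) is the deviance (the SSE for an anova tree), \(|T|\) is the number of splits, and \(T_1\) is the root-only tree. Larger cp values demand a bigger SSE reduction per split, so they yield smaller trees.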

set.seed(1234)
cs_mdl_cart_full <- rpart(Sales ~ ., cs_train, method = "anova")
print(cs_mdl_cart_full)
## n= 321 
## 
## node), split, n, deviance, yval
##       * denotes terminal node
## 
##  1) root 321 2600  7.5  
##    2) ShelveLoc=Bad,Medium 251 1500  6.8  
##      4) Price>=1.1e+02 168  720  6.0  
##        8) ShelveLoc=Bad 50  170  4.7  
##         16) Population< 2e+02 20   48  3.6 *
##         17) Population>=2e+02 30   81  5.4 *
##        9) ShelveLoc=Medium 118  430  6.5  
##         18) Advertising< 12 88  290  6.1  
##           36) CompPrice< 1.4e+02 69  190  5.8  
##             72) Price>=1.3e+02 16   51  4.5 *
##             73) Price< 1.3e+02 53  110  6.2 *
##           37) CompPrice>=1.4e+02 19   58  7.4 *
##         19) Advertising>=12 30   83  7.8 *
##      5) Price< 1.1e+02 83  440  8.4  
##       10) Age>=64 32  150  6.9  
##         20) Price>=85 25   67  6.2  
##           40) ShelveLoc=Bad 9   18  4.8 *
##           41) ShelveLoc=Medium 16   21  6.9 *
##         21) Price< 85 7   20  9.6 *
##       11) Age< 64 51  180  9.3  
##         22) Income< 58 12   28  7.7 *
##         23) Income>=58 39  120  9.7  
##           46) Age>=50 14   21  8.5 *
##           47) Age< 50 25   60 10.0 *
##    3) ShelveLoc=Good 70  420 10.0  
##      6) Price>=1.1e+02 49  240  9.4  
##       12) Advertising< 14 41  160  8.9  
##         24) Age>=61 17   53  7.8 *
##         25) Age< 61 24   69  9.8 *
##       13) Advertising>=14 8   13 12.0 *
##      7) Price< 1.1e+02 21   61 12.0 *

The predicted Sales at the root is the mean Sales for the training data set, 7.535950 (values are $000s). The deviance at the root is the SSE, 2567.768. The first split is at ShelveLoc = [Bad, Medium] vs Good. Here is the unpruned tree diagram.

rpart.plot(cs_mdl_cart_full, yesno = TRUE)

The boxes show the node predicted value (mean) and the proportion of observations that are in the node (or child nodes).
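
As a quick check, the root-node mean and deviance reported above can be reproduced directly from the training data (a minimal sketch):

mean(cs_train$Sales)                             # root yval: mean Sales, ~7.536
sum((cs_train$Sales - mean(cs_train$Sales))^2)   # root deviance: SSE, ~2568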

rpart() grew the full tree, and used cross-validation to test the performance of the possible complexity hyperparameters. printcp() displays the candidate cp values. You can use this table to decide how to prune the tree.

printcp(cs_mdl_cart_full)
## 
## Regression tree:
## rpart(formula = Sales ~ ., data = cs_train, method = "anova")
## 
## Variables actually used in tree construction:
## [1] Advertising Age         CompPrice   Income      Population  Price      
## [7] ShelveLoc  
## 
## Root node error: 2567.8/321 = 7.9993
## 
## n= 321 
## 
##       CP nsplit rel error xerror  xstd
## 1  0.263      0      1.00   1.01 0.077
## 2  0.121      1      0.74   0.75 0.059
## 3  0.046      2      0.62   0.65 0.051
## 4  0.045      3      0.57   0.67 0.052
## 5  0.042      4      0.52   0.66 0.051
## 6  0.026      5      0.48   0.62 0.049
## 7  0.026      6      0.46   0.62 0.048
## 8  0.024      7      0.43   0.62 0.048
## 9  0.015      8      0.41   0.58 0.042
## 10 0.015      9      0.39   0.56 0.041
## 11 0.015     10      0.38   0.56 0.041
## 12 0.014     11      0.36   0.56 0.041
## 13 0.014     12      0.35   0.56 0.038
## 14 0.014     13      0.33   0.56 0.038
## 15 0.011     14      0.32   0.57 0.039
## 16 0.010     15      0.31   0.57 0.038

There were 16 possible cp values in this model. The model with the smallest complexity parameter allows the most splits (nsplit). The highest complexity parameter corresponds to a tree with just a root node. rel error is the SSE relative to the root node. The root node SSE is 2567.76800, so its rel error is 2567.76800/2567.76800 = 1.0. That means the absolute error of the full tree (at CP = 0.01) is 0.30963 * 2567.76800 = 795.058. You can verify that by calculating the SSE of the model predicted values:

data.frame(pred = predict(cs_mdl_cart_full, newdata = cs_train)) %>%
   mutate(obs = cs_train$Sales,
          sq_err = (obs - pred)^2) %>%
   summarize(sse = sum(sq_err))
##   sse
## 1 795

Finishing the CP table tour, xerror is the cross-validated error (scaled, like rel error, so the root node error equals 1) and xstd is its standard error. If you want the lowest possible error, prune to the tree with the smallest cross-validated error (xerror). If you want to balance predictive power with simplicity, prune to the smallest tree whose xerror is within 1 standard error of the smallest xerror. The CP table is not super-helpful for finding that tree, so I’ll add a column to find it.

cs_mdl_cart_full$cptable %>%
   data.frame() %>%
   mutate(min_xerror_idx = which.min(cs_mdl_cart_full$cptable[, "xerror"]),
          rownum = row_number(),
          xerror_cap = cs_mdl_cart_full$cptable[min_xerror_idx, "xerror"] + 
             cs_mdl_cart_full$cptable[min_xerror_idx, "xstd"],
          eval = case_when(rownum == min_xerror_idx ~ "min xerror",
                           xerror < xerror_cap ~ "under cap",
                           TRUE ~ "")) %>%
   select(-rownum, -min_xerror_idx) 
##       CP nsplit rel.error xerror  xstd xerror_cap       eval
## 1  0.263      0      1.00   1.01 0.077       0.59           
## 2  0.121      1      0.74   0.75 0.059       0.59           
## 3  0.046      2      0.62   0.65 0.051       0.59           
## 4  0.045      3      0.57   0.67 0.052       0.59           
## 5  0.042      4      0.52   0.66 0.051       0.59           
## 6  0.026      5      0.48   0.62 0.049       0.59           
## 7  0.026      6      0.46   0.62 0.048       0.59           
## 8  0.024      7      0.43   0.62 0.048       0.59           
## 9  0.015      8      0.41   0.58 0.042       0.59  under cap
## 10 0.015      9      0.39   0.56 0.041       0.59  under cap
## 11 0.015     10      0.38   0.56 0.041       0.59  under cap
## 12 0.014     11      0.36   0.56 0.041       0.59  under cap
## 13 0.014     12      0.35   0.56 0.038       0.59 min xerror
## 14 0.014     13      0.33   0.56 0.038       0.59  under cap
## 15 0.011     14      0.32   0.57 0.039       0.59  under cap
## 16 0.010     15      0.31   0.57 0.038       0.59  under cap

Okay, so the simplest tree is the one with CP = 0.02599265 (5 splits). Fortunately, plotcp() presents a nice graphical representation of the relationship between xerror and cp.

plotcp(cs_mdl_cart_full, upper = "splits")
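
One quirk: plotcp() labels the bottom axis with the geometric mean of each CP value and the one before it, rather than the table values themselves. A minimal sketch that reproduces those axis labels:

cp_vals <- cs_mdl_cart_full$cptable[, "CP"]
# geometric mean of adjacent CP values -- the labels plotcp() prints
# (the leftmost label shows as Inf)
sqrt(cp_vals * c(Inf, cp_vals[-length(cp_vals)]))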

The dashed line is set at the minimum xerror + xstd. The top axis shows the number of splits in the tree. The CP values printed on the bottom axis are the geometric means computed above, which is why they are close to, but not the same as, the values in the table. The smallest relative error is at CP = 0.01000000 (15 splits), but the maximum CP below the dashed line (one standard deviation above the minimum error) is at CP = 0.02599265 (5 splits). Use the prune() function to prune the tree by specifying the associated cost-complexity cp.

cs_mdl_cart <- prune(
   cs_mdl_cart_full,
   cp = cs_mdl_cart_full$cptable[cs_mdl_cart_full$cptable[, "nsplit"] == 5, "CP"]
)
rpart.plot(cs_mdl_cart, yesno = TRUE)

The most “important” indicator of Sales is ShelveLoc. Here are the importance values from the model.

cs_mdl_cart$variable.importance %>% 
   data.frame() %>%
   rownames_to_column(var = "Feature") %>%
   rename(Overall = '.') %>%
   ggplot(aes(x = fct_reorder(Feature, Overall), y = Overall)) +
   geom_pointrange(aes(ymin = 0, ymax = Overall), color = "cadetblue", size = .3) +
   theme_minimal() +
   coord_flip() +
   labs(x = "", y = "", title = "Variable Importance with Simple Regression")

The most important indicator of Sales is ShelveLoc, then Price, then Age, all of which appear in the final model. CompPrice was also important.

The last step is to make predictions on the test data set. The root mean squared error (\(RMSE = \sqrt{(1/n) \sum{(actual - pred)^2}}\)) and mean absolute error (\(MAE = (1/n) \sum{|actual - pred|}\)) are the two most common measures of predictive accuracy. The key difference is that RMSE punishes large errors more harshly. For a regression tree, set argument type = "vector" (or do not specify it at all).

cs_preds_cart <- predict(cs_mdl_cart, cs_test, type = "vector")

cs_rmse_cart <- RMSE(
   pred = cs_preds_cart,
   obs = cs_test$Sales
)
cs_rmse_cart
## [1] 2.4
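
As a sanity check of the formulas above, the same RMSE (and the corresponding MAE) can be computed by hand; a minimal sketch:

err <- cs_test$Sales - cs_preds_cart
sqrt(mean(err^2))   # RMSE by hand, same as caret::RMSE()
mean(abs(err))      # MAE by hand (caret::MAE() computes the same quantity)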

The pruning process leads to an average prediction error of 2.363 in the test data set. Not too bad considering the standard deviation of Sales is 2.8. Here is a predicted vs actual plot.

data.frame(Predicted = cs_preds_cart, Actual = cs_test$Sales) %>%
   ggplot(aes(x = Actual, y = Predicted)) +
   geom_point(alpha = 0.6, color = "cadetblue") +
   geom_smooth() +
   geom_abline(intercept = 0, slope = 1, linetype = 2) +
   labs(title = "Carseats CART, Predicted vs Actual")
## `geom_smooth()` using method = 'loess' and formula 'y ~ x'

The 6 possible predicted values do a decent job of binning the observations.
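
With 5 splits the pruned tree has 6 terminal nodes, so the predictions can only take 6 distinct values. A quick tabulation shows how the test observations are binned:

# distinct leaf predictions and how many test observations land in each
table(round(cs_preds_cart, 2))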

8.2.1 Training with Caret

I can also fit the model with caret::train(), specifying method = "rpart". I’ll build the model using 10-fold cross-validation to optimize the hyperparameter CP.

cs_trControl = trainControl(
   method = "cv",
   number = 10,
   savePredictions = "final"       # save predictions for the optimal tuning parameter
)

I’ll let the model look for the best CP tuning parameter with tuneLength to get close, then fine-tune with tuneGrid.

set.seed(1234)
cs_mdl_cart2 = train(
   Sales ~ ., 
   data = cs_train, 
   method = "rpart",
   tuneLength = 5,
   metric = "RMSE",
   trControl = cs_trControl
)
## Warning in nominalTrainWorkflow(x = x, y = y, wts = weights, info = trainInfo, :
## There were missing values in resampled performance measures.
print(cs_mdl_cart2)
## CART 
## 
## 321 samples
##  10 predictor
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 289, 289, 289, 289, 289, 289, ... 
## Resampling results across tuning parameters:
## 
##   cp     RMSE  Rsquared  MAE
##   0.042  2.2   0.41      1.8
##   0.045  2.2   0.38      1.8
##   0.046  2.3   0.37      1.8
##   0.121  2.4   0.29      1.9
##   0.263  2.7   0.19      2.2
## 
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was cp = 0.042.

The first cp (0.04167149) produced the smallest RMSE. I can drill into the best value of cp using a tuning grid. I’ll try that now.

set.seed(1234)
cs_mdl_cart2 = train(
   Sales ~ ., 
   data = cs_train, 
   method = "rpart",
   tuneGrid = expand.grid(cp = seq(from = 0, to = 0.1, by = 0.01)),
   metric = "RMSE",
   trControl = cs_trControl
)
print(cs_mdl_cart2)
## CART 
## 
## 321 samples
##  10 predictor
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold) 
## Summary of sample sizes: 289, 289, 289, 289, 289, 289, ... 
## Resampling results across tuning parameters:
## 
##   cp    RMSE  Rsquared  MAE
##   0.00  2.1   0.50      1.7
##   0.01  2.1   0.46      1.7
##   0.02  2.1   0.47      1.7
##   0.03  2.1   0.45      1.7
##   0.04  2.1   0.44      1.7
##   0.05  2.3   0.36      1.8
##   0.06  2.3   0.37      1.8
##   0.07  2.3   0.36      1.8
##   0.08  2.3   0.36      1.8
##   0.09  2.3   0.36      1.8
##   0.10  2.3   0.36      1.8
## 
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was cp = 0.

It looks like the best performing tree is the unpruned one.
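
The selected tuning parameter is stored in the fitted object's bestTune element, which can be checked directly:

cs_mdl_cart2$bestTune   # the cp value caret selected (0 here, per the output above)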

plot(cs_mdl_cart2)

Let’s see the final model.

rpart.plot(cs_mdl_cart2$finalModel)

What were the most important variables?

plot(varImp(cs_mdl_cart2), main="Variable Importance with Simple Regression")

Evaluate the model by making predictions with the test data set.

cs_preds_cart2 <- predict(cs_mdl_cart2, cs_test, type = "raw")
data.frame(Actual = cs_test$Sales, Predicted = cs_preds_cart2) %>%
   ggplot(aes(x = Actual, y = Predicted)) +
   geom_point(alpha = 0.6, color = "cadetblue") +
   geom_smooth(method = "loess", formula = "y ~ x") +
   geom_abline(intercept = 0, slope = 1, linetype = 2) +
   labs(title = "Carseats CART, Predicted vs Actual (caret)")

The model overestimates at the low end and underestimates at the high end. Calculate the test data set RMSE.

(cs_rmse_cart2 <- RMSE(pred = cs_preds_cart2, obs = cs_test$Sales))
## [1] 2.3

The caret model performed slightly better. Here is a summary of the RMSE values of the two models.

cs_scoreboard <- rbind(
   data.frame(Model = "Single Tree", RMSE = cs_rmse_cart),
   data.frame(Model = "Single Tree (caret)", RMSE = cs_rmse_cart2)
) %>% arrange(RMSE)
scoreboard(cs_scoreboard)

Model                   RMSE
Single Tree (caret)   2.2983
Single Tree           2.3632