Modul 5: Beslutningstræer - EKSEMPEL
library(caret)
library(DescTools)
library(dplyr)
library(ggplot2)
library(gridExtra)
##
## Vedhæfter pakke: 'gridExtra'
## Det følgende objekt er maskeret fra 'package:dplyr':
##
## combine
library(magrittr)
library(openxlsx)
library(pdp)
library(ranger)
library(rsample)
library(rpart)
library(rpart.plot)
library(vip)
##
## Vedhæfter pakke: 'vip'
## Det følgende objekt er maskeret fra 'package:utils':
##
## vi
Klargøring af datasæt
# tilføjelse af variabel "Kommune"
$Kommune <- ifelse(bolig$Postnummer < 3000, "Gentofte", "Tønder")
bolig# skalering af Salgspris så den angives i 1.000 kr.
# (for at lette læsningen af beslutningstræerne nedenfor)
$Salgspris <- bolig$Salgspris/1000 bolig
# opdeling i test- og træningsdatasæt
set.seed(4321)
<- initial_split(bolig, 2/3)
split <- training(split)
train <- testing(split) test
Beslutningstræer
Valg af splitværdi
Hvordan vælges splitværdien i hver enkelt knude?
Eksempel med estimation af et beslutningstræ med 1 lag og UDEN “pruning”
# model
<- rpart(formula = Salgspris ~ Boligareal, data=train, method="anova",
dt0 control = c(maxdepth = 1, cp=0))
rpart.plot(dt0)
# prædiktion baseret på træningssæt
<- predict(dt0, train)
dt0.train.predict ::RMSE(dt0.train.predict, train$Salgspris)
DescTools## [1] 2886.19
Kontrolberegning af splitværdi i dt0 output:
# beregning af SSE for forskellige split-værdier
<- seq(100.5, 300.5, by=1)
test.split_value <- c()
SSE for(i in 1:length(test.split_value)){
<- train$Salgspris[train$Boligareal < test.split_value[i]]
test.knot2 <- train$Salgspris[train$Boligareal >= test.split_value[i]]
test.knot3 <- sum((test.knot2-mean(test.knot2))^2) + sum((test.knot3-mean(test.knot3))^2)
SSE[i]
}# figur af SSE for forskellige split-værdier
plot(test.split_value, SSE, type="l")
# split-værdi med mindst SSE
<- test.split_value[which.min(SSE)] split_value
Kontrolberegning af RMSE i dt0
# beslutningstræets split i knude 2 og 3
<- train$Salgspris[train$Boligareal < split_value]
knot2 <- train$Salgspris[train$Boligareal >= split_value]
knot3 sqrt(( sum((knot2-mean(knot2))^2) + sum((knot3-mean(knot3))^2) ) / nrow(train))
## [1] 2886.19
Output
Hvordan fremkommer beslutningstræets output?
Eksempel med estimation af et beslutningstræ med 3 lag og UDEN “pruning.”
# model
<- rpart(formula = Salgspris ~ Boligareal+Kommune, data=train, method="anova",
dt1 control = c(maxdepth = 3, cp=0))
rpart.plot(dt1)
# sammenlign fit for trænings- og testdatasæt
<- predict(dt1, train)
dt1.train.predict <- predict(dt1, test)
dt1.test.predict ::RMSE(dt1.train.predict, train$Salgspris)
DescTools## [1] 1573.105
::RMSE(dt1.test.predict, test$Salgspris)
DescTools## [1] 1570.103
Kontrolberegninger i dt1 output:
dt1## n= 6108
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 6108 65562630000 4002.6320
## 2) Kommune=Tønder 2503 1430496000 881.7720
## 4) Boligareal< 138.5 1214 320298900 670.5956
## 8) Boligareal< 109.5 529 60316450 585.9036 *
## 9) Boligareal>=109.5 685 253257800 736.0001 *
## 5) Boligareal>=138.5 1289 1005070000 1080.6610
## 10) Boligareal< 163.5 568 280850600 934.8211 *
## 11) Boligareal>=163.5 721 702620600 1195.5530 *
## 3) Kommune=Gentofte 3605 22827070000 6169.4880
## 6) Boligareal< 184.5 2210 6821505000 5055.0310
## 12) Boligareal< 126.5 756 1087817000 4059.8330 *
## 13) Boligareal>=126.5 1454 4595618000 5572.4790 *
## 7) Boligareal>=184.5 1395 8912244000 7935.0430
## 14) Boligareal< 213.5 597 3208348000 7071.9120 *
## 15) Boligareal>=213.5 798 4926398000 8580.7690 *
# knude 1
<- train$Salgspris
knot1 length(knot1)
## [1] 6108
sum((knot1 - mean(knot1))^2)
## [1] 65562631332
mean(knot1)
## [1] 4002.632
# knude 3
<- train$Salgspris[(train$Kommune=="Gentofte")]
knot3 sum((knot3 - mean(knot3))^2)
## [1] 22827070550
mean(knot3)
## [1] 6169.488
# knude 6
<- train$Salgspris[(train$Kommune=="Gentofte")&(train$Boligareal<198.5)]
knot6 sum((knot6 - mean(knot6))^2)
## [1] 9382568721
mean(knot6)
## [1] 5291.134
# knude 13
<- train$Salgspris[(train$Kommune=="Gentofte")&(train$Boligareal<198.5)&(train$Boligareal>=160.5)]
knot13 sum((knot13 - mean(knot13))^2)
## [1] 3948730255
mean(knot13)
## [1] 6386.468
unique(dt1.train.predict[(train$Kommune=="Gentofte")&(train$Boligareal<198.5)&(train$Boligareal>=160.5)])
## [1] 5572.479 7071.912
Beslutningstræets dybde
Hvilken betydning har træets dybde?
Eksempel med estimation af maksimalt beslutningstræ (dvs. UDEN “pruning”).
# model
<- rpart(formula = Salgspris ~ Boligareal+Kommune, data=train, method="anova",
dt2 control = c(cp=0, xval=10))
rpart.plot(dt2) ## sammenlign med rpart.plot(dt1)
## Warning: labs do not fit even at cex 0.15, there may be some overplotting
plotcp(dt2)
# sammenlign fit for trænings- og testdatasæt
<- predict(dt2, train)
dt2.train.predict <- predict(dt2, test)
dt2.test.predict ::RMSE(dt2.train.predict, train$Salgspris)
DescTools## [1] 1467.361
::RMSE(dt2.test.predict, test$Salgspris)
DescTools## [1] 1558.686
Pruning
Hvilken betydning har “pruning?”
Eksempel med estimation af beslutningstræ med 3 lag og MED “pruning.”
# model
<- rpart(formula = Salgspris ~ Boligareal+Kommune, data=train, method="anova",
dt3 control = c(maxdepth=3, xval=10))
rpart.plot(dt3) ## sammenlign med rpart.plot(dt1)
plotcp(dt3)
# sammenlign fit for trænings- og testdatasæt
<- predict(dt3, train)
dt3.train.predict <- predict(dt3, test)
dt3.test.predict ::RMSE(dt3.train.predict, train$Salgspris)
DescTools## [1] 1580.034
::RMSE(dt3.test.predict, test$Salgspris)
DescTools## [1] 1576.344
Betydning af forklarende variable
Hvor meget betyder hver forklarende variabel?
# model
<- rpart(formula = Salgspris ~ Boligareal+Opførselsår+Antal.værelser+Grundareal+Tidligere.solgt+Kommune,
dt4 data = train, method="anova", control = c(cp=0, xval=10))
# variable importance plot
vip(dt4)
# partial dependence plots ("pdp") for kontinuerte forklarende variable: 1 variabel
<- partial(dt4, pred.var = "Boligareal")
pdp.Boligareal
pdp.Boligareal## Boligareal yhat
## 1 38.00 2095.392
## 2 57.76 2095.392
## 3 77.52 2446.473
## 4 97.28 2715.734
## 5 117.04 2932.118
## 6 136.80 3456.118
## 7 156.56 3597.617
## 8 176.32 3988.372
## 9 196.08 4453.851
## 10 215.84 5247.973
## 11 235.60 4929.438
## 12 255.36 4962.621
## 13 275.12 5035.658
## 14 294.88 5866.319
## 15 314.64 5056.218
## 16 334.40 5261.173
## 17 354.16 5272.109
## 18 373.92 5272.109
## 19 393.68 5272.109
## 20 413.44 5272.109
## 21 433.20 5272.109
## 22 452.96 5272.109
## 23 472.72 5272.109
## 24 492.48 5272.109
## 25 512.24 5272.109
## 26 532.00 5272.109
## 27 551.76 5272.109
## 28 571.52 5272.109
## 29 591.28 5272.109
## 30 611.04 5272.109
## 31 630.80 5272.109
## 32 650.56 5272.109
## 33 670.32 5272.109
## 34 690.08 5272.109
## 35 709.84 5272.109
## 36 729.60 5272.109
## 37 749.36 5272.109
## 38 769.12 5272.109
## 39 788.88 5272.109
## 40 808.64 5272.109
## 41 828.40 5272.109
## 42 848.16 5272.109
## 43 867.92 5272.109
## 44 887.68 5272.109
## 45 907.44 5272.109
## 46 927.20 5272.109
## 47 946.96 5272.109
## 48 966.72 5272.109
## 49 986.48 5272.109
## 50 1006.24 5272.109
## 51 1026.00 5272.109
autoplot(pdp.Boligareal)
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]` instead.
plot(train$Boligareal, train$Salgspris)
<- partial(dt4, pred.var = "Opførselsår")
pdp.Opførselsår autoplot(pdp.Opførselsår)
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]` instead.
plot(train$Opførselsår, train$Salgspris)
<- partial(dt4, pred.var = "Antal.værelser")
pdp.Antal.værelser autoplot(pdp.Antal.værelser)
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]` instead.
plot(train$Antal.værelser, train$Salgspris)
<- partial(dt4, pred.var = "Grundareal")
pdp.Grundareal autoplot(pdp.Grundareal)
## Warning: Use of `object[[1L]]` is discouraged. Use `.data[[1L]]` instead.
## Warning: Use of `object[["yhat"]]` is discouraged. Use `.data[["yhat"]]` instead.
plot(train$Grundareal, train$Salgspris)
# partial dependence plots ("pdp") for kontinuerte forklarende variable: 2 variable
<- partial(dt4, pred.var = c("Boligareal", "Opførselsår"))
pdp.Boligareal.Opførselsår plotPartial(pdp.Boligareal.Opførselsår, levelplot=FALSE, drape=TRUE)
Bagging
# model
<- train(
dt5 ~ Boligareal+Opførselsår+Antal.værelser+Grundareal+Tidligere.solgt+Kommune,
Salgspris data = train,
method = "treebag",
trControl = trainControl(method="cv", number=10),
nbagg = 50,
control = rpart.control(minsplit=2, cp=0)
)vip(dt5) ## sammenlign med vip(dt4)
Random forests
Estimation
# model
<- train(
dt6 ~ Boligareal+Opførselsår+Antal.værelser+Grundareal+Tidligere.solgt+Kommune,
Salgspris data = train,
method = "ranger"
)
Hypertuning
# vælg hvilke hyperparameterværdier, der skal undersøges
<- expand.grid(
hyper_grid mtry = c(1:5),
splitrule = c("variance","extratrees","maxstat"),
min.node.size = c(5, 10)
)
# estimation af random forest model for hver kombination af hyperparametre
<- train(
dt_tune ~ Boligareal+Opførselsår+Antal.værelser+Grundareal+Tidligere.solgt+Kommune,
Salgspris data = train,
method = "ranger",
tuneGrid = hyper_grid
)
# sammenligning af de estimerede modeller
<- arrange(dt_tune$results, RMSE)
results
results## mtry splitrule min.node.size RMSE Rsquared MAE RMSESD RsquaredSD MAESD
## 1 2 variance 10 1455.110 0.8038821 963.5533 33.48078 0.009276595 20.18752
## 2 2 variance 5 1458.715 0.8028612 962.4676 33.88389 0.009252204 21.58820
## 3 3 extratrees 10 1465.646 0.8009742 968.9371 34.96640 0.009651970 21.21562
## 4 3 variance 10 1467.982 0.8004989 962.4889 35.22994 0.009382435 22.59192
## 5 4 extratrees 10 1469.354 0.7999862 968.6904 36.08406 0.009910627 22.17343
## 6 3 maxstat 5 1471.402 0.7996145 980.3506 34.05371 0.009847222 20.12009
## 7 3 extratrees 5 1471.957 0.7992925 967.2131 34.52667 0.009334161 21.87398
## 8 5 extratrees 10 1472.431 0.7991919 969.7641 36.17279 0.009918929 22.30883
## 9 3 maxstat 10 1473.020 0.7992658 983.3653 35.12748 0.010241090 20.37377
## 10 4 variance 10 1475.801 0.7985254 966.0639 35.14168 0.009319587 23.03018
## 11 4 maxstat 10 1477.690 0.7977188 978.8162 36.55335 0.010318933 21.87229
## 12 4 maxstat 5 1479.209 0.7972980 977.9780 36.30596 0.010168541 22.57949
## 13 3 variance 5 1479.335 0.7976210 965.5289 34.37255 0.009059103 23.11744
## 14 4 extratrees 5 1483.452 0.7963470 971.5727 35.71164 0.009639408 23.02863
## 15 5 variance 10 1483.741 0.7964888 970.6303 35.37106 0.009321622 23.46073
## 16 2 maxstat 5 1484.736 0.7979994 1011.4640 32.71729 0.010148586 19.10762
## 17 5 extratrees 5 1487.343 0.7953623 973.9502 35.28342 0.009439750 22.68994
## 18 2 maxstat 10 1488.472 0.7971701 1014.9189 32.80092 0.010302429 19.11739
## 19 4 variance 5 1488.709 0.7952634 970.2621 35.42926 0.009303104 23.87076
## 20 2 extratrees 5 1489.168 0.7955557 1001.8671 31.49658 0.009018808 19.54391
## 21 5 maxstat 10 1491.949 0.7938191 983.4338 39.19683 0.010908063 23.98772
## 22 5 maxstat 5 1493.961 0.7933278 982.7309 38.09420 0.010581035 23.55947
## 23 2 extratrees 10 1496.582 0.7936034 1008.0833 31.22641 0.009146593 18.81949
## 24 5 variance 5 1497.279 0.7930669 975.1887 35.82601 0.009435892 24.68483
## 25 1 variance 5 1635.621 0.7841928 1219.6067 26.27294 0.009905537 18.05028
## 26 1 variance 10 1639.854 0.7835121 1223.7972 28.84381 0.010138700 20.35316
## 27 1 maxstat 10 1712.228 0.7735645 1301.9841 28.14747 0.010358127 27.39370
## 28 1 maxstat 5 1714.385 0.7739082 1306.6919 27.88852 0.010464434 26.21179
## 29 1 extratrees 5 1829.719 0.7436443 1387.6357 24.42494 0.009635193 16.76873
## 30 1 extratrees 10 1834.225 0.7440079 1393.8018 28.55343 0.008275010 22.55080