# Chapter 8 提升方法

``````rm(list = ls())
library(purrr)
# Candidate split points for a 1-D feature: the midpoints between each
# pair of consecutive sorted values.
#
# x_little: numeric vector of feature values, in any order.
# Returns:  numeric vector of length(x_little) - 1 midpoints, or
#           numeric(0) when fewer than two values are supplied.
#           (The original loop used seq(len_x - 1), which for a single
#           value iterates over c(1, 0) and produced a spurious NA.)
slice_point_fun <- function(x_little) {
  stopifnot(is.vector(x_little))
  x_sorted <- sort(x_little)
  n <- length(x_sorted)
  if (n < 2) {
    return(numeric(0))
  }
  # Vectorized midpoint of each adjacent pair, replacing the index loop.
  (x_sorted[-n] + x_sorted[-1]) / 2
}
# slice_point <- slice_point_fun(x[, 1])

# Weighted classification error of the two decision stumps that split
# feature `x` at `classifier_slice_point`.
#
# classifier_slice_point: threshold separating the two predicted classes.
# sample_weight: AdaBoost observation weights, same length as x.
# x: numeric feature vector.
# y: true labels in {-1, +1}. Defaults to the global `y`, preserving the
#    original three-argument calls (e.g. via pmap in final()); the old
#    code read `y` as a free variable resolved in the global environment.
#
# Returns a list with `error` (the smaller of the two weighted errors)
# and `classifier` ("less than" / "more than") naming the winning stump.
Error <- function(classifier_slice_point, sample_weight, x,
                  y = get("y", envir = globalenv())) {
  # Stump 1: predict +1 when x < threshold. Building labels with ifelse
  # avoids the original negative-index idiom, where an empty which()
  # result made `yi[-integer(0)] <- -1` assign nothing, leaving labels 0.
  pred_less <- ifelse(x < classifier_slice_point, 1L, -1L)
  error_less <- sum(sample_weight[pred_less != y])

  # Stump 2: predict +1 when x > threshold.
  pred_more <- ifelse(x > classifier_slice_point, 1L, -1L)
  error_more <- sum(sample_weight[pred_more != y])

  errors <- c(error_less, error_more)
  best <- which.min(errors)  # ties resolve to "less than", as before
  list(error = errors[best],
       classifier = if (best == 1L) "less than" else "more than")
}

# sample_weight <- rep(1 / row_data, row_data)
# Error(2.0, sample_weight, x[, 3])    ####################################################
# x[order(x[, 3]),3]

# Evaluate the weak classifier at `sample_point_x`.
#
# Error() decides which orientation of the stump at this threshold has
# the smaller weighted error; "less than" stumps predict +1 below the
# threshold, "more than" stumps predict +1 above it.
G <- function(classifier_slice_point,
              sample_weight,
              sample_point_x, x) {
  side <- Error(classifier_slice_point,
                sample_weight, x)[["classifier"]]
  positive <- if (side == "less than") {
    sample_point_x < classifier_slice_point
  } else {
    sample_point_x > classifier_slice_point
  }
  ifelse(positive, 1, -1)
}

#G(2.5, sample_weight, x[1, 1], x[, 1])

# Normalization constant Z_m for the AdaBoost weight update:
#   Z = sum_i w_i * exp(-alpha * y_i * G(x_i)),
# rounded to 4 decimals to match the book's worked example.
Z <- function(slice_point_error_min,
              sample_weight, x,
              alpha, y) {
  contributions <- vapply(seq_along(x), function(i) {
    g_i <- G(slice_point_error_min,
             sample_weight, x[i], x)
    sample_weight[i] * exp(-alpha * y[i] * g_i)
  }, numeric(1))
  round(sum(contributions), 4)
}

#Z(2.5, sample_weight, x[, 1], 0.4236, y)

# AdaBoost with depth-1 decision stumps (Algorithm 8.1 of the book).
#
# x: numeric feature matrix, observations in rows.
# y: labels in {-1, +1}, one per row of x.
#
# Iterates until the additive model classifies every training point
# correctly, printing per-round weights, error, and alpha.
# NOTE(review): depends on the global helpers slice_point_fun/Error/G/Z,
# and Error() reads the global `y` — so `final` must be called with the
# labels also bound to a global `y`, as the surrounding script does.
# Returns the (NULL) value of the while loop; results are side effects.
final <- function(x, y) {
row_data <- nrow(x)
# col_data is ncol(x) + 1, so seq(col_data - 1) below walks the features.
col_data <- ncol(x) + 1
# Step 1: uniform initial weights 1/N.
sample_weight <- rep(1 / row_data, row_data)
miss_classify <- Inf
# Accumulates the additive classifier f_m(x_i) per observation.
f_value_vec <- vector("numeric", length = row_data)
k <- 1
while (miss_classify > 0) {
print(paste0("k= " ,k))
print(paste0("sample_weight= " ,sample_weight))
var_ <- list()
# For each feature, find the stump threshold with minimal weighted error.
for (kk in seq(col_data - 1)) {
#kk <- 1
slice_point <- slice_point_fun(x[, kk])
len_slice_point <- length(slice_point) # number of candidate thresholds
# Replicate the weight vector so pmap() can pair it with each threshold.
sample_weight_list <- vector("list",
length = len_slice_point)
for (i in seq(len_slice_point)) {
sample_weight_list[[i]] <- sample_weight
}
x_kk <- vector("list", length = len_slice_point)
for (j in seq(len_slice_point)) {
x_kk[[j]] <- x[, kk]
}
# Weighted error of the best stump orientation at every threshold.
all_error_i <- pmap(list(slice_point,
sample_weight_list,
x_kk), Error)
vec_all_error_i <- modify_depth(all_error_i, 1, "error") %>%
as_vector(.)
error_min_i <- vec_all_error_i %>%
min(.) # smallest weighted error for this feature
print(paste0("error_min_i = ", error_min_i))
error_min_idx_i <- which(vec_all_error_i == error_min_i)[1] # first threshold attaining it
slice_point_error_min_i <- slice_point[error_min_idx_i] # the winning threshold

var_[[kk]] <- list(error_min_i = error_min_i,
error_min_idx_i = error_min_idx_i,
slice_point_error_min_i = slice_point_error_min_i,
slice_point = slice_point)
}
# Across features, keep the single best (feature, threshold) pair.
vec_all_error <- modify_depth(var_, 1, "error_min_i") %>%
as_vector(.)
error_min <- vec_all_error %>%
min(.) # best error over all features
error_min_var_idx <- which(vec_all_error == error_min)[1] # which feature attains the minimum

slice_point_error_min <-
var_[[error_min_var_idx]][["slice_point_error_min_i"]]

# Classifier weight alpha_m = (1/2) ln((1 - e_m) / e_m); the small
# additive constant guards against division by zero when e_m == 0.
alpha <- 1 / 2 * log((1 - error_min) /
(error_min + 0.00001), base = exp(1))
alpha <- round(alpha, 4)
print(paste0("alpha=", alpha))
miss_classify <- 0
k_var_idx <- error_min_var_idx
# Normalization constant for the weight update (step 2d).
Z_value <- Z(slice_point_error_min,
sample_weight, x[, k_var_idx],
alpha, y)
for (i in seq(row_data)) {
#print(paste0("i = ", i))
G_value <- G(slice_point_error_min,
sample_weight,
x[i, k_var_idx], x[, k_var_idx])
# f_m(x_i) = f_{m-1}(x_i) + alpha_m * G_m(x_i)
f_value <- alpha * G_value + f_value_vec[i]
miss <- ifelse(sign(f_value) != y[i], 1, 0)
miss_classify <- miss_classify + miss
f_value_vec[i] <- f_value

# Weight update: w_{m+1,i} = w_{m,i} * exp(-alpha y_i G(x_i)) / Z_m
exp_value <- exp(-alpha * y[i] * G_value)
sample_weight[i] <- round(sample_weight[i] /
Z_value * exp_value , 5)
}
print(paste0("miss_classify= ", miss_classify))
k <- k + 1
}
}

########
# AdaBoost demo (book Example 8.1): ten 1-D points with labels in {-1, +1}.
x <- c(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) %>% matrix()
y <- c(1, 1, 1, -1, -1, -1, 1, 1, 1, -1)
# Train until the ensemble misclassifies nothing (3 rounds; see output below).
final(x = x, y = y)``````
``````## [1] "k= 1"
##  [1] "sample_weight= 0.1" "sample_weight= 0.1" "sample_weight= 0.1"
##  [4] "sample_weight= 0.1" "sample_weight= 0.1" "sample_weight= 0.1"
##  [7] "sample_weight= 0.1" "sample_weight= 0.1" "sample_weight= 0.1"
## [10] "sample_weight= 0.1"
## [1] "error_min_i = 0.3"
## [1] "alpha=0.4236"
## [1] "miss_classify= 3"
## [1] "k= 2"
##  [1] "sample_weight= 0.07143" "sample_weight= 0.07143"
##  [3] "sample_weight= 0.07143" "sample_weight= 0.07143"
##  [5] "sample_weight= 0.07143" "sample_weight= 0.07143"
##  [7] "sample_weight= 0.16666" "sample_weight= 0.16666"
##  [9] "sample_weight= 0.16666" "sample_weight= 0.07143"
## [1] "error_min_i = 0.21429"
## [1] "alpha=0.6496"
## [1] "miss_classify= 3"
## [1] "k= 3"
##  [1] "sample_weight= 0.04545" "sample_weight= 0.04545"
##  [3] "sample_weight= 0.04545" "sample_weight= 0.16665"
##  [5] "sample_weight= 0.16665" "sample_weight= 0.16665"
##  [7] "sample_weight= 0.10605" "sample_weight= 0.10605"
##  [9] "sample_weight= 0.10605" "sample_weight= 0.04545"
## [1] "error_min_i = 0.1818"
## [1] "alpha=0.7521"
## [1] "miss_classify= 0"``````
``````# iris_1 <- iris[1:100, ]
# iris_1[, 5] <- NULL
# iris_1[["labels"]] <- c(rep(1, 50), rep(-1, 50))
# x_iris <- iris_1[, 1:4]
# y_iris <- iris_1[, 5]
# row_data <- nrow(x)
# col_data <- ncol(x) + 1
# final(x = x_iris, y = y_iris)``````
``````x <- c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
y <- c(5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05)
# Regression example (book Table 8.2): candidate split points are the
# midpoints between consecutive x values.
len_x <- length(x)
slice_point <- vector("numeric", length = len_x - 1)
for (i in seq(len_x - 1)) {
slice_point[i] <- (x[i] + x[i + 1]) / 2
}
slice_point``````
``## [1] 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5``
``````#sample_weight <- rep(1 / len_x, len_x)

# One-split regression stump: evaluate every candidate split point and
# return the one minimizing the total squared error around the two
# region means, together with the regions, the (rounded) fitted
# constants, and the residuals needed for the next boosting round.
#
# vector_slice_point: candidate thresholds.
# x, y: feature values and regression targets (x need not be sorted).
best_slicepoint <- function(vector_slice_point, x, y) {
  len_x <- length(x)

  # Squared loss m(s) of each candidate split s.
  ms <- vapply(vector_slice_point, function(s) {
    left_idx <- which(x <= s)
    c_left <- mean(y[left_idx])
    c_right <- mean(y[-left_idx])
    sum((y[left_idx] - c_left)^2) + sum((y[-left_idx] - c_right)^2)
  }, numeric(1))

  best_idx <- which.min(ms)
  minvalue <- round(ms[best_idx], 2)
  best_slice_point <- vector_slice_point[best_idx]

  # The two regions and their fitted constants (rounded to 2 decimals,
  # matching the book's tables).
  R1 <- subset(x, x <= best_slice_point)
  R2 <- subset(x, x > best_slice_point)
  idx_R1 <- which(x %in% R1)
  c1_R1 <- round(mean(y[idx_R1]), 2)
  c2_R1 <- round(mean(y[-idx_R1]), 2)

  # Residuals per observation; works for unsorted x as well.
  residuals <- ifelse(seq_len(len_x) %in% idx_R1,
                      y - c1_R1,
                      y - c2_R1)

  list(best_slice_point = best_slice_point,
       minvalue = minvalue,
       R = list(R1 = R1, R2 = R2),
       c = c(c1_R1, c2_R1),
       residuals = residuals)
}
# First stump on the regression example; kept for interactive inspection.
tree <- best_slicepoint(slice_point,x,y)    ####################################################

# x need not be sorted beforehand; it can be passed in any order:
# x1 <- c(8, 5, 2, 3, 4, 6, 7, 9, 10, 1)
# y1 <- c(8.70, 6.80, 5.70, 5.91, 6.40,7.05, 8.90, 9.00, 9.05, 5.56)
# best_slicepoint(slice_point, x1, y1)

# The next boosting round would fit the residuals:
# best_slicepoint(slice_point, x, tree$residuals)

# Boosting tree (book Example 8.2): repeatedly fit a one-split regression
# stump, then refit on the residuals until the squared loss falls to eps.
x <- c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
y <- c(5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05)
square_loss_error <- Inf
# Acceptable squared-loss floor (stopping threshold).
eps <- 0.17
k <- 1
final_TREE <- list()
while (square_loss_error > eps) {
print(paste0("k= ",k))
tree <- best_slicepoint(slice_point, x, y)
final_TREE[[k]] <- tree
square_loss_error <- tree\$minvalue
TREE <- tree\$best_slice_point
TREE_value <- tree\$c
print(paste0("square_loss_error = ", square_loss_error))
print(paste0("TREE_slice_point = ", TREE))
print(paste0("TREE_value = ", TREE_value))
x <- x
# The next round fits the current residuals.
y <- tree\$residuals
k <- k + 1
}``````
``````## [1] "k= 1"
## [1] "square_loss_error = 1.93"
## [1] "TREE_slice_point = 6.5"
## [1] "TREE_value = 6.24" "TREE_value = 8.91"
## [1] "k= 2"
## [1] "square_loss_error = 0.79"
## [1] "TREE_slice_point = 3.5"
## [1] "TREE_value = -0.52" "TREE_value = 0.22"
## [1] "k= 3"
## [1] "square_loss_error = 0.47"
## [1] "TREE_slice_point = 6.5"
## [1] "TREE_value = 0.15"  "TREE_value = -0.22"
## [1] "k= 4"
## [1] "square_loss_error = 0.3"
## [1] "TREE_slice_point = 4.5"
## [1] "TREE_value = -0.16" "TREE_value = 0.11"
## [1] "k= 5"
## [1] "square_loss_error = 0.23"
## [1] "TREE_slice_point = 6.5"
## [1] "TREE_value = 0.07"  "TREE_value = -0.11"
## [1] "k= 6"
## [1] "square_loss_error = 0.17"
## [1] "TREE_slice_point = 2.5"
## [1] "TREE_value = -0.15" "TREE_value = 0.04"``````
``final_TREE``
``````## [[1]]
## [[1]]\$best_slice_point
## [1] 6.5
##
## [[1]]\$minvalue
## [1] 1.93
##
## [[1]]\$R
## [[1]]\$R\$R1
## [1] 1 2 3 4 5 6
##
## [[1]]\$R\$R2
## [1]  7  8  9 10
##
##
## [[1]]\$c
## [1] 6.24 8.91
##
## [[1]]\$residuals
##  [1] -0.68 -0.54 -0.33  0.16  0.56  0.81 -0.01 -0.21  0.09  0.14
##
##
## [[2]]
## [[2]]\$best_slice_point
## [1] 3.5
##
## [[2]]\$minvalue
## [1] 0.79
##
## [[2]]\$R
## [[2]]\$R\$R1
## [1] 1 2 3
##
## [[2]]\$R\$R2
## [1]  4  5  6  7  8  9 10
##
##
## [[2]]\$c
## [1] -0.52  0.22
##
## [[2]]\$residuals
##  [1] -0.16 -0.02  0.19 -0.06  0.34  0.59 -0.23 -0.43 -0.13 -0.08
##
##
## [[3]]
## [[3]]\$best_slice_point
## [1] 6.5
##
## [[3]]\$minvalue
## [1] 0.47
##
## [[3]]\$R
## [[3]]\$R\$R1
## [1] 1 2 3 4 5 6
##
## [[3]]\$R\$R2
## [1]  7  8  9 10
##
##
## [[3]]\$c
## [1]  0.15 -0.22
##
## [[3]]\$residuals
##  [1] -0.31 -0.17  0.04 -0.21  0.19  0.44 -0.01 -0.21  0.09  0.14
##
##
## [[4]]
## [[4]]\$best_slice_point
## [1] 4.5
##
## [[4]]\$minvalue
## [1] 0.3
##
## [[4]]\$R
## [[4]]\$R\$R1
## [1] 1 2 3 4
##
## [[4]]\$R\$R2
## [1]  5  6  7  8  9 10
##
##
## [[4]]\$c
## [1] -0.16  0.11
##
## [[4]]\$residuals
##  [1] -0.15 -0.01  0.20 -0.05  0.08  0.33 -0.12 -0.32 -0.02  0.03
##
##
## [[5]]
## [[5]]\$best_slice_point
## [1] 6.5
##
## [[5]]\$minvalue
## [1] 0.23
##
## [[5]]\$R
## [[5]]\$R\$R1
## [1] 1 2 3 4 5 6
##
## [[5]]\$R\$R2
## [1]  7  8  9 10
##
##
## [[5]]\$c
## [1]  0.07 -0.11
##
## [[5]]\$residuals
##  [1] -0.22 -0.08  0.13 -0.12  0.01  0.26 -0.01 -0.21  0.09  0.14
##
##
## [[6]]
## [[6]]\$best_slice_point
## [1] 2.5
##
## [[6]]\$minvalue
## [1] 0.17
##
## [[6]]\$R
## [[6]]\$R\$R1
## [1] 1 2
##
## [[6]]\$R\$R2
## [1]  3  4  5  6  7  8  9 10
##
##
## [[6]]\$c
## [1] -0.15  0.04
##
## [[6]]\$residuals
##  [1] -0.07  0.07  0.09 -0.16 -0.03  0.22 -0.05 -0.25  0.05  0.10``````
``````final_tree_func <- function(x) {
# Predict at scalar x with the boosted stump ensemble in the global
# final_TREE: each stump contributes c[1] below its split point and
# c[2] at or above it; the prediction is the sum of the contributions.
y <- vector("numeric", length = length(final_TREE))
for (i in seq(length(final_TREE))) {
if (x < final_TREE[[i]][["best_slice_point"]]) {
y[i] <- final_TREE[[i]][["c"]][1]
} else {
y[i] <- final_TREE[[i]][["c"]][2]
}
}
sum_y <- sum(y)
return(sum_y)
}

final_tree_func(3.8)``````
``## [1] 6.56``
``final_tree_func(2.5)``
``## [1] 5.82``
``final_tree_func(100)``
``## [1] 8.95``
``````predict_fun <- function(x) {
# Vectorized wrapper: apply final_tree_func to every element of x.
stopifnot(is.vector(x))
leng_x <- length(x)
predict_value <- vector("numeric", length = leng_x)
for (i in seq(leng_x)) {
predict_value[i] <- final_tree_func(x[i])
}
return(predict_value)
}
# Step-function predictions over a grid; constant beyond the data range.
predict_fun(seq(1, 20, by = 0.2))``````
``````##  [1] 5.63 5.63 5.63 5.63 5.63 5.63 5.63 5.63 5.82 5.82 5.82 5.82 5.82 6.56
## [15] 6.56 6.56 6.56 6.56 6.83 6.83 6.83 6.83 6.83 6.83 6.83 6.83 6.83 6.83
## [29] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## [43] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## [57] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## [71] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## [85] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95``````
``````## Extension: boosting regression on the bodyfat data
require(TH.data)``````
``## Loading required package: TH.data``
``## Loading required package: survival``
``## Loading required package: MASS``
``````##
## Attaching package: 'MASS'``````
``````## The following object is masked from 'package:dplyr':
##
##     select``````
``````##
## Attaching package: 'TH.data'``````
``````## The following object is masked from 'package:MASS':
##
##     geyser``````
``dplyr::glimpse(bodyfat)``
``````## Observations: 71
## Variables: 10
## \$ age          <dbl> 57, 65, 59, 58, 60, 61, 56, 60, 58, 62, 63, 62, 6...
## \$ DEXfat       <dbl> 41.68, 43.29, 35.41, 22.79, 36.42, 24.13, 29.83, ...
## \$ waistcirc    <dbl> 100.0, 99.5, 96.0, 72.0, 89.5, 83.5, 81.0, 89.0, ...
## \$ hipcirc      <dbl> 112.0, 116.5, 108.5, 96.5, 100.5, 97.0, 103.0, 10...
## \$ elbowbreadth <dbl> 7.1, 6.5, 6.2, 6.1, 7.1, 6.5, 6.9, 6.2, 6.4, 7.0,...
## \$ kneebreadth  <dbl> 9.4, 8.9, 8.9, 9.2, 10.0, 8.8, 8.9, 8.5, 8.8, 8.8...
## \$ anthro3a     <dbl> 4.42, 4.63, 4.12, 4.03, 4.24, 3.55, 4.14, 4.04, 3...
## \$ anthro3b     <dbl> 4.95, 5.01, 4.74, 4.48, 4.68, 4.06, 4.52, 4.70, 4...
## \$ anthro3c     <dbl> 4.50, 4.48, 4.60, 3.91, 4.15, 3.64, 4.31, 4.47, 3...
## \$ anthro4      <dbl> 6.13, 6.37, 5.82, 5.66, 5.91, 5.14, 5.69, 5.70, 5...``````
``````data("bodyfat", package = "TH.data")

### final model proposed by Garcia et al. (2005)
# Linear-model baseline on three predictors, to compare against boosting.
fmod <- lm(DEXfat ~ hipcirc + anthro3a + kneebreadth, data = bodyfat)
predict(fmod, bodyfat)``````
``````##        47        48        49        50        51        52        53
## 39.315480 42.537378 33.901251 27.531664 32.970386 22.750005 31.266049
##        54        55        56        57        58        59        60
## 30.637342 25.957475 21.683960 24.729106 26.281913 24.595095 19.542765
##        61        62        63        64        65        66        67
## 31.839068 38.161299 28.242285 41.339475 37.800842 22.799220 33.086417
##        68        69        70        71        72        73        74
## 23.194527 27.160642 25.726705 19.917902 34.546497 34.448555 22.393620
##        75        76        77        78        79        80        81
## 47.808223 27.296823 45.745559 47.594972 29.738424 39.405537 22.018267
##        82        83        84        85        86        87        88
## 27.112875 29.457939 41.091646 28.080651 36.046999 52.240951 41.482610
##        89        90        91        92        93        94        95
## 38.731896 34.377219 50.574094 49.318790 49.075053 52.846606 35.132022
##        96        97        98        99       100       101       102
## 27.154165 22.987527 39.534244 37.250424 42.809951 19.012029 32.969458
##       103       104       105       106       107       108       109
## 20.414520 43.778485 35.387305 29.518457 17.839849 16.502462 15.579532
##       110       111       112       113       114       115       116
## 19.154684 17.308153  9.477999  9.477999 15.465475 23.268506 23.288446
##       117
## 17.866250``````
``(bodyfat\$DEXfat - fmod\$fitted.values)^2 %>% sum()``
``## [1] 838.7505``
``````data_train <- bodyfat[, c("hipcirc","anthro3a", "kneebreadth", "DEXfat")]
# Training frame: the three predictors used by the lm baseline plus DEXfat.
dplyr::glimpse(data_train)``````
``````## Observations: 71
## Variables: 4
## \$ hipcirc     <dbl> 112.0, 116.5, 108.5, 96.5, 100.5, 97.0, 103.0, 105...
## \$ anthro3a    <dbl> 4.42, 4.63, 4.12, 4.03, 4.24, 3.55, 4.14, 4.04, 3....
## \$ kneebreadth <dbl> 9.4, 8.9, 8.9, 9.2, 10.0, 8.8, 8.9, 8.5, 8.8, 8.8,...
## \$ DEXfat      <dbl> 41.68, 43.29, 35.41, 22.79, 36.42, 24.13, 29.83, 3...``````
``````#####################################################################

# Hand-rolled boosting on the same three predictors used by the lm fit.
row_data <- nrow(data_train)
col_data <- ncol(data_train)

# Split off the feature columns and the DEXfat target (last column).
x <- data_train[, -col_data]
y <- data_train[, col_data]

# Candidate split points for the first feature (hipcirc).
(slice_point <- slice_point_fun(x[, 1]))``````
``````##  [1]  88.15  89.15  90.50  91.50  92.10  92.20  92.60  93.00  93.10  93.60
## [11]  94.00  94.25  94.75  95.00  95.00  95.50  96.25  96.75  97.00  97.00
## [21]  98.00  99.00  99.00  99.15  99.40  99.75 100.00 100.25 100.75 101.10
## [31] 101.35 101.75 102.00 102.15 102.65 103.00 103.30 103.95 104.65 105.50
## [41] 106.75 107.50 107.60 107.85 108.25 108.75 109.00 109.00 109.00 109.00
## [51] 109.25 109.90 111.15 112.00 112.50 114.50 116.15 116.40 116.75 117.75
## [61] 120.25 122.25 122.50 123.25 124.50 125.00 125.50 126.50 127.75 130.25``````
``````tree <- best_slicepoint(slice_point, x[, 1], y)    ####################################################

square_loss_error <- Inf
eps <- 500 # acceptable squared-loss floor (stopping threshold)
k <- 1
final_TREE <- list()
# Multivariate boosting: each round fits one stump per feature, keeps the
# best of them, and continues on the residuals until the loss drops below eps.
while (square_loss_error > eps) {
print(paste0("k= ",k))
VAR <- list()
for (kk in seq(col_data - 1)) {
slice_point <- slice_point_fun(x[, kk])
Tree <- best_slicepoint(slice_point, x[, kk], y)
VAR[[kk]] <- Tree
}

all_minvalue <- modify_depth(VAR, 1, "minvalue") %>% as_vector()
idx_minvalue <- all_minvalue %>% which.min()  # which feature wins this round
tree <- VAR[[idx_minvalue]]

# Record the stump and remember which feature it splits on.
final_TREE[[k]] <- tree
final_TREE[[k]][["idx_var_minvalue"]] <- idx_minvalue

square_loss_error <- tree\$minvalue
TREE <- tree\$best_slice_point
TREE_value <- tree\$c

print(paste0("square_loss_error = ", square_loss_error))
print(paste0("TREE_slice_point = ", TREE))
print(paste0("TREE_value = ", TREE_value))
x <- x
# Fit the next stump to the residuals.
y <- tree\$residuals
k <- k + 1
}``````
``````## [1] "k= 1"
## [1] "square_loss_error = 3197.3"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 24.19" "TREE_value = 42.19"
## [1] "k= 2"
## [1] "square_loss_error = 2394.98"
## [1] "TREE_slice_point = 3.52"
## [1] "TREE_value = -6.23" "TREE_value = 1.81"
## [1] "k= 3"
## [1] "square_loss_error = 1774.28"
## [1] "TREE_slice_point = 4.635"
## [1] "TREE_value = -0.5"  "TREE_value = 17.37"
## [1] "k= 4"
## [1] "square_loss_error = 1478.86"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 1.55"  "TREE_value = -2.68"
## [1] "k= 5"
## [1] "square_loss_error = 1145.53"
## [1] "TREE_slice_point = 99.4"
## [1] "TREE_value = -2.94" "TREE_value = 1.6"
## [1] "k= 6"
## [1] "square_loss_error = 1040.36"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 0.92" "TREE_value = -1.6"
## [1] "k= 7"
## [1] "square_loss_error = 911.36"
## [1] "TREE_slice_point = 4.14"
## [1] "TREE_value = -0.81" "TREE_value = 2.23"
## [1] "k= 8"
## [1] "square_loss_error = 860.6"
## [1] "TREE_slice_point = 2.4"
## [1] "TREE_value = -4.97" "TREE_value = 0.14"
## [1] "k= 9"
## [1] "square_loss_error = 801.02"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 0.7"  "TREE_value = -1.2"
## [1] "k= 10"
## [1] "square_loss_error = 695.07"
## [1] "TREE_slice_point = 114.5"
## [1] "TREE_value = -0.64" "TREE_value = 2.36"
## [1] "k= 11"
## [1] "square_loss_error = 645.62"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 0.64" "TREE_value = -1.1"
## [1] "k= 12"
## [1] "square_loss_error = 608.61"
## [1] "TREE_slice_point = 91.5"
## [1] "TREE_value = -2.96" "TREE_value = 0.18"
## [1] "k= 13"
## [1] "square_loss_error = 586.56"
## [1] "TREE_slice_point = 127.75"
## [1] "TREE_value = -0.1" "TREE_value = 3.27"
## [1] "k= 14"
## [1] "square_loss_error = 564.35"
## [1] "TREE_slice_point = 3.52"
## [1] "TREE_value = 1.04" "TREE_value = -0.3"
## [1] "k= 15"
## [1] "square_loss_error = 538.65"
## [1] "TREE_slice_point = 9.75"
## [1] "TREE_value = -0.39" "TREE_value = 0.93"
## [1] "k= 16"
## [1] "square_loss_error = 522.87"
## [1] "TREE_slice_point = 3.11"
## [1] "TREE_value = -1.93" "TREE_value = 0.11"
## [1] "k= 17"
## [1] "square_loss_error = 504.79"
## [1] "TREE_slice_point = 3.21"
## [1] "TREE_value = 1.66"  "TREE_value = -0.15"
## [1] "k= 18"
## [1] "square_loss_error = 493.09"
## [1] "TREE_slice_point = 3.11"
## [1] "TREE_value = -1.66" "TREE_value = 0.1"``````
``````# final_TREE

# x是多维的, x[1, ]
# Evaluate the multivariate boosted regression tree at one observation.
#
# x: one row of the feature matrix/data.frame (coerced to numeric).
# Uses the global `final_TREE`, where each element stores a stump's
# `best_slice_point`, its two region constants `c`, and
# `idx_var_minvalue` (the feature the stump splits on).
#
# Returns the scalar prediction: the sum of all stump outputs.
# Fixes vs original: the accumulator was named `sum`, shadowing
# base::sum; `seq(...)` replaced with seq_len/seq_along-safe forms;
# the purrr modify_depth pipeline replaced with base vapply.
final_tree_func <- function(x) {
  x <- as.numeric(x)
  total <- 0
  # Feature index used by each stump in the ensemble.
  vars <- vapply(final_TREE,
                 function(tree) as.numeric(tree[["idx_var_minvalue"]]),
                 numeric(1))
  for (k in seq_along(x)) {
    idx <- which(vars == k)
    if (length(idx) > 0) {
      # Contribution of every stump that splits on feature k: the left
      # constant below the split point, the right constant otherwise.
      contrib <- vapply(idx, function(j) {
        if (x[k] < final_TREE[[j]][["best_slice_point"]]) {
          final_TREE[[j]][["c"]][1]
        } else {
          final_TREE[[j]][["c"]][2]
        }
      }, numeric(1))
      total <- total + sum(contrib)
    }
  }
  return(total)
}

final_tree_func(x[1, ]); bodyfat\$DEXfat[1]``````
``## [1] 39.7``
``## [1] 41.68``
``final_tree_func(x[5, ]); bodyfat\$DEXfat[5]``
``## [1] 33.41``
``## [1] 36.42``
``final_tree_func(x[60, ]); bodyfat\$DEXfat[60]``
``## [1] 29.05``
``## [1] 28.98``
``````predict_fun <- function(x) {
# Matrix version: one prediction per row of x via final_tree_func.
# NOTE(review): redefines the earlier vector-input predict_fun.
stopifnot(is.matrix(x))
leng_x <- nrow(x)
predict_value <- vector("numeric", length = leng_x)
for (i in seq(leng_x)) {
predict_value[i] <- final_tree_func(x[i, ])
}
return(predict_value)
}
# In-sample predictions for all 71 bodyfat observations.
predict_fun(as.matrix(x))``````
``````##  [1] 39.70 42.70 36.66 24.51 33.41 24.51 32.09 29.05 24.51 24.51 24.51
## [12] 27.55 24.51 21.37 29.05 39.70 29.05 44.02 39.70 21.37 32.09 24.51
## [23] 29.05 29.05 19.62 33.41 36.66 17.81 44.02 29.05 44.02 44.02 29.05
## [34] 39.70 24.16 29.05 30.37 40.98 25.83 41.02 61.89 40.98 41.02 30.37
## [45] 47.39 44.02 47.39 61.89 36.66 29.05 24.51 41.02 39.66 40.98 17.81
## [56] 29.05 17.81 44.02 36.66 29.05 17.81 17.81 14.67 17.81 15.82 15.82
## [67] 15.82 15.82 22.35 22.35 14.67``````