Chapter 8 Boosting Methods

AdaBoost

rm(list = ls())
library(purrr)  # pmap(), modify_depth(), as_vector() and the pipe %>%

# Candidate split points: midpoints between consecutive sorted values of x.
slice_point_fun <- function(x_little) {
  stopifnot(is.vector(x_little))
  x_little <- sort(x_little)
  len_x <- length(x_little)
  slice_point <- vector("numeric", length = len_x - 1)
  for (i in seq(len_x - 1)) {
    slice_point[i] <- (x_little[i] + x_little[i + 1]) / 2
  }
  return(slice_point)
}
# slice_point <- slice_point_fun(x[, 1])
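A quick sanity check of the midpoint rule (this call is not part of the original script):

slice_point_fun(c(0, 1, 2, 3))
# expected: 0.5 1.5 2.5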

# Weighted error of a decision stump at a given split point. Both
# orientations are tried (predict +1 on the "less than" side, or on the
# "more than" side) and the better one is kept.
# NB: the labels y are read from the global environment, not passed in.
Error <- function(classifier_slice_point, sample_weight, x) {
  len_x <- length(sample_weight)
  yi_1 <- vector("integer", length = len_x)
  # G(x) = 1 when x < classifier_slice_point
  idx_1 <- which(x < classifier_slice_point)
  yi_1[idx_1] <- 1
  yi_1[-idx_1] <- -1
  error_less <- sum(sample_weight[yi_1 != y])
  
  # G(x) = 1 when x > classifier_slice_point
  yi_2 <- vector("integer", length = len_x)
  idx_2 <- which(x > classifier_slice_point)
  yi_2[idx_2] <- 1
  yi_2[-idx_2] <- -1
  error_more <- sum(sample_weight[yi_2 != y])
  
  error <- min(c(error_less, error_more))
  error_idx <- which.min(c(error_less, error_more))
  error_list <- list(error = error, 
                     classifier = ifelse(error_idx == 1, 
                                         "less than", 
                                         "more than"))
  return(error_list)
}

# sample_weight <- rep(1 / row_data, row_data)
# Error(2.0, sample_weight, x[, 3])
# x[order(x[, 3]), 3]
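A hand check against the textbook example (not in the original script): with uniform weights 0.1 and split point v = 2.5, the "less than" stump misclassifies x = 6, 7, 8, so the weighted error is 0.3. Error() reads y from the global environment, so it must be set first:

y <- c(1, 1, 1, -1, -1, -1, 1, 1, 1, -1)
Error(2.5, rep(0.1, 10), 0:9)
# expected: error = 0.3, classifier = "less than"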


# Stump prediction G(x) at one sample point, using the orientation that
# Error() found best for this split point.
G <- function(classifier_slice_point, 
              sample_weight, 
              sample_point_x, x) {

  classifier <- Error(classifier_slice_point, 
                      sample_weight, x)[["classifier"]]
  if (classifier == "less than") {
    G_value <- ifelse(sample_point_x < classifier_slice_point, 1, -1)
  } else {
    G_value <- ifelse(sample_point_x > classifier_slice_point, 1, -1)
  }
  return(G_value)
}

#G(2.5, sample_weight, x[1, 1], x[, 1])

# Normalization constant Z_m = sum_i w_i * exp(-alpha * y_i * G(x_i)),
# so that the updated weights again sum to 1.
Z <- function(slice_point_error_min, 
              sample_weight, x, 
              alpha, y) {
  
  len_x <- length(x)
  Z_value <- 0
  for (i in seq(len_x)) {
    Gi <- G(slice_point_error_min, 
            sample_weight, x[i], x)
    zi <- sample_weight[i] * exp(-alpha * y[i] * Gi)
    Z_value <- Z_value + zi
  }
  Z_value <- round(Z_value, 4)
  return(Z_value)
} 

#Z(2.5, sample_weight, x[, 1], 0.4236, y)
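A worked check of the round-1 weight update (not in the original script). With w_i = 0.1 and alpha = 0.4236, each of the 7 correctly classified points contributes 0.1 * exp(-0.4236) ≈ 0.06548 to Z, and each of the 3 misclassified points 0.1 * exp(0.4236) ≈ 0.15273:

7 * 0.1 * exp(-0.4236) + 3 * 0.1 * exp(0.4236)
# expected: about 0.9165; dividing gives the updated weights
# 0.06548 / 0.9165 ≈ 0.07143 and 0.15273 / 0.9165 ≈ 0.16666,
# matching the printed weights in round k = 2 below.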

# AdaBoost with one-dimensional decision stumps as base classifiers.
# The loop runs until the training set is classified perfectly, so it
# assumes boosted stumps can reach zero training error on the data.
final <- function(x, y) {
  row_data <- nrow(x)
  col_data <- ncol(x) + 1
  sample_weight <- rep(1 / row_data, row_data)  # initial weights D_1 = 1/N
  miss_classify <- Inf
  f_value_vec <- vector("numeric", length = row_data)  # running f(x_i)
  k <- 1
  while (miss_classify > 0) {
    print(paste0("k= ", k))
    print(paste0("sample_weight= ", sample_weight))
    var_ <- list()
    # For every variable, find the split point with minimal weighted error.
    for (kk in seq(col_data - 1)) {
      slice_point <- slice_point_fun(x[, kk])
      len_slice_point <- length(slice_point)
      sample_weight_list <- vector("list", 
                                   length = len_slice_point)
      for (i in seq(len_slice_point)) {
        sample_weight_list[[i]] <- sample_weight
      }
      x_kk <- vector("list", length = len_slice_point)
      for (j in seq(len_slice_point)) {
        x_kk[[j]] <- x[, kk]
      }
      all_error_i <- pmap(list(slice_point, 
                               sample_weight_list, 
                               x_kk), Error)
      vec_all_error_i <- modify_depth(all_error_i, 1, "error") %>% 
        as_vector(.)
      error_min_i <- min(vec_all_error_i)
      print(paste0("error_min_i = ", error_min_i))
      error_min_idx_i <- which(vec_all_error_i == error_min_i)[1]
      slice_point_error_min_i <- slice_point[error_min_idx_i]
      
      var_[[kk]] <- list(error_min_i = error_min_i,
                         error_min_idx_i = error_min_idx_i,
                         slice_point_error_min_i = slice_point_error_min_i,
                         slice_point = slice_point)
    }
    vec_all_error <- modify_depth(var_, 1, "error_min_i") %>% 
      as_vector(.)
    error_min <- min(vec_all_error)
    error_min_var_idx <- which(vec_all_error == error_min)[1]  # which variable wins
    
    slice_point_error_min <- 
      var_[[error_min_var_idx]][["slice_point_error_min_i"]]
    
    # alpha_k = (1/2) * ln((1 - e_k) / e_k); the small constant guards
    # against division by zero when e_k = 0.
    alpha <- 1 / 2 * log((1 - error_min) / 
                           (error_min + 0.00001), base = exp(1))
    alpha <- round(alpha, 4)
    print(paste0("alpha=", alpha))
    miss_classify <- 0
    k_var_idx <- error_min_var_idx
    Z_value <- Z(slice_point_error_min, 
                 sample_weight, x[, k_var_idx], 
                 alpha, y)  
    # Update f(x_i), count misclassifications of sign(f), and reweight:
    # w_i <- w_i * exp(-alpha * y_i * G(x_i)) / Z.
    for (i in seq(row_data)) {
      G_value <- G(slice_point_error_min, 
                   sample_weight, 
                   x[i, k_var_idx], x[, k_var_idx])
      f_value <- alpha * G_value + f_value_vec[i]
      miss <- ifelse(sign(f_value) != y[i], 1, 0)
      miss_classify <- miss_classify + miss 
      f_value_vec[i] <- f_value
  
      exp_value <- exp(-alpha * y[i] * G_value)
      sample_weight[i] <- round(sample_weight[i] /
                                  Z_value * exp_value, 5)
    }
    print(paste0("miss_classify= ", miss_classify))
    k <- k + 1
  }
}

########
x <- c(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) %>% matrix()
y <- c(1, 1, 1, -1, -1, -1, 1, 1, 1, -1)
final(x = x, y = y)
## [1] "k= 1"
##  [1] "sample_weight= 0.1" "sample_weight= 0.1" "sample_weight= 0.1"
##  [4] "sample_weight= 0.1" "sample_weight= 0.1" "sample_weight= 0.1"
##  [7] "sample_weight= 0.1" "sample_weight= 0.1" "sample_weight= 0.1"
## [10] "sample_weight= 0.1"
## [1] "error_min_i = 0.3"
## [1] "alpha=0.4236"
## [1] "miss_classify= 3"
## [1] "k= 2"
##  [1] "sample_weight= 0.07143" "sample_weight= 0.07143"
##  [3] "sample_weight= 0.07143" "sample_weight= 0.07143"
##  [5] "sample_weight= 0.07143" "sample_weight= 0.07143"
##  [7] "sample_weight= 0.16666" "sample_weight= 0.16666"
##  [9] "sample_weight= 0.16666" "sample_weight= 0.07143"
## [1] "error_min_i = 0.21429"
## [1] "alpha=0.6496"
## [1] "miss_classify= 3"
## [1] "k= 3"
##  [1] "sample_weight= 0.04545" "sample_weight= 0.04545"
##  [3] "sample_weight= 0.04545" "sample_weight= 0.16665"
##  [5] "sample_weight= 0.16665" "sample_weight= 0.16665"
##  [7] "sample_weight= 0.10605" "sample_weight= 0.10605"
##  [9] "sample_weight= 0.10605" "sample_weight= 0.04545"
## [1] "error_min_i = 0.1818"
## [1] "alpha=0.7521"
## [1] "miss_classify= 0"
# iris_1 <- iris[1:100, ]
# iris_1[, 5] <- NULL
# iris_1[["labels"]] <- c(rep(1, 50), rep(-1, 50))
# head(iris_1)
# x_iris <- iris_1[, 1:4]
# y_iris <- iris_1[, 5]
# y <- y_iris  # Error() reads y from the global environment
# final(x = x_iris, y = y_iris)
Boosting trees

x <- c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
y <- c(5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05)
len_x <- length(x)
slice_point <- slice_point_fun(x)  # reuse the midpoint rule defined above
slice_point
## [1] 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5
#sample_weight <- rep(1 / len_x, len_x)

# Regression stump: scan candidate split points s and for each compute
# m(s) = sum_{x_i in R1} (y_i - c1)^2 + sum_{x_i in R2} (y_i - c2)^2,
# where c1, c2 are the mean responses of the two regions.
# Returns the best split, the two region means, and the residuals.
best_slicepoint <- function(vector_slice_point, x, y) {
  len_x <- length(x)
  len_vsp <- length(vector_slice_point)
  ms <- vector("numeric", length = len_vsp)
  for (i in seq(len_vsp)) {
    idx_1 <- which(x <= vector_slice_point[i])
    c1 <- mean(y[idx_1])
    c2 <- mean(y[-idx_1])
    ms[i] <- sum((y[idx_1] - c1)^2) + sum((y[-idx_1] - c2)^2) 
  }
  idx_minvalue <- which.min(ms)
  minvalue <- ms[idx_minvalue] %>% round(2)
  best_slice_point <- vector_slice_point[idx_minvalue]
  R1 <- subset(x, x <= best_slice_point)
  R2 <- subset(x, x > best_slice_point)
  idx_R1 <- which(x %in% R1)
  c1_hat <- y[idx_R1] %>% mean(.) %>% round(2)   # fitted value on R1
  c2_hat <- y[-idx_R1] %>% mean(.) %>% round(2)  # fitted value on R2
  
  # x need not be sorted in advance; it can be passed in as-is
  residuals <- vector("numeric", length = len_x)
  for (i in seq(len_x)) {
    if (i %in% idx_R1) {
      residuals[i] <- y[i] - c1_hat
    } else {
      residuals[i] <- y[i] - c2_hat
    }
  }
  return_list <- list(best_slice_point = best_slice_point, 
                      minvalue = minvalue,
                      R = list(R1 = R1, R2 = R2),
                      c = c(c1_hat, c2_hat),
                      residuals = residuals)
  return(return_list)
}
tree <- best_slicepoint(slice_point, x, y)
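A hand check of the first tree (not in the original script): the best split is s = 6.5, and the two region means match the values printed in the run below.

round(mean(y[x <= 6.5]), 2)  # c1, expected 6.24
round(mean(y[x > 6.5]), 2)   # c2, expected 8.91
tree$best_slice_point        # expected 6.5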

# x need not be sorted in advance, e.g.:
# x1 <- c(8, 5, 2, 3, 4, 6, 7, 9, 10, 1)
# y1 <- c(8.70, 6.80, 5.70, 5.91, 6.40, 7.05, 8.90, 9.00, 9.05, 5.56)
# best_slicepoint(slice_point, x1, y1)

# best_slicepoint(slice_point, x, tree$residuals)

x <- c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
y <- c(5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05)
square_loss_error <- Inf
eps <- 0.17  # acceptable squared-loss threshold
k <- 1
final_TREE <- list()
# Boosting tree: each round fits a stump to the current residuals,
# then uses its residuals as the targets for the next round.
while (square_loss_error > eps) {
  print(paste0("k= ", k))
  tree <- best_slicepoint(slice_point, x, y)
  final_TREE[[k]] <- tree
  square_loss_error <- tree$minvalue
  TREE <- tree$best_slice_point
  TREE_value <- tree$c
  print(paste0("square_loss_error = ", square_loss_error))
  print(paste0("TREE_slice_point = ", TREE))
  print(paste0("TREE_value = ", TREE_value))
  y <- tree$residuals  # next round fits the residuals
  k <- k + 1
}
## [1] "k= 1"
## [1] "square_loss_error = 1.93"
## [1] "TREE_slice_point = 6.5"
## [1] "TREE_value = 6.24" "TREE_value = 8.91"
## [1] "k= 2"
## [1] "square_loss_error = 0.79"
## [1] "TREE_slice_point = 3.5"
## [1] "TREE_value = -0.52" "TREE_value = 0.22" 
## [1] "k= 3"
## [1] "square_loss_error = 0.47"
## [1] "TREE_slice_point = 6.5"
## [1] "TREE_value = 0.15"  "TREE_value = -0.22"
## [1] "k= 4"
## [1] "square_loss_error = 0.3"
## [1] "TREE_slice_point = 4.5"
## [1] "TREE_value = -0.16" "TREE_value = 0.11" 
## [1] "k= 5"
## [1] "square_loss_error = 0.23"
## [1] "TREE_slice_point = 6.5"
## [1] "TREE_value = 0.07"  "TREE_value = -0.11"
## [1] "k= 6"
## [1] "square_loss_error = 0.17"
## [1] "TREE_slice_point = 2.5"
## [1] "TREE_value = -0.15" "TREE_value = 0.04"
final_TREE
## [[1]]
## [[1]]$best_slice_point
## [1] 6.5
## 
## [[1]]$minvalue
## [1] 1.93
## 
## [[1]]$R
## [[1]]$R$R1
## [1] 1 2 3 4 5 6
## 
## [[1]]$R$R2
## [1]  7  8  9 10
## 
## 
## [[1]]$c
## [1] 6.24 8.91
## 
## [[1]]$residuals
##  [1] -0.68 -0.54 -0.33  0.16  0.56  0.81 -0.01 -0.21  0.09  0.14
## 
## 
## [[2]]
## [[2]]$best_slice_point
## [1] 3.5
## 
## [[2]]$minvalue
## [1] 0.79
## 
## [[2]]$R
## [[2]]$R$R1
## [1] 1 2 3
## 
## [[2]]$R$R2
## [1]  4  5  6  7  8  9 10
## 
## 
## [[2]]$c
## [1] -0.52  0.22
## 
## [[2]]$residuals
##  [1] -0.16 -0.02  0.19 -0.06  0.34  0.59 -0.23 -0.43 -0.13 -0.08
## 
## 
## [[3]]
## [[3]]$best_slice_point
## [1] 6.5
## 
## [[3]]$minvalue
## [1] 0.47
## 
## [[3]]$R
## [[3]]$R$R1
## [1] 1 2 3 4 5 6
## 
## [[3]]$R$R2
## [1]  7  8  9 10
## 
## 
## [[3]]$c
## [1]  0.15 -0.22
## 
## [[3]]$residuals
##  [1] -0.31 -0.17  0.04 -0.21  0.19  0.44 -0.01 -0.21  0.09  0.14
## 
## 
## [[4]]
## [[4]]$best_slice_point
## [1] 4.5
## 
## [[4]]$minvalue
## [1] 0.3
## 
## [[4]]$R
## [[4]]$R$R1
## [1] 1 2 3 4
## 
## [[4]]$R$R2
## [1]  5  6  7  8  9 10
## 
## 
## [[4]]$c
## [1] -0.16  0.11
## 
## [[4]]$residuals
##  [1] -0.15 -0.01  0.20 -0.05  0.08  0.33 -0.12 -0.32 -0.02  0.03
## 
## 
## [[5]]
## [[5]]$best_slice_point
## [1] 6.5
## 
## [[5]]$minvalue
## [1] 0.23
## 
## [[5]]$R
## [[5]]$R$R1
## [1] 1 2 3 4 5 6
## 
## [[5]]$R$R2
## [1]  7  8  9 10
## 
## 
## [[5]]$c
## [1]  0.07 -0.11
## 
## [[5]]$residuals
##  [1] -0.22 -0.08  0.13 -0.12  0.01  0.26 -0.01 -0.21  0.09  0.14
## 
## 
## [[6]]
## [[6]]$best_slice_point
## [1] 2.5
## 
## [[6]]$minvalue
## [1] 0.17
## 
## [[6]]$R
## [[6]]$R$R1
## [1] 1 2
## 
## [[6]]$R$R2
## [1]  3  4  5  6  7  8  9 10
## 
## 
## [[6]]$c
## [1] -0.15  0.04
## 
## [[6]]$residuals
##  [1] -0.07  0.07  0.09 -0.16 -0.03  0.22 -0.05 -0.25  0.05  0.10
# Additive prediction: f(x) = sum over trees of the selected region value.
final_tree_func <- function(x) {
  y <- vector("numeric", length = length(final_TREE))
  for (i in seq(length(final_TREE))) {
    if (x < final_TREE[[i]][["best_slice_point"]]) {
      y[i] <- final_TREE[[i]][["c"]][1]
    } else {
      y[i] <- final_TREE[[i]][["c"]][2]
    }
  }
  sum_y <- sum(y)
  return(sum_y)
}

final_tree_func(3.8)
## [1] 6.56
final_tree_func(2.5)
## [1] 5.82
final_tree_func(100)
## [1] 8.95
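The prediction is just the sum of one region value per tree. For x = 3.8 (this check is not in the original script), x falls in R1 of trees 1, 3, 4, 5 and in R2 of trees 2, 6:

sum(c(6.24, 0.22, 0.15, -0.16, 0.07, 0.04))
# expected: 6.56, matching final_tree_func(3.8)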
predict_fun <- function(x) {
  stopifnot(is.vector(x))
  leng_x <- length(x)
  predict_value <- vector("numeric", length = leng_x)
  for (i in seq(leng_x)) {
    predict_value[i] <- final_tree_func(x[i])
  }
  return(predict_value)
}
predict_fun(seq(1, 20, by = 0.2))
##  [1] 5.63 5.63 5.63 5.63 5.63 5.63 5.63 5.63 5.82 5.82 5.82 5.82 5.82 6.56
## [15] 6.56 6.56 6.56 6.56 6.83 6.83 6.83 6.83 6.83 6.83 6.83 6.83 6.83 6.83
## [29] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## [43] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## [57] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## [71] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## [85] 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95 8.95
## Extension
require(TH.data)
## Loading required package: TH.data
## Loading required package: survival
## Loading required package: MASS
## 
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
## 
##     select
## 
## Attaching package: 'TH.data'
## The following object is masked from 'package:MASS':
## 
##     geyser
dplyr::glimpse(bodyfat)
## Observations: 71
## Variables: 10
## $ age          <dbl> 57, 65, 59, 58, 60, 61, 56, 60, 58, 62, 63, 62, 6...
## $ DEXfat       <dbl> 41.68, 43.29, 35.41, 22.79, 36.42, 24.13, 29.83, ...
## $ waistcirc    <dbl> 100.0, 99.5, 96.0, 72.0, 89.5, 83.5, 81.0, 89.0, ...
## $ hipcirc      <dbl> 112.0, 116.5, 108.5, 96.5, 100.5, 97.0, 103.0, 10...
## $ elbowbreadth <dbl> 7.1, 6.5, 6.2, 6.1, 7.1, 6.5, 6.9, 6.2, 6.4, 7.0,...
## $ kneebreadth  <dbl> 9.4, 8.9, 8.9, 9.2, 10.0, 8.8, 8.9, 8.5, 8.8, 8.8...
## $ anthro3a     <dbl> 4.42, 4.63, 4.12, 4.03, 4.24, 3.55, 4.14, 4.04, 3...
## $ anthro3b     <dbl> 4.95, 5.01, 4.74, 4.48, 4.68, 4.06, 4.52, 4.70, 4...
## $ anthro3c     <dbl> 4.50, 4.48, 4.60, 3.91, 4.15, 3.64, 4.31, 4.47, 3...
## $ anthro4      <dbl> 6.13, 6.37, 5.82, 5.66, 5.91, 5.14, 5.69, 5.70, 5...
data("bodyfat", package = "TH.data")

### final model proposed by Garcia et al. (2005)
fmod <- lm(DEXfat ~ hipcirc + anthro3a + kneebreadth, data = bodyfat)
predict(fmod, bodyfat)
##        47        48        49        50        51        52        53 
## 39.315480 42.537378 33.901251 27.531664 32.970386 22.750005 31.266049 
##        54        55        56        57        58        59        60 
## 30.637342 25.957475 21.683960 24.729106 26.281913 24.595095 19.542765 
##        61        62        63        64        65        66        67 
## 31.839068 38.161299 28.242285 41.339475 37.800842 22.799220 33.086417 
##        68        69        70        71        72        73        74 
## 23.194527 27.160642 25.726705 19.917902 34.546497 34.448555 22.393620 
##        75        76        77        78        79        80        81 
## 47.808223 27.296823 45.745559 47.594972 29.738424 39.405537 22.018267 
##        82        83        84        85        86        87        88 
## 27.112875 29.457939 41.091646 28.080651 36.046999 52.240951 41.482610 
##        89        90        91        92        93        94        95 
## 38.731896 34.377219 50.574094 49.318790 49.075053 52.846606 35.132022 
##        96        97        98        99       100       101       102 
## 27.154165 22.987527 39.534244 37.250424 42.809951 19.012029 32.969458 
##       103       104       105       106       107       108       109 
## 20.414520 43.778485 35.387305 29.518457 17.839849 16.502462 15.579532 
##       110       111       112       113       114       115       116 
## 19.154684 17.308153  9.477999  9.477999 15.465475 23.268506 23.288446 
##       117 
## 17.866250
(bodyfat$DEXfat - fmod$fitted.values)^2 %>% sum()
## [1] 838.7505
data_train <- bodyfat[, c("hipcirc","anthro3a", "kneebreadth", "DEXfat")]
dplyr::glimpse(data_train)
## Observations: 71
## Variables: 4
## $ hipcirc     <dbl> 112.0, 116.5, 108.5, 96.5, 100.5, 97.0, 103.0, 105...
## $ anthro3a    <dbl> 4.42, 4.63, 4.12, 4.03, 4.24, 3.55, 4.14, 4.04, 3....
## $ kneebreadth <dbl> 9.4, 8.9, 8.9, 9.2, 10.0, 8.8, 8.9, 8.5, 8.8, 8.8,...
## $ DEXfat      <dbl> 41.68, 43.29, 35.41, 22.79, 36.42, 24.13, 29.83, 3...
########

row_data <- nrow(data_train)
col_data <- ncol(data_train)

x <- data_train[, -col_data]
y <- data_train[, col_data]

(slice_point <- slice_point_fun(x[, 1]))
##  [1]  88.15  89.15  90.50  91.50  92.10  92.20  92.60  93.00  93.10  93.60
## [11]  94.00  94.25  94.75  95.00  95.00  95.50  96.25  96.75  97.00  97.00
## [21]  98.00  99.00  99.00  99.15  99.40  99.75 100.00 100.25 100.75 101.10
## [31] 101.35 101.75 102.00 102.15 102.65 103.00 103.30 103.95 104.65 105.50
## [41] 106.75 107.50 107.60 107.85 108.25 108.75 109.00 109.00 109.00 109.00
## [51] 109.25 109.90 111.15 112.00 112.50 114.50 116.15 116.40 116.75 117.75
## [61] 120.25 122.25 122.50 123.25 124.50 125.00 125.50 126.50 127.75 130.25
tree <- best_slicepoint(slice_point, x[, 1], y)

square_loss_error <- Inf
eps <- 500  # acceptable floor for the squared loss
k <- 1
final_TREE <- list()
while (square_loss_error > eps) {
  print(paste0("k= ", k))
  # Fit a stump on every variable and keep the one with minimal loss.
  VAR <- list()
  for (kk in seq(col_data - 1)) {
    slice_point <- slice_point_fun(x[, kk])
    Tree <- best_slicepoint(slice_point, x[, kk], y)
    VAR[[kk]] <- Tree
  }
  
  all_minvalue <- modify_depth(VAR, 1, "minvalue") %>% as_vector()
  idx_minvalue <- all_minvalue %>% which.min()  # which variable wins
  tree <- VAR[[idx_minvalue]]
  
  final_TREE[[k]] <- tree
  final_TREE[[k]][["idx_var_minvalue"]] <- idx_minvalue
  
  square_loss_error <- tree$minvalue
  TREE <- tree$best_slice_point
  TREE_value <- tree$c
  
  print(paste0("square_loss_error = ", square_loss_error))
  print(paste0("TREE_slice_point = ", TREE))
  print(paste0("TREE_value = ", TREE_value))
  y <- tree$residuals  # next round fits the residuals
  k <- k + 1
}
## [1] "k= 1"
## [1] "square_loss_error = 3197.3"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 24.19" "TREE_value = 42.19"
## [1] "k= 2"
## [1] "square_loss_error = 2394.98"
## [1] "TREE_slice_point = 3.52"
## [1] "TREE_value = -6.23" "TREE_value = 1.81" 
## [1] "k= 3"
## [1] "square_loss_error = 1774.28"
## [1] "TREE_slice_point = 4.635"
## [1] "TREE_value = -0.5"  "TREE_value = 17.37"
## [1] "k= 4"
## [1] "square_loss_error = 1478.86"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 1.55"  "TREE_value = -2.68"
## [1] "k= 5"
## [1] "square_loss_error = 1145.53"
## [1] "TREE_slice_point = 99.4"
## [1] "TREE_value = -2.94" "TREE_value = 1.6"  
## [1] "k= 6"
## [1] "square_loss_error = 1040.36"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 0.92" "TREE_value = -1.6"
## [1] "k= 7"
## [1] "square_loss_error = 911.36"
## [1] "TREE_slice_point = 4.14"
## [1] "TREE_value = -0.81" "TREE_value = 2.23" 
## [1] "k= 8"
## [1] "square_loss_error = 860.6"
## [1] "TREE_slice_point = 2.4"
## [1] "TREE_value = -4.97" "TREE_value = 0.14" 
## [1] "k= 9"
## [1] "square_loss_error = 801.02"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 0.7"  "TREE_value = -1.2"
## [1] "k= 10"
## [1] "square_loss_error = 695.07"
## [1] "TREE_slice_point = 114.5"
## [1] "TREE_value = -0.64" "TREE_value = 2.36" 
## [1] "k= 11"
## [1] "square_loss_error = 645.62"
## [1] "TREE_slice_point = 108.25"
## [1] "TREE_value = 0.64" "TREE_value = -1.1"
## [1] "k= 12"
## [1] "square_loss_error = 608.61"
## [1] "TREE_slice_point = 91.5"
## [1] "TREE_value = -2.96" "TREE_value = 0.18" 
## [1] "k= 13"
## [1] "square_loss_error = 586.56"
## [1] "TREE_slice_point = 127.75"
## [1] "TREE_value = -0.1" "TREE_value = 3.27"
## [1] "k= 14"
## [1] "square_loss_error = 564.35"
## [1] "TREE_slice_point = 3.52"
## [1] "TREE_value = 1.04" "TREE_value = -0.3"
## [1] "k= 15"
## [1] "square_loss_error = 538.65"
## [1] "TREE_slice_point = 9.75"
## [1] "TREE_value = -0.39" "TREE_value = 0.93" 
## [1] "k= 16"
## [1] "square_loss_error = 522.87"
## [1] "TREE_slice_point = 3.11"
## [1] "TREE_value = -1.93" "TREE_value = 0.11" 
## [1] "k= 17"
## [1] "square_loss_error = 504.79"
## [1] "TREE_slice_point = 3.21"
## [1] "TREE_value = 1.66"  "TREE_value = -0.15"
## [1] "k= 18"
## [1] "square_loss_error = 493.09"
## [1] "TREE_slice_point = 3.11"
## [1] "TREE_value = -1.66" "TREE_value = 0.1"
# final_TREE

# x is one row of the feature matrix, e.g. x[1, ]
final_tree_func <- function(x) {
  x <- as.numeric(x)
  total <- 0
  len_x <- length(x)
  vars <- modify_depth(final_TREE, 1, "idx_var_minvalue") %>% 
    as_vector()
  for (k in seq(len_x)) {
    idx <- which(vars == k)  # trees that split on variable k
    if (length(idx) > 0) {
      y <- vector("numeric", length = length(idx))
      for (i in seq(length(idx))) {
        if (x[k] < final_TREE[[ idx[i] ]][["best_slice_point"]]) {
          y[i] <- final_TREE[[ idx[i] ]][["c"]][1]
        } else {
          y[i] <- final_TREE[[ idx[i] ]][["c"]][2]
        }
      }
      sum_y <- sum(y)
    } else {
      sum_y <- 0
    }
    total <- total + sum_y
  }
  return(total)
}

final_tree_func(x[1, ]); bodyfat$DEXfat[1]
## [1] 39.7
## [1] 41.68
final_tree_func(x[5, ]); bodyfat$DEXfat[5]
## [1] 33.41
## [1] 36.42
final_tree_func(x[60, ]); bodyfat$DEXfat[60]
## [1] 29.05
## [1] 28.98
predict_fun <- function(x) {
  stopifnot(is.matrix(x))
  leng_x <- nrow(x)
  predict_value <- vector("numeric", length = leng_x)
  for (i in seq(leng_x)) {
    predict_value[i] <- final_tree_func(x[i, ])
  }
  return(predict_value)
}
predict_fun(as.matrix(x))
##  [1] 39.70 42.70 36.66 24.51 33.41 24.51 32.09 29.05 24.51 24.51 24.51
## [12] 27.55 24.51 21.37 29.05 39.70 29.05 44.02 39.70 21.37 32.09 24.51
## [23] 29.05 29.05 19.62 33.41 36.66 17.81 44.02 29.05 44.02 44.02 29.05
## [34] 39.70 24.16 29.05 30.37 40.98 25.83 41.02 61.89 40.98 41.02 30.37
## [45] 47.39 44.02 47.39 61.89 36.66 29.05 24.51 41.02 39.66 40.98 17.81
## [56] 29.05 17.81 44.02 36.66 29.05 17.81 17.81 14.67 17.81 15.82 15.82
## [67] 15.82 15.82 22.35 22.35 14.67
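For comparison with the linear model's residual sum of squares of 838.75, the training error of the boosted stumps can be computed from the predictions (this check is not in the original script); up to the rounding inside best_slicepoint it should land near the last printed square_loss_error of 493.09:

sum((bodyfat$DEXfat - predict_fun(as.matrix(x)))^2)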