第 61 章 机器学习

RStudio 的 Max Kuhn 大神正在主持 tidymodels 机器学习框架的开发,目前已日臻成熟,感觉很强大。

61.1 数据

# Load the raw penguin measurements, convert column names to snake_case,
# and keep only the complete rows (no NA in any column).
penguins <-
  readr::read_csv("./demo_data/penguins.csv") %>%
  janitor::clean_names() %>%
  tidyr::drop_na()

# First few rows of the cleaned data.
head(penguins)

# Bill length against bill depth; species is encoded by both colour and shape.
ggplot(
  penguins,
  aes(x = bill_length_mm, y = bill_depth_mm, color = species, shape = species)
) +
  geom_point()

61.2 机器学习

# Build a two-class outcome for the classifiers below. Logistic regression
# (model01) needs exactly two levels, so the most frequent species is kept
# and every other species is lumped into "Other".
# fct_lump_n() replaces the superseded fct_lump(f, n) call.
# The seed makes the random train/test split, and all results below,
# reproducible.
set.seed(1234)
split <- penguins %>%
  mutate(species = fct_lump_n(as_factor(species), n = 1)) %>%
  initial_split()

# Print the rsplit object.
split

# Extract both partitions; the surrounding parentheses print each one
# right after it is assigned.
(training_data <- rsample::training(split))
(testing_data <- rsample::testing(split))

61.3 model01

# Model 1: logistic regression (glm engine) on both bill measurements.
# logistic_reg() only supports classification, so the mode is passed directly.
model_logistic <-
  parsnip::logistic_reg(mode = "classification") %>%
  set_engine("glm") %>%
  fit(species ~ bill_length_mm + bill_depth_mm, data = training_data)

# Predicted class and class probabilities, placed next to the test columns.
bind_cols(
  predict(model_logistic, new_data = testing_data, type = "class"),
  predict(model_logistic, new_data = testing_data, type = "prob"),
  testing_data
)

# Confusion counts: predicted class vs. observed species.
testing_data %>%
  bind_cols(predict(model_logistic, new_data = testing_data)) %>%
  count(.pred_class, species)

61.4 model02

# Model 2: k-nearest neighbours (k = 10, kknn engine) on bill length only.
model_neighbor <-
  parsnip::nearest_neighbor(mode = "classification", neighbors = 10) %>%
  set_engine("kknn") %>%
  fit(species ~ bill_length_mm, data = training_data)

# Confusion counts: predicted class vs. observed species on the test set.
testing_data %>%
  bind_cols(predict(model_neighbor, new_data = testing_data)) %>%
  count(.pred_class, species)

61.5 model03

# Model 3: multinomial regression (nnet engine) on bill length only.
model_multinom <-
  parsnip::multinom_reg(mode = "classification") %>%
  set_engine("nnet") %>%
  fit(species ~ bill_length_mm, data = training_data)

# Confusion counts: predicted class vs. observed species on the test set.
testing_data %>%
  bind_cols(predict(model_multinom, new_data = testing_data)) %>%
  count(.pred_class, species)

61.6 model04

# Model 4: decision tree (rpart engine) on bill length only.
model_decision <-
  parsnip::decision_tree(mode = "classification") %>%
  set_engine("rpart") %>%
  fit(species ~ bill_length_mm, data = training_data)

# Confusion counts: predicted class vs. observed species on the test set.
testing_data %>%
  bind_cols(predict(model_decision, new_data = testing_data)) %>%
  count(.pred_class, species)

61.7 workflow

61.7.1 使用 recipes

library(tidyverse)
library(tidymodels)
library(workflows)

# Re-read the data (column names in snake_case; NAs are dropped just below).
penguins <- readr::read_csv("./demo_data/penguins.csv") %>%
  janitor::clean_names()

# Fix the RNG before the random 3/4 : 1/4 split so that the partition, and
# every fitted value and metric printed below, can be reproduced when the
# document is re-run.
set.seed(1234)
split <- penguins %>%
  tidyr::drop_na() %>%
  rsample::initial_split(prop = 3 / 4)

training_data <- rsample::training(split)
testing_data  <- rsample::testing(split)

参考 *Tidy Modeling with R*:对被预测变量(outcome)的变换(比如对数变换或标准化)应当在数据分割之前完成。但这里的案例为了偷懒,被预测变量 bill_length_mm 暂时保持不变,只对预测变量做标准化处理。

# Linear-regression model specification using the "stan" engine.
# For an ordinary least-squares fit, use "lm" as the engine instead.
penguins_lm <- parsnip::set_engine(
  parsnip::linear_reg(mode = "regression"),
  "stan"
)

# Preprocessing recipe: predict bill_length_mm from bill_depth_mm and sex.
# - numeric predictors are centred and scaled (step_normalize);
# - nominal predictors (sex) become dummy variables (step_dummy).
# The *_predictors() selectors guarantee the outcome is never touched. The
# previous all_nominal() would also have dummy-encoded a nominal outcome.
penguins_recipe <-
  recipes::recipe(
    bill_length_mm ~ bill_depth_mm + sex,
    data = training_data
  ) %>%
  recipes::step_normalize(recipes::all_numeric_predictors()) %>%
  recipes::step_dummy(recipes::all_nominal_predictors())


# List the recipe's steps (not trained yet, so trained = FALSE).
penguins_recipe %>%
  broom::tidy()
## # A tibble: 2 × 6
##   number operation type      trained skip  id          
##    <int> <chr>     <chr>     <lgl>   <lgl> <chr>       
## 1      1 step      normalize FALSE   FALSE normalize_G…
## 2      2 step      dummy     FALSE   FALSE dummy_Gnn60
# Estimate the recipe once on the training set and reuse the result.
# Note: prep()'s data argument is `training`, not `data`. The previous
# `prep(data = training_data)` passed the data through `...`, so prep()
# fell back on the template data given to recipe().
penguins_prep <- recipes::prep(penguins_recipe, training = training_data)

# Processed training set; bake(new_data = NULL) is the replacement for the
# superseded juice() and relies on prep()'s default retain = TRUE.
recipes::bake(penguins_prep, new_data = NULL)

# Apply the training-set estimates to the test set.
recipes::bake(penguins_prep, new_data = testing_data)

# Keep both processed sets for later use.
train_data <- recipes::bake(penguins_prep, new_data = NULL)
test_data  <- recipes::bake(penguins_prep, new_data = testing_data)

61.7.2 workflows的思路更清晰

workflows 的思路让模型结构更清晰:有了它,prep()、bake() 和 juice() 都可以省略,只需要提供 recipe 和 model,它们往往是成对出现的。

# Bundle the model specification and the preprocessing recipe into one
# workflow object.
wflow <- workflows::workflow() %>%
  workflows::add_model(penguins_lm) %>%
  workflows::add_recipe(penguins_recipe)

# A single fit() call preps the recipe on training_data and fits the model.
wflow_fit <- parsnip::fit(wflow, data = training_data)
# Coefficient table of the fitted model. extract_fit_parsnip() replaces the
# deprecated pull_workflow_fit() and returns the same parsnip model fit.
wflow_fit %>%
  workflows::extract_fit_parsnip() %>%
  broom.mixed::tidy()
## # A tibble: 3 × 3
##   term          estimate std.error
##   <chr>            <dbl>     <dbl>
## 1 (Intercept)      41.1      0.471
## 2 bill_depth_mm    -2.23     0.346
## 3 sex_male          5.70     0.669
# The recipe as estimated on training_data. extract_recipe() (estimated = TRUE
# by default) replaces the deprecated pull_workflow_prepped_recipe().
wflow_fit %>%
  workflows::extract_recipe()
## Recipe
## 
## Inputs:
## 
##       role #variables
##    outcome          1
##  predictor          2
## 
## Training data contained 249 data points and no missing data.
## 
## Operations:
## 
## Centering and scaling for bill_depth_mm [trained]
## Dummy variables from sex [trained]

先提取出模型,再用于 predict() 也是可以的,但这样太麻烦了:必须传入已经用 recipe 处理(bake)过的 test_data,而不是原始的 testing_data。

# The long way: extract the parsnip fit (extract_fit_parsnip() replaces the
# deprecated pull_workflow_fit()) and predict on data that has already been
# baked by the recipe.
wflow_fit %>%
  workflows::extract_fit_parsnip() %>%
  stats::predict(new_data = test_data) # note: test_data not testing_data

更好的做法是直接对 workflow 调用 predict():它会自动把 recipe(基于 training_data 估计的预处理)应用到 testing_data 上,非常方便,参考这里。

# Predict directly from the fitted workflow on the raw test set. predict()
# applies the trained recipe itself, so the outcome column can be dropped
# from new_data and re-attached afterwards for comparison.
penguins_pred <-
  predict(
    wflow_fit,
    new_data = testing_data %>% dplyr::select(-bill_length_mm), # note: testing_data not test_data
    type = "numeric"
  ) %>%
  dplyr::bind_cols(testing_data %>% dplyr::select(bill_length_mm))

penguins_pred
## # A tibble: 84 × 2
##   .pred bill_length_mm
##   <dbl>          <dbl>
## 1  45.1           39.1
## 2  39.9           37.8
## 3  45.2           40.6
## 4  41.7           39.5
## 5  40.4           39.5
## 6  44.9           40.9
## # … with 78 more rows
# Observed vs. predicted, with the identity line as reference.
# coord_obs_pred() gives both axes the same scale and range, so the dashed
# 45-degree line is a fair guide to over- or under-prediction.
penguins_pred %>%
  ggplot(aes(x = bill_length_mm, y = .pred)) +
  geom_abline(linetype = 2) +
  geom_point(alpha = 0.5) +
  tune::coord_obs_pred() +
  labs(y = "Predicted ", x = "bill_length_mm")

augment() 具有与 predict() 一样的功能和特性,而且简练得多:它直接在 new_data 上添加预测列 .pred,无需再手动 bind_cols()。

# augment() appends the .pred column to the raw test set in one step.
# coord_obs_pred() gives both axes the same scale and range, so the dashed
# identity line is a fair reference.
wflow_fit %>%
  augment(new_data = testing_data) %>%       # note: testing_data not test_data
  ggplot(aes(x = bill_length_mm, y = .pred)) +
  geom_abline(linetype = 2) +
  geom_point(alpha = 0.5) +
  tune::coord_obs_pred() +
  labs(y = "Predicted ", x = "bill_length_mm")

61.7.3 模型评估

参考https://www.tmwr.org/performance.html#regression-metrics

# Root mean squared error of the test-set predictions.
yardstick::rmse(penguins_pred, truth = bill_length_mm, estimate = .pred)
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard        4.02

用 yardstick::metric_set() 自定义一个指标评价函数 my_multi_metric,把多个指标放在一起一次算出;不过这种写法感觉不够 tidyverse。

# Combine several regression metrics into a single callable metric set.
my_multi_metric <- yardstick::metric_set(rmse, rsq, mae, ccc)

# Evaluate all of them on the test-set predictions at once.
my_multi_metric(penguins_pred, truth = bill_length_mm, estimate = .pred)
## # A tibble: 4 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard       4.02 
## 2 rsq     standard       0.327
## 3 mae     standard       3.35 
## 4 ccc     standard       0.506