Chapter 28 Tidying Model Output

28.1 Case study

We again use the gapminder example from Chapter 13.

library(tidyverse)
library(gapminder)
gapminder
## # A tibble: 1,704 x 6
##    country    continent  year lifeExp     pop gdpPercap
##    <fct>      <fct>     <int>   <dbl>   <int>     <dbl>
##  1 Afghanist~ Asia       1952    28.8  8.43e6      779.
##  2 Afghanist~ Asia       1957    30.3  9.24e6      821.
##  3 Afghanist~ Asia       1962    32.0  1.03e7      853.
##  4 Afghanist~ Asia       1967    34.0  1.15e7      836.
##  5 Afghanist~ Asia       1972    36.1  1.31e7      740.
##  6 Afghanist~ Asia       1977    38.4  1.49e7      786.
##  7 Afghanist~ Asia       1982    39.9  1.29e7      978.
##  8 Afghanist~ Asia       1987    40.8  1.39e7      852.
##  9 Afghanist~ Asia       1992    41.7  1.63e7      649.
## 10 Afghanist~ Asia       1997    41.8  2.22e7      635.
## # ... with 1,694 more rows

28.1.1 Exploratory visualization

First, a simple scatter plot:

gapminder %>%
  ggplot(aes(x = log(gdpPercap), y = lifeExp)) +
  geom_point(alpha = 0.2)

We want to fit the relationship between log(gdpPercap) and lifeExp with several different models:

library(colorspace)

model_colors <- colorspace::qualitative_hcl(4, palette = "dark 2")
# model_colors <- c("darkorange", "purple", "cyan4")

ggplot(
  data = gapminder,
  mapping = aes(x = log(gdpPercap), y = lifeExp)
) +
  geom_point(alpha = 0.2) +
  geom_smooth(
    method = "lm",
    aes(color = "OLS", fill = "OLS") # one
  ) +
  geom_smooth(
    method = "lm", formula = y ~ splines::bs(x, df = 3),
    aes(color = "Cubic Spline", fill = "Cubic Spline") # two
  ) +
  geom_smooth(
    method = "loess",
    aes(color = "LOESS", fill = "LOESS") # three
  ) +
  scale_color_manual(name = "Models", values = model_colors) +
  scale_fill_manual(name = "Models", values = model_colors) +
  theme(legend.position = "top")

28.1.2 A simple model

Now back to the main topic of this chapter. We fit a simple linear model:

out <- lm(
  formula = lifeExp ~ gdpPercap + pop + continent,
  data = gapminder
)
out
## 
## Call:
## lm(formula = lifeExp ~ gdpPercap + pop + continent, data = gapminder)
## 
## Coefficients:
##       (Intercept)          gdpPercap  
##          4.78e+01           4.50e-04  
##               pop  continentAmericas  
##          6.57e-09           1.35e+01  
##     continentAsia    continentEurope  
##          8.19e+00           1.75e+01  
##  continentOceania  
##          1.81e+01
str(out)
summary(out)
## 
## Call:
## lm(formula = lifeExp ~ gdpPercap + pop + continent, data = gapminder)
## 
## Residuals:
##    Min     1Q Median     3Q    Max 
## -49.16  -4.49   0.30   5.11  25.17 
## 
## Coefficients:
##                   Estimate Std. Error t value Pr(>|t|)
## (Intercept)       4.78e+01   3.40e-01  140.82   <2e-16
## gdpPercap         4.50e-04   2.35e-05   19.16   <2e-16
## pop               6.57e-09   1.98e-09    3.33    9e-04
## continentAmericas 1.35e+01   6.00e-01   22.46   <2e-16
## continentAsia     8.19e+00   5.71e-01   14.34   <2e-16
## continentEurope   1.75e+01   6.25e-01   27.97   <2e-16
## continentOceania  1.81e+01   1.78e+00   10.15   <2e-16
##                      
## (Intercept)       ***
## gdpPercap         ***
## pop               ***
## continentAmericas ***
## continentAsia     ***
## continentEurope   ***
## continentOceania  ***
## ---
## Signif. codes:  
## 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 8.37 on 1697 degrees of freedom
## Multiple R-squared:  0.582,  Adjusted R-squared:  0.581 
## F-statistic:  394 on 6 and 1697 DF,  p-value: <2e-16
The fitted model is a complex list; Figure 28.1 shows the structure of out.

Figure 28.1: Schematic of the structure of a linear model object

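The out object contains many components, such as the coefficients, the residuals and the residual degrees of freedom, and each of them can be read directly with list indexing: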

out$coefficients
out$residuals
out$fitted.values
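Base R also provides generic accessor functions that return the same components; they are often preferable to `$` indexing because they work across many model classes. A minimal sketch, assuming the gapminder package is installed (the model is refit so the block runs on its own):

```r
library(gapminder)

# Refit the same model as above so this block is self-contained
out <- lm(lifeExp ~ gdpPercap + pop + continent, data = gapminder)

# Names of all components stored in the lm object
names(out)

# Generic accessors equivalent to the `$` lookups above
coef(out)              # same values as out$coefficients
head(residuals(out))   # same values as out$residuals
head(fitted(out))      # same values as out$fitted.values
out$df.residual        # residual degrees of freedom
```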

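In fact, the summary() function used above only selects and prints a small part of the information in out, and the printed result is not structured in a way that dplyr can manipulate or ggplot2 can plot directly.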
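To see the kind of work involved, here is roughly what turning the printed coefficient table into a data frame looks like in base R. This is only an illustrative sketch; `coef_mat` and `coef_df` are hypothetical names:

```r
library(gapminder)

out <- lm(lifeExp ~ gdpPercap + pop + continent, data = gapminder)

# summary() stores the printed coefficient table as a numeric matrix
coef_mat <- summary(out)$coefficients

# Convert the matrix to a data frame, moving the row names into a column
coef_df <- data.frame(
  term = rownames(coef_mat),
  coef_mat,
  row.names = NULL,      # use plain integer row names
  check.names = FALSE    # keep headers such as "Std. Error" unchanged
)
coef_df
```

Repeating this for every model class and every piece of output is the repetitive work that a dedicated package can take over.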

28.2 broom

To tidy model output, we recommend the broom package developed by David Robinson.

library(broom)

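The broom package converts the output of more than 100 commonly used model types into tidy data frames (tibbles), so dplyr functions can be used directly for model comparison and visualization. broom provides three main functions:

  • tidy() extracts the main component-level results of a model, such as the coefficients and t-statistics
  • glance() treats the model as a whole and extracts model-level statistics such as the F-statistic, model deviance or R-squared
  • augment() adds observation-level information from the model, such as fitted values and residuals, to the data used for fitting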
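One way to remember the three functions is by the unit each output row represents. A small sketch, assuming broom and gapminder are installed:

```r
library(gapminder)
library(broom)

out <- lm(lifeExp ~ gdpPercap + pop + continent, data = gapminder)

coef_tbl  <- tidy(out)     # one row per model term (coefficient)
model_tbl <- glance(out)   # one row per model
obs_tbl   <- augment(out)  # one row per observation used in the fit

# Compare the number of rows in each output
c(tidy = nrow(coef_tbl), glance = nrow(model_tbl), augment = nrow(obs_tbl))
```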

28.2.1 tidy

tidy(out)
## # A tibble: 7 x 5
##   term         estimate   std.error statistic   p.value
##   <chr>           <dbl>       <dbl>     <dbl>     <dbl>
## 1 (Intercept)   4.78e+1     3.40e-1    141.   0.       
## 2 gdpPercap     4.50e-4     2.35e-5     19.2  3.24e- 74
## 3 pop           6.57e-9     1.98e-9      3.33 9.01e-  4
## 4 continentAm~  1.35e+1     6.00e-1     22.5  5.19e- 98
## 5 continentAs~  8.19e+0     5.71e-1     14.3  4.06e- 44
## 6 continentEu~  1.75e+1     6.25e-1     28.0  6.34e-142
## 7 continentOc~  1.81e+1     1.78e+0     10.1  1.59e- 23
out %>%
  tidy() %>%
  ggplot(mapping = aes(
    x = term,
    y = estimate
  )) +
  geom_point() +
  coord_flip()

Confidence intervals for the coefficients are easy to get:

out %>%
  tidy(conf.int = TRUE)
## # A tibble: 7 x 7
##   term  estimate std.error statistic   p.value conf.low
##   <chr>    <dbl>     <dbl>     <dbl>     <dbl>    <dbl>
## 1 (Int~  4.78e+1   3.40e-1    141.   0.         4.71e+1
## 2 gdpP~  4.50e-4   2.35e-5     19.2  3.24e- 74  4.03e-4
## 3 pop    6.57e-9   1.98e-9      3.33 9.01e-  4  2.70e-9
## 4 cont~  1.35e+1   6.00e-1     22.5  5.19e- 98  1.23e+1
## 5 cont~  8.19e+0   5.71e-1     14.3  4.06e- 44  7.07e+0
## 6 cont~  1.75e+1   6.25e-1     28.0  6.34e-142  1.62e+1
## 7 cont~  1.81e+1   1.78e+0     10.1  1.59e- 23  1.46e+1
## # ... with 1 more variable: conf.high <dbl>
out %>%
  tidy(conf.int = TRUE) %>%
  filter(!term %in% c("(Intercept)")) %>%
  ggplot(aes(
    x = reorder(term, estimate),
    y = estimate, ymin = conf.low, ymax = conf.high
  )) +
  geom_pointrange() +
  coord_flip() +
  labs(x = "", y = "OLS Estimate")
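tidy() for lm models also accepts a `conf.level` argument (0.95 by default). As a sanity check, the intervals should agree with base R's confint(); a minimal sketch using 90% intervals:

```r
library(gapminder)
library(broom)

out <- lm(lifeExp ~ gdpPercap + pop + continent, data = gapminder)

# 90% confidence intervals via broom
ci_broom <- tidy(out, conf.int = TRUE, conf.level = 0.90)

# The same intervals from base R: a matrix with one row per term
ci_base <- confint(out, level = 0.90)

# Compare lower and upper bounds; names are dropped so only values are compared
all.equal(ci_broom$conf.low,  unname(ci_base[, 1]))
all.equal(ci_broom$conf.high, unname(ci_base[, 2]))
```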

28.2.2 augment

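augment() returns a data frame with one row per observation: the data used to fit the model plus new columns such as the fitted values (.fitted), residuals (.resid), standardized residuals (.std.resid) and Cook's distance (.cooksd). When called without a data argument it starts from the model frame, so only the variables in the model formula are kept; in current broom versions the standard error of the fitted values (.se.fit) is added only when se_fit = TRUE.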

augment(out)
## # A tibble: 1,704 x 10
##    lifeExp gdpPercap    pop continent .fitted .resid
##      <dbl>     <dbl>  <int> <fct>       <dbl>  <dbl>
##  1    28.8      779. 8.43e6 Asia         56.4  -27.6
##  2    30.3      821. 9.24e6 Asia         56.4  -26.1
##  3    32.0      853. 1.03e7 Asia         56.5  -24.5
##  4    34.0      836. 1.15e7 Asia         56.5  -22.4
##  5    36.1      740. 1.31e7 Asia         56.4  -20.3
##  6    38.4      786. 1.49e7 Asia         56.5  -18.0
##  7    39.9      978. 1.29e7 Asia         56.5  -16.7
##  8    40.8      852. 1.39e7 Asia         56.5  -15.7
##  9    41.7      649. 1.63e7 Asia         56.4  -14.7
## 10    41.8      635. 2.22e7 Asia         56.4  -14.7
## # ... with 1,694 more rows, and 4 more variables:
## #   .std.resid <dbl>, .hat <dbl>, .sigma <dbl>,
## #   .cooksd <dbl>
out %>%
  augment() %>%
  ggplot(mapping = aes(x = lifeExp, y = .fitted)) +
  geom_point()
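To keep the columns that were not used in the model (such as country and year) and to get the standard errors of the fitted values, pass the original data and set `se_fit = TRUE`. A sketch, assuming a recent broom version where augment() for lm has the `se_fit` argument:

```r
library(gapminder)
library(broom)

out <- lm(lifeExp ~ gdpPercap + pop + continent, data = gapminder)

# Start from the full gapminder data instead of the model frame,
# and request .se.fit alongside the usual columns
aug <- augment(out, data = gapminder, se_fit = TRUE)

names(aug)

# Residuals are observed minus fitted values
all.equal(aug$.resid, aug$lifeExp - aug$.fitted, check.attributes = FALSE)
```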

28.2.3 glance

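glance() also returns a data frame, but with only one row describing the model as a whole. Its contents largely match the statistics printed at the bottom of the summary() output (R-squared, residual standard error, F-statistic and its p-value), plus extra measures such as logLik, AIC and BIC.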

glance(out)
## # A tibble: 1 x 12
##   r.squared adj.r.squared sigma statistic   p.value
##       <dbl>         <dbl> <dbl>     <dbl>     <dbl>
## 1     0.582         0.581  8.37      394. 3.94e-317
## # ... with 7 more variables: df <dbl>, logLik <dbl>,
## #   AIC <dbl>, BIC <dbl>, deviance <dbl>,
## #   df.residual <int>, nobs <int>
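Because every glance() result has the same columns, the one-row summaries of several models can simply be stacked for comparison. A sketch comparing the model above with a hypothetical alternative that uses log(gdpPercap), as suggested by the earlier scatter plot:

```r
library(tidyverse)
library(gapminder)
library(broom)

# Two candidate models: raw vs. log-transformed GDP per capita
m_raw <- lm(lifeExp ~ gdpPercap + pop + continent, data = gapminder)
m_log <- lm(lifeExp ~ log(gdpPercap) + pop + continent, data = gapminder)

# Stack the one-row summaries; .id records which model each row came from
comparison <- bind_rows(
  raw = glance(m_raw),
  log = glance(m_log),
  .id = "model"
)

comparison %>%
  select(model, r.squared, adj.r.squared, AIC, BIC)
```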

28.3 Applications

broom's three main functions are especially convenient when fitting models by group.

penguins <-
  palmerpenguins::penguins %>%
  drop_na()
penguins %>%
  group_nest(species) %>%
  mutate(model = purrr::map(data, ~ lm(bill_depth_mm ~ bill_length_mm, data = .))) %>%
  mutate(glance = purrr::map(model, ~ broom::glance(.))) %>%
  tidyr::unnest(glance)
## # A tibble: 3 x 15
##   species      data model r.squared adj.r.squared sigma
##   <fct>   <list<tb> <lis>     <dbl>         <dbl> <dbl>
## 1 Adelie  [146 x 7] <lm>      0.149         0.143 1.13 
## 2 Chinst~  [68 x 7] <lm>      0.427         0.418 0.866
## 3 Gentoo  [119 x 7] <lm>      0.428         0.423 0.749
## # ... with 9 more variables: statistic <dbl>,
## #   p.value <dbl>, df <dbl>, logLik <dbl>, AIC <dbl>,
## #   BIC <dbl>, deviance <dbl>, df.residual <int>,
## #   nobs <int>
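The same nest, map and unnest pattern works with augment(), giving per-observation fitted values and residuals from each species' own model. A minimal sketch, assuming the palmerpenguins package is installed:

```r
library(tidyverse)
library(broom)

penguins <- palmerpenguins::penguins %>%
  drop_na()

# Fit one model per species, then attach fitted values and residuals
# to that species' data
per_obs <- penguins %>%
  group_nest(species) %>%
  mutate(model = map(data, ~ lm(bill_depth_mm ~ bill_length_mm, data = .x))) %>%
  mutate(aug = map(model, augment)) %>%
  select(species, aug) %>%
  unnest(aug)

per_obs
```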
fit_ols <- function(df) {
  lm(body_mass_g ~ bill_depth_mm + bill_length_mm, data = df)
}


out_tidy <- penguins %>%
  group_nest(species) %>%
  mutate(model = purrr::map(data, fit_ols)) %>%
  mutate(tidy = purrr::map(model, ~ broom::tidy(.))) %>%
  tidyr::unnest(tidy) %>%
  dplyr::filter(!term %in% "(Intercept)")

out_tidy
## # A tibble: 6 x 8
##   species      data model term  estimate std.error
##   <fct>   <list<tb> <lis> <chr>    <dbl>     <dbl>
## 1 Adelie  [146 x 7] <lm>  bill~    164.       25.1
## 2 Adelie  [146 x 7] <lm>  bill~     64.8      11.5
## 3 Chinst~  [68 x 7] <lm>  bill~    159.       43.3
## 4 Chinst~  [68 x 7] <lm>  bill~     23.8      14.7
## 5 Gentoo  [119 x 7] <lm>  bill~    255.       40.0
## 6 Gentoo  [119 x 7] <lm>  bill~     54.7      12.7
## # ... with 2 more variables: statistic <dbl>,
## #   p.value <dbl>
out_tidy %>%
  ggplot(aes(
    x = species, y = estimate,
    ymin = estimate - 2 * std.error,
    ymax = estimate + 2 * std.error,
    color = term
  )) +
  geom_pointrange(position = position_dodge(width = 0.25)) +
  theme(legend.position = "top") +
  labs(x = NULL, y = "Estimate", color = "Term")
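The error bars above use estimate ± 2 × std.error, a normal-approximation interval. tidy(conf.int = TRUE) instead computes intervals from the t distribution with each model's residual degrees of freedom; for samples of this size the two should be close. A sketch of the same plot using those intervals:

```r
library(tidyverse)
library(broom)

penguins <- palmerpenguins::penguins %>%
  drop_na()

fit_ols <- function(df) {
  lm(body_mass_g ~ bill_depth_mm + bill_length_mm, data = df)
}

# Same pipeline as above, but ask tidy() for 95% t-based intervals
out_ci <- penguins %>%
  group_nest(species) %>%
  mutate(model = map(data, fit_ols)) %>%
  mutate(tidy = map(model, ~ tidy(.x, conf.int = TRUE))) %>%
  unnest(tidy) %>%
  filter(term != "(Intercept)")

ggplot(out_ci, aes(
  x = species, y = estimate,
  ymin = conf.low, ymax = conf.high,
  color = term
)) +
  geom_pointrange(position = position_dodge(width = 0.25)) +
  theme(legend.position = "top") +
  labs(x = NULL, y = "Estimate", color = "Term")
```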