第 72 章 抽样数据的规整与可视化

贝叶斯抽样得到的样本量通常比较大,要对这些样本进行规整和可视化,就需要借助一些函数。这里简单介绍 tidybayes 宏包和它的姊妹宏包 ggdist,更多的技术参数见官方手册。

72.1 企鹅案例

为了简化问题,我们只挑选 Gentoo 类企鹅。

library(palmerpenguins)

# Keep only the Gentoo penguins that have no missing value in any column.
gentoo <- penguins %>%
  filter(species == "Gentoo") %>%
  drop_na()

gentoo
## # A tibble: 119 × 8
##   species island bill_length_mm bill_depth_mm
##   <fct>   <fct>           <dbl>         <dbl>
## 1 Gentoo  Biscoe           46.1          13.2
## 2 Gentoo  Biscoe           50            16.3
## 3 Gentoo  Biscoe           48.7          14.1
## 4 Gentoo  Biscoe           50            15.2
## 5 Gentoo  Biscoe           47.6          14.5
## 6 Gentoo  Biscoe           46.5          13.5
## # … with 113 more rows, and 4 more variables:
## #   flipper_length_mm <int>, body_mass_g <int>,
## #   sex <fct>, year <int>

先看下两个变量的关系

# Scatter plot of bill depth against bill length for the Gentoo subset.
ggplot(data = gentoo, mapping = aes(x = bill_length_mm, y = bill_depth_mm)) +
  geom_point()

72.2 Stan模型

假设我们建立最简单的线性模型,其中预测因子是 bill_length_mm,被解释变量是 bill_depth_mm。

\[ \begin{aligned} y_n &\sim \operatorname{normal}(\mu_n, \,\, \sigma)\\ \mu_n &= \alpha + \beta x_n \end{aligned} \]

# Stan program (kept as an R string) for a simple linear regression of y on x.
#   data:        N observations (y, x) plus M new predictor values (new_x)
#                at which the fitted line and predictions are evaluated.
#   parameters:  intercept alpha, slope beta, residual sd sigma (lower = 0).
#   model:       y ~ normal(alpha + beta * x, sigma), with normal(0, 10)
#                priors on alpha and beta and an exponential(1) prior on sigma.
#   generated quantities, for each new_x[n]:
#                y_fit[n] = alpha + beta * new_x[n]   (linear predictor)
#                y_rep[n] = one normal_rng() draw centred on that value
#                           with sd sigma.
stan_program <- "
data {
  int<lower=0> N;
  vector[N] y;
  vector[N] x;
  int<lower=0> M;
  vector[M] new_x;  
}
parameters {
  real alpha;
  real beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(alpha + beta * x, sigma);
  
  alpha  ~ normal(0, 10);
  beta   ~ normal(0, 10);
  sigma  ~ exponential(1);
}
generated quantities {
  vector[M] y_fit;
  vector[M] y_rep;
  for (n in 1:M) {
    y_fit[n] = alpha + beta * new_x[n];
    y_rep[n] = normal_rng(alpha + beta * new_x[n], sigma);
  }
}
"

library(modelr)
# Grid of 100 evenly spaced bill lengths spanning the observed range;
# these become the new_x values passed to Stan.
newdata <- data_grid(
  gentoo,
  bill_length_mm = seq_range(bill_length_mm, n = 100)
)

# or
# newdata <- data.frame(
#     bill_length_mm = seq(min(gentoo$bill_length_mm), max(gentoo$bill_length_mm), length.out = 100)
#    ) 


# Named list whose elements match the entries of the Stan `data` block.
stan_data <- list(
  N     = nrow(gentoo),
  x     = gentoo[["bill_length_mm"]],
  y     = gentoo[["bill_depth_mm"]],
  M     = nrow(newdata),
  new_x = newdata[["bill_length_mm"]]
)

# Compile the program and sample from the posterior.
fit <- stan(data = stan_data, model_code = stan_program)

72.3 抽样

# Long-format posterior draws: one row per draw per parameter,
# with the parameter name in .variable and its value in .value.
draws <- tidybayes::gather_draws(fit, alpha, beta, sigma)
draws
## # A tibble: 12,000 × 5
## # Groups:   .variable [3]
##   .chain .iteration .draw .variable .value
##    <int>      <int> <int> <chr>      <dbl>
## 1      1          1     1 alpha       4.40
## 2      1          2     2 alpha       4.02
## 3      1          3     3 alpha       3.79
## 4      1          4     4 alpha       4.00
## 5      1          5     5 alpha       6.16
## 6      1          6     6 alpha       5.71
## # … with 11,994 more rows

72.4 统计汇总

# Posterior mean with 65% and 89% quantile intervals for each parameter.
ggdist::mean_qi(draws, .width = c(0.65, 0.89))
## # A tibble: 6 × 7
##   .variable .value .lower .upper .width .point
##   <chr>      <dbl>  <dbl>  <dbl>  <dbl> <chr> 
## 1 alpha      5.09   4.13   6.03    0.65 mean  
## 2 beta       0.208  0.188  0.228   0.65 mean  
## 3 sigma      0.754  0.707  0.798   0.65 mean  
## 4 alpha      5.09   3.39   6.83    0.89 mean  
## 5 beta       0.208  0.172  0.244   0.89 mean  
## 6 sigma      0.754  0.681  0.840   0.89 mean  
## # … with 1 more variable: .interval <chr>

72.5 可视化

  • geom_slabinterval() / stat_slabinterval() family
# Interval-only view of each parameter's posterior draws.
ggplot(draws, aes(x = .value, y = .variable)) +
  ggdist::stat_interval()

# Slab-only view.
ggplot(draws, aes(x = .value, y = .variable)) +
  ggdist::stat_slab()

# Slab and interval combined.
ggplot(draws, aes(x = .value, y = .variable)) +
  ggdist::stat_slabinterval()

# beta and sigma only, one facet per parameter with free scales.
draws %>%
  filter(.variable %in% c("beta", "sigma")) %>%
  ggplot(mapping = aes(y = .variable, x = .value)) +
  facet_grid(~ .variable, scales = "free", labeller = "label_both") +
  ggdist::stat_slabinterval()
  • geom_dotsinterval() / stat_dotsinterval() family
# Quantile dot plot plus interval for beta and sigma.
# stat_dotsinterval() is namespace-qualified like the other ggdist calls in
# this chapter, so the chunk does not depend on ggdist being attached.
draws %>%
  filter(.variable %in% c("beta", "sigma")) %>%
  ggplot(aes(x = .value, y = .variable)) +
  ggdist::stat_dotsinterval(
    quantiles = 200,
    justification = -0.1,
    slab_color = "black",
    slab_fill = "orange",
    interval_color = "red"
  )
  • geom_lineribbon() / stat_lineribbon() family
# Posterior median and 89% interval of the fitted line at every grid point,
# drawn on top of the raw Gentoo observations.
fit %>%
  tidybayes::gather_draws(y_fit[i]) %>%
  ggdist::median_qi(.width = 0.89) %>%
  # y_fit[i] was computed from new_x[i], i.e. row i of newdata, so look the
  # predictor up by index instead of relying on the row order of the summary.
  mutate(bill_length_mm = newdata$bill_length_mm[i]) %>%
  ggplot() +
  geom_point(
    data = gentoo,
    aes(x = bill_length_mm, y = bill_depth_mm)
  ) +
  # The ribbon fill is a fixed colour, so no fill scale is needed.
  ggdist::geom_lineribbon(
    aes(x = bill_length_mm, y = .value, ymin = .lower, ymax = .upper),
    alpha = 0.3,
    fill = "gray50"
  ) +
  theme_classic()
  • 组合
# Rain cloud plot of bill length by species: a slab above, dots and an
# interval below.
penguins %>%
  # Drop the rows with a missing bill length explicitly rather than leaving
  # them to be removed implicitly when the plot is drawn.
  drop_na(bill_length_mm) %>%
  ggplot(aes(y = species, x = bill_length_mm, fill = species)) +
  # thickness = pdf * n scales each slab by the computed variable n
  # (presumably the per-group sample size; confirm in the ggdist docs).
  ggdist::stat_slab(aes(thickness = after_stat(pdf * n)), scale = 0.7) +
  ggdist::stat_dotsinterval(side = "bottom", scale = 0.7, slab_size = NA) +
  scale_fill_brewer(palette = "Set2") +
  ggtitle("Rain cloud plot")