# 第 66 章 贝叶斯广义线性模型

library(tidyverse)
library(tidybayes)
library(rstan)
library(wesanderson)

# Session-wide configuration for sampling and plotting.
# Use every core reported by the machine for parallel chains, and let rstan
# write compiled models to disk (see ?rstan_options for `auto_write`).
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

# Default ggplot2 theme for the chapter: bayesplot's default theme, overlaid
# with the Tufte theme, on a background painted (fill and border) with the
# third colour of the "Moonrise2" palette.
theme_set(
  bayesplot::theme_default() +
    ggthemes::theme_tufte() +
    theme(
      plot.background = element_rect(
        fill = wes_palette("Moonrise2")[3],
        color = wes_palette("Moonrise2")[3]
      )
    )
)

## 66.1 广义线性模型

1. 响应变量的概率分布(Normal, Binomial, Poisson, Categorical, Multinomial, Beta)

2. 预测变量的线性组合

$\eta = X \vec{\beta}$

3. 连接函数 $g(\cdot)$，将期望值映射到预测变量的线性组合 $g(\mu) = \eta$

$\mu = g^{-1}(\eta)$

## 66.2 研究生院录取时有性别歧视？

# Read the admissions table from demo_data/UCBadmit.rds (path resolved with
# here::here) and fix the gender factor order as male, then female, so the
# integer codes are male = 1 and female = 2.
UCBadmit <- here::here('demo_data', "UCBadmit.rds") %>%
  readr::read_rds() %>%
  mutate(
    applicant_gender = factor(
      applicant_gender,
      levels = c("male", "female")
    )
  )

# Print the table.
UCBadmit
##    dept applicant_gender admit rejection applications      ratio
## 1     A             male   512       313          825 0.62060606
## 2     A           female    89        19          108 0.82407407
## 3     B             male   353       207          560 0.63035714
## 4     B           female    17         8           25 0.68000000
## 5     C             male   120       205          325 0.36923077
## 6     C           female   202       391          593 0.34064081
## 7     D             male   138       279          417 0.33093525
## 8     D           female   131       244          375 0.34933333
## 9     E             male    53       138          191 0.27748691
## 10    E           female    94       299          393 0.23918575
## 11    F             male    22       351          373 0.05898123
## 12    F           female    24       317          341 0.07038123

\begin{align*} \text{admit}_i & \sim \operatorname{Binomial}(n_i, p_i) \\ \text{logit}(p_i) & = \alpha_{\text{gender}[i]} \\ \alpha_j & \sim \operatorname{Normal}(0, 1.5), \end{align*}

\begin{align*} \text{logit}(p_{i}) &= \log\Big(\frac{p_{i}}{1 - p_{i}}\Big) = \alpha_{\text{gender}[i]}\\ \text{equivalent to,} \quad p_{i} &= \frac{1}{1 + \exp[- \alpha_{\text{gender}[i]}]} \\ & = \frac{\exp(\alpha_{\text{gender}[i]})}{1 + \exp (\alpha_{\text{gender}[i]})} \\ & = \text{inv_logit}(\alpha_{\text{gender}[i]}) \end{align*}

R语言glm函数能够拟合一系列的广义线性模型

# Frequentist reference fit: aggregated-binomial logistic regression with one
# intercept per gender (`0 +` drops the common intercept).
# The response is the two-column matrix of successes (admit) and failures
# (rejection). `family` must be given explicitly because glm() defaults to
# gaussian, and `data` tells glm() where to find the columns.
model_logit <- glm(
  cbind(admit, rejection) ~ 0 + applicant_gender,
  family = binomial(link = "logit"),
  data = UCBadmit
)

summary(model_logit)
##
## Call:
## glm(formula = cbind(admit, rejection) ~ 0 + applicant_gender,
##     family = binomial(link = "logit"), data = UCBadmit)
##
## Deviance Residuals:
##      Min        1Q    Median        3Q       Max
## -16.7915   -4.7613   -0.4365    5.1025   11.2022
##
## Coefficients:
##                        Estimate Std. Error z value Pr(>|z|)
## applicant_gendermale   -0.22013    0.03879  -5.675 1.38e-08 ***
## applicant_genderfemale -0.83049    0.05077 -16.357  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
##     Null deviance: 1107.08  on 12  degrees of freedom
## Residual deviance:  783.61  on 10  degrees of freedom
## AIC: 856.55
##
## Number of Fisher Scoring iterations: 4

### 66.2.1 Stan 代码

# Stan program for the gender-only model:
#   admit[i] ~ Binomial(applications[i], p[i])
#   logit(p[i]) = a[applicant_gender[i]],  a[j] ~ Normal(0, 1.5)
# The data block declares `admit` and the model block contains the binomial
# likelihood; without them the sampler would only draw from the prior.
stan_program_A <- '
data {
  int n;
  int admit[n];
  int applications[n];
  int applicant_gender[n];
}
parameters {
  real a[2];
}
transformed parameters {
  vector[n] p;
  for (i in 1:n) {
    p[i] = inv_logit(a[applicant_gender[i]]);
  }
}
model {
  a ~ normal(0, 1.5);
  for (i in 1:n) {
    admit[i] ~ binomial(applications[i], p[i]);
  }
}
'

# Build the data list for Stan from the admissions table.
# compose_data() needs the data frame as input and its result must be stored,
# because stan() below reads `stan_data`. Per the tidybayes documentation it
# turns factor columns into integer indices and adds the row count `n`.
stan_data <- UCBadmit %>%
  tidybayes::compose_data()

# Compile and sample the gender-only model, then print the posterior summary.
fit01 <- stan(model_code = stan_program_A, data = stan_data)
fit01
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
##           mean se_mean   sd     2.5%      25%      50%      75%    97.5% n_eff
## a[1]     -0.22    0.00 0.04    -0.30    -0.25    -0.22    -0.19    -0.15  3654
## a[2]     -0.83    0.00 0.05    -0.93    -0.86    -0.83    -0.80    -0.73  2837
## p[1]      0.45    0.00 0.01     0.43     0.44     0.45     0.45     0.46  3655
## p[2]      0.30    0.00 0.01     0.28     0.30     0.30     0.31     0.32  2844
## p[3]      0.45    0.00 0.01     0.43     0.44     0.45     0.45     0.46  3655
## p[4]      0.30    0.00 0.01     0.28     0.30     0.30     0.31     0.32  2844
## p[5]      0.45    0.00 0.01     0.43     0.44     0.45     0.45     0.46  3655
## p[6]      0.30    0.00 0.01     0.28     0.30     0.30     0.31     0.32  2844
## p[7]      0.45    0.00 0.01     0.43     0.44     0.45     0.45     0.46  3655
## p[8]      0.30    0.00 0.01     0.28     0.30     0.30     0.31     0.32  2844
## p[9]      0.45    0.00 0.01     0.43     0.44     0.45     0.45     0.46  3655
## p[10]     0.30    0.00 0.01     0.28     0.30     0.30     0.31     0.32  2844
## p[11]     0.45    0.00 0.01     0.43     0.44     0.45     0.45     0.46  3655
## p[12]     0.30    0.00 0.01     0.28     0.30     0.30     0.31     0.32  2844
## lp__  -2976.60    0.02 0.98 -2979.26 -2976.97 -2976.31 -2975.89 -2975.63  1944
##       Rhat
## a[1]     1
## a[2]     1
## p[1]     1
## p[2]     1
## p[3]     1
## p[4]     1
## p[5]     1
## p[6]     1
## p[7]     1
## p[8]     1
## p[9]     1
## p[10]    1
## p[11]    1
## p[12]    1
## lp__     1
##
## Samples were drawn using NUTS(diag_e) at Tue Jul 18 20:11:11 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
# Inverse-logit (logistic) function: maps log-odds `x` to a probability.
#
# x: numeric vector of log-odds.
# Returns a numeric vector of the same length with values in [0, 1].
#
# Uses stats::plogis() rather than exp(x) / (1 + exp(x)): the hand-written
# form evaluates Inf / Inf = NaN once exp(x) overflows (x above roughly 709).
inv_logit <- function(x) {
  stats::plogis(x)
}

# Posterior contrast between the two gender intercepts (index 1 minus index 2),
# on the log-odds scale (diff_a) and on the probability scale (diff_p),
# summarised by posterior mean and 89% quantile interval.
fit01 %>%
  # A stanfit object cannot be reshaped directly: first extract the draws of
  # a[i] as a long data frame (one row per draw and index i).
  tidybayes::spread_draws(a[i]) %>%
  # spread_draws() returns the draws grouped by i; drop the grouping before
  # spreading i into columns.
  ungroup() %>%
  pivot_wider(
    names_from = i,
    values_from = a,
    names_prefix = "a_"
  ) %>%
  mutate(
    diff_a = a_1 - a_2,
    diff_p = inv_logit(a_1) - inv_logit(a_2)
  ) %>%
  pivot_longer(contains("diff")) %>%
  group_by(name) %>%
  tidybayes::mean_qi(value, .width = .89)
## # A tibble: 2 × 7
##   name   value .lower .upper .width .point .interval
##   <chr>  <dbl>  <dbl>  <dbl>  <dbl> <chr>  <chr>
## 1 diff_a 0.609  0.509  0.712   0.89 mean   qi
## 2 diff_p 0.141  0.119  0.165   0.89 mean   qi

• 从 log-odds 度量看，男性录取率高于女性录取率
• 从概率的角度看，男性的录取概率比女性高约 12 到 16 个百分点（89% 可信区间）

# Compare the model's expected admission probability (open circles) with the
# observed admission ratio (filled points joined by lines), one panel per
# department.
fit01 %>%
  tidybayes::gather_draws(p[i]) %>%
  tidybayes::mean_qi(.width = .89) %>%
  ungroup() %>%
  rename(Estimate = .value) %>%
  # The summary has one row per p[i], i.e. one per row of UCBadmit in the same
  # order; attach the data columns (applicant_gender, ratio, dept) that the
  # plot below maps and facets on.
  bind_cols(UCBadmit) %>%
  ggplot(aes(x = applicant_gender, y = ratio)) +
  geom_point(
    aes(y = Estimate),
    color = wes_palette("Moonrise2")[1],
    shape = 1, size = 3
  ) +
  geom_point(color = wes_palette("Moonrise2")[2]) +
  geom_line(
    aes(group = dept),
    color = wes_palette("Moonrise2")[2]
  ) +
  scale_y_continuous(limits = 0:1) +
  facet_grid(. ~ dept) +
  labs(x = NULL, y = 'Probability of admission')

• 原始数据中只有学院C和E，女性录取率略低于男性，但模型结果却表明，女性预期的录取概率比男性低14%。

• 同时，我们看到男性和女性申请的院系不一样，以下是各学院男女申请人数的比例。女性更倾向于选择A、B之外的学院，比如F学院，而这些学院申请人数比较多，因而男女录取率都很低，甚至不到10%。也就是说，大多数女性选择了录取率低的学院，从而拉低了女性整体的录取率。

# Share of each department's applications that come from each gender,
# shown as one row per gender and one column per department, rounded to
# two decimals.
UCBadmit %>%
  group_by(dept) %>%
  mutate(proportion = applications / sum(applications)) %>%
  # Drop the grouping once the within-department shares are computed.
  ungroup() %>%
  select(dept, applicant_gender, proportion) %>%
  pivot_wider(
    names_from = dept,
    values_from = proportion
  ) %>%
  mutate(
    # Pass `digits` through an anonymous function: supplying extra arguments
    # via across(...) is deprecated as of dplyr 1.1.0.
    across(where(is.double), \(x) round(x, digits = 2))
  )
## # A tibble: 2 × 7
##   applicant_gender     A     B     C     D     E     F
##   <fct>            <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 male              0.88  0.96  0.35  0.53  0.33  0.52
## 2 female            0.12  0.04  0.65  0.47  0.67  0.48
• 模型没有问题，而是我们的提问（对全体学院，男女平均录取率有什么差别？）是有问题的。

### 66.2.2 增加预测变量

\begin{align*} \text{admit}_i & \sim \operatorname{Binomial} (n_i, p_i) \\ \text{logit}(p_i) & = \alpha_{\text{gender}[i]} + \delta_{\text{dept}[i]} \\ \alpha_j & \sim \operatorname{Normal} (0, 1.5) \\ \delta_k & \sim \operatorname{Normal} (0, 1.5), \end{align*}

# Stan program for the gender + department model:
#   admit[i] ~ Binomial(applications[i], p[i])
#   logit(p[i]) = a[applicant_gender[i]] + b[dept[i]]
#   a[j] ~ Normal(0, 1.5),  b[k] ~ Normal(0, 1.5)
# The data block declares `admit` and the model block contains the binomial
# likelihood; without them the sampler would only draw from the prior.
stan_program_B <- '
data {
  int n;
  int n_dept;
  int admit[n];
  int applications[n];
  int applicant_gender[n];
  int dept[n];
}
parameters {
  real a[2];
  real b[n_dept];
}
transformed parameters {
  vector[n] p;
  for (i in 1:n) {
    p[i] = inv_logit(a[applicant_gender[i]] + b[dept[i]]);
  }
}
model {
  a ~ normal(0, 1.5);
  b ~ normal(0, 1.5);
  for (i in 1:n) {
    admit[i] ~ binomial(applications[i], p[i]);
  }
}
'
# Build the data list for Stan from the admissions table.
# compose_data() needs the data frame as input and its result must be stored,
# because stan() below reads `stan_data`. Per the tidybayes documentation it
# turns factor columns into integer indices and adds the row count `n` and a
# level count per factor (such as `n_dept`, which the Stan program declares).
stan_data <- UCBadmit %>%
  tidybayes::compose_data()

# Compile and sample the gender + department model.
fit02 <- stan(model_code = stan_program_B, data = stan_data)
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## https://mc-stan.org/misc/warnings.html#bulk-ess
# Print the posterior summary.
fit02
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
##           mean se_mean   sd     2.5%      25%      50%      75%    97.5% n_eff
## a[1]     -0.54    0.03 0.52    -1.53    -0.89    -0.55    -0.21     0.54   383
## a[2]     -0.44    0.03 0.52    -1.43    -0.79    -0.45    -0.11     0.65   392
## b[1]      1.12    0.03 0.52     0.00     0.79     1.12     1.48     2.12   387
## b[2]      1.07    0.03 0.52    -0.02     0.75     1.08     1.43     2.08   395
## b[3]     -0.14    0.03 0.52    -1.23    -0.47    -0.13     0.21     0.85   391
## b[4]     -0.18    0.03 0.52    -1.30    -0.50    -0.17     0.17     0.84   390
## b[5]     -0.62    0.03 0.52    -1.72    -0.95    -0.61    -0.27     0.39   392
## b[6]     -2.18    0.03 0.54    -3.34    -2.51    -2.17    -1.81    -1.16   421
## p[1]      0.64    0.00 0.02     0.61     0.63     0.64     0.65     0.67  5217
## p[2]      0.66    0.00 0.02     0.62     0.65     0.66     0.68     0.71  4817
## p[3]      0.63    0.00 0.02     0.59     0.62     0.63     0.64     0.67  4756
## p[4]      0.65    0.00 0.03     0.60     0.64     0.65     0.67     0.70  4176
## p[5]      0.34    0.00 0.02     0.30     0.32     0.34     0.35     0.38  4791
## p[6]      0.36    0.00 0.02     0.33     0.35     0.36     0.37     0.39  5145
## p[7]      0.33    0.00 0.02     0.29     0.32     0.33     0.34     0.37  5140
## p[8]      0.35    0.00 0.02     0.31     0.34     0.35     0.36     0.39  5241
## p[9]      0.24    0.00 0.02     0.20     0.23     0.24     0.25     0.28  4115
## p[10]     0.26    0.00 0.02     0.22     0.25     0.26     0.27     0.30  4651
## p[11]     0.06    0.00 0.01     0.05     0.06     0.06     0.07     0.08  3006
## p[12]     0.07    0.00 0.01     0.05     0.06     0.07     0.08     0.09  3151
## lp__  -2599.47    0.06 1.98 -2604.08 -2600.58 -2599.15 -2597.99 -2596.58  1083
##       Rhat
## a[1]  1.03
## a[2]  1.03
## b[1]  1.03
## b[2]  1.03
## b[3]  1.03
## b[4]  1.03
## b[5]  1.03
## b[6]  1.03
## p[1]  1.00
## p[2]  1.00
## p[3]  1.00
## p[4]  1.00
## p[5]  1.00
## p[6]  1.00
## p[7]  1.00
## p[8]  1.00
## p[9]  1.00
## p[10] 1.00
## p[11] 1.00
## p[12] 1.00
## lp__  1.00
##
## Samples were drawn using NUTS(diag_e) at Tue Jul 18 20:12:07 2023.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
# Inverse-logit (logistic) function: maps log-odds `x` to a probability.
#
# x: numeric vector of log-odds.
# Returns a numeric vector of the same length with values in [0, 1].
#
# Uses stats::plogis() rather than exp(x) / (1 + exp(x)): the hand-written
# form evaluates Inf / Inf = NaN once exp(x) overflows (x above roughly 709).
inv_logit <- function(x) {
  stats::plogis(x)
}

# Posterior contrast between the two gender intercepts (index 1 minus index 2)
# in the model that also includes department, on the log-odds scale (diff_a)
# and on the probability scale (diff_p), summarised by posterior mean and
# 89% quantile interval.
fit02 %>%
  # A stanfit object cannot be reshaped directly: first extract the draws of
  # a[i] as a long data frame (one row per draw and index i).
  tidybayes::spread_draws(a[i]) %>%
  # spread_draws() returns the draws grouped by i; drop the grouping before
  # spreading i into columns.
  ungroup() %>%
  pivot_wider(
    names_from = i,
    values_from = a,
    names_prefix = "a_"
  ) %>%
  mutate(
    diff_a = a_1 - a_2,
    diff_p = inv_logit(a_1) - inv_logit(a_2)
  ) %>%
  pivot_longer(contains("diff")) %>%
  group_by(name) %>%
  tidybayes::mean_qi(value, .width = .89)
## # A tibble: 2 × 7
##   name     value  .lower  .upper .width .point .interval
##   <chr>    <dbl>   <dbl>   <dbl>  <dbl> <chr>  <chr>
## 1 diff_a -0.0992 -0.228  0.0296    0.89 mean   qi
## 2 diff_p -0.0222 -0.0522 0.00638   0.89 mean   qi

• 从 log-odds 度量看，男性录取率略低于女性录取率
• 从概率的角度看，男性的录取概率比女性低约 2 个百分点（89% 可信区间包含 0，差异很小）

(增加了一个变量，剧情反转了。辛普森佯谬)

# Compare the expected admission probability from the gender + department
# model (open circles) with the observed admission ratio (filled points joined
# by lines), one panel per department.
fit02 %>%
  tidybayes::gather_draws(p[i]) %>%
  tidybayes::mean_qi(.width = .89) %>%
  ungroup() %>%
  rename(Estimate = .value) %>%
  # The summary has one row per p[i], i.e. one per row of UCBadmit in the same
  # order; attach the data columns (applicant_gender, ratio, dept) that the
  # plot below maps and facets on.
  bind_cols(UCBadmit) %>%
  ggplot(aes(x = applicant_gender, y = ratio)) +
  geom_point(
    aes(y = Estimate),
    color = wes_palette("Moonrise2")[1],
    shape = 1, size = 3
  ) +
  geom_point(color = wes_palette("Moonrise2")[2]) +
  geom_line(
    aes(group = dept),
    color = wes_palette("Moonrise2")[2]
  ) +
  scale_y_continuous(limits = 0:1) +
  facet_grid(. ~ dept) +
  labs(x = NULL, y = 'Probability of admission')

## 66.3 作业

• 将性别从模型中去除，然后看看模型结果

• 如果我们假定申请人的性别既影响院系选择，也影响录取率，把院系当作中介，建立中介模型

library(ggdag)
## Warning: package 'ggdag' was built under R version 4.2.3
##
## Attaching package: 'ggdag'
## The following object is masked from 'package:stats':
##
##     filter
# Plot positions for the three DAG nodes. Going by the exercise text these are
# presumably G = gender, D = department, A = admission.
dag_coords <- tibble(
  name = c("G", "D", "A"),
  x = c(1, 2, 3),
  y = c(1, 2, 1)
)

# Draw the mediation DAG (G -> D, G -> A, D -> A) at those coordinates, with
# node labels and edges in the fourth "Moonrise2" colour and both axes blank.
ggplot(
  dagify(
    D ~ G,
    A ~ D + G,
    coords = dag_coords
  ),
  aes(x = x, y = y, xend = xend, yend = yend)
) +
  geom_dag_text(color = wes_palette("Moonrise2")[4], family = "serif") +
  geom_dag_edges(edge_color = wes_palette("Moonrise2")[4]) +
  scale_x_continuous(NULL, breaks = NULL) +
  scale_y_continuous(NULL, breaks = NULL)