第 72 章 贝叶斯分类模型
library(tidyverse)
library(tidybayes)
library(rstan)
# Turn on rstan's auto_write option; per the rstan documentation this saves
# compiled Stan programs to disk so unchanged models are not recompiled.
rstan_options(auto_write = TRUE)
# Set the mc.cores option to the core count reported by parallel::detectCores().
options(mc.cores = parallel::detectCores())
72.1 数据
这里，我们模拟了500个人的家庭收入(family_income)和职业选择(career = 1, 2, 3)。
## # A tibble: 500 × 2
## family_income career
## <dbl> <int>
## 1 0.696 3
## 2 0.296 3
## 3 0.104 3
## 4 0.633 3
## 5 0.270 2
## 6 0.536 3
## 7 0.372 2
## 8 0.229 2
## 9 0.436 3
## 10 0.687 3
## # ℹ 490 more rows
以 career = 3 为基线(baseline)，我们要估计下面公式中的四个参数($\alpha_{1}, \beta_{1}, \alpha_{2}, \beta_{2}$)。
回想一下logit回归的数学表达式：
$$
\begin{aligned}
\log\left(\frac{P(\text{career}=1)}{P(\text{career}=3)}\right) &= \alpha_{1} + \beta_{1}\,\text{income} \\
\log\left(\frac{P(\text{career}=2)}{P(\text{career}=3)}\right) &= \alpha_{2} + \beta_{2}\,\text{income}
\end{aligned}
$$
对于多项Logistic回归模型，R语言可以使用 nnet::multinom() 函数来拟合：
# Reference fit with nnet::multinom(), used to compare against the Stan models.
# The career codes are turned into a factor and the level order is reversed,
# so that career = 3 comes first and serves as the baseline category.
df %>%
  dplyr::mutate(
    career = career %>%
      as_factor() %>%
      fct_rev()
  ) %>%
  nnet::multinom(formula = career ~ family_income, data = .)
## # weights: 9 (4 variable)
## initial value 549.306144
## iter 10 value 338.904780
## final value 337.833351
## converged
## Call:
## nnet::multinom(formula = career ~ family_income, data = .)
##
## Coefficients:
## (Intercept) family_income
## 2 -0.3915280 -1.844089
## 1 -0.9065662 -4.162446
##
## Residual Deviance: 675.6667
## AIC: 683.6667
72.2 stan for multi-logit Regression
72.2.1 stan 1
# Stan program 1: multinomial logit of career choice on family income.
#
# Category K is the reference: its linear predictor is fixed at 0, so a[j] and
# b[j] are the intercept and slope of log(P(career = j) / P(career = K)) for
# j = 1..K-1.
#
# Changes relative to the first draft of this model:
#   * the likelihood uses categorical_logit(s) instead of
#     categorical(softmax(s)); it is the same distribution, but the logit form
#     works on the linear predictors directly and is the numerically safer one;
#   * s[K] = 0 is set once, outside the observation loop, since it never changes;
#   * the data block validates its inputs (N >= 0, K >= 2, career in 1..K).
#
# NOTE(review): the `type name[N]` array declarations are the older Stan array
# syntax, which recent Stan releases no longer accept. They are kept so the
# program still compiles with the toolchain this chapter was run on -- confirm
# the installed Stan version before switching to `array[N] type name`.
stan_program <- "
data{
  int<lower = 0> N;                     // number of observations
  int<lower = 2> K;                     // number of outcome values
  int<lower = 1, upper = K> career[N];  // outcome, coded 1..K
  real family_income[N];                // predictor
}
parameters{
  vector[K-1] a;                        // intercepts
  vector[K-1] b;                        // coefficients on family income
}
model{
  vector[K] s;                          // linear predictor of each outcome
  a ~ normal(0, 5);
  b ~ normal(0, 5);
  s[K] = 0;                             // reference category, same for every i
  for ( i in 1:N ) {
    for ( j in 1:(K-1) ) s[j] = a[j] + b[j]*family_income[i];
    career[i] ~ categorical_logit( s );
  }
}
"
# Bundle the inputs under the names the Stan data block declares:
# N observations, K = 3 career categories, the outcome and the predictor.
stan_data <- list(
  N = nrow(df),
  K = 3,
  career = df[["career"]],
  family_income = df[["family_income"]]
)
# Compile the program and sample with stan()'s default settings,
# then print the posterior summary.
m1 <- stan(data = stan_data, model_code = stan_program)
m1
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## a[1] -0.94 0.01 0.32 -1.62 -1.15 -0.93 -0.72 -0.31 1802 1
## a[2] -0.40 0.00 0.22 -0.83 -0.54 -0.40 -0.25 0.03 1922 1
## b[1] -4.16 0.02 0.89 -6.02 -4.72 -4.11 -3.55 -2.52 1633 1
## b[2] -1.84 0.01 0.42 -2.68 -2.13 -1.83 -1.55 -1.03 1888 1
## lp__ -340.27 0.03 1.42 -343.94 -341.00 -339.94 -339.22 -338.49 1710 1
##
## Samples were drawn using NUTS(diag_e) at Mon Oct 28 09:57:36 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
72.2.2 stan 2
# Stan program 2: the same multi-logit model written in matrix form.
# Note that this assignment overwrites the `stan_program` string defined earlier.
#
# data:        K outcome categories, N observations, D predictor columns,
#              outcomes y[n] constrained to 1..K, and an N x D matrix x.
# parameters:  beta_raw is D x (K - 1).
# transformed: beta appends a column of zeros to beta_raw, so category K has
#              all-zero coefficients and acts as the reference category.
# model:       normal(0, 5) prior on every element of beta_raw; each y[n]
#              follows categorical_logit of row n of x * beta.
stan_program <- "
data {
int<lower = 2> K;
int<lower = 0> N;
int<lower = 1> D;
int<lower = 1, upper = K> y[N];
matrix[N, D] x;
}
transformed data {
vector[D] zeros = rep_vector(0, D);
}
parameters {
matrix[D, K - 1] beta_raw;
}
transformed parameters {
matrix[D, K] beta;
beta = append_col(beta_raw, zeros);
}
model {
matrix[N, K] x_beta = x * beta;
to_vector(beta_raw) ~ normal(0, 5);
for (n in 1:N)
y[n] ~ categorical_logit(to_vector(x_beta[n]));
}
"
# Inputs for the matrix-form program: x is the design matrix with an
# intercept column and family_income, hence D = 2 predictor columns.
stan_data <- list(
  N = nrow(df),
  K = 3,
  D = 2,
  y = df[["career"]],
  x = model.matrix(~ 1 + family_income, data = df)
)
# Compile and sample with stan()'s default settings, then print the summary.
m2 <- stan(data = stan_data, model_code = stan_program)
m2
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5%
## beta_raw[1,1] -0.94 0.01 0.31 -1.58 -1.15 -0.94 -0.73 -0.34
## beta_raw[1,2] -0.40 0.00 0.21 -0.82 -0.54 -0.39 -0.25 0.02
## beta_raw[2,1] -4.14 0.02 0.87 -5.95 -4.70 -4.12 -3.55 -2.53
## beta_raw[2,2] -1.85 0.01 0.42 -2.70 -2.14 -1.85 -1.56 -1.05
## beta[1,1] -0.94 0.01 0.31 -1.58 -1.15 -0.94 -0.73 -0.34
## beta[1,2] -0.40 0.00 0.21 -0.82 -0.54 -0.39 -0.25 0.02
## beta[1,3] 0.00 NaN 0.00 0.00 0.00 0.00 0.00 0.00
## beta[2,1] -4.14 0.02 0.87 -5.95 -4.70 -4.12 -3.55 -2.53
## beta[2,2] -1.85 0.01 0.42 -2.70 -2.14 -1.85 -1.56 -1.05
## beta[2,3] 0.00 NaN 0.00 0.00 0.00 0.00 0.00 0.00
## lp__ -340.28 0.04 1.44 -343.89 -340.97 -339.96 -339.20 -338.49
## n_eff Rhat
## beta_raw[1,1] 2056 1
## beta_raw[1,2] 2001 1
## beta_raw[2,1] 2040 1
## beta_raw[2,2] 1918 1
## beta[1,1] 2056 1
## beta[1,2] 2001 1
## beta[1,3] NaN NaN
## beta[2,1] 2040 1
## beta[2,2] 1918 1
## beta[2,3] NaN NaN
## lp__ 1560 1
##
## Samples were drawn using NUTS(diag_e) at Mon Oct 28 09:58:15 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).