Chapter 72 Bayesian Classification Models

library(tidyverse)
library(tidybayes)
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

72.1 Data

Here we have simulated the family income and career choice (career = 1, 2, 3) of 500 people.

df <- readr::read_rds(here::here("demo_data", "career.rds")) 
df
## # A tibble: 500 × 2
##    family_income career
##            <dbl>  <int>
##  1         0.696      3
##  2         0.296      3
##  3         0.104      3
##  4         0.633      3
##  5         0.270      2
##  6         0.536      3
##  7         0.372      2
##  8         0.229      2
##  9         0.436      3
## 10         0.687      3
## # ℹ 490 more rows
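
The data are read from disk above; the block below is only an illustrative sketch of how data of this shape could be simulated. The coefficient values and object names are assumptions for illustration, not the ones used to create career.rds.

# illustrative simulation only; career.rds was generated elsewhere
set.seed(2024)
N <- 500
family_income <- runif(N)

# assumed coefficients; career = 3 is the baseline with score fixed at 0
s1 <- -1.0 - 4.0 * family_income
s2 <- -0.4 - 2.0 * family_income
s3 <- 0

p <- cbind(exp(s1), exp(s2), exp(s3))
p <- p / rowSums(p)                      # softmax: rows sum to 1

career <- apply(p, 1, function(prob) sample(1:3, size = 1, prob = prob))
sim <- tibble::tibble(family_income, career)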

With career = 3 as the baseline, we want to estimate the four parameters in the formulas below.

Recall the mathematical form of the multinomial logit regression:

$$
\begin{align*}
\log\left(\frac{P(\text{career}=1)}{P(\text{career}=3)}\right) &= \alpha_{1} + \beta_{1}\,\text{income} \\
\log\left(\frac{P(\text{career}=2)}{P(\text{career}=3)}\right) &= \alpha_{2} + \beta_{2}\,\text{income}
\end{align*}
$$
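
Equivalently (a small step added here for clarity), exponentiating and normalizing gives the softmax form that the Stan programs below use, with the baseline score fixed at zero:

$$
P(\text{career}=k \mid \text{income}) = \frac{\exp(s_k)}{\sum_{j=1}^{3}\exp(s_j)},
\qquad
s_1 = \alpha_1 + \beta_1\,\text{income},\quad
s_2 = \alpha_2 + \beta_2\,\text{income},\quad
s_3 = 0
$$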

For a multinomial logistic regression model we can use the nnet::multinom() function in R. Note that fct_rev() below reverses the factor levels so that career = 3 becomes the reference level, matching the baseline above.

df %>% 
  dplyr::mutate(career = fct_rev(as_factor(career))) %>% 
  nnet::multinom(career ~ family_income, data = .)
## # weights:  9 (4 variable)
## initial  value 549.306144 
## iter  10 value 338.904780
## final  value 337.833351 
## converged
## Call:
## nnet::multinom(formula = career ~ family_income, data = .)
## 
## Coefficients:
##   (Intercept) family_income
## 2  -0.3915280     -1.844089
## 1  -0.9065662     -4.162446
## 
## Residual Deviance: 675.6667 
## AIC: 683.6667
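
For later comparison with the Stan fits, the same model can also produce per-observation fitted probabilities. This is a minimal sketch; fit_multinom is a name introduced here for illustration.

fit_multinom <- df %>% 
  dplyr::mutate(career = fct_rev(as_factor(career))) %>% 
  nnet::multinom(career ~ family_income, data = .)

head(predict(fit_multinom, type = "probs"))   # fitted probability of each career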

72.2 Stan for multinomial logit regression

72.2.1 Stan model 1: explicit softmax

stan_program <- "
data{
    int N;              // number of observations
    int K;              // number of outcome values
    int career[N];      // outcome
    real family_income[N];
}
parameters{
    vector[K-1] a;      // intercepts
    vector[K-1] b;      // coefficients on family income
}
model{
    vector[K] p;
    vector[K] s;
    a ~ normal(0, 5);
    b ~ normal(0, 5);
    for ( i in 1:N ) {
        // linear score for each non-baseline category
        for ( j in 1:(K-1) ) s[j] = a[j] + b[j]*family_income[i];
        s[K] = 0;            // career = K is the baseline, score pinned at 0
        p = softmax( s );    // convert scores to category probabilities
        career[i] ~ categorical( p );
    }
}
"


stan_data <- list(
    N             = nrow(df),        # number of observations
    K             = 3,               # number of career categories
    career        = df$career,
    family_income = df$family_income
  )


m1 <- stan(model_code = stan_program, data = stan_data)
m1
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##         mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
## a[1]   -0.94    0.01 0.32   -1.62   -1.15   -0.93   -0.72   -0.31  1802    1
## a[2]   -0.40    0.00 0.22   -0.83   -0.54   -0.40   -0.25    0.03  1922    1
## b[1]   -4.16    0.02 0.89   -6.02   -4.72   -4.11   -3.55   -2.52  1633    1
## b[2]   -1.84    0.01 0.42   -2.68   -2.13   -1.83   -1.55   -1.03  1888    1
## lp__ -340.27    0.03 1.42 -343.94 -341.00 -339.94 -339.22 -338.49  1710    1
## 
## Samples were drawn using NUTS(diag_e) at Mon Oct 28 09:57:36 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).
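
As a quick check, the posterior draws of a and b can be turned into predicted career probabilities at a given family income by applying the same softmax used in the Stan model. This is a minimal sketch; post, income, and prob are names introduced here for illustration.

post <- rstan::extract(m1, pars = c("a", "b"))

income <- 0.5                               # an example income value
s1 <- post$a[, 1] + post$b[, 1] * income    # linear score for career = 1
s2 <- post$a[, 2] + post$b[, 2] * income    # linear score for career = 2
s3 <- 0                                     # career = 3 is the baseline

denom <- exp(s1) + exp(s2) + exp(s3)
prob  <- cbind(exp(s1), exp(s2), exp(s3)) / denom   # softmax by hand

colMeans(prob)   # posterior mean probability of each career at income = 0.5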

72.2.2 Stan model 2: categorical_logit with a design matrix

stan_program <- "
data {
  int<lower = 2> K;
  int<lower = 0> N;
  int<lower = 1> D;
  int<lower = 1, upper = K> y[N];
  matrix[N, D] x;
}
transformed data {
  vector[D] zeros = rep_vector(0, D);   // coefficients of the baseline category
}
parameters {
  matrix[D, K - 1] beta_raw;
}
transformed parameters {
  matrix[D, K] beta;
  beta = append_col(beta_raw, zeros);   // last column (baseline) fixed at 0
}
model {
  matrix[N, K] x_beta = x * beta;       // linear scores for every observation

  to_vector(beta_raw) ~ normal(0, 5);   // weakly informative prior

  for (n in 1:N)
    y[n] ~ categorical_logit(to_vector(x_beta[n]));
}
"

stan_data <- list(
    N = nrow(df),
    K = 3,         # number of career categories
    D = 2,         # design-matrix columns: intercept + family_income
    y = df$career,
    x = model.matrix(~ 1 + family_income, data = df)
  )

m2 <- stan(model_code = stan_program, data = stan_data)
m2
## Inference for Stan model: anon_model.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##                  mean se_mean   sd    2.5%     25%     50%     75%   97.5%
## beta_raw[1,1]   -0.94    0.01 0.31   -1.58   -1.15   -0.94   -0.73   -0.34
## beta_raw[1,2]   -0.40    0.00 0.21   -0.82   -0.54   -0.39   -0.25    0.02
## beta_raw[2,1]   -4.14    0.02 0.87   -5.95   -4.70   -4.12   -3.55   -2.53
## beta_raw[2,2]   -1.85    0.01 0.42   -2.70   -2.14   -1.85   -1.56   -1.05
## beta[1,1]       -0.94    0.01 0.31   -1.58   -1.15   -0.94   -0.73   -0.34
## beta[1,2]       -0.40    0.00 0.21   -0.82   -0.54   -0.39   -0.25    0.02
## beta[1,3]        0.00     NaN 0.00    0.00    0.00    0.00    0.00    0.00
## beta[2,1]       -4.14    0.02 0.87   -5.95   -4.70   -4.12   -3.55   -2.53
## beta[2,2]       -1.85    0.01 0.42   -2.70   -2.14   -1.85   -1.56   -1.05
## beta[2,3]        0.00     NaN 0.00    0.00    0.00    0.00    0.00    0.00
## lp__          -340.28    0.04 1.44 -343.89 -340.97 -339.96 -339.20 -338.49
##               n_eff Rhat
## beta_raw[1,1]  2056    1
## beta_raw[1,2]  2001    1
## beta_raw[2,1]  2040    1
## beta_raw[2,2]  1918    1
## beta[1,1]      2056    1
## beta[1,2]      2001    1
## beta[1,3]       NaN  NaN
## beta[2,1]      2040    1
## beta[2,2]      1918    1
## beta[2,3]       NaN  NaN
## lp__           1560    1
## 
## Samples were drawn using NUTS(diag_e) at Mon Oct 28 09:58:15 2024.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).
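
The beta[1,3] and beta[2,3] rows are the appended baseline column, fixed at 0, which is why their se_mean and Rhat are NaN. As a rough cross-check against the nnet::multinom() fit above, the posterior mean of beta can be combined with the design matrix to reproduce fitted probabilities. This is a sketch; beta_hat, scores, and probs are names introduced here for illustration.

beta_hat <- apply(rstan::extract(m2, pars = "beta")$beta, c(2, 3), mean)  # D x K posterior means

x      <- model.matrix(~ 1 + family_income, data = df)
scores <- x %*% beta_hat                       # N x K linear scores
probs  <- exp(scores) / rowSums(exp(scores))   # row-wise softmax

head(round(probs, 3))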