7.7 Multinomial Logistic Regression
When the response variable is categorical with more than two possible outcomes, multinomial logistic regression is a natural extension of the binary logistic model.
7.7.1 The Multinomial Distribution
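Suppose we have a categorical response variable $Y_i$ that can take values in $\{1, 2, \dots, J\}$. For each observation $i$, the probability that it falls into category $j$ is given by:
$$p_{ij} = P(Y_i = j), \quad \sum_{j=1}^{J} p_{ij} = 1.$$
The response follows a multinomial distribution with a single trial:
$$Y_i \sim \text{Multinomial}(1;\, p_{i1}, p_{i2}, \dots, p_{iJ}).$$
This means that each observation belongs to exactly one of the $J$ categories.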
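As a minimal sketch of what a single-trial multinomial response looks like, the following base R snippet draws a few observations with `rmultinom()`. The probability vector `p` is an arbitrary illustrative choice, not taken from any dataset.

```r
set.seed(1)

# Hypothetical category probabilities for J = 3 (illustrative values only)
p <- c(0.2, 0.5, 0.3)

# Each column is one observation: a one-hot vector over the J categories
draws <- rmultinom(n = 5, size = 1, prob = p)

# Every observation falls in exactly one category, so each column sums to 1
colSums(draws)

# Recover the category label (1, 2, or 3) for each observation
y <- apply(draws, 2, which.max)
y
```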
7.7.2 Modeling Probabilities Using Log-Odds
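We cannot model the probabilities $p_{ij}$ directly with a linear function because each must lie between 0 and 1 and together they must sum to 1. Instead, we use a logit transformation, comparing each category $j$ to a baseline category (typically the first category, $j = 1$):
$$\eta_{ij} = \log\frac{p_{ij}}{p_{i1}}, \quad j = 2, \dots, J.$$
Using a linear function of covariates $\mathbf{x}_i$, we define:
$$\eta_{ij} = \mathbf{x}_i'\boldsymbol{\beta}_j = \beta_{j0} + \sum_{p} \beta_{jp} x_{ip}.$$
Rearranging to express probabilities explicitly:
$$p_{ij} = p_{i1}\exp(\eta_{ij}), \quad j = 2, \dots, J.$$
Since all probabilities must sum to 1:
$$p_{i1} + \sum_{j=2}^{J} p_{ij} = 1.$$
Substituting for $p_{ij}$:
$$p_{i1} + \sum_{j=2}^{J} p_{i1}\exp(\eta_{ij}) = 1.$$
Solving for $p_{i1}$:
$$p_{i1} = \frac{1}{1 + \sum_{j=2}^{J}\exp(\eta_{ij})}.$$
Thus, the probability for category $j$ is:
$$p_{ij} = \frac{\exp(\eta_{ij})}{1 + \sum_{l=2}^{J}\exp(\eta_{il})}, \quad j = 2, \dots, J.$$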
This formulation is known as the multinomial logit model.
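The baseline-category formulas can be checked numerically. The sketch below is illustrative only: `baseline_probs()` is a hypothetical helper name and the linear predictors in `eta` are made-up values for a three-category response.

```r
# Baseline-category probabilities from linear predictors eta_2, ..., eta_J
# (category 1 is the baseline, so its linear predictor is implicitly 0)
baseline_probs <- function(eta) {
  denom <- 1 + sum(exp(eta))
  c(1, exp(eta)) / denom
}

# Hypothetical linear predictors for categories 2 and 3
eta <- c(0.3, -0.4)
p <- baseline_probs(eta)

sum(p)             # the probabilities sum to 1
log(p[-1] / p[1])  # the log-odds against the baseline recover eta
```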
7.7.3 Softmax Representation
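An alternative formulation avoids choosing a baseline category and instead treats all $J$ categories symmetrically using the softmax function:
$$P(Y_i = j \mid X_i = x) = \frac{\exp\left(\beta_{j0} + \sum_{p} \beta_{jp} x_p\right)}{\sum_{l=1}^{J}\exp\left(\beta_{l0} + \sum_{p} \beta_{lp} x_p\right)}.$$
This representation is often used in neural networks and general machine learning models. Its coefficients are identified only up to a common additive shift across categories; fixing $\boldsymbol{\beta}_1 = \mathbf{0}$ recovers the baseline-category form.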
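A short sketch of the softmax function in base R is given below. The linear predictors in `z` are hypothetical; the point is that shifting all of them by the same constant (for example, so that the first category has predictor 0) leaves the probabilities unchanged, which is why the symmetric and baseline forms describe the same model.

```r
# Numerically stable softmax: subtracting max(z) does not change the result
softmax <- function(z) {
  z <- z - max(z)
  exp(z) / sum(exp(z))
}

# Hypothetical linear predictors for J = 3 categories
z <- c(1.2, 0.3, -0.4)

p_sym  <- softmax(z)          # symmetric form
p_base <- softmax(z - z[1])   # shifted so the first category has predictor 0

all.equal(p_sym, p_base)      # identical probabilities
```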
7.7.4 Log-Odds Ratio Between Two Categories
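The log-odds ratio between two categories $k$ and $k'$ is:
$$\log\frac{P(Y = k \mid X = x)}{P(Y = k' \mid X = x)} = (\beta_{k0} - \beta_{k'0}) + \sum_{p} (\beta_{kp} - \beta_{k'p})\, x_p.$$
This equation tells us that:
- If $\beta_{kp} > \beta_{k'p}$, then increasing $x_p$ increases the odds of choosing category $k$ over $k'$.
- If $\beta_{kp} < \beta_{k'p}$, then increasing $x_p$ decreases the odds of choosing $k$ over $k'$.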
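As a small worked example, the contrast between two non-baseline categories can be computed by differencing their coefficients. The values below are the age coefficients from the age-only fit reported in the application in Section 7.7.7, with the baseline category ("Neutral") fixed at zero; the variable names are illustrative.

```r
# Age coefficients relative to the baseline ("Neutral"), copied from the
# age-only fit reported in Section 7.7.7
beta_age <- c(Neutral = 0, Strong = 0.009832916, Weak = -0.005954348)

# Change in the log-odds of Strong vs Weak per additional year of age
diff_sw <- beta_age[["Strong"]] - beta_age[["Weak"]]
diff_sw

# Corresponding multiplicative change in the odds of Strong vs Weak
exp(diff_sw)
```

Because the difference is positive, the odds of "Strong" relative to "Weak" increase with age under this fit.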
7.7.5 Estimation
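To estimate the parameters $\boldsymbol{\beta}_j$, we use maximum likelihood estimation.
Given $n$ independent observations $(Y_i, X_i)$, the likelihood function is:
$$L(\boldsymbol{\beta}) = \prod_{i=1}^{n}\prod_{j=1}^{J} P(Y_i = j \mid X_i)^{Y_{ij}},$$
where $Y_{ij} = 1$ if observation $i$ falls in category $j$ and $Y_{ij} = 0$ otherwise.
Taking the log-likelihood:
$$\ell(\boldsymbol{\beta}) = \sum_{i=1}^{n}\sum_{j=1}^{J} Y_{ij}\,\log P(Y_i = j \mid X_i).$$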
Since there is no closed-form solution, numerical methods (see Non-linear Least Squares Estimation) are used for estimation.
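To make the numerical step concrete, here is a minimal sketch that maximizes the multinomial log-likelihood with base R's `optim()` on simulated data. Everything in it is an assumption for illustration: the data are simulated, `beta_true` is made up, and BFGS is simply one convenient general-purpose optimizer, not necessarily the routine a given package uses internally.

```r
set.seed(123)

# Simulated data: one covariate, J = 3 categories, category 1 as baseline
n <- 500
J <- 3
x <- rnorm(n)
X <- cbind(1, x)                                   # design matrix with intercept

# Hypothetical true coefficients; columns correspond to categories 2 and 3
beta_true <- cbind(c(0.5, 1.0), c(-0.5, -1.0))
eta <- cbind(0, X %*% beta_true)                   # n x J, baseline column is 0
P <- exp(eta) / rowSums(exp(eta))
y <- apply(P, 1, function(p) sample.int(J, size = 1, prob = p))

# Negative log-likelihood as a function of the stacked coefficient vector
neg_loglik <- function(par) {
  B <- matrix(par, nrow = ncol(X))                 # one column per non-baseline category
  eta <- cbind(0, X %*% B)
  log_p <- eta - log(rowSums(exp(eta)))            # log of the category probabilities
  -sum(log_p[cbind(seq_len(n), y)])                # pick the observed category per row
}

fit <- optim(rep(0, ncol(X) * (J - 1)), neg_loglik, method = "BFGS")
beta_hat <- matrix(fit$par, nrow = ncol(X))
beta_hat                                           # compare with beta_true
```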
7.7.6 Interpretation of Coefficients
- Each $\beta_{jp}$ represents the effect of a one-unit increase in $x_p$ on the log-odds of category $j$ relative to the baseline, holding the other covariates fixed.
- A positive coefficient means that increasing $x_p$ raises the odds of category $j$ relative to the baseline.
- A negative coefficient means that increasing $x_p$ lowers the odds of category $j$ relative to the baseline.
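Exponentiating a coefficient turns it into an odds ratio against the baseline, which is often easier to communicate. The sketch below uses the age coefficient for "Strong" from the age-only fit reported in Section 7.7.7; the variable names are illustrative.

```r
# Age coefficient for "Strong" vs the baseline "Neutral", copied from the
# age-only fit reported in Section 7.7.7
beta_strong_age <- 0.009832916

# Odds ratios for a one-year and a ten-year increase in age
or_1yr  <- exp(beta_strong_age)
or_10yr <- exp(10 * beta_strong_age)
c(one_year = or_1yr, ten_years = or_10yr)
```

An odds ratio above 1 corresponds to a positive coefficient, and the ten-year odds ratio is the one-year odds ratio raised to the tenth power.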
7.7.7 Application: Multinomial Logistic Regression
1. Load Necessary Libraries and Data
library(faraway) # For the dataset
library(dplyr) # For data manipulation
library(ggplot2) # For visualization
library(nnet) # For multinomial logistic regression
# Load and inspect data
data(nes96, package="faraway")
head(nes96, 3)
#> popul TVnews selfLR ClinLR DoleLR PID age educ income vote
#> 1 0 7 extCon extLib Con strRep 36 HS $3Kminus Dole
#> 2 190 1 sliLib sliLib sliCon weakDem 20 Coll $3Kminus Clinton
#> 3 31 7 Lib Lib Con weakDem 24 BAdeg $3Kminus Clinton
The dataset nes96 contains survey responses, including political party identification (PID), age (age), and education level (educ).
2. Define Political Strength Categories
We classify political strength into three categories:
- Strong: Strong Democrat or Strong Republican
- Weak: Weak Democrat or Weak Republican
- Neutral: Independents and other affiliations
# Check distribution of political identity
table(nes96$PID)
#> strDem weakDem indDem indind indRep weakRep strRep
#> 200 180 108 37 94 150 175
# Define Political Strength variable
nes96 <- nes96 %>%
    mutate(Political_Strength = case_when(
        PID %in% c("strDem", "strRep") ~ "Strong",
        PID %in% c("weakDem", "weakRep") ~ "Weak",
        PID %in% c("indDem", "indind", "indRep") ~ "Neutral",
        TRUE ~ NA_character_
    ))
# Summarize
nes96 %>% group_by(Political_Strength) %>% summarise(Count = n())
#> # A tibble: 3 × 2
#> Political_Strength Count
#> <chr> <int>
#> 1 Neutral 239
#> 2 Strong 375
#> 3 Weak 330
3. Visualizing Political Strength by Age
We visualize the proportion of each political strength category across age groups.
# Prepare data for visualization
Plot_DF <- nes96 %>%
    mutate(Age_Grp = cut_number(age, 4)) %>%
    group_by(Age_Grp, Political_Strength) %>%
    summarise(count = n(), .groups = 'drop') %>%
    group_by(Age_Grp) %>%
    mutate(etotal = sum(count), proportion = count / etotal)
# Plot age vs political strength
Age_Plot <- ggplot(
    data = Plot_DF,
    mapping = aes(
        x = Age_Grp,
        y = proportion,
        group = Political_Strength,
        linetype = Political_Strength,
        color = Political_Strength
    )
) +
    geom_line(linewidth = 2) + # on ggplot2 < 3.4.0, use size = 2 instead
    labs(title = "Political Strength by Age Group",
         x = "Age Group",
         y = "Proportion")
# Display plot
Age_Plot
4. Fit a Multinomial Logistic Model
We model political strength as a function of age and education.
# Fit multinomial logistic regression
Multinomial_Model <-
    multinom(Political_Strength ~ age + educ,
             data = nes96,
             trace = FALSE)
summary(Multinomial_Model)
#> Call:
#> multinom(formula = Political_Strength ~ age + educ, data = nes96,
#> trace = FALSE)
#> Coefficients:
#> (Intercept) age educ.L educ.Q educ.C educ^4
#> Strong -0.08788729 0.010700364 -0.1098951 -0.2016197 -0.1757739 -0.02116307
#> Weak 0.51976285 -0.004868771 -0.1431104 -0.2405395 -0.2411795 0.18353634
#> educ^5 educ^6
#> Strong -0.1664377 -0.1359449
#> Weak -0.1489030 -0.2173144
#> Std. Errors:
#> (Intercept) age educ.L educ.Q educ.C educ^4
#> Strong 0.3017034 0.005280743 0.4586041 0.4318830 0.3628837 0.2964776
#> Weak 0.3097923 0.005537561 0.4920736 0.4616446 0.3881003 0.3169149
#> educ^5 educ^6
#> Strong 0.2515012 0.2166774
#> Weak 0.2643747 0.2199186
#> Residual Deviance: 2024.596
#> AIC: 2056.596
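The summary output reports estimates and standard errors but no test statistics. One common follow-up is a Wald z-test, sketched below for the age coefficients. The numbers are copied from the output above; the normal approximation and the variable names are assumptions of this sketch rather than part of the fitted object.

```r
# Age estimates and standard errors copied from the summary output above
est <- c(Strong = 0.010700364, Weak = -0.004868771)
se  <- c(Strong = 0.005280743, Weak = 0.005537561)

# Wald z statistics and two-sided p-values (normal approximation)
z <- est / se
p_value <- 2 * pnorm(abs(z), lower.tail = FALSE)

round(cbind(estimate = est, std_error = se, z = z, p_value = p_value), 4)
```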
5. Stepwise Model Selection Based on AIC
We use step() to perform AIC-based backward selection, starting from the full model and dropping terms while the AIC improves.
Multinomial_Step <- step(Multinomial_Model, trace = 0)
#> trying - age
#> trying - educ
#> trying - age
Multinomial_Step
#> Call:
#> multinom(formula = Political_Strength ~ age, data = nes96, trace = FALSE)
#> Coefficients:
#> (Intercept) age
#> Strong -0.01988977 0.009832916
#> Weak 0.59497046 -0.005954348
#> Residual Deviance: 2030.756
#> AIC: 2038.756
Compare the best model to the full model based on deviance:
pchisq(
    q = deviance(Multinomial_Step) - deviance(Multinomial_Model),
    df = Multinomial_Model$edf - Multinomial_Step$edf,
    lower.tail = FALSE
)
#> [1] 0.9078172
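The p-value of about 0.91 is far from significant, so this likelihood-ratio test gives no evidence that the full model fits better than the reduced, age-only model; the simpler model is adequate.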
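The reported quantities can be cross-checked by hand. The sketch below uses only the deviance and AIC values printed above and assumes the usual relation AIC = deviance + 2 × (number of estimated parameters); variable names are illustrative.

```r
# Values copied from the printed output above
dev_full <- 2024.596; aic_full <- 2056.596
dev_step <- 2030.756; aic_step <- 2038.756

# AIC = deviance + 2 * (number of estimated parameters)
k_full <- (aic_full - dev_full) / 2   # about 16 parameters in the full model
k_step <- (aic_step - dev_step) / 2   # about 4 parameters in the reduced model

# Likelihood-ratio statistic and its chi-squared p-value
lr_stat <- dev_step - dev_full
p_value <- pchisq(lr_stat, df = k_full - k_step, lower.tail = FALSE)
c(lr_stat = lr_stat, df = k_full - k_step, p_value = p_value)
```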
6. Predictions & Visualization
Predicting Political Strength Probabilities by Age
# Create data for prediction
PlotData <- data.frame(age = seq(from = 19, to = 91))
# Get predicted probabilities
Preds <- PlotData %>%
    bind_cols(data.frame(predict(Multinomial_Step,
                                 newdata = PlotData,
                                 type = "probs")))
# Plot predicted probabilities across age
plot(
    x = Preds$age,
    y = Preds$Neutral,
    type = "l",
    ylim = c(0.2, 0.6),
    col = "black",
    ylab = "Proportion",
    xlab = "Age"
)
lines(x = Preds$age,
      y = Preds$Weak,
      col = "blue")
lines(x = Preds$age,
      y = Preds$Strong,
      col = "red")
legend(
    "topleft", # legend position is a layout choice; adjust as needed
    legend = c("Neutral", "Weak", "Strong"),
    col = c("black", "blue", "red"),
    lty = 1
)
Predict for Specific Ages
# Predict class for a 34-year-old
predict(Multinomial_Step, data.frame(age = 34))
#> [1] Weak
#> Levels: Neutral Strong Weak
# Predict probabilities for 34 and 35-year-olds
predict(Multinomial_Step, data.frame(age = c(34, 35)), type = "probs")
#> Neutral Strong Weak
#> 1 0.2597275 0.3556910 0.3845815
#> 2 0.2594080 0.3587639 0.3818281
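The predicted probabilities for a 34-year-old can be reproduced by hand from the baseline-category formula. The coefficients below are copied from the printed age-only fit; the object names are illustrative, and small differences in the last digits are expected because the printed coefficients are rounded.

```r
# Coefficients of the age-only model, copied from the printed fit above
# (rows: non-baseline categories; the baseline is "Neutral")
coefs <- rbind(Strong = c(-0.01988977,  0.009832916),
               Weak   = c( 0.59497046, -0.005954348))

age <- 34
eta <- coefs %*% c(1, age)   # linear predictors for Strong and Weak

# Baseline-category formula: the baseline has linear predictor 0
probs <- c(Neutral = 1, Strong = exp(eta[1]), Weak = exp(eta[2]))
probs <- probs / sum(probs)
round(probs, 4)

# The predicted class is the category with the largest probability
names(which.max(probs))
```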
7.7.8 Application: Gamma Regression
Gamma regression is suited to continuous response variables that are strictly positive, typically right-skewed with variance that grows with the mean.
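Before turning to real data, a minimal simulated sketch may help fix ideas. All values here are hypothetical: the true inverse-mean curve, the shape parameter, and the sample size are arbitrary choices made only to show how a Gamma GLM with an inverse link is fitted and how its fitted means relate to the linear predictor.

```r
set.seed(42)

# Simulated data (hypothetical): the mean follows an inverse-linear curve in x
n  <- 200
x  <- runif(n, min = 1, max = 10)
mu <- 1 / (0.1 + 0.05 * x)

# Gamma response with mean mu and constant shape (variance = mu^2 / shape)
shape <- 5
y <- rgamma(n, shape = shape, rate = shape / mu)

fit <- glm(y ~ x, family = Gamma(link = "inverse"))
coef(fit)   # coefficient estimates on the inverse-mean scale

# With the inverse link, fitted means are the reciprocal of the linear predictor
all.equal(unname(fitted(fit)), unname(1 / predict(fit, type = "link")))
```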
1. Load and Prepare Data
library(agridat) # Agricultural dataset
# Load and filter data
dat <- agridat::streibig.competition
gammaDat <- subset(dat, sseeds < 1) # Keep only barley
gammaDat <-
    transform(gammaDat,
              x = bseeds,
              y = bdwt,
              block = factor(block))
2. Visualization of Inverse Yield
ggplot(gammaDat, aes(x = x, y = 1 / y)) +
geom_point(aes(color = block, shape = block)) +
labs(title = "Inverse Yield vs Seeding Rate",
x = "Seeding Rate",
y = "Inverse Yield")
3. Fit Gamma Regression Model
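Gamma regression models yield as a function of seeding rate using an inverse link. With $j$ indexing blocks, and block-specific coefficients matching the block * x and block * I(x^2) terms in the model formula:
$$\eta_{ij} = \beta_{0j} + \beta_{1j} x_{ij} + \beta_{2j} x_{ij}^2, \quad E(Y_{ij}) = \eta_{ij}^{-1}.$$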
m1 <- glm(y ~ block + block * x + block * I(x^2),
          data = gammaDat, family = Gamma(link = "inverse"))
summary(m1)
#> Call:
#> glm(formula = y ~ block + block * x + block * I(x^2), family = Gamma(link = "inverse"),
#> data = gammaDat)
#> Deviance Residuals:
#> Min 1Q Median 3Q Max
#> -1.21708 -0.44148 0.02479 0.17999 0.80745
#> Coefficients:
#> Estimate Std. Error t value Pr(>|t|)
#> (Intercept) 1.115e-01 2.870e-02 3.886 0.000854 ***
#> blockB2 -1.208e-02 3.880e-02 -0.311 0.758630
#> blockB3 -2.386e-02 3.683e-02 -0.648 0.524029
#> x -2.075e-03 1.099e-03 -1.888 0.072884 .
#> I(x^2) 1.372e-05 9.109e-06 1.506 0.146849
#> blockB2:x 5.198e-04 1.468e-03 0.354 0.726814
#> blockB3:x 7.475e-04 1.393e-03 0.537 0.597103
#> blockB2:I(x^2) -5.076e-06 1.184e-05 -0.429 0.672475
#> blockB3:I(x^2) -6.651e-06 1.123e-05 -0.592 0.560012
#> ---
#> Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
#> (Dispersion parameter for Gamma family taken to be 0.3232083)
#> Null deviance: 13.1677 on 29 degrees of freedom
#> Residual deviance: 7.8605 on 21 degrees of freedom
#> AIC: 225.32
#> Number of Fisher Scoring iterations: 5
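Because the coefficients live on the inverse-mean scale, they are easiest to read after back-transforming. The sketch below computes the fitted mean yield for the reference block (B1) from the coefficients printed above. The seeding rate `x0 = 30` is a hypothetical value chosen for illustration, and the result is approximate because the printed coefficients are rounded.

```r
# Block B1 (reference block) coefficients copied from the summary output above
b0 <- 1.115e-01   # (Intercept)
b1 <- -2.075e-03  # x
b2 <- 1.372e-05   # I(x^2)

# Linear predictor on the inverse scale, then back-transform to the mean
x0  <- 30                         # hypothetical seeding rate for illustration
eta <- b0 + b1 * x0 + b2 * x0^2
mu  <- 1 / eta
c(eta = eta, mu = mu)
```

Note that the linear predictor must stay positive for the implied mean to be valid; that holds at this value of `x0`.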
4. Predictions and Visualization
# Generate new data for prediction
newdf <-
    expand.grid(x = seq(0, 120, length.out = 50),
                block = factor(c("B1", "B2", "B3")))
# Predict responses
newdf$pred <- predict(m1, newdata = newdf, type = "response")
# Plot predictions
ggplot(gammaDat, aes(x = x, y = y)) +
geom_point(aes(color = block, shape = block)) +
geom_line(data = newdf, aes(
x = x,
y = pred,
color = block,
linetype = block
)) +
labs(title = "Predicted Yield by Seeding Rate",
x = "Seeding Rate",
y = "Yield")