Chapter 12 Multinomial Regression Model
Sometimes the response variable is categorical with \(K\) classes, where \(K > 2\), and we wish to classify observations into these classes.
In this case, we can use multinomial logistic regression, which extends the two-class logistic regression model.
First, we select one class to serve as the baseline (also called the reference class). Without loss of generality, suppose we use the \(K\)th class as the baseline.
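(Note: In practice, the default baseline chosen by software may differ. For example, multinom() from the nnet package, used below, takes the first level of the response factor as the baseline.)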
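In R, the baseline of a factor response is its first level. The minimal base-R sketch below uses a toy vector with the same quality levels as the example later in this chapter. It shows how to inspect the levels and how stats::relevel() makes a different level the reference. The choice of "6" is arbitrary and only for illustration.

```r
# A toy response with the quality levels used later in this chapter
quality <- factor(c(6, 5, 4, 8, 7, 6))

# R sorts the levels; the first one ("4") is the default baseline
levels(quality)       # "4" "5" "6" "7" "8"

# relevel() moves the chosen level to the front, making it the new baseline
quality_ref6 <- relevel(quality, ref = "6")
levels(quality_ref6)  # "6" "4" "5" "7" "8"
```

Refitting with the releveled factor would report coefficients for every class relative to quality 6 instead.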
Then the multinomial logistic regression model specifies that, for \(k = 1, \ldots, K-1\), \[\begin{equation*} \mathbb{P}(Y = k \mid X = x) = \frac{e^{\beta_{k0} + \beta_{k1} x_1 + \cdots + \beta_{kp} x_p}} {1 + \sum_{l=1}^{K-1} e^{\beta_{l0} + \beta_{l1} x_1 + \cdots + \beta_{lp} x_p}}, \end{equation*}\] and for the baseline class, \[\begin{equation*} \mathbb{P}(Y = K \mid X = x) = \frac{1} {1 + \sum_{l=1}^{K-1} e^{\beta_{l0} + \beta_{l1} x_1 + \cdots + \beta_{lp} x_p}}. \end{equation*}\]
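- Consequently, for \(k = 1, \ldots, K-1\), the log odds of class \(k\) relative to the baseline class \(K\) is linear in \(x\): \[\begin{equation*} \log \frac{\mathbb{P}(Y = k \mid X = x)}{\mathbb{P}(Y = K \mid X = x)} = \beta_{k0} + \beta_{k1} x_1 + \cdots + \beta_{kp} x_p. \end{equation*}\]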
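The base-R sketch below shows the two probability formulas and the log-odds identity in action. It uses hypothetical coefficients for \(K = 3\) classes and \(p = 2\) predictors. All numbers are made up for illustration, and the names beta, x and eta are ours.

```r
# Hypothetical coefficients: row k holds (beta_k0, beta_k1, beta_k2) for the
# non-baseline classes k = 1, 2; class 3 is the baseline (coefficients all zero)
beta <- rbind(c(0.5, 1.0, -0.3),
              c(-0.2, 0.4, 0.8))
x <- c(1.2, -0.7)  # hypothetical predictor values x_1, x_2

# Linear predictors eta_k = beta_k0 + beta_k1 * x_1 + beta_k2 * x_2
eta <- as.vector(beta %*% c(1, x))

# Class probabilities from the formulas above: classes 1, 2, then baseline 3
denom <- 1 + sum(exp(eta))
probs <- c(exp(eta) / denom, 1 / denom)
print(probs)
sum(probs)  # the K probabilities sum to 1

# Log odds of each non-baseline class versus the baseline recover eta
log(probs[1:2] / probs[3])
```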
Fitting the Model in R
We can use the function multinom() from the package nnet to fit a multinomial logistic regression model.
For illustration, we use the wine quality dataset from https://archive.ics.uci.edu/dataset/186/wine+quality.
The variable quality is treated as a categorical variable, so this is a classification problem.
Note: The predictors on the right-hand side of the formula should be on roughly similar scales (for example, standardized to mean 0 and variance 1). Otherwise, the optimization algorithm used by multinom() may converge slowly or fail to converge.
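The short base-R sketch below illustrates what standardization does. It uses the first ten alcohol values shown in the str() output that follows. scale() is the same base function used later to standardize the predictors.

```r
# First ten alcohol values, copied from the str(wine_df) output
alcohol <- c(8.8, 9.5, 10.1, 9.9, 9.9, 10.1, 9.6, 8.8, 9.5, 11)

# scale() subtracts the mean and divides by the sample standard deviation
z <- scale(alcohol)
mean(z)  # essentially 0 (up to floating-point error)
sd(z)    # 1

# The centering and scaling constants are stored as attributes,
# so standardized values can be mapped back to the original units
orig <- as.vector(z) * attr(z, "scaled:scale") + attr(z, "scaled:center")
all.equal(orig, alcohol)
```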
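library(nnet) # for multinom()
library(tidyverse) # for filter()
# wine_df holds the white wine data from the URL above (4898 wines);
# the loading step is not shown here
str(wine_df)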
## 'data.frame': 4898 obs. of 12 variables:
## $ fixed.acidity : num 7 6.3 8.1 7.2 7.2 8.1 6.2 7 6.3 8.1 ...
## $ volatile.acidity : num 0.27 0.3 0.28 0.23 0.23 0.28 0.32 0.27 0.3 0.22 ...
## $ citric.acid : num 0.36 0.34 0.4 0.32 0.32 0.4 0.16 0.36 0.34 0.43 ...
## $ residual.sugar : num 20.7 1.6 6.9 8.5 8.5 6.9 7 20.7 1.6 1.5 ...
## $ chlorides : num 0.045 0.049 0.05 0.058 0.058 0.05 0.045 0.045 0.049 0.044 ...
## $ free.sulfur.dioxide : num 45 14 30 47 47 30 30 45 14 28 ...
## $ total.sulfur.dioxide: num 170 132 97 186 186 97 136 170 132 129 ...
## $ density : num 1.001 0.994 0.995 0.996 0.996 ...
## $ pH : num 3 3.3 3.26 3.19 3.19 3.26 3.18 3 3.3 3.22 ...
## $ sulphates : num 0.45 0.49 0.44 0.4 0.4 0.44 0.47 0.45 0.49 0.45 ...
## $ alcohol : num 8.8 9.5 10.1 9.9 9.9 10.1 9.6 8.8 9.5 11 ...
## $ quality : int 6 6 6 6 6 6 6 6 6 6 ...
table(wine_df$quality)
##
## 3 4 5 6 7 8 9
## 20 163 1457 2198 880 175 5
# Only a small number of observations have quality 3 or 9,
# so we remove them
wine_df <- filter(wine_df, quality > 3, quality < 9)
# Scale predictors for numerical stability;
# the 12th column is the response variable "quality" and is left unscaled
wine_df[, -12] <- scale(wine_df[, -12])
# Treat quality as a categorical response (levels 4 to 8)
wine_df$quality <- factor(wine_df$quality)
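# Randomly choose 4000 indices as training data;
# the remaining 873 observations form the test set
set.seed(1)
random_index <- sample(1:nrow(wine_df), 4000, replace = FALSE)
# Fit the multinomial logistic regression on the training data
fit <- multinom(quality ~ ., data = wine_df[random_index, ])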
## # weights: 65 (48 variable)
## initial value 6437.751650
## iter 10 value 4787.842148
## iter 20 value 4602.011928
## iter 30 value 4430.212773
## iter 40 value 4298.888895
## iter 50 value 4255.839267
## final value 4251.314670
## converged
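The trace reports 48 variable weights. This appears to match the number of coefficients in the table printed next. Each of the \(K - 1 = 4\) non-baseline classes (quality 5 to 8) has one intercept and \(p = 11\) slopes, so \[\begin{equation*} (K-1)(p+1) = (5-1)(11+1) = 48. \end{equation*}\]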
round(summary(fit)$coefficients, 3)
## (Intercept) fixed.acidity volatile.acidity citric.acid residual.sugar chlorides
## 5 2.712 -0.029 -0.354 0.080 1.142 0.003
## 6 3.398 -0.011 -0.937 0.079 1.756 0.015
## 7 2.036 0.409 -1.140 -0.012 2.989 -0.248
## 8 0.220 0.516 -1.050 0.014 3.550 0.021
## free.sulfur.dioxide total.sulfur.dioxide density pH sulphates alcohol
## 5 0.478 0.234 -1.385 0.144 0.093 -0.610
## 6 0.604 0.166 -1.945 0.264 0.229 0.183
## 7 0.741 0.147 -3.652 0.690 0.452 0.135
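## 8 0.955 0.161 -4.325 0.815 0.375 0.165
For example, with quality 4 (the first level of the factor) as the baseline, \[\begin{equation*} \log \frac{\mathbb{P}(Y = 8 \mid X = x)}{\mathbb{P}(Y = 4 \mid X = x)} = 0.220 + 0.516 \times \text{fixed.acidity} + \cdots + 0.165 \times \text{alcohol}, \end{equation*}\] where the predictors are on the standardized scale.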
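- Because the predictors were standardized, a one-unit increase in alcohol corresponds to a one-standard-deviation increase in the original alcohol content. Such an increase is associated with an increase of 0.165 in the log odds of quality being 8 versus quality being 4, holding all other variables fixed. Equivalently, these odds are multiplied by \(e^{0.165} \approx 1.18\).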
p-values
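The multinom() function does not automatically report \(p\)-values. We can compute Wald-type \(p\)-values by dividing each estimated coefficient by its standard error and comparing the resulting \(z\)-statistic with a standard normal distribution: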
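# Wald z-statistics: estimate divided by its standard error
z <- summary(fit)$coefficients / summary(fit)$standard.errors
# Two-sided p-values from the standard normal distribution
p <- (1 - pnorm(abs(z), 0, 1)) * 2
round(p, digits = 3)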
## (Intercept) fixed.acidity volatile.acidity citric.acid residual.sugar chlorides
## 5 0.000 0.852 0 0.405 0.001 0.975
## 6 0.000 0.943 0 0.422 0.000 0.867
## 7 0.000 0.017 0 0.913 0.000 0.056
## 8 0.257 0.020 0 0.926 0.000 0.902
## free.sulfur.dioxide total.sulfur.dioxide density pH sulphates alcohol
## 5 0.001 0.074 0.006 0.337 0.414 0.026
## 6 0.000 0.211 0.000 0.082 0.043 0.513
## 7 0.000 0.318 0.000 0.000 0.000 0.665
## 8 0.000 0.417 0.000 0.000 0.008 0.693
Prediction and Confusion Matrix
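For a multinom fit, predict() returns the predicted classes by default (type = "class"). With type = "probs", it returns the matrix of estimated class probabilities instead. The predicted class is the level with the largest estimated probability. The base-R sketch below applies that rule to a small probability matrix whose numbers are hypothetical.

```r
# Hypothetical estimated probabilities: one row per observation,
# one column per quality level
probs <- matrix(c(0.02, 0.30, 0.50, 0.15, 0.03,
                  0.01, 0.55, 0.35, 0.08, 0.01,
                  0.01, 0.10, 0.40, 0.44, 0.05),
                nrow = 3, byrow = TRUE,
                dimnames = list(NULL, c("4", "5", "6", "7", "8")))

rowSums(probs)  # each row sums to 1

# Predicted class = column with the largest probability in each row
predicted <- colnames(probs)[apply(probs, 1, which.max)]
predicted  # "6" "5" "7"
```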
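# Predicted classes for the test set (observations not used for training)
prediction <- predict(fit, wine_df[-random_index, ])
# Observed classes in the test set
test_value <- wine_df[-random_index, ]$quality
# Rows: predicted quality; columns: observed quality
confusion_mat <- table(prediction, test_value)
confusion_mat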
## test_value
## prediction 4 5 6 7 8
## 4 2 0 2 0 0
## 5 14 142 79 8 0
## 6 10 112 278 126 23
## 7 0 0 30 35 12
## 8 0 0 0 0 0
# Accuracy: proportion of test observations classified correctly
mean(prediction == test_value)
## [1] 0.5234822
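The accuracy can also be read off the confusion matrix: it is the diagonal total divided by the grand total. The base-R sketch below enters the printed table by hand, so it runs without the fitted model. It also computes per-class recall, a summary not reported above. For instance, no test wine with quality 8 is predicted correctly, because the row for predicted quality 8 is all zeros.

```r
# Confusion matrix copied from the output above:
# rows = predicted quality, columns = observed quality in the test set
lv <- c("4", "5", "6", "7", "8")
confusion_mat <- matrix(c( 2,   0,   2,   0,  0,
                          14, 142,  79,   8,  0,
                          10, 112, 278, 126, 23,
                           0,   0,  30,  35, 12,
                           0,   0,   0,   0,  0),
                        nrow = 5, byrow = TRUE,
                        dimnames = list(prediction = lv, test_value = lv))

# Overall accuracy: correct predictions (diagonal) over all test observations
accuracy <- sum(diag(confusion_mat)) / sum(confusion_mat)
accuracy  # 457 / 873, matching mean(prediction == test_value)

# Per-class recall: share of each observed class that was predicted correctly
recall <- diag(confusion_mat) / colSums(confusion_mat)
round(recall, 3)
```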