Chapter 12 Multinomial Regression Model
Sometimes the response variable we wish to predict has \(K\) classes, where \(K > 2\).
In this case, we can use multinomial logistic regression, which extends two-class logistic regression.
We first select a single class to serve as the baseline. Without loss of generality, we use the \(K\)th class as the baseline (but the default choice of the baseline in functions for fitting multinomial logistic regression can be different). Then, the multinomial logistic regression specifies that \[\begin{equation*} \mathbb{P}(Y=k|X=x) = \frac{e^{\beta_{k0} + \beta_{k1} x_1+ \ldots + \beta_{kp} x_p}}{1 + \sum^{K-1}_{l=1} e^{\beta_{l0} + \beta_{l1} x_1+\ldots + \beta_{lp} x_p}}, \quad k=1,\ldots,K-1, \end{equation*}\] and \[\begin{equation*} \mathbb{P}(Y=K|X=x) = \frac{1}{1 + \sum^{K-1}_{l=1} e^{\beta_{l0} + \beta_{l1} x_1+\ldots + \beta_{lp} x_p}} \end{equation*}\]
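To make the baseline-category formulation concrete, here is a minimal R sketch that evaluates these probabilities for a single observation. The helper name `multinom_probs`, the coefficient matrix `beta` (one row per non-baseline class, first column the intercept) and the input `x` are hypothetical and made up for illustration. They are not estimates from any dataset.

```r
# Hypothetical helper: class probabilities under the baseline-category model.
# beta: (K-1) x (p+1) matrix; row k holds beta_k0, beta_k1, ..., beta_kp.
# x:    numeric vector of length p (the predictors x_1, ..., x_p).
multinom_probs <- function(beta, x) {
  eta <- as.vector(beta %*% c(1, x))  # linear predictors for classes 1..K-1
  denom <- 1 + sum(exp(eta))          # common denominator
  c(exp(eta), 1) / denom              # classes 1..K-1, then the baseline K
}

# Made-up example with K = 3 classes and p = 2 predictors
beta <- rbind(c(0.5, 1.0, -0.3),
              c(-0.2, 0.4, 0.8))
x <- c(1.2, -0.5)
probs <- multinom_probs(beta, x)
probs
sum(probs)                # 1, up to floating-point error
log(probs[1] / probs[3])  # equals the linear predictor of class 1
```

The baseline class gets the numerator 1, so its probability is just the reciprocal of the denominator.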
Then, the log odds of class \(k\) relative to the baseline class \(K\) is \[\begin{equation*} \log \frac{\mathbb{P}(Y=k|X=x)}{\mathbb{P}(Y=K|X=x)} = \beta_{k0} + \beta_{k1}x_1 + \ldots + \beta_{kp}x_p. \end{equation*}\]
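This follows from dividing the two probabilities in the model. Writing \(D = 1 + \sum^{K-1}_{l=1} e^{\beta_{l0} + \beta_{l1} x_1+\ldots + \beta_{lp} x_p}\) for their common denominator, \[\begin{equation*} \frac{\mathbb{P}(Y=k|X=x)}{\mathbb{P}(Y=K|X=x)} = \frac{e^{\beta_{k0} + \beta_{k1} x_1+ \ldots + \beta_{kp} x_p}/D}{1/D} = e^{\beta_{k0} + \beta_{k1} x_1+ \ldots + \beta_{kp} x_p}, \end{equation*}\] and taking logs gives the expression above. By the same cancellation, the log odds between two non-baseline classes \(k\) and \(j\) is the difference of their linear predictors: \[\begin{equation*} \log \frac{\mathbb{P}(Y=k|X=x)}{\mathbb{P}(Y=j|X=x)} = (\beta_{k0} - \beta_{j0}) + (\beta_{k1} - \beta_{j1})x_1 + \ldots + (\beta_{kp} - \beta_{jp})x_p. \end{equation*}\]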
We can use the function multinom() in the package nnet to fit a multinomial logistic regression. By default, multinom() uses the first level of the response factor as the baseline.
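As a small self-contained illustration of the multinom() interface, the sketch below uses the built-in chickwts data (feed type with 6 levels, predicted from chick weight). This dataset is chosen only for illustration and is unrelated to the wine example. It shows the coefficient layout, how relevel() changes the baseline, and the fitted class probabilities; `trace = FALSE` is passed on to nnet() to silence the iteration log.

```r
library(nnet)  # provides multinom()

# chickwts ships with R: feed is a factor with 6 levels, weight is numeric.
# Standardize the predictor, following the scaling advice for multinom().
chick <- transform(chickwts, weight = as.numeric(scale(weight)))

# trace = FALSE is passed on to nnet() and suppresses the iteration log
fit_chick <- multinom(feed ~ weight, data = chick, trace = FALSE)

# One row per non-baseline class; the first level ("casein") is the baseline
coef(fit_chick)

# Use a different baseline by moving another level to the front
chick2 <- transform(chick, feed = relevel(feed, ref = "sunflower"))
fit_chick2 <- multinom(feed ~ weight, data = chick2, trace = FALSE)
rownames(coef(fit_chick2))  # every level except "sunflower"

# Fitted class probabilities: one column per class, each row sums to 1
probs <- predict(fit_chick, type = "probs")
head(round(probs, 3))
```

Changing the baseline changes the individual coefficients but not the fitted probabilities, since the model is just reparameterized.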
In the following, I use the wine quality dataset from for illustration.
The quality is a categorical variable and so we have a classification problem.
Note: The variables on the right-hand side of the formula should be roughly scaled to [0,1], or the fit will be slow or may not converge at all.
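The code below standardizes the features with scale() (mean 0, standard deviation 1), which also puts them on comparable, modest scales. If you prefer to follow the [0,1] advice literally, min-max rescaling is one option. Here is a minimal sketch; the helper `rescale01` and the small data frame are hypothetical (toy values in the style of the wine data).

```r
# Hypothetical helper: linearly map a numeric vector onto [0, 1]
rescale01 <- function(x) {
  rng <- range(x, na.rm = TRUE)
  (x - rng[1]) / (rng[2] - rng[1])  # assumes x is not constant
}

# Toy predictors on very different scales
df <- data.frame(sugar  = c(1.6, 6.9, 20.7, 8.5),
                 sulfur = c(97, 132, 170, 186))
df_scaled <- as.data.frame(lapply(df, rescale01))
df_scaled
```

When a fitted model is applied to new data, the new data should be rescaled with the training set's minimum and maximum, not its own.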
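library(nnet)  # for using multinom()
library(dplyr) # for using filter()
# wine_df is assumed to be already loaded with the wine quality data
str(wine_df)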
## 'data.frame': 4898 obs. of 12 variables:
## $ fixed.acidity : num 7 6.3 8.1 7.2 7.2 8.1 6.2 7 6.3 8.1 ...
## $ volatile.acidity : num 0.27 0.3 0.28 0.23 0.23 0.28 0.32 0.27 0.3 0.22 ...
## $ citric.acid : num 0.36 0.34 0.4 0.32 0.32 0.4 0.16 0.36 0.34 0.43 ...
## $ residual.sugar : num 20.7 1.6 6.9 8.5 8.5 6.9 7 20.7 1.6 1.5 ...
## $ chlorides : num 0.045 0.049 0.05 0.058 0.058 0.05 0.045 0.045 0.049 0.044 ...
## $ free.sulfur.dioxide : num 45 14 30 47 47 30 30 45 14 28 ...
## $ total.sulfur.dioxide: num 170 132 97 186 186 97 136 170 132 129 ...
## $ density : num 1.001 0.994 0.995 0.996 0.996 ...
## $ pH : num 3 3.3 3.26 3.19 3.19 3.26 3.18 3 3.3 3.22 ...
## $ sulphates : num 0.45 0.49 0.44 0.4 0.4 0.44 0.47 0.45 0.49 0.45 ...
## $ alcohol : num 8.8 9.5 10.1 9.9 9.9 10.1 9.6 8.8 9.5 11 ...
## $ quality : int 6 6 6 6 6 6 6 6 6 6 ...
table(wine_df$quality) # only a few observations with quality 3 or 9
## 3 4 5 6 7 8 9
## 20 163 1457 2198 880 175 5
wine_df <- filter(wine_df, quality > 3, quality < 9) # remove data with quality 3 and 9
# scale the features before using multinom for better numerical stability
wine_df[, -12] <- scale(wine_df[, -12]) # the 12th column is the response "quality"
wine_df$quality <- factor(wine_df$quality)
# indices corresponding to training data
random_index <- sample(1:nrow(wine_df), 4000, replace = FALSE)
fit <- multinom(quality ~ ., data = wine_df[random_index, ])
## # weights: 65 (48 variable)
## initial value 6437.751650
## iter 10 value 4787.842148
## iter 20 value 4602.011928
## iter 30 value 4430.212773
## iter 40 value 4298.888895
## iter 50 value 4255.839267
## final value 4251.314670
## converged
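Because sample() is called without a seed, rerunning the code draws a different training set, so the estimates, p-values and accuracy will vary somewhat from run to run. Fixing the random seed first makes the split reproducible. A minimal sketch, where the seed 2024 is arbitrary and 100 and 80 stand in for nrow(wine_df) and 4000:

```r
# Same seed, same draw: the training indices are reproducible
set.seed(2024)
train_idx_1 <- sample(1:100, 80, replace = FALSE)

set.seed(2024)
train_idx_2 <- sample(1:100, 80, replace = FALSE)

identical(train_idx_1, train_idx_2)  # TRUE

# The rows not drawn for training form the test set
test_idx <- setdiff(1:100, train_idx_1)
length(test_idx)  # 20
```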
round(summary(fit)$coefficients, 3)
## (Intercept) fixed.acidity volatile.acidity
## 5 2.712 -0.029 -0.354
## 6 3.398 -0.011 -0.937
## 7 2.036 0.409 -1.140
## 8 0.220 0.516 -1.050
## citric.acid residual.sugar chlorides
## 5 0.080 1.142 0.003
## 6 0.079 1.756 0.015
## 7 -0.012 2.989 -0.248
## 8 0.014 3.550 0.021
## free.sulfur.dioxide total.sulfur.dioxide density
## 5 0.478 0.234 -1.385
## 6 0.604 0.166 -1.945
## 7 0.741 0.147 -3.652
## 8 0.955 0.161 -4.325
## pH sulphates alcohol
## 5 0.144 0.093 -0.610
## 6 0.264 0.229 0.183
## 7 0.690 0.452 0.135
## 8 0.815 0.375 0.165
For example, with quality 4 (the first level of the factor) as the baseline, \[\begin{equation*} \log \frac{\mathbb{P}(Y = 8|X=x)}{\mathbb{P}(Y = 4 | X = x)} = 0.220 + 0.516\,\text{fixed.acidity} + \ldots + 0.165\,\text{alcohol}. \end{equation*}\]
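- A one-unit increase in alcohol, holding the other predictors fixed, is associated with an increase of 0.165 in the log odds of quality being 8 versus quality being 4. Because the predictors were standardized, one unit here means one standard deviation of alcohol.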
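The same coefficients can be read on the odds scale, and two non-baseline levels can be compared by subtracting their rows, because both log odds are measured against the same baseline (quality 4). A minimal sketch using the rounded alcohol coefficients copied from the summary(fit) output above:

```r
# Alcohol coefficients from the output above (log odds vs. quality 4)
alcohol_coef <- c("5" = -0.610, "6" = 0.183, "7" = 0.135, "8" = 0.165)

# Multiplicative change in the odds of quality 8 vs 4 per 1-SD increase in alcohol
odds_ratio_8_vs_4 <- exp(alcohol_coef[["8"]])
odds_ratio_8_vs_4  # about 1.18

# Log odds of quality 8 vs 7: the baseline terms cancel, so subtract
log_or_8_vs_7 <- alcohol_coef[["8"]] - alcohol_coef[["7"]]
log_or_8_vs_7      # 0.03
```

So a one-standard-deviation increase in alcohol multiplies the odds of quality 8 versus 4 by roughly 1.18, other predictors held fixed.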
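# Wald z statistics: estimate / standard error
z <- summary(fit)$coefficients / summary(fit)$standard.errors
# two-sided p-values from the standard normal distribution
p <- (1 - pnorm(abs(z), 0, 1)) * 2
round(p, digits = 3)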
## (Intercept) fixed.acidity volatile.acidity
## 5 0.000 0.852 0
## 6 0.000 0.943 0
## 7 0.000 0.017 0
## 8 0.257 0.020 0
## citric.acid residual.sugar chlorides
## 5 0.405 0.001 0.975
## 6 0.422 0.000 0.867
## 7 0.913 0.000 0.056
## 8 0.926 0.000 0.902
## free.sulfur.dioxide total.sulfur.dioxide density
## 5 0.001 0.074 0.006
## 6 0.000 0.211 0.000
## 7 0.000 0.318 0.000
## 8 0.000 0.417 0.000
## pH sulphates alcohol
## 5 0.337 0.414 0.026
## 6 0.082 0.043 0.513
## 7 0.000 0.000 0.665
## 8 0.000 0.008 0.693
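These p-values come from Wald tests: each estimate is divided by its standard error and the resulting \(z\) is compared with a standard normal distribution, two-sided. One small numerical point, not specific to multinom(): computing 1 - pnorm(abs(z)) loses all precision for large \(|z|\), because pnorm() returns exactly 1 there. Requesting the upper tail directly with lower.tail = FALSE avoids this. For a table rounded to 3 digits the two forms look the same; the sketch below uses two hypothetical z values to show the difference.

```r
# Two hypothetical Wald statistics, one moderate and one very large
z <- c(2.5, 20)

p_naive  <- (1 - pnorm(abs(z))) * 2                # form used above
p_stable <- 2 * pnorm(abs(z), lower.tail = FALSE)  # upper tail computed directly

p_naive   # second entry underflows to exactly 0
p_stable  # second entry is tiny but positive
```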
Prediction and Confusion Matrix
# predicted classes for the test data (rows not used for training)
prediction <- predict(fit, wine_df[-random_index, ])
test_value <- wine_df[-random_index, ]$quality
confusion_mat <- table(prediction, test_value)
confusion_mat
## test_value
## prediction 4 5 6 7 8
## 4 2 0 2 0 0
## 5 14 142 79 8 0
## 6 10 112 278 126 23
## 7 0 0 30 35 12
## 8 0 0 0 0 0
# accuracy
mean(prediction == test_value)
## [1] 0.5234822
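Overall accuracy can hide how the model does on the rarer quality levels. As a sketch, the matrix below re-enters the counts printed above and computes the accuracy together with a per-class recall (the share of test wines of each true quality that are predicted correctly). Per-class recall is an extra diagnostic added here, not part of the original analysis.

```r
# Confusion matrix copied from the output above (rows = prediction, columns = truth)
confusion_mat <- matrix(c( 2,   0,   2,   0,  0,
                          14, 142,  79,   8,  0,
                          10, 112, 278, 126, 23,
                           0,   0,  30,  35, 12,
                           0,   0,   0,   0,  0),
                        nrow = 5, byrow = TRUE,
                        dimnames = list(prediction = 4:8, test_value = 4:8))

# Correct predictions lie on the diagonal: 457 out of 873 test wines
accuracy <- sum(diag(confusion_mat)) / sum(confusion_mat)
accuracy

# Per-class recall: correct predictions / true members of each class (column sums)
recall <- diag(confusion_mat) / colSums(confusion_mat)
round(recall, 3)
```

With these counts, recall is about 0.077 for quality 4, 0.559 for 5, 0.715 for 6, 0.207 for 7 and 0 for 8: the model almost never recovers the rare levels 4 and 8, which the overall accuracy of 0.52 does not reveal.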