Chapter 12 Multinomial Regression Model

  • Sometimes we wish to classify a response variable that has \(K\) classes, where \(K > 2\).

  • In this case, we can use multinomial logistic regression, which extends the two-class logistic regression model.

  • First, we select one class to serve as the baseline (also called the reference class). Without loss of generality, suppose we use the \(K\)th class as the baseline.

(Note: In practice, the default baseline chosen by software may differ; for example, R's multinom() uses the first level of the response factor.)

Then the multinomial logistic regression model specifies that, for \(k = 1, \ldots, K-1\), \[\begin{equation*} \mathbb{P}(Y = k \mid X = x) = \frac{e^{\beta_{k0} + \beta_{k1} x_1 + \cdots + \beta_{kp} x_p}} {1 + \sum_{l=1}^{K-1} e^{\beta_{l0} + \beta_{l1} x_1 + \cdots + \beta_{lp} x_p}}, \end{equation*}\] and for the baseline class, \[\begin{equation*} \mathbb{P}(Y = K \mid X = x) = \frac{1} {1 + \sum_{l=1}^{K-1} e^{\beta_{l0} + \beta_{l1} x_1 + \cdots + \beta_{lp} x_p}}. \end{equation*}\]

  • Then, for \(k = 1, \ldots, K-1\), the log odds of class \(k\) relative to the baseline is \[\begin{equation*} \log \frac{\mathbb{P}(Y=k \mid X=x)}{\mathbb{P}(Y=K \mid X=x)} = \beta_{k0} + \beta_{k1}x_1 + \cdots + \beta_{kp}x_p. \end{equation*}\]
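This linear form follows directly from the two probability expressions above: dividing them cancels the shared denominator, \[\begin{equation*} \frac{\mathbb{P}(Y=k \mid X=x)}{\mathbb{P}(Y=K \mid X=x)} = e^{\beta_{k0} + \beta_{k1}x_1 + \cdots + \beta_{kp}x_p}, \end{equation*}\] and taking the logarithm of both sides gives the log odds.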

Fitting the Model in R

We can use the function multinom() from the package nnet to fit a multinomial logistic regression model.

Note: The predictors on the right-hand side of the formula should be roughly scaled to a similar range (for example, mean 0 and variance 1). Otherwise, the algorithm may converge slowly or fail to converge.
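As a quick illustration of what standardization does, here is a minimal sketch on toy numbers (not the wine data): scale() centers each column and divides it by its standard deviation.

x <- matrix(c(1, 5, 9, 100, 200, 300), ncol = 2)
z <- scale(x)            # center each column, then divide by its sd
colMeans(z)              # approximately 0 in each column
apply(z, 2, sd)          # exactly 1 in each column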

library(nnet)      # for multinom()
library(tidyverse) # for filter() and other data-manipulation verbs

wine_df <- read.csv("winequality-white.csv", sep = ";")

str(wine_df)
## 'data.frame':    4898 obs. of  12 variables:
##  $ fixed.acidity       : num  7 6.3 8.1 7.2 7.2 8.1 6.2 7 6.3 8.1 ...
##  $ volatile.acidity    : num  0.27 0.3 0.28 0.23 0.23 0.28 0.32 0.27 0.3 0.22 ...
##  $ citric.acid         : num  0.36 0.34 0.4 0.32 0.32 0.4 0.16 0.36 0.34 0.43 ...
##  $ residual.sugar      : num  20.7 1.6 6.9 8.5 8.5 6.9 7 20.7 1.6 1.5 ...
##  $ chlorides           : num  0.045 0.049 0.05 0.058 0.058 0.05 0.045 0.045 0.049 0.044 ...
##  $ free.sulfur.dioxide : num  45 14 30 47 47 30 30 45 14 28 ...
##  $ total.sulfur.dioxide: num  170 132 97 186 186 97 136 170 132 129 ...
##  $ density             : num  1.001 0.994 0.995 0.996 0.996 ...
##  $ pH                  : num  3 3.3 3.26 3.19 3.19 3.26 3.18 3 3.3 3.22 ...
##  $ sulphates           : num  0.45 0.49 0.44 0.4 0.4 0.44 0.47 0.45 0.49 0.45 ...
##  $ alcohol             : num  8.8 9.5 10.1 9.9 9.9 10.1 9.6 8.8 9.5 11 ...
##  $ quality             : int  6 6 6 6 6 6 6 6 6 6 ...

table(wine_df$quality) 
## 
##    3    4    5    6    7    8    9 
##   20  163 1457 2198  880  175    5
# Only a small number of observations for quality 3 and 9

wine_df <- filter(wine_df, quality > 3, quality < 9) 
# Remove observations with quality 3 and 9

# Scale predictors for numerical stability
wine_df[, -12] <- scale(wine_df[, -12]) 
# The 12th column is the response variable "quality"

wine_df$quality <- factor(wine_df$quality)
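As noted earlier, multinom() uses the first level of the response factor (here, quality 4) as the baseline. If we preferred a different reference class, we could relevel the factor before fitting; a sketch (not run here, since it would change all of the output below):

# Not run: use quality 6 as the baseline class instead of quality 4
# wine_df$quality <- relevel(wine_df$quality, ref = "6")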

# Indices corresponding to training data
set.seed(1)
random_index <- sample(1:nrow(wine_df), 4000, replace = FALSE)

fit <- multinom(quality ~ ., data = wine_df[random_index, ])
## # weights:  65 (48 variable)
## initial  value 6437.751650 
## iter  10 value 4787.842148
## iter  20 value 4602.011928
## iter  30 value 4430.212773
## iter  40 value 4298.888895
## iter  50 value 4255.839267
## final  value 4251.314670 
## converged
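The trace above ends with converged. If it did not, one remedy (in addition to scaling the predictors, as noted earlier) is to raise the optimizer's iteration cap; multinom() forwards maxit to nnet, whose default is 100. A sketch:

# Not run: allow up to 500 iterations instead of the default 100
# fit <- multinom(quality ~ ., data = wine_df[random_index, ], maxit = 500)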
round(summary(fit)$coefficients, 3)
##   (Intercept) fixed.acidity volatile.acidity citric.acid residual.sugar chlorides
## 5       2.712        -0.029           -0.354       0.080          1.142     0.003
## 6       3.398        -0.011           -0.937       0.079          1.756     0.015
## 7       2.036         0.409           -1.140      -0.012          2.989    -0.248
## 8       0.220         0.516           -1.050       0.014          3.550     0.021
##   free.sulfur.dioxide total.sulfur.dioxide density    pH sulphates alcohol
## 5               0.478                0.234  -1.385 0.144     0.093  -0.610
## 6               0.604                0.166  -1.945 0.264     0.229   0.183
## 7               0.741                0.147  -3.652 0.690     0.452   0.135
## 8               0.955                0.161  -4.325 0.815     0.375   0.165

For example, \[\begin{equation*} \log \frac{\mathbb{P}(Y = 8 \mid X=x)}{\mathbb{P}(Y = 4 \mid X = x)} = 0.220 + 0.516 \cdot \text{fixed.acidity} + \cdots + 0.165 \cdot \text{alcohol}. \end{equation*}\]

  • A one-unit increase in alcohol (here, one standard deviation, since the predictors were standardized) is associated with an increase of 0.165 in the log odds of quality being 8 versus quality being 4, holding all other predictors fixed.
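Rather than reasoning through coefficients, we can also ask the fitted model for class probabilities directly. A minimal sketch using predict() with type = "probs" on the first held-out observation (the row choice is arbitrary):

# Predicted class probabilities for one observation; the five values sum to 1
probs <- predict(fit, wine_df[-random_index, ][1, ], type = "probs")
round(probs, 3)
sum(probs)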

p-values

The multinom() function does not automatically report \(p\)-values. We can compute two-sided Wald-type \(p\)-values by dividing each estimated coefficient by its standard error and comparing the resulting \(z\)-statistic to a standard normal distribution:

z <- summary(fit)$coefficients / summary(fit)$standard.errors  # Wald z-statistics
p <- (1 - pnorm(abs(z), 0, 1)) * 2  # two-sided p-values
round(p, digits = 3)
##   (Intercept) fixed.acidity volatile.acidity citric.acid residual.sugar chlorides
## 5       0.000         0.852                0       0.405          0.001     0.975
## 6       0.000         0.943                0       0.422          0.000     0.867
## 7       0.000         0.017                0       0.913          0.000     0.056
## 8       0.257         0.020                0       0.926          0.000     0.902
##   free.sulfur.dioxide total.sulfur.dioxide density    pH sulphates alcohol
## 5               0.001                0.074   0.006 0.337     0.414   0.026
## 6               0.000                0.211   0.000 0.082     0.043   0.513
## 7               0.000                0.318   0.000 0.000     0.000   0.665
## 8               0.000                0.417   0.000 0.000     0.008   0.693
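These Wald tests examine one coefficient at a time. An alternative is a likelihood-ratio test, which tests all \(K-1\) coefficients of a predictor at once by comparing nested fits; a sketch using anova() on multinom objects (trace = FALSE merely silences the iteration log):

# Likelihood-ratio test for citric.acid: refit without it and compare
fit_reduced <- multinom(quality ~ . - citric.acid,
                        data = wine_df[random_index, ], trace = FALSE)
anova(fit_reduced, fit)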

Prediction and Confusion Matrix

# Predict classes for the held-out observations and cross-tabulate
prediction <- predict(fit, wine_df[-random_index, ])
test_value <- wine_df[-random_index, ]$quality
confusion_mat <- table(prediction, test_value)
confusion_mat
##           test_value
## prediction   4   5   6   7   8
##          4   2   0   2   0   0
##          5  14 142  79   8   0
##          6  10 112 278 126  23
##          7   0   0  30  35  12
##          8   0   0   0   0   0

# Accuracy
mean(prediction == test_value)
## [1] 0.5234822
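The overall accuracy hides how unevenly the classes are handled: for example, no test observation is predicted to have quality 8. A minimal follow-up computed from the confusion matrix above is the per-class recall (the diagonal divided by the column totals):

# Proportion of each true quality level that is predicted correctly
round(diag(confusion_mat) / colSums(confusion_mat), 3)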