Chapter 12 Multinomial Regression Model

  • Sometimes the response variable we wish to predict is categorical with \(K\) classes, where \(K > 2\).

  • In that case we can use multinomial logistic regression, an extension of two-class logistic regression.

  • We first select a single class to serve as the baseline. Without loss of generality, we use the \(K\)th class as the baseline (software defaults may differ; for example, multinom() in nnet uses the first level of the response factor). The multinomial logistic regression model then specifies that \[\begin{equation*} \mathbb{P}(Y=k|X=x) = \frac{e^{\beta_{k0} + \beta_{k1} x_1+ \ldots + \beta_{kp} x_p}}{1 + \sum^{K-1}_{l=1} e^{\beta_{l0} + \beta_{l1} x_1+\ldots + \beta_{lp} x_p}}, \quad k=1,\ldots,K-1, \end{equation*}\] and \[\begin{equation*} \mathbb{P}(Y=K|X=x) = \frac{1}{1 + \sum^{K-1}_{l=1} e^{\beta_{l0} + \beta_{l1} x_1+\ldots + \beta_{lp} x_p}}. \end{equation*}\]

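  • Dividing the first expression by the second cancels the common denominator, so the log odds of class \(k\) relative to the baseline class \(K\) is linear in \(x\): \[\begin{equation*} \log \frac{\mathbb{P}(Y=k|X=x)}{\mathbb{P}(Y=K|X=x)} = \beta_{k0} + \beta_{k1}x_1 + \ldots + \beta_{kp}x_p, \quad k=1,\ldots,K-1. \end{equation*}\]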
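The two probability formulas can be checked numerically. The sketch below uses a hypothetical helper, baseline_probs(), and made-up coefficients for \(K = 3\) classes and \(p = 2\) predictors; it computes the class probabilities with class \(K\) as the baseline and confirms that they sum to one and that the log odds against the baseline equal the linear predictor.

```r
# Hypothetical helper: class probabilities under the baseline-category model,
# with class K as the baseline.
# beta: (K-1) x (p+1) matrix whose k-th row is (beta_k0, beta_k1, ..., beta_kp).
# x:    numeric vector of the p predictor values.
baseline_probs <- function(beta, x) {
  eta <- as.vector(beta %*% c(1, x))  # linear predictors for classes 1, ..., K-1
  denom <- 1 + sum(exp(eta))          # the baseline class contributes e^0 = 1
  c(exp(eta) / denom, 1 / denom)      # P(Y = 1), ..., P(Y = K-1), P(Y = K)
}

# Toy example: K = 3 classes, p = 2 predictors, made-up coefficients
beta <- rbind(c( 0.5, 1.0, -0.3),
              c(-1.0, 0.2,  0.8))
x <- c(0.4, -1.2)
probs <- baseline_probs(beta, x)

sum(probs)                # 1: the K probabilities sum to one
log(probs[1] / probs[3])  # 1.26, equal to 0.5 + 1.0 * 0.4 - 0.3 * (-1.2)
```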

We can use the function multinom() in the package nnet for fitting a multinomial logistic regression.
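A minimal sketch on simulated data (the names sim_df, x1, x2, fit_a, fit_c and the true coefficients are made up for illustration) shows how multinom() picks its baseline and how relevel() changes it. The argument trace = FALSE is passed through to nnet() and only suppresses the iteration log. Changing the baseline reparameterizes the model; the fitted probabilities should agree up to the optimizer's tolerance.

```r
library(nnet)  # for multinom()

set.seed(42)
n <- 3000
x1 <- rnorm(n)
x2 <- rnorm(n)

# Simulate from a 3-class model with class "c" as the baseline (made-up coefficients)
eta_a <-  0.5 + 1.0 * x1 - 0.5 * x2
eta_b <- -0.5 + 0.3 * x1 + 1.0 * x2
denom <- 1 + exp(eta_a) + exp(eta_b)
p <- cbind(a = exp(eta_a) / denom, b = exp(eta_b) / denom, c = 1 / denom)
y <- apply(p, 1, function(pr) sample(c("a", "b", "c"), 1, prob = pr))
sim_df <- data.frame(y = factor(y), x1 = x1, x2 = x2)

# Default: the FIRST factor level ("a") is the baseline
fit_a <- multinom(y ~ x1 + x2, data = sim_df, trace = FALSE)
coef(fit_a)  # rows "b" and "c", each relative to "a"

# Make "c" the baseline, matching the formulas above
sim_df$y_c <- relevel(sim_df$y, ref = "c")
fit_c <- multinom(y_c ~ x1 + x2, data = sim_df, trace = FALSE)
coef(fit_c)  # rows "a" and "b"; roughly (0.5, 1.0, -0.5) and (-0.5, 0.3, 1.0)

# Same fitted probabilities (columns reordered to a, b, c)
max(abs(fitted(fit_a) - fitted(fit_c)[, levels(sim_df$y)]))
```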

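Note (from the multinom() documentation): the variables on the right-hand side of the formula should be roughly scaled to [0,1]; otherwise the fit may be slow or may not converge at all. The code below instead standardizes each predictor to mean 0 and standard deviation 1 with scale(), which also puts the predictors on comparable scales.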
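Two common ways to bring a predictor onto a comparable scale are min-max scaling to [0,1], which is what the documentation note describes, and standardization with scale(). A small sketch on a few residual.sugar values taken from the str() output:

```r
# A few residual.sugar values from the str() output
x <- c(20.7, 1.6, 6.9, 8.5, 7.0)

# Min-max scaling to [0, 1]
x_minmax <- (x - min(x)) / (max(x) - min(x))

# Standardization to mean 0 and sd 1 (scale() returns a matrix, so flatten it)
x_std <- as.vector(scale(x))

range(x_minmax)            # 0 and 1
c(mean(x_std), sd(x_std))  # approximately 0 and 1
```

Either choice keeps the inputs within a few units of zero.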

library(nnet)      # for multinom()
library(tidyverse) # for filter()
wine_df <- read.csv("winequality-white.csv", sep = ";") # semicolon-separated file

str(wine_df)
## 'data.frame':    4898 obs. of  12 variables:
##  $ fixed.acidity       : num  7 6.3 8.1 7.2 7.2 8.1 6.2 7 6.3 8.1 ...
##  $ volatile.acidity    : num  0.27 0.3 0.28 0.23 0.23 0.28 0.32 0.27 0.3 0.22 ...
##  $ citric.acid         : num  0.36 0.34 0.4 0.32 0.32 0.4 0.16 0.36 0.34 0.43 ...
##  $ residual.sugar      : num  20.7 1.6 6.9 8.5 8.5 6.9 7 20.7 1.6 1.5 ...
##  $ chlorides           : num  0.045 0.049 0.05 0.058 0.058 0.05 0.045 0.045 0.049 0.044 ...
##  $ free.sulfur.dioxide : num  45 14 30 47 47 30 30 45 14 28 ...
##  $ total.sulfur.dioxide: num  170 132 97 186 186 97 136 170 132 129 ...
##  $ density             : num  1.001 0.994 0.995 0.996 0.996 ...
##  $ pH                  : num  3 3.3 3.26 3.19 3.19 3.26 3.18 3 3.3 3.22 ...
##  $ sulphates           : num  0.45 0.49 0.44 0.4 0.4 0.44 0.47 0.45 0.49 0.45 ...
##  $ alcohol             : num  8.8 9.5 10.1 9.9 9.9 10.1 9.6 8.8 9.5 11 ...
##  $ quality             : int  6 6 6 6 6 6 6 6 6 6 ...

table(wine_df$quality) # very few wines have quality 3 or 9
## 
##    3    4    5    6    7    8    9 
##   20  163 1457 2198  880  175    5
wine_df <- filter(wine_df, quality > 3, quality < 9) # drop the rare quality levels 3 and 9

# standardize the 11 predictors (mean 0, sd 1) for better numerical stability in multinom()
wine_df[, -12] <- scale(wine_df[, -12]) # column 12 is the response "quality"
wine_df$quality <- factor(wine_df$quality) # treat quality as a categorical response

# randomly choose 4000 rows for training; the remaining rows form the test set
set.seed(1)
random_index <- sample(1:nrow(wine_df), 4000, replace = FALSE)

fit <- multinom(quality ~ ., data = wine_df[random_index, ])
## # weights:  65 (48 variable)
## initial  value 6437.751650 
## iter  10 value 4787.842148
## iter  20 value 4602.011928
## iter  30 value 4430.212773
## iter  40 value 4298.888895
## iter  50 value 4255.839267
## final  value 4251.314670 
## converged
round(summary(fit)$coefficients, 3)
##   (Intercept) fixed.acidity volatile.acidity
## 5       2.712        -0.029           -0.354
## 6       3.398        -0.011           -0.937
## 7       2.036         0.409           -1.140
## 8       0.220         0.516           -1.050
##   citric.acid residual.sugar chlorides
## 5       0.080          1.142     0.003
## 6       0.079          1.756     0.015
## 7      -0.012          2.989    -0.248
## 8       0.014          3.550     0.021
##   free.sulfur.dioxide total.sulfur.dioxide density
## 5               0.478                0.234  -1.385
## 6               0.604                0.166  -1.945
## 7               0.741                0.147  -3.652
## 8               0.955                0.161  -4.325
##      pH sulphates alcohol
## 5 0.144     0.093  -0.610
## 6 0.264     0.229   0.183
## 7 0.690     0.452   0.135
## 8 0.815     0.375   0.165

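For example, with quality 4 as the baseline (the first factor level), \[\begin{equation*} \log \frac{\mathbb{P}(Y = 8|X=x)}{\mathbb{P}(Y = 4 | X = x)} = 0.220 + 0.516 \times \text{fixed.acidity} + \ldots + 0.165 \times \text{alcohol}, \end{equation*}\] where each predictor is on its standardized scale.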

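  • Holding the other predictors fixed, a one-unit increase in the standardized alcohol variable (one standard deviation of the original alcohol content) is associated with an increase of 0.165 in the log odds of quality 8 versus quality 4; equivalently, these odds are multiplied by \(e^{0.165} \approx 1.18\).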
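Coefficients convert to odds ratios by exponentiation. Comparisons between two non-baseline classes also follow from the table: since \(\log[\mathbb{P}(Y=8|x)/\mathbb{P}(Y=5|x)] = \log[\mathbb{P}(Y=8|x)/\mathbb{P}(Y=4|x)] - \log[\mathbb{P}(Y=5|x)/\mathbb{P}(Y=4|x)]\), the coefficients for 8 versus 5 are the row for 8 minus the row for 5. A small sketch using the alcohol column copied from the output above:

```r
# Standardized-alcohol coefficients copied from the table above;
# each entry compares that quality level with the baseline quality 4.
alcohol <- c(`5` = -0.610, `6` = 0.183, `7` = 0.135, `8` = 0.165)

# Odds ratio for quality 8 vs 4 per one-SD increase in alcohol
exp(alcohol[["8"]])        # about 1.18

# Quality 8 vs quality 5: difference of the two rows
diff_8_vs_5 <- alcohol[["8"]] - alcohol[["5"]]
diff_8_vs_5                # 0.775 on the log-odds scale
exp(diff_8_vs_5)           # about 2.17
```

With the fitted object in hand, exp(coef(fit)) returns the whole coefficient table on the odds-ratio scale.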

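p-values

summary() does not report p-values for a multinom fit, but it does return standard errors, so we can form Wald statistics \(z = \hat\beta / \widehat{\text{SE}}(\hat\beta)\) and two-sided p-values \(2(1 - \Phi(|z|))\), where \(\Phi\) is the standard normal CDF.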

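fit_summary <- summary(fit)
# Wald z statistics: estimate / standard error
z <- fit_summary$coefficients / fit_summary$standard.errors
# two-sided p-values from the standard normal distribution
p <- (1 - pnorm(abs(z))) * 2
round(p, digits = 3)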
##   (Intercept) fixed.acidity volatile.acidity
## 5       0.000         0.852                0
## 6       0.000         0.943                0
## 7       0.000         0.017                0
## 8       0.257         0.020                0
##   citric.acid residual.sugar chlorides
## 5       0.405          0.001     0.975
## 6       0.422          0.000     0.867
## 7       0.913          0.000     0.056
## 8       0.926          0.000     0.902
##   free.sulfur.dioxide total.sulfur.dioxide density
## 5               0.001                0.074   0.006
## 6               0.000                0.211   0.000
## 7               0.000                0.318   0.000
## 8               0.000                0.417   0.000
##      pH sulphates alcohol
## 5 0.337     0.414   0.026
## 6 0.082     0.043   0.513
## 7 0.000     0.000   0.665
## 8 0.000     0.008   0.693
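The code above computes the upper tail as \(1 - \Phi(|z|)\). This matches computing the tail directly for moderate \(z\). For very large \(|z|\), pnorm(abs(z)) rounds to exactly 1 in double precision, so the subtraction returns 0. A small sketch comparing the two forms on a few illustrative \(z\) values:

```r
z <- c(1.96, 3, 10)

p_subtract <- (1 - pnorm(abs(z))) * 2  # upper tail obtained by subtraction
p_tail     <- 2 * pnorm(-abs(z))       # upper tail computed directly (by symmetry)

rbind(p_subtract, p_tail)
# For z = 10, p_subtract is exactly 0 while p_tail is about 1.5e-23.
```

pnorm(abs(z), lower.tail = FALSE) is an equivalent way to request the upper tail directly. At three decimals, as printed above, the two forms give the same table.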

Prediction and Confusion Matrix

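# predicted classes for the held-out rows (type = "class" is the default)
prediction <- predict(fit, wine_df[-random_index, ])
test_value <- wine_df[-random_index, ]$quality
confusion_mat <- table(prediction, test_value) # rows: predicted, columns: true
confusion_mat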
##           test_value
## prediction   4   5   6   7   8
##          4   2   0   2   0   0
##          5  14 142  79   8   0
##          6  10 112 278 126  23
##          7   0   0  30  35  12
##          8   0   0   0   0   0
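Overall accuracy can hide large differences between classes. The sketch below copies the confusion matrix printed above into a matrix named cm (a new name, so the original object is untouched). It computes per-class recall, defined as correct predictions divided by the column total, and precision, defined as correct predictions divided by the row total. Quality 8 is never predicted, so its recall is 0 and its precision is undefined (NaN).

```r
# Confusion matrix copied from the output above (rows = predicted, columns = true)
lv <- c("4", "5", "6", "7", "8")
cm <- matrix(c( 2,   0,   2,   0,  0,
               14, 142,  79,   8,  0,
               10, 112, 278, 126, 23,
                0,   0,  30,  35, 12,
                0,   0,   0,   0,  0),
             nrow = 5, byrow = TRUE,
             dimnames = list(prediction = lv, test_value = lv))

# Recall per true class: correct predictions / column total
recall <- diag(cm) / colSums(cm)
# Precision per predicted class: correct predictions / row total (NaN if never predicted)
precision <- diag(cm) / rowSums(cm)

round(rbind(recall, precision), 3)
sum(diag(cm)) / sum(cm)  # overall accuracy, 457 / 873
```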

# overall accuracy on the test set
mean(prediction == test_value)
## [1] 0.5234822
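A useful reference point for this accuracy is a classifier that always predicts the most frequent class. The sketch below takes the true-class counts from the column totals of the confusion matrix above. By this count, always predicting quality 6 is correct about 44.6% of the time.

```r
# True class counts in the test set: column totals of the confusion matrix above
test_counts <- c(`4` = 26, `5` = 254, `6` = 389, `7` = 169, `8` = 35)

names(which.max(test_counts))          # "6", the most frequent class
max(test_counts) / sum(test_counts)    # about 0.446
```

Against this baseline, the accuracy of 0.523 is an improvement of roughly eight percentage points.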