Explainability Blog with Detailed Code Notebook and Visualizations

R Notebook

Goal: Explain which features make a product good quality.

Example question: Which properties of a wine make it a good wine?

Wine has been produced and enjoyed for thousands of years, across different cultures and age groups. The market is worth 400B, and companies across the world compete to produce better quality wine to gain market share.

However, there is no consensus on the definition of a good quality wine. Good quality is hard to define and explain in words. To explain what makes a good quality wine, we study a wine dataset and build an explanatory model.

Dataset

We use the “Wine Quality” dataset from the UC Irvine Machine Learning Repository.

#Tutorial
#install.packages("ggridges")
#install.packages("ggthemes")
#install.packages("iml")
#install.packages("breakDown")
#install.packages("DALEX")
#install.packages("glmnet")
#install.packages("partykit")
# data wrangling
library(tidyverse)
library(readr)

# ml
library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## The following objects are masked from 'package:MLmetrics':
## 
##     MAE, RMSE
## The following object is masked from 'package:purrr':
## 
##     lift
# plotting
library(gridExtra)
library(grid)
library(ggridges)
library(ggthemes)
theme_set(theme_minimal())

# explaining models
# https://github.com/christophM/iml
library(iml)

# https://pbiecek.github.io/breakDown/
library(breakDown)

# https://pbiecek.github.io/DALEX/
library(DALEX)
## Welcome to DALEX (version: 2.3.0).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
## Additional features will be available after installation of: ggpubr.
## Use 'install_dependencies()' to get all suggested dependencies
## 
## Attaching package: 'DALEX'
## The following object is masked from 'package:dplyr':
## 
##     explain
library(partykit)
## Loading required package: libcoin
## Loading required package: mvtnorm
library(libcoin)
library(mvtnorm)

Overview

We first load and clean the data, then explore it. Next we build a linear regression model and a random forest predictor. Finally, we explain the models with the following techniques:

  1. Feature importance
  2. Partial dependence plots
  3. Individual conditional expectation plots (ICE)
  4. Tree surrogate
  5. LocalModel: Local Interpretable Model-agnostic Explanations (similar to lime)
  6. Shapley value for explaining single predictions
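
To make the first item in the list concrete, here is a minimal base R sketch of permutation feature importance on synthetic data. It is not the notebook's model or the `iml` implementation: the data frame `toy_df`, the helper functions `error_rate` and `perm_importance`, and the logistic regression are all assumptions chosen for illustration. The idea is that shuffling an informative feature should increase the prediction error, while shuffling an uninformative one should barely change it.

```r
# Sketch of permutation feature importance on synthetic data.
# toy_df, error_rate and perm_importance are hypothetical names.
set.seed(42)
n <- 500
toy_df <- data.frame(
  alcohol = rnorm(n),  # informative feature
  noise   = rnorm(n)   # uninformative feature
)
toy_df$quality_high <- as.integer(toy_df$alcohol + rnorm(n, sd = 0.5) > 0)

# Simple stand-in model: logistic regression on both features
model <- glm(quality_high ~ alcohol + noise, data = toy_df, family = binomial())

# Classification error of the model on a data frame
error_rate <- function(df) {
  pred <- as.integer(predict(model, newdata = df, type = "response") > 0.5)
  mean(pred != df$quality_high)
}

# Importance = increase in error after shuffling one feature
perm_importance <- function(feature) {
  shuffled <- toy_df
  shuffled[[feature]] <- sample(shuffled[[feature]])
  error_rate(shuffled) - error_rate(toy_df)
}

importance <- sapply(c("alcohol", "noise"), perm_importance)
print(importance)
```

In this toy setup the importance of `alcohol` should come out clearly larger than that of `noise`, which is the pattern a feature importance plot summarizes.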

Load the data

# Load and clean data
clean_data <- function(){
  # Read the red and white wine files and tag each row with its wine type
  red_wine_df <- read_delim("data/winequality-red.csv", delim=";")
  red_wine_df['wine_type'] <- 'red'
  
  white_wine_df <- read_delim("data/winequality-white.csv", delim=";")
  white_wine_df['wine_type'] <- 'white'
  
  # Combine both files, keep quality scores between 0 and 10, drop incomplete rows
  wine_df <- bind_rows(red_wine_df,white_wine_df) %>% 
    filter(quality >= 0 & quality <= 10) %>% 
    drop_na()

  return(wine_df)
}

wine_df <- clean_data()
## Rows: 1599 Columns: 12
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ";"
## dbl (12): fixed acidity, volatile acidity, citric acid, residual sugar, chlo...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
## Rows: 4898 Columns: 12
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ";"
## dbl (12): fixed acidity, volatile acidity, citric acid, residual sugar, chlo...
## 
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
wine_df = wine_df %>%
            mutate(quality_cat = as.factor(ifelse(quality < 6, "qual_low", "qual_high")))
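
The thresholding rule above can be checked on a handful of scores. This is a small sketch with made-up values in a hypothetical vector `scores`; it also shows the order of the factor levels, which R sorts alphabetically. That order can matter later because some tools, for example `caret::confusionMatrix`, treat the first level as the positive class by default.

```r
# Hypothetical quality scores, chosen to sit on both sides of the threshold
scores <- c(3, 5, 6, 8)

# Same rule as in the notebook: below 6 is low quality, 6 and above is high
quality_cat <- as.factor(ifelse(scores < 6, "qual_low", "qual_high"))

print(data.frame(scores, quality_cat))
print(levels(quality_cat))  # levels are sorted alphabetically
```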

Data Exploration

“Table 1” shows all the variables in our data along with the first few rows. The variables are defined as follows:

  • fixed_acidity - acids in wine that are fixed (do not evaporate readily)

  • volatile_acidity - the amount of acetic acid in wine, which at high levels can lead to an unpleasant, vinegar taste

  • citric_acid - a weak organic acid that occurs naturally in citrus fruits and can add ‘freshness’ and flavor to wines

  • residual_sugar - any natural grape sugars that are left over after fermentation stops. It is rare to find wines with less than 1 gram/liter, and wines with more than 45 grams/liter are considered sweet

  • chlorides - the amount of salt in the wine

  • free_sulfur_dioxide - the free form of SO2, which exists in equilibrium between molecular SO2 (as a dissolved gas) and bisulfite ion. It exhibits both germicidal and antioxidant properties

  • total_sulfur_dioxide - the amount of free and bound forms of SO2

  • density - self-explanatory

  • pH - from a winemaker’s point of view, a way to measure ripeness in relation to acidity

  • sulphates - a wine additive that can contribute to sulfur dioxide gas (SO2) levels. It acts as an antimicrobial and antioxidant

  • alcohol - the percent alcohol content of the wine

  • quality - the output variable

Out of the 6497 rows in the dataset, 6251 are clean rows. We used 1875 rows for exploration and 4376 rows for modeling and testing.
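
As a quick sanity check on the row counts quoted above, the arithmetic can be reproduced in a few lines of R. The variable names are hypothetical; the numbers are the ones stated in the text.

```r
# Row counts as stated in the text
total_rows  <- 6497
clean_rows  <- 6251
exploration <- 1875
modelling   <- 4376

# The two subsets add up to the clean rows
print(exploration + modelling == clean_rows)

# Rows removed during cleaning
print(total_rows - clean_rows)

# Share of clean rows in each subset, roughly a 30/70 split
shares <- round(c(exploration, modelling) / clean_rows, 3)
print(shares)
```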

colnames(wine_df) = gsub(" ", "_", colnames(wine_df))
glimpse(wine_df)
## Rows: 6,497
## Columns: 14
## $ fixed_acidity        <dbl> 7.4, 7.8, 7.8, 11.2, 7.4, 7.4, 7.9, 7.3, 7.8, 7.5…
## $ volatile_acidity     <dbl> 0.700, 0.880, 0.760, 0.280, 0.700, 0.660, 0.600, …
## $ citric_acid          <dbl> 0.00, 0.00, 0.04, 0.56, 0.00, 0.00, 0.06, 0.00, 0…
## $ residual_sugar       <dbl> 1.9, 2.6, 2.3, 1.9, 1.9, 1.8, 1.6, 1.2, 2.0, 6.1,…
## $ chlorides            <dbl> 0.076, 0.098, 0.092, 0.075, 0.076, 0.075, 0.069, …
## $ free_sulfur_dioxide  <dbl> 11, 25, 15, 17, 11, 13, 15, 15, 9, 17, 15, 17, 16…
## $ total_sulfur_dioxide <dbl> 34, 67, 54, 60, 34, 40, 59, 21, 18, 102, 65, 102,…
## $ density              <dbl> 0.9978, 0.9968, 0.9970, 0.9980, 0.9978, 0.9978, 0…
## $ pH                   <dbl> 3.51, 3.20, 3.26, 3.16, 3.51, 3.51, 3.30, 3.39, 3…
## $ sulphates            <dbl> 0.56, 0.68, 0.65, 0.58, 0.56, 0.56, 0.46, 0.47, 0…
## $ alcohol              <dbl> 9.4, 9.8, 9.8, 9.8, 9.4, 9.4, 9.4, 10.0, 9.5, 10.…
## $ quality              <dbl> 5, 5, 5, 6, 5, 5, 5, 7, 7, 5, 5, 5, 5, 5, 5, 5, 7…
## $ wine_type            <chr> "red", "red", "red", "red", "red", "red", "red", …
## $ quality_cat          <fct> qual_low, qual_low, qual_low, qual_high, qual_low…
# Bar chart of the raw quality scores; quality is converted to a factor
# so that the discrete Tableau fill scale applies
p1 = wine_df %>%
  ggplot(aes(x = quality, fill = factor(quality))) +
    geom_bar(alpha = 0.8) +
    scale_fill_tableau() +
    guides(fill = "none")
p1

# Bar chart of the binary quality label
p2 = wine_df %>%
  ggplot(aes(x = quality_cat, fill = quality_cat)) +
    geom_bar(alpha = 0.8) +
    scale_fill_tableau() +
    guides(fill = "none")
p2

# Ridge density plots of every feature, split by the binary quality label
p3 = wine_df %>%
  gather(x, y, fixed_acidity:alcohol) %>%
  ggplot(aes(x = y, y = quality_cat, color = quality_cat, fill = quality_cat)) +
    facet_wrap( ~ x, scales = "free", ncol = 4) +
    scale_color_tableau() +
    scale_fill_viridis_d(direction = -1, guide = "none") +
    geom_density_ridges(alpha = 0.7) +
    guides(fill = "none", color = "none") +
    theme(plot.title = element_text(size = 24, hjust = 0.5)) +
    labs(title = "Relationship between Quality and Features", y = "Quality")
p3
## Picking joint bandwidth of 0.182
## Picking joint bandwidth of 0.00375
## Picking joint bandwidth of 0.0211
## Picking joint bandwidth of 0.000499
## Picking joint bandwidth of 0.168
## Picking joint bandwidth of 3.28
## Picking joint bandwidth of 0.0278
## Picking joint bandwidth of 0.855
## Picking joint bandwidth of 0.0214
## Picking joint bandwidth of 10.2
## Picking joint bandwidth of 0.0266

#grid.arrange(p1, p2, ncol = 2, widths = c(0.3, 0.7))
wine_df2 <- wine_df[c ('fixed_acidity' ,'volatile_acidity','citric_acid', 
                       'residual_sugar', 'chlorides','free_sulfur_dioxide', 
                       'total_sulfur_dioxide', 'density',
                      'pH', 'sulphates','alcohol', 'quality_cat' )]
glimpse(wine_df2)
## Rows: 6,497
## Columns: 12
## $ fixed_acidity        <dbl> 7.4, 7.8, 7.8, 11.2, 7.4, 7.4, 7.9, 7.3, 7.8, 7.5…
## $ volatile_acidity     <dbl> 0.700, 0.880, 0.760, 0.280, 0.700, 0.660, 0.600, …
## $ citric_acid          <dbl> 0.00, 0.00, 0.04, 0.56, 0.00, 0.00, 0.06, 0.00, 0…
## $ residual_sugar       <dbl> 1.9, 2.6, 2.3, 1.9, 1.9, 1.8, 1.6, 1.2, 2.0, 6.1,…
## $ chlorides            <dbl> 0.076, 0.098, 0.092, 0.075, 0.076, 0.075, 0.069, …
## $ free_sulfur_dioxide  <dbl> 11, 25, 15, 17, 11, 13, 15, 15, 9, 17, 15, 17, 16…
## $ total_sulfur_dioxide <dbl> 34, 67, 54, 60, 34, 40, 59, 21, 18, 102, 65, 102,…
## $ density              <dbl> 0.9978, 0.9968, 0.9970, 0.9980, 0.9978, 0.9978, 0…
## $ pH                   <dbl> 3.51, 3.20, 3.26, 3.16, 3.51, 3.51, 3.30, 3.39, 3…
## $ sulphates            <dbl> 0.56, 0.68, 0.65, 0.58, 0.56, 0.56, 0.46, 0.47, 0…
## $ alcohol              <dbl> 9.4, 9.8, 9.8, 9.8, 9.4, 9.4, 9.4, 10.0, 9.5, 10.…
## $ quality_cat          <fct> qual_low, qual_low, qual_low, qual_high, qual_low…
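# Remove the categorical column so that only numeric features remain
wine_df_num = subset(wine_df2, select = -c(quality_cat))

library(rbokeh)        # provides figure(), ly_hist(), ly_density(), grid_plot()
library(RColorBrewer)  # provides brewer.pal()

# One histogram with a density overlay per feature; iterating over the
# column names keeps the name available for the x-axis label
feature_names <- colnames(wine_df_num)
histograms <- lapply(feature_names, function(feature) {
  x <- wine_df_num[[feature]]
  figure(xlab = feature, width = 400, height = 250) %>%
    ly_hist(x, breaks = 40, freq = FALSE,
            color = brewer.pal(9, "GnBu")) %>%
    ly_density(x)
})
names(histograms) <- feature_names

grid_plot(histograms, nrow = 6)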

Build Model

Build Train and Test Sets

set.seed(42)

idx = createDataPartition(wine_df2$quality_cat, 
                           p = 0.7, 
                           list = FALSE, 
                           times = 1)

wine_train = wine_df2[ idx,]
wine_test  = wine_df2[-idx,]
glimpse(wine_df2)
## Rows: 6,497
## Columns: 12
## $ fixed_acidity        <dbl> 7.4, 7.8, 7.8, 11.2, 7.4, 7.4, 7.9, 7.3, 7.8, 7.5…
## $ volatile_acidity     <dbl> 0.700, 0.880, 0.760, 0.280, 0.700, 0.660, 0.600, …
## $ citric_acid          <dbl> 0.00, 0.00, 0.04, 0.56, 0.00, 0.00, 0.06, 0.00, 0…
## $ residual_sugar       <dbl> 1.9, 2.6, 2.3, 1.9, 1.9, 1.8, 1.6, 1.2, 2.0, 6.1,…
## $ chlorides            <dbl> 0.076, 0.098, 0.092, 0.075, 0.076, 0.075, 0.069, …
## $ free_sulfur_dioxide  <dbl> 11, 25, 15, 17, 11, 13, 15, 15, 9, 17, 15, 17, 16…
## $ total_sulfur_dioxide <dbl> 34, 67, 54, 60, 34, 40, 59, 21, 18, 102, 65, 102,…
## $ density              <dbl> 0.9978, 0.9968, 0.9970, 0.9980, 0.9978, 0.9978, 0…
## $ pH                   <dbl> 3.51, 3.20, 3.26, 3.16, 3.51, 3.51, 3.30, 3.39, 3…
## $ sulphates            <dbl> 0.56, 0.68, 0.65, 0.58, 0.56, 0.56, 0.46, 0.47, 0…
## $ alcohol              <dbl> 9.4, 9.8, 9.8, 9.8, 9.4, 9.4, 9.4, 10.0, 9.5, 10.…
## $ quality_cat          <fct> qual_low, qual_low, qual_low, qual_high, qual_low…
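
When the outcome is a factor, `createDataPartition` samples within each class so that the class proportions are preserved in the split. The following base R sketch imitates that idea on a made-up label vector. It is a simplified stand-in, not caret's implementation, and the names `labels` and `train_idx` are hypothetical.

```r
# Made-up labels: 60 high quality and 40 low quality wines
set.seed(42)
labels <- factor(c(rep("qual_high", 60), rep("qual_low", 40)))

# Sample 70% of the row indices separately within each class
train_idx <- unlist(lapply(split(seq_along(labels), labels), function(rows) {
  sample(rows, size = round(0.7 * length(rows)))
}))

# Class shares in the full data and in the training part
print(prop.table(table(labels)))
print(prop.table(table(labels[train_idx])))
```

Both tables should show the same class shares, which is the point of a stratified split: the training and test sets see the same balance of high and low quality wines as the full data.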
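library(kableExtra)  # provides kbl(), kable_classic(), kable_styling()

# Render the first rows of the dataset as a formatted table
options(knitr.table.format = "latex")
head(wine_df2) %>%
  kbl(caption = "Summary Table of Wine Dataset") %>% 
  kable_classic(html_font = "Cambria", full_width = FALSE)  %>%
  kable_styling(latex_options = c("striped", "scale_down"))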
Summary Table of Wine Dataset

fixed_acidity  volatile_acidity  citric_acid  residual_sugar  chlorides  free_sulfur_dioxide  total_sulfur_dioxide  density  pH    sulphates  alcohol  quality_cat
7.4            0.70              0.00         1.9             0.076      11                   34                    0.9978   3.51  0.56       9.4      qual_low
7.8            0.88              0.00         2.6             0.098      25                   67                    0.9968   3.20  0.68       9.8      qual_low
7.8            0.76              0.04         2.3             0.092      15                   54                    0.9970   3.26  0.65       9.8      qual_low
11.2           0.28              0.56         1.9             0.075      17                   60                    0.9980   3.16  0.58       9.8      qual_high
7.4            0.70              0.00         1.9             0.076      11                   34                    0.9978   3.51  0.56       9.4      qual_low
7.4            0.66              0.00         1.8             0.075      13                   40                    0.9978   3.51  0.56       9.4      qual_low

# Figure 2: Pearson correlation matrix of the numeric features
library(ggcorrplot)  # provides ggcorrplot()

corr=cor(wine_df_num, method = "pearson")
ggcorrplot(corr, hc.order = TRUE, 
           lab = TRUE, 
           lab_size = 3, 
           method="square", 
           colors = c("tomato2", "white", "springgreen3"),
           title="Figure 2: Correlation of Variables")
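
Each cell of the correlation matrix is a Pearson coefficient, r = sum((x - mean(x)) * (y - mean(y))) / sqrt(sum((x - mean(x))^2) * sum((y - mean(y))^2)). The sketch below computes it by hand for the alcohol and density values of the six rows shown in the summary table and compares the result with `cor()`. Six rows are far too few to draw conclusions about the data; the example only illustrates the formula.

```r
# Alcohol and density values from the six rows of the summary table
alcohol <- c(9.4, 9.8, 9.8, 9.8, 9.4, 9.4)
density <- c(0.9978, 0.9968, 0.9970, 0.9980, 0.9978, 0.9978)

# Pearson correlation computed directly from its definition
dx <- alcohol - mean(alcohol)
dy <- density - mean(density)
manual <- sum(dx * dy) / sqrt(sum(dx^2) * sum(dy^2))

# Should agree with the built-in implementation
print(manual)
print(all.equal(manual, cor(alcohol, density, method = "pearson")))
```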