Problem Statement

Heart disease describes a range of conditions that affect your heart. With growing stress, the number of heart disease cases is increasing rapidly.

According to the World Health Organisation (WHO), cardiovascular diseases (CVDs) are the number 1 cause of death globally, taking an estimated 17.9 million lives each year, about 31% of all deaths worldwide.

The doctors of Health Hospital in Zastra wish to incorporate data science into their work. Seeing the rising number of heart disease cases, they are especially interested in predicting the presence of heart disease in a person using existing data.

Note: This case is taken from the DPhi 51st Challenge.

Objective

  • Build a Machine Learning model to determine whether heart disease is present.
  • Build a Shapley Additive Explanations (SHAP) explainer to explain how each predictor contributes to the predicted presence of heart disease.

Import Libraries

import numpy as np
import pandas as pd

# data visualization
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('seaborn')  # on matplotlib >= 3.6 this style is named 'seaborn-v0_8'

# modeling
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

# explainable AI
import shap
from IPython.display import display, HTML

Data Loading

Load the data, which was downloaded from here.

heart_data = pd.read_csv("data_input/heart_disease.csv")
heart_data.head()
   age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  ca  thal  target
0   48    1   2       124   255    1        1      175      0      0.0      2   2     2       1
1   68    0   2       120   211    0        0      115      0      1.5      1   0     2       1
2   46    1   0       120   249    0        0      144      0      0.8      2   0     3       0
3   60    1   0       130   253    0        1      144      1      1.4      2   1     3       0
4   43    1   0       115   303    0        1      181      0      1.2      1   0     2       1

Data description:

  • age: Age in years
  • sex: 1 = male, 0 = female
  • cp: Chest pain type (0 = typical, 1 = atypical, 2 = non-anginal, 3 = asymptomatic)
  • trestbps: Resting blood pressure (in mm Hg on admission to the hospital)
  • chol: Serum cholesterol (mg/dl)
  • fbs: Fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
  • restecg: Resting electrocardiographic results
  • thalach: Maximum heart rate achieved
  • exang: Exercise induced angina (1 = yes; 0 = no)
  • oldpeak: ST depression induced by exercise relative to rest
  • slope: The slope of the peak exercise ST segment
  • ca: Number of major vessels (0-4) colored by fluoroscopy
  • thal: 0 = null, 1 = fixed defect found, 2 = blood flow is normal, 3 = reversible defect found
  • target: 1 = Heart disease present, 0 = Heart disease not present

Exploratory Data Analysis

Check Missing Values

heart_data.isna().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64

Great, there are no missing values in our data.

Class Proportion

class_prop = heart_data.target.value_counts()
class_prop.plot.bar()
plt.xticks([0, 1], labels=['Have Heart Disease', 'No Heart Disease'], rotation=0)
plt.show()

(class_prop / class_prop.sum() * 100).round(2).astype('str') + ' %'
1    54.25 %
0    45.75 %
Name: target, dtype: object

The target variable is fairly balanced, with a 54:46 proportion.

Correlation Heatmap

plt.subplots(figsize=(10, 10))
sns.heatmap(heart_data.corr(), annot=True, linewidths=.5)
plt.show()

Note: There are no alarmingly strong correlations between the features. This is good news for the linear model we are going to build, since it reduces the risk of strong multicollinearity.
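Pairwise correlations do not catch every form of multicollinearity, because one feature can be close to a combination of several others. A common complementary check, not used in the original analysis, is the variance inflation factor: VIF_j = 1 / (1 - R²_j), where R²_j comes from regressing feature j on all the other features. The sketch below computes it with plain NumPy on synthetic data; values well above roughly 5 to 10 are usually read as a warning sign.

```python
import numpy as np

def variance_inflation_factors(X):
    """Return the VIF of every column of a 2-D array X (rows = observations)."""
    X = np.asarray(X, dtype=float)
    n, p = X.shape
    vifs = np.empty(p)
    for j in range(p):
        y = X[:, j]
        # regress column j on the remaining columns plus an intercept
        design = np.column_stack([np.ones(n), np.delete(X, j, axis=1)])
        coef, *_ = np.linalg.lstsq(design, y, rcond=None)
        residuals = y - design @ coef
        r_squared = 1 - residuals.var() / y.var()
        vifs[j] = 1 / (1 - r_squared)
    return vifs

# Synthetic example: a and b are independent, c is almost exactly a + b
rng = np.random.default_rng(0)
a = rng.normal(size=500)
b = rng.normal(size=500)
c = a + b + rng.normal(scale=0.1, size=500)
print(variance_inflation_factors(np.column_stack([a, b, c])).round(1))
```

On this notebook's data it could be called as variance_inflation_factors(heart_data.drop(columns='target')).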

Distribution of Numerical Features

num_col = ['age', 'trestbps', 'chol',
           'restecg', 'thalach', 'oldpeak', 'slope', 'ca']

fig, axes = plt.subplots(2, 4, figsize=(20, 5))

for ax, col in zip(axes.flat, num_col):
    sns.kdeplot(data=heart_data, x=col, hue='target', ax=ax)
    ax.set_title(f'{col.upper()} DISTRIBUTION')

plt.tight_layout()
plt.show()

Note: Based on the visualization above, heart disease is more likely to be present at lower age, higher restecg, higher thalach, lower oldpeak, higher slope, and lower ca. The distributions of trestbps and chol differ only slightly between patients with and without heart disease.

Proportion of Categorical Features

cat_col = ['sex', 'cp', 'fbs', 'exang', 'thal']

# label for x-axis
xlab_list = [['Female', 'Male'],
             ['0 (Typical)', '1 (Atypical)',
              '2 (Non-anginal)', '3 (Asymptomatic)'],
             ['False', 'True'],
             ['No', 'Yes'],
             ['0 (Null)', '1 (Fixed)', '2 (Normal)', '3 (Reversible)']]

fig, axes = plt.subplots(2, 3, figsize=(15, 5))

for ax, col, xlab in zip(axes.flat, cat_col, xlab_list):
    sns.countplot(data=heart_data, x=col, hue='target', ax=ax)
    ax.set_title(f'{col.upper()} PROPORTION')
    ax.set_xticklabels(xlab, rotation=10)

axes[-1, -1].set_visible(False)  # turn off last axis
plt.tight_layout()
plt.show()

Based on the visualization above:

  • sex: Among female patients, more have heart disease than not, while among male patients the two groups are nearly the same size.
  • cp: Patients with cp=0 are more likely not to have heart disease than patients with other chest pain types.
  • fbs: The proportions are nearly the same.
  • exang: Patients with exercise-induced angina (exang=1) are more likely not to have heart disease than patients without it.
  • thal: Patients with thal=2 are more likely to have heart disease than patients with other thal types.

Data Preprocessing

Let's prepare the data before model fitting:

  1. Feature-target splitting
  2. Train-test splitting

# Feature-target splitting
X = heart_data.drop('target', axis=1)
y = heart_data.target

# Train-test splitting
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
print(f'X_train shape: {X_train.shape}')
print(f'X_test shape: {X_test.shape}')
print(f'y_train shape: {y_train.shape}')
print(f'y_test shape: {y_test.shape}')
X_train shape: (169, 13)
X_test shape: (43, 13)
y_train shape: (169,)
y_test shape: (43,)

  3. Scale the data with StandardScaler, which converts each feature to its Z-score.

scaler = StandardScaler()
X_train_scale = pd.DataFrame(scaler.fit_transform(X_train), columns=X.columns, index=X_train.index)
X_test_scale = pd.DataFrame(scaler.transform(X_test), columns=X.columns, index=X_test.index)

Important: Only apply the .fit() method to the training set, then use the fitted scaler to transform the test set. Fitting on the test set would leak information about the test data into the preprocessing step.
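One way to make this rule hard to break is to chain the scaler and the model in a scikit-learn Pipeline, so that fit() on the training data fits both steps on the training data only. This is an alternative to the notebook's approach, sketched here on synthetic data from make_classification sized like this dataset (212 rows, 13 features).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the heart disease data (212 rows, 13 features)
X_demo, y_demo = make_classification(n_samples=212, n_features=13, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.2, random_state=0)

# fit() learns the scaling statistics from X_tr only, then fits the classifier
pipe = make_pipeline(StandardScaler(), LogisticRegression())
pipe.fit(X_tr, y_tr)

# predict() reuses the training mean/std to transform X_te before classifying
y_te_pred = pipe.predict(X_te)
scaler = pipe.named_steps['standardscaler']
print(np.allclose(scaler.mean_, X_tr.mean(axis=0)))  # True
```

The notebook keeps the scaler and the model separate instead, which makes it straightforward to pass lr and the scaled data to shap.LinearExplainer later.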

After standardization, each column of X_train_scale should have a mean of around 0 and a standard deviation (std) of around 1.

pd.DataFrame({'mean': X_train_scale.mean(),
              'std': X_train_scale.std()})
                   mean       std
age        1.716247e-16  1.002972
sex        5.124106e-17  1.002972
cp        -9.591276e-17  1.002972
trestbps  -4.795638e-17  1.002972
chol      -5.091259e-18  1.002972
fbs        5.781043e-17  1.002972
restecg    2.496359e-17  1.002972
thalach   -3.416071e-17  1.002972
exang     -1.510954e-16  1.002972
oldpeak   -1.267888e-16  1.002972
slope      1.327012e-16  1.002972
ca        -1.313873e-17  1.002972
thal      -3.521181e-16  1.002972
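The std column shows 1.002972 rather than exactly 1. This is expected: StandardScaler divides by the population standard deviation (ddof=0), while pandas' .std() reports the sample standard deviation (ddof=1), so the printed value is sqrt(n / (n - 1)) with n = 169 training rows. The sketch below reproduces this on random data of the same size.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n = 169  # same number of rows as X_train in this notebook
X_demo = pd.DataFrame(rng.normal(50, 10, size=(n, 3)), columns=['a', 'b', 'c'])
X_demo_scale = pd.DataFrame(StandardScaler().fit_transform(X_demo), columns=X_demo.columns)

sample_std = X_demo_scale.std()             # pandas default: ddof=1
population_std = X_demo_scale.std(ddof=0)   # what StandardScaler scales to
print(sample_std.round(6).tolist())         # [1.002972, 1.002972, 1.002972]
print(population_std.round(6).tolist())     # [1.0, 1.0, 1.0]
print(round(np.sqrt(n / (n - 1)), 6))       # 1.002972
```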

Logistic Regression

We fit a LogisticRegression model with default hyperparameters on the scaled training set, X_train_scale.

lr = LogisticRegression()
lr.fit(X_train_scale, y_train)
LogisticRegression()

We evaluate the model using the F1 score, the harmonic mean of precision and recall:

$F_1 = 2 \cdot \dfrac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}$
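Here precision = TP / (TP + FP) and recall = TP / (TP + FN). As a quick check of the formula, the sketch below computes F1 by hand on a small set of hypothetical labels (not taken from the heart disease data) and compares it with sklearn's f1_score.

```python
from sklearn.metrics import f1_score, precision_score, recall_score

# Hypothetical labels: 5 actual positives, 6 predicted positives, 4 true positives
y_true = [1, 1, 1, 0, 0, 1, 0, 1]
y_hat = [1, 0, 1, 0, 1, 1, 1, 1]

precision = precision_score(y_true, y_hat)  # 4 / 6
recall = recall_score(y_true, y_hat)        # 4 / 5
f1_manual = 2 * precision * recall / (precision + recall)

print(f'precision={precision:.4f}, recall={recall:.4f}')            # precision=0.6667, recall=0.8000
print(f'manual F1={f1_manual:.4f}, sklearn F1={f1_score(y_true, y_hat):.4f}')  # both 0.7273
```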

y_pred_train = lr.predict(X_train_scale)
y_pred_test = lr.predict(X_test_scale)
print(f'F1 score on train set: {f1_score(y_train, y_pred_train):.5f}')
print(f'F1 score on test set: {f1_score(y_test, y_pred_test):.5f}')
F1 score on train set: 0.86813
F1 score on test set: 0.84000

The model performs well on both sets, and the small gap between the train F1 (0.868) and the test F1 (0.840) suggests that it does not overfit the training data.

SHAP Explainer

SHAP values explain the output of a model (a function) as a sum of the effects of each feature being introduced into a conditional expectation. Because the order in which features are introduced matters, each effect is averaged over all possible orderings. SHAP values serve two purposes:

  1. Global interpretability: the collective SHAP values show how much each predictor contributes, positively or negatively, to the prediction of the target variable.

  2. Local interpretability: each observation has its own set of SHAP values. This enables us to pinpoint and contrast the impact of the predictors for each individual case.
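To make the "average over all orderings" idea concrete, the sketch below computes exact Shapley values by brute force for a tiny hypothetical linear model in log-odds space, with random data standing in for X_train_scale as the background. The value of a set S of "known" features is taken as the model's average output when the features in S are fixed and the rest are drawn from the background (an independent-features assumption). For a linear model the brute-force result should coincide with coef * (x - background mean), and adding the SHAP values to the base value recovers the model output.

```python
import itertools
import numpy as np

rng = np.random.default_rng(1)
background = rng.normal(size=(50, 3))   # hypothetical background data (stand-in for X_train_scale)
coef = np.array([0.8, -1.2, 0.5])       # hypothetical logistic regression coefficients
intercept = 0.3

def f(X):
    """Model output in log-odds (logit) space."""
    return X @ coef + intercept

def value(x, subset):
    """v(S): mean output with features in S fixed to x and the rest taken from the background."""
    X = background.copy()
    idx = list(subset)
    X[:, idx] = x[idx]
    return f(X).mean()

def shapley_by_orderings(x):
    """Exact Shapley values: average each feature's marginal contribution over all orderings."""
    n_features = len(x)
    phi = np.zeros(n_features)
    orderings = list(itertools.permutations(range(n_features)))
    for order in orderings:
        included = []
        for j in order:
            before = value(x, included)
            included.append(j)
            phi[j] += value(x, included) - before
    return phi / len(orderings)

x = np.array([1.0, 0.5, -2.0])
phi = shapley_by_orderings(x)
base_value = f(background).mean()

print(phi)                                       # brute-force Shapley values
print(coef * (x - background.mean(axis=0)))      # closed form for a linear model
print(base_value + phi.sum(), f(x[None, :])[0])  # additivity: the two numbers match
```

Brute force needs n! orderings, so it only works for a handful of features; explainers such as SHAP's linear one rely on closed forms or approximations instead.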

explainer = shap.LinearExplainer(model=lr, masker=X_train_scale)
explain_result = explainer(X_test_scale)
shap_values = explain_result.values  # explainer.shap_values(X_test_scale)
shap_df = pd.DataFrame(shap_values, columns=X.columns)
shap_df.head()
         age        sex         cp   trestbps       chol        fbs    restecg    thalach      exang    oldpeak      slope         ca       thal
0   0.086223  -0.279619   1.159063  -0.240417  -0.034219  -0.055162  -0.673186   0.473878   0.570792   0.372988   0.318451   0.317311   0.357128
1  -0.023755   0.622378   1.159063   0.037522  -0.030002   0.005456  -0.673186  -1.113122   0.570792   0.372988   0.318451   0.317311   3.002523
2   0.020236   0.622378   0.159871  -0.055124   0.076274   0.005456   0.412598   0.505618   0.570792  -0.254762   0.318451   0.317311   0.357128
3  -0.287704  -0.279619   0.159871   0.454429  -0.013976   0.005456   0.412598   0.092998   0.570792   0.372988   0.318451   0.317311   0.357128
4   0.218197  -0.279619   2.158256  -0.935263  -0.020724   0.005456  -0.673186   0.156478   0.570792   0.059113  -0.359105   0.317311  -0.965569

logit = explainer.expected_value
odds = np.exp(logit)
prob = odds/(1+odds)
print(f'BASE VALUE LOGIT (Expected Value): {logit}')
print(f'BASE VALUE PROB (Base Value): {prob}')
BASE VALUE LOGIT (Expected Value): 0.2977444318944053
BASE VALUE PROB (Base Value): 0.5738910320745784

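The expected value is the average model output, in log-odds (logit), over the background data passed as the masker (X_train_scale); it is also called the base value on the force plot. Note that 0.5739 is the sigmoid of this average logit, which is not necessarily the same as the average predicted probability.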
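The conversion above, odds / (1 + odds), is the logistic (sigmoid) function 1 / (1 + e^(-logit)). Because the sigmoid is non-linear, the sigmoid of the mean logit generally differs a little from the mean predicted probability. The sketch below checks this on synthetic data (make_classification standing in for the heart data), assuming the base value is the mean logit over the whole background set. It also shows that, for standardized features, the mean logit over the training data reduces to the model intercept, since every scaled column has mean 0.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for X_train / y_train (169 rows x 13 features)
X_demo, y_demo = make_classification(n_samples=169, n_features=13, random_state=0)
X_demo_scale = StandardScaler().fit_transform(X_demo)
model = LogisticRegression().fit(X_demo_scale, y_demo)

logits = model.decision_function(X_demo_scale)  # log-odds for every background row
base_logit = logits.mean()                      # mean output in log-odds space
odds = np.exp(base_logit)
base_prob = odds / (1 + odds)                   # same as 1 / (1 + exp(-base_logit))
mean_prob = model.predict_proba(X_demo_scale)[:, 1].mean()

print(f'mean logit: {base_logit:.6f}, intercept: {model.intercept_[0]:.6f}')
print(f'sigmoid(mean logit): {base_prob:.4f}, mean probability: {mean_prob:.4f}')
```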

SHAP Force Plot

Individual SHAP Force Plot

Create a force plot for the first observation of the test data (position 0, row label 206).

X_test.iloc[[0]]
     age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  ca  thal
206   58    1   2       140   211    1        0      165      0      0.0      2   0     2
shap.initjs()
shap.force_plot(base_value=explainer.expected_value,
                shap_values=shap_values[0],  # first row of test data
                features=X_test.iloc[0],  # first row of test data
                link='logit')  # transforms log-odds to probability

Tip: To visualize the force plot on fastpages, we have to manually add bundle.js inside _includes/custom-head.html.

The visualization above is called an individual SHAP force plot and is used for local interpretability. Its components are:

  1. $f(x)$, the output value, is the model's prediction for that observation. The predicted probability of heart disease for the first patient in X_test_scale is 0.94.
  2. The base value is the explainer's expected value: 0.2977 in log-odds, which corresponds to a probability of 0.5739.
  3. The color of each bar indicates whether the feature supports or contradicts the positive class. Red means the feature pushes the prediction toward the positive class (higher chance of heart disease), while blue means it pushes the prediction away from it (lower chance of heart disease).

Note: Based on the colors of the force plot, the features cp, exang, thalach, oldpeak, thal, slope, ca, and age push the output value for the first row of X_test upward, while the remaining features (restecg, sex, trestbps, fbs, chol) push it downward.
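These directions, and the output value of 0.94, can be checked from numbers already printed in this post. SHAP values are additive: the base value plus the sum of a row's SHAP values gives that row's model output in log-odds, and link='logit' maps it to a probability. The sketch below copies the first row of shap_df and the printed expected value.

```python
import numpy as np

# SHAP values of the first test patient, copied from shap_df.head() (log-odds units)
shap_row0 = {
    'age': 0.086223, 'sex': -0.279619, 'cp': 1.159063, 'trestbps': -0.240417,
    'chol': -0.034219, 'fbs': -0.055162, 'restecg': -0.673186, 'thalach': 0.473878,
    'exang': 0.570792, 'oldpeak': 0.372988, 'slope': 0.318451, 'ca': 0.317311,
    'thal': 0.357128,
}
base_logit = 0.2977444318944053  # explainer.expected_value printed earlier

positive = [name for name, value in shap_row0.items() if value > 0]
negative = [name for name, value in shap_row0.items() if value < 0]

# Additivity: base value + sum of SHAP values = model output in log-odds
output_logit = base_logit + sum(shap_row0.values())
output_prob = 1 / (1 + np.exp(-output_logit))  # the logit link applied for display

print('push up:  ', positive)
print('push down:', negative)
print(f'f(x) = {output_logit:.4f} log-odds = {output_prob:.2f} probability')
```

The probability comes out at about 0.94, matching f(x) on the force plot.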

The force plot can also be drawn as a waterfall plot, which ranks the features by their local importance.

Note: Ignore the numbers inside the plot below. The SHAP values are in log-odds (not probability) and the feature values are standardized; we are only interested in the ranking.

shap.plots.waterfall(explain_result[0],  # first row of test data
                     max_display=len(X.columns))

Let's interpret each feature by checking whether its value is higher or lower than the average. Since the explainer uses the training set as its background data, we compare against X_train.mean().

(X_test.iloc[0] > X_train.mean()).apply(lambda x: "Higher" if x else "Lower")
age         Higher
sex         Higher
cp          Higher
trestbps    Higher
chol         Lower
fbs         Higher
restecg      Lower
thalach     Higher
exang        Lower
oldpeak      Lower
slope       Higher
ca           Lower
thal         Lower
dtype: object

Based on the waterfall plot, the three most important features for the first patient in the test data are cp, restecg, and exang. Let's interpret them as follows:

  • cp pushes the first patient's probability of heart disease up. The chest pain type is 2, which is higher than the training average.
  • restecg pushes the first patient's probability of heart disease down. The resting electrocardiographic result is 0, which is lower than the training average.
  • exang pushes the first patient's probability of heart disease up. The exercise-induced angina value is 0, which is lower than the training average.
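For a linear model explained against independent background data, each SHAP value is approximately coef * (x - mean), so the sign of a feature's coefficient can be read off from the sign of its SHAP value and whether the patient lies above or below the average. The sketch below applies that rule to the three features discussed above, using the SHAP values from shap_df and the Higher/Lower comparison; it infers only the sign of each coefficient, not its size.

```python
import numpy as np

# (SHAP value from shap_df.head(), value vs. X_train.mean()) for the first test patient
top3 = {
    'cp': (1.159063, 'Higher'),
    'restecg': (-0.673186, 'Lower'),
    'exang': (0.570792, 'Lower'),
}

coef_sign = {}
for feature, (shap_value, position) in top3.items():
    deviation_sign = 1 if position == 'Higher' else -1
    # SHAP ~ coef * (x - mean)  =>  sign(coef) = sign(SHAP) * sign(x - mean)
    coef_sign[feature] = int(np.sign(shap_value)) * deviation_sign
    direction = 'raises' if coef_sign[feature] > 0 else 'lowers'
    print(f'{feature}: a higher value {direction} the predicted log-odds of heart disease')
```

These signs agree with the global directions reported later in this post for cp, restecg, and exang.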

Multiple SHAP Force Plot

Create a force plot for all observations of the test data.

shap.initjs()
shap.force_plot(base_value=explainer.expected_value,
                shap_values=shap_values,  # all rows of test data
                features=X_test,  # all rows of test data
                link='logit')  # transforms log-odds to probability

Instead of plotting the force plot for one observation, we can plot it for all observations, as shown above. Each observation's force plot is rotated vertically and stacked side by side: the x-axis refers to the observation, and the y-axis refers to the predicted probability (centered around the base value, 0.5739).

There are three main options for the x-axis:

  1. Sample order by similarity: the observations are ordered by the clustering algorithm behind the function (Reference). Looking closely, observations 10 to 20 have a lower probability of heart disease and more blue than red area, meaning that many of their feature values decrease the probability of heart disease.
  2. Sample order by output value: the observations are sorted by their output value in descending order.
  3. Original sample ordering: the observations appear in the same order as the rows of X_test.

The plot is interpreted the same way as the individual force plot. On hover, features labeled in red pull the probability of heart disease higher, while features labeled in blue pull it lower.

The y-axis can also be changed to show the SHAP effect of a single feature. Examples for the first three options:

  • cp effects: Having cp = 0 (blue) decreases the predicted probability by 0.2061 (from the base value to 0.3678), while the other cp types (red) increase it to the corresponding hovered probability.
  • thal effects: Having thal = 3 (blue) decreases the predicted probability by 0.2349 (from the base value to 0.339), while the other thal types (red) increase it to the corresponding hovered probability.
  • exang effects: Having exang = 0 (red) increases the predicted probability by 0.1305 (from the base value to 0.7044), while exang = 1 (blue) decreases it by 0.2661 (from the base value to 0.3078).
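The hovered numbers are probabilities, whereas the SHAP values of this model live in log-odds space. Converting each hovered probability back with the logit function and subtracting the base logit recovers the single-feature SHAP value. For example, exang = 0 gives about +0.5706, close to the exang SHAP value of 0.570792 in shap_df for the first patient (whose exang is 0); the small gap comes from the rounded 0.7044.

```python
import numpy as np

base_logit = 0.2977444318944053  # explainer.expected_value printed earlier

def logit(p):
    """Inverse of the sigmoid: probability -> log-odds."""
    return np.log(p / (1 - p))

# Hovered probabilities quoted in the bullets above
hovered = {'cp = 0': 0.3678, 'thal = 3': 0.339, 'exang = 0': 0.7044, 'exang = 1': 0.3078}

shap_logit = {label: logit(p) - base_logit for label, p in hovered.items()}
for label, value in shap_logit.items():
    print(f'{label}: {value:+.4f} log-odds')
```

Working in log-odds also explains why the probability shifts are not symmetric: the same SHAP value moves the probability more near 0.5 than near 0 or 1.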

SHAP Summary Plot

shap.plots.bar(explain_result,  # all rows of test data
               max_display=len(X.columns))
# shap.summary_plot(shap_values, features=X_test, plot_type='bar')

The visualization above is a summary (variable importance) plot for global interpretability. It lists the features in descending order of their mean absolute SHAP value over all observations. In this case, the three most important features are cp, thal, and exang, while the three least important are fbs, chol, and age.
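The ranking in the bar plot is the mean absolute SHAP value per feature; taking the absolute value first keeps positive and negative contributions from cancelling out. The helper below is hypothetical (not part of the SHAP API) and reproduces such a ranking from a plain array of SHAP values, shown here on a made-up 3x3 array. In the notebook it could be called as mean_abs_shap_ranking(shap_values, X.columns).

```python
import numpy as np
import pandas as pd

def mean_abs_shap_ranking(shap_values, feature_names):
    """Global importance: mean |SHAP| per feature, sorted from most to least important."""
    importance = np.abs(np.asarray(shap_values)).mean(axis=0)
    return pd.Series(importance, index=feature_names).sort_values(ascending=False)

# Made-up SHAP values: 3 observations x 3 features
toy_shap = np.array([[ 0.9, -0.1,  0.3],
                     [-1.1,  0.2, -0.2],
                     [ 0.7, -0.3,  0.4]])
ranking = mean_abs_shap_ranking(toy_shap, ['f1', 'f2', 'f3'])
print(ranking)
```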

The summary plot can also be drawn as a violin plot to show the positive and negative relationships with the target variable.

shap.summary_plot(shap_values, features=X_test, plot_type='violin')

The violin plot shows the distribution of SHAP values for each feature. Red indicates higher feature values, while blue indicates lower feature values. For example, higher cp values come with more positive SHAP values, meaning cp has a positive impact on the model output globally. Conversely, higher thal values come with more negative SHAP values, meaning thal has a negative impact on the model output globally.

So we can conclude that, globally, higher values of cp, restecg, thalach, slope, age, and chol increase the predicted probability of heart disease, while higher values of the other features (thal, exang, oldpeak, ca, sex, trestbps, fbs) decrease it.

SHAP Dependence Plot

fig, axes = plt.subplots(2, 7, figsize=(20, 5))

for ax, col in zip(axes.flat, X_test.columns):
    shap.dependence_plot(ind=col,
                         interaction_index=col,
                         shap_values=shap_values,  # all rows of test data
                         features=X_test,  # all rows of test data
                         display_features=X_test,  # all rows of test data
                         show=False,
                         ax=ax)

    ax.set_title(f'{col.upper()} DEPENDENCE PLOT')
axes[-1,-1].set_visible(False)
plt.tight_layout()
plt.show()

The plot above shows how the value of each feature relates to its SHAP value. The trend is linear because logistic regression is a linear model in log-odds: each feature's SHAP value is proportional to how far the feature value lies from its background mean.

The conclusion is the same as for the violin summary plot: globally, higher values of cp, restecg, thalach, slope, age, and chol increase the predicted probability, while higher values of the other features (thal, exang, oldpeak, ca, sex, trestbps, fbs) decrease it. The difference is that the summary plot also ranks the features by importance.

SHAP Decision Plot

In this section we explore one more visualization that SHAP offers: the decision plot. It shows how the model makes decisions and arrives at each prediction.

r = shap.decision_plot(base_value=explainer.expected_value,
                       shap_values=shap_values,  # all rows of test data
                       features=X_test,  # all rows of test data
                       link='logit',  # transforms log-odds to probability
                       highlight=0,  # highlight first row observation
                       return_objects=True)  # returning the plot structures

The visualization above is the decision plot for all rows of the test data. Moving from the bottom to the top, the SHAP value of each feature is added to the model's base value, showing how each feature contributes to the overall prediction.

  • The x-axis represents the model output, in this case the probability of heart disease (converted from log-odds with link='logit'). The plot is centered on the x-axis at the base value.
  • The y-axis lists the features, ordered by descending importance.
  • Each colored line traces one observation's prediction path as the feature contributions are added.
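Each line in a decision plot is a running total: start at the base value, add the SHAP values one feature at a time (least important at the bottom, most important at the top), and pass the running log-odds through the sigmoid when link='logit'. The sketch below builds one such path for a hypothetical observation with three features. It illustrates the idea rather than reproducing SHAP's plotting code, which by default orders features by their importance across all plotted observations.

```python
import numpy as np

def decision_path(base_logit, shap_row, feature_names):
    """Cumulative predicted probability as features are added, least important first."""
    shap_row = np.asarray(shap_row, dtype=float)
    order = np.argsort(np.abs(shap_row))  # ascending |SHAP|: bottom of the plot first
    cumulative_logit = base_logit + np.cumsum(shap_row[order])
    probabilities = 1 / (1 + np.exp(-cumulative_logit))
    return [(feature_names[i], p) for i, p in zip(order, probabilities)]

# Hypothetical observation: base logit and SHAP values in log-odds
path = decision_path(0.3, [1.2, -0.4, 0.1], ['cp', 'restecg', 'chol'])
for feature, prob in path:
    print(f'after {feature:8s} -> {prob:.3f}')
```

The last point of the path is always the model's final prediction, because every SHAP value has been added by then.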

We also highlight the first row of the test set, shown as a dotted line. To see it more clearly, we can draw it in a separate decision plot:

shap.decision_plot(base_value=explainer.expected_value,
                   shap_values=shap_values[0],  # first row of test data
                   features=X_test.iloc[0,:],  # first row of test data
                   link='logit',  # transforms log-odds to probability
                   feature_order=r.feature_idx)  # preserving order of features from previous plot

The plot above essentially combines the force plot and the local importance (waterfall) plot, showing:

  • The direction of each feature's impact: positive (to the right) or negative (to the left)
  • The ranking of features by importance (absolute SHAP value)

A decision plot can expose a model’s typical prediction paths. Here, we plot all of the predictions in the probability interval [0.9, 1.0] to see what high-scoring predictions have in common. We use feature_order='hclust' to group similar prediction paths.

high_prob = np.where(lr.predict_proba(X_test_scale)[:,1] >= 0.9)[0]
shap.decision_plot(base_value=explainer.expected_value,
                   shap_values=shap_values[high_prob],  # rows with high probability
                   features=X_test.iloc[high_prob,:],  # rows with high probability
                   feature_order='hclust',  # hierarchical clustering
                   link='logit')  # transforms log-odds to probability

Other uses of the decision plot include:

  • Show a large number of feature effects clearly
  • Visualize multioutput predictions
  • Display the cumulative effect of interactions
  • Explore feature effects for a range of feature values
  • Identify outliers
  • Compare and contrast predictions for several models

Note: Reference for the SHAP decision plot