Explainable Heart Disease Classifier with Shapley Additive Explanations (SHAP)
A simple workflow to classify whether a patient has heart disease using a Logistic Regression model. A SHAP explainer is then used to explain the model's decisions through several plots: SHAP force, summary, dependence, and decision plots.
- Problem Statement
- Objective
- Import Libraries
- Data Loading
- Exploratory Data Analysis
- Data Preprocessing
- Logistic Regression
- SHAP Explainer
Problem Statement
Heart disease describes a range of conditions that affect your heart. With growing stress, the number of heart disease cases is increasing rapidly.
According to the World Health Organisation (WHO), cardiovascular diseases (CVDs) are the number 1 cause of death globally, taking an estimated 17.9 million lives each year, which is an estimated 31% of all deaths worldwide.
The doctors of Health Hospital in Zastra wish to incorporate Data Science into their work. Seeing the rising number of heart disease cases, they are especially interested in predicting the presence of heart disease in a person using existing data.
import numpy as np
import pandas as pd
# data visualization
import matplotlib.pyplot as plt
import seaborn as sns
plt.style.use('seaborn-v0_8')  # the old 'seaborn' style name is deprecated since matplotlib 3.6
# modeling
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
# explainable AI
import shap
from IPython.display import display, HTML
Data Loading
Load the data, which was downloaded from here.
heart_data = pd.read_csv("data_input/heart_disease.csv")
heart_data.head()
Data description:
- age: Age in years
- sex: 1 = male, 0 = female
- cp: Chest pain type
- trestbps: Resting blood pressure (in mm Hg on admission to the hospital)
- chol: Serum cholesterol in mg/dl
- fbs: Fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
- restecg: Resting electrocardiographic results
- thalach: Maximum heart rate achieved
- exang: Exercise induced angina (1 = yes; 0 = no)
- oldpeak: ST depression induced by exercise relative to rest
- slope: The slope of the peak exercise ST segment
- ca: Number of major vessels (0-4) colored by fluoroscopy
- thal: 0 = null, 1 = fixed defect found, 2 = blood flow is normal, 3 = reversible defect found
- target: 1 = Heart disease present, 0 = Heart disease not present
heart_data.isna().sum()
Great, there are no missing values in our data.
class_prop = heart_data.target.value_counts()
class_prop.plot.bar()
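plt.xticks(range(len(class_prop)),  # value_counts() sorts by frequency, so label ticks from the index
           labels=['Have Heart Disease' if t == 1 else 'No Heart Disease' for t in class_prop.index],
           rotation=0)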
plt.show()
(class_prop / class_prop.sum() * 100).round(2).astype('str') + ' %'
The target variable can be considered balanced, with a 54:46 proportion.
plt.subplots(figsize=(10, 10))
sns.heatmap(heart_data.corr(), annot=True, linewidths=.5)
plt.show()
num_col = ['age', 'trestbps', 'chol',
'restecg', 'thalach', 'oldpeak', 'slope', 'ca']
fig, axes = plt.subplots(2, 4, figsize=(20, 5))
for ax, col in zip(axes.flat, num_col):
    sns.kdeplot(data=heart_data, x=col, hue='target', ax=ax)
    ax.set_title(f'{col.upper()} DISTRIBUTION')
plt.tight_layout()
plt.show()
Patients with heart disease tend to have lower age, higher restecg results, higher thalach, lower oldpeak, higher slope, and lower ca. There are only slight differences in the trestbps and chol distributions between patients with and without heart disease.
cat_col = ['sex', 'cp', 'fbs', 'exang', 'thal']
# label for x-axis
xlab_list = [['Female', 'Male'],
['0 (Typical)', '1 (Atypical)',
'2 (Non-anginal)', '3 (Asymptomatic)'],
['False', 'True'],
['No', 'Yes'],
['0 (Null)', '1 (Fixed)', '2 (Normal)', '3 (Reversible)']]
fig, axes = plt.subplots(2, 3, figsize=(15, 5))
for ax, col, xlab in zip(axes.flat, cat_col, xlab_list):
    sns.countplot(data=heart_data, x=col, hue='target', ax=ax)
    ax.set_title(f'{col.upper()} PROPORTION')
    ax.set_xticklabels(xlab, rotation=10)
axes[-1, -1].set_visible(False) # turn off last axis
plt.tight_layout()
plt.show()
Based on the visualization above:
- sex: The proportion of females with heart disease is higher than those without, while for males the two are nearly the same.
- cp: Patients with cp=0 are more likely not to have heart disease than those with other cp types.
- fbs: The proportions are nearly the same.
- exang: Patients with exercise induced angina (exang=1) are more likely not to have heart disease than those without it.
- thal: Patients with thal=2 are more likely to have heart disease than those with other thal types.
# Feature-target splitting
X = heart_data.drop('target', axis=1)
y = heart_data.target
# Train-test splitting
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
print(f'X_train shape: {X_train.shape}')
print(f'X_test shape: {X_test.shape}')
print(f'y_train shape: {y_train.shape}')
print(f'y_test shape: {y_test.shape}')
- Scale the data to its Z-score using StandardScaler.
scaler = StandardScaler()
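# keep the original row index so the scaled frames stay aligned with y_train / y_test
X_train_scale = pd.DataFrame(scaler.fit_transform(X_train), columns=X.columns, index=X_train.index)
X_test_scale = pd.DataFrame(scaler.transform(X_test), columns=X.columns, index=X_test.index)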
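Note that the scaler is fitted (.fit()) only on the training set; the fitted scaler is then used to transform the testing set, so that no information from the testing set leaks into preprocessing.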
After standardization, the mean of each column of X_train_scale should be around 0 and the standard deviation (std) around 1.
pd.DataFrame({'mean': X_train_scale.mean(),
'std': X_train_scale.std()})
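As a quick check of what StandardScaler does, here is a NumPy-only sketch of the same Z-score transform, z = (x - mean) / std, with the mean and std computed on the training rows only. The data is synthetic. One detail worth noting: StandardScaler uses the population standard deviation (ddof=0), while pandas' .std() defaults to ddof=1, which is why the std column in the table above comes out slightly above 1 rather than exactly 1.

```python
import numpy as np

def z_score(train, test):
    """Standardize train and test with statistics computed on the training rows only."""
    mean = train.mean(axis=0)
    std = train.std(axis=0)  # np.std defaults to ddof=0, the same estimator StandardScaler uses
    return (train - mean) / std, (test - mean) / std

# Synthetic stand-ins for X_train / X_test (3 numeric columns)
rng = np.random.default_rng(0)
X_tr = rng.normal(loc=50, scale=10, size=(242, 3))
X_te = rng.normal(loc=50, scale=10, size=(61, 3))

Z_tr, Z_te = z_score(X_tr, X_te)
print(np.round(Z_tr.mean(axis=0), 6))  # ~0 in every column
print(Z_tr.std(axis=0))                # 1 in every column (ddof=0)
print(Z_tr.std(axis=0, ddof=1))        # sqrt(n / (n - 1)), slightly above 1
```

The testing rows are transformed with the training statistics, so their mean and std will not be exactly 0 and 1, and that is expected.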
lr = LogisticRegression()
lr.fit(X_train_scale, y_train)
Evaluate the model using the F1-score, which is the harmonic mean of precision and recall:
$F_1 = 2 \dfrac{Precision \times Recall}{Precision + Recall}$
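To connect the formula with f1_score, here is a small sketch that computes precision, recall, and F1 for the positive class (label 1) from confusion counts. The labels are made up for illustration; sklearn's f1_score also scores label 1 by default (average='binary'). This sketch does not guard against division by zero.

```python
import numpy as np

def f1_from_labels(y_true, y_pred):
    """F1 score of the positive class (label 1), built from confusion-matrix counts."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_pred == 1) & (y_true == 1))  # true positives
    fp = np.sum((y_pred == 1) & (y_true == 0))  # false positives
    fn = np.sum((y_pred == 0) & (y_true == 1))  # false negatives
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Made-up labels: 4 TP, 1 FP, 1 FN, so precision = recall = 0.8
y_true = [1, 1, 1, 0, 0, 1, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1, 0, 1]
print(f"F1: {f1_from_labels(y_true, y_pred):.4f}")
```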
y_pred_train = lr.predict(X_train_scale)
y_pred_test = lr.predict(X_test_scale)
print(f'F1 score on train set: {f1_score(y_train, y_pred_train):.5f}')
print(f'F1 score on test set: {f1_score(y_test, y_pred_test):.5f}')
The train and test F1 scores are close to each other, so the model performs quite well and does not appear to overfit the training data.
SHAP Explainer
SHAP values explain the output of a model (function) as a sum of the effects of each feature being introduced into a conditional expectation. Because the order in which features are introduced matters, each SHAP value is obtained by averaging over all possible orderings. SHAP values can be used in two ways:
- Global interpretability: the collective SHAP values show the contribution of each predictor, either positive or negative, to the target variable.
- Local interpretability: each observation has its own set of SHAP values, which lets us pinpoint and contrast the impact of the predictors for each individual case.
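To make "averaging over all possible orderings" concrete, here is a brute-force sketch that computes exact Shapley values for a tiny made-up function. It assumes an interventional value function: a coalition's value is the average model output when the coalition's features are fixed to the explained row and the remaining features come from background rows. The cost grows factorially with the number of features, so this is only for illustration; shap's explainers rely on faster, model-specific methods (LinearExplainer, for example, uses a closed form for linear models).

```python
import itertools
import math

import numpy as np

def shapley_values(f, x, background):
    """Exact Shapley values of f at row x, averaged over every feature ordering."""
    n = len(x)

    def value(coalition):
        z = background.copy()
        idx = list(coalition)
        z[:, idx] = x[idx]      # features in the coalition take x's values
        return f(z).mean()      # the rest keep background values; average the output

    phi = np.zeros(n)
    for order in itertools.permutations(range(n)):
        present = []
        for i in order:
            before = value(present)
            present.append(i)
            phi[i] += value(present) - before   # marginal contribution of feature i
    return phi / math.factorial(n)

# Made-up nonlinear model with an interaction between features 1 and 2
def f(z):
    return 2.0 * z[:, 0] + z[:, 1] * z[:, 2]

rng = np.random.default_rng(0)
background = rng.normal(size=(100, 3))
x = np.array([1.0, 2.0, -1.0])

phi = shapley_values(f, x, background)
print("Shapley values:", phi.round(4))
# Additivity: base value + sum of Shapley values = prediction for x
print(f(background).mean() + phi.sum(), f(x[None, :])[0])
```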
explainer = shap.LinearExplainer(model=lr, masker=X_train_scale)
explain_result = explainer(X_test_scale)
shap_values = explain_result.values # explainer.shap_values(X_test_scale)
shap_df = pd.DataFrame(shap_values, columns=X.columns)
shap_df.head()
logit = explainer.expected_value
odds = np.exp(logit)
prob = odds/(1+odds)
print(f'BASE VALUE LOGIT (Expected Value): {logit}')
print(f'BASE VALUE PROB (Base Value): {prob}')
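The expected value is the average model output over the background data X_train_scale, measured in log-odds (logit) space; it is also called the base value on the force plot. Converting it to a probability gives the sigmoid of the average logit, which is not exactly the same as the average predicted probability.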
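The difference between "average logit" and "average probability" is easy to see in a NumPy-only sketch. The weights, intercept, and background data below are made up and stand in for lr and X_train_scale; the assumption, based on how LinearExplainer works with an independent masker, is that expected_value equals the mean log-odds over the background rows.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Made-up stand-ins for X_train_scale and the fitted coefficients/intercept of lr
rng = np.random.default_rng(42)
X_bg = rng.normal(size=(500, 4))
w = np.array([1.5, -2.0, 0.7, 0.0])
b = 1.0

logits = X_bg @ w + b
base_logit = logits.mean()            # analogous to explainer.expected_value (log-odds)
odds = np.exp(base_logit)
base_prob = odds / (1 + odds)         # same conversion as the cell above, i.e. sigmoid(base_logit)

print(f"sigmoid(mean logit): {base_prob:.4f}")
print(f"mean probability:    {sigmoid(logits).mean():.4f}")  # generally a different number
```

Because the sigmoid is nonlinear, averaging before or after the transformation gives different numbers, so the base value should be read as the probability at the average log-odds.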
X_test.iloc[[0]]
shap.initjs()
shap.force_plot(base_value=explainer.expected_value,
shap_values=shap_values[0], # first row of test data
features=X_test.iloc[0], # first row of test data
link='logit') # transforms log-odds to probability
Note: for the interactive SHAP plots to render on fastpages, we have to manually add bundle.js inside _includes/custom-head.html.
The visualization above is an individual SHAP force plot for local interpretability. Its components are:
- $f(x)$, or the output value, is the prediction for that observation. The predicted probability for the first patient in X_test_scale is 0.94.
- The base value is the explainer's expected value, converted to a probability: 0.5739.
- The bar color indicates whether a feature supports or contradicts the prediction. Red means the feature pushes toward the positive class (a higher chance of heart disease), while blue means it pushes toward the negative class (a lower chance of heart disease).
For the first row of X_test, cp, exang, thalach, oldpeak, thal, slope, ca, and age push the output value in the positive direction, while the other features (restecg, sex, trestbps, fbs, and chol) push it in the negative direction.
The same explanation can be visualized as a waterfall plot to rank the locally important features.
shap.plots.waterfall(explain_result[0], # first row of test data
max_display=len(X.columns))
Let's interpret each feature by comparing its value with the average (higher or lower). Remember that the SHAP explainer uses the training set as its background data, so we compare with X_train.mean().
(X_test.iloc[0] > X_train.mean()).apply(lambda x: "Higher" if x else "Lower")
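For a linear model this comparison is more than a heuristic. Assuming LinearExplainer's default independent masker, the SHAP value of feature i is coef_i * (x_i - mean_i), where the mean is taken over the background data. The scaled background has mean 0, so the sign of a SHAP value is the sign of the coefficient times the sign of "higher or lower than the training mean". The sketch below checks this rule on made-up coefficients and data standing in for lr.coef_[0], X_train, and X_test.iloc[0].

```python
import numpy as np

# Made-up stand-ins: raw training data, one raw test row, and coefficients fitted on scaled data
rng = np.random.default_rng(7)
X_train_raw = rng.normal(loc=[55.0, 130.0, 1.0], scale=[9.0, 17.0, 0.5], size=(240, 3))
x_raw = np.array([48.0, 140.0, 2.0])
coef = np.array([0.4, -0.3, 0.9])

# Standardize with training statistics, as StandardScaler does
mean, std = X_train_raw.mean(axis=0), X_train_raw.std(axis=0)
X_train_scaled = (X_train_raw - mean) / std
x_scaled = (x_raw - mean) / std

# Linear-model SHAP values relative to the (scaled) background mean, which is ~0
phi = coef * (x_scaled - X_train_scaled.mean(axis=0))

# "Higher"/"Lower" than the training mean, as in the cell above
higher = x_raw > mean
for i, (c, h, p) in enumerate(zip(coef, higher, phi)):
    side = "Higher" if h else "Lower"
    print(f"feature {i}: {side}, coef {c:+.1f} -> SHAP {p:+.3f}")
```

Read the other way around, under this rule a feature whose value is below average but still pushes the prediction up, like exang = 0 for the first patient, implies a negative coefficient for that feature.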
Based on the waterfall plot, the three most important features for the first patient in the testing data are cp, restecg, and exang. Let's interpret them as follows:
- cp pushes the first patient's probability of heart disease in the positive direction. The chest pain type is 2, which is higher than the average.
- restecg pushes the first patient's probability of heart disease in the negative direction. The resting electrocardiographic result is 0, which is lower than the average.
- exang pushes the first patient's probability of heart disease in the positive direction. The exercise induced angina value is 0, which is lower than the average.
shap.initjs()
shap.force_plot(base_value=explainer.expected_value,
shap_values=shap_values, # all rows of test data
features=X_test, # all rows of test data
link='logit') # transforms log-odds to probability
Instead of plotting a force plot for one observation, we can plot it for all observations, as shown above. The x-axis refers to the observation index, and the y-axis refers to the predicted probability (centered around the base value, 0.5739).
There are three main options for the x-axis:
- Sample order by similarity: the order is determined by the clustering algorithm behind the function (Reference). Looking carefully, indices 10 to 20 have a lower probability of heart disease and more blue than red regions, meaning that many of their feature values decrease the probability of heart disease.
- Sample order by output value: the observations are sorted by their output value in descending order.
- Original sample ordering: the observations are shown in the order of their row index in X_test.
The interpretation of the plot above is the same as in the previous section. When hovering, features labeled in red push the probability of heart disease higher, while features labeled in blue push it lower.
The y-axis can be altered to show the effect of a single feature. Examples for the first three options:
- cp effects: Having cp = 0 (blue) decreases the predicted probability by 0.2061 (from the base value to 0.3678), while the other cp types (red) increase it to the corresponding hovered probability instead.
- thal effects: Having thal = 3 (blue) decreases the predicted probability by 0.2349 (from the base value to 0.339), while the other thal types (red) increase it to the corresponding hovered probability instead.
- exang effects: Having exang = 0 (red) increases the predicted probability by 0.1305 (from the base value to 0.7044), while exang = 1 (blue) decreases it by 0.2661 (from the base value to 0.3078) instead.
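The probability figures above come from the logit link: SHAP values live in log-odds space, and the plot maps the base log-odds plus a feature's SHAP value through the sigmoid. As a sketch, assuming that mapping, we can back out the log-odds SHAP value implied by the cp numbers quoted above and see why equal log-odds pushes do not give equal probability shifts.

```python
import numpy as np

def logit(p):
    return np.log(p / (1 - p))

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

base_prob = 0.5739                     # base value shown on the force plot
base_logit = logit(base_prob)

# Log-odds SHAP value implied by "cp = 0 moves the probability from the base value to 0.3678"
phi_cp = logit(0.3678) - base_logit
print(f"implied SHAP value for cp = 0 (log-odds): {phi_cp:.4f}")

down = base_prob - sigmoid(base_logit + phi_cp)   # recovers the 0.2061 drop
up = sigmoid(base_logit - phi_cp) - base_prob     # same-sized push in the other direction
print(f"probability drop: {down:.4f}, mirror-image rise: {up:.4f}")
```

This asymmetry is also why the probability changes of individual features do not simply add up to the final predicted probability; only the log-odds SHAP values are additive.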
shap.plots.bar(explain_result, # all rows of test data
max_display=len(X.columns))
# shap.summary_plot(shap_values, features=X_test, plot_type='bar')
The visualization above is a summary (variable importance) plot for global interpretability. It lists the features in descending order of their mean absolute SHAP value over all observations. In this case, the top three features with the highest predictive power are cp, thal, and exang, while the bottom three are fbs, chol, and age.
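The bar plot's ranking can be reproduced directly from a SHAP matrix such as shap_df: take the absolute value, average over rows, and sort. The small DataFrame below is made up for illustration; it also shows why the absolute value matters, since a feature that pushes some patients up and others down can have a signed mean close to zero.

```python
import pandas as pd

# Made-up SHAP values (rows = observations, columns = features), shaped like shap_df
shap_toy = pd.DataFrame({
    "cp":   [0.90, -0.80, 0.70, -0.60],
    "thal": [-0.50, 0.60, -0.40, 0.50],
    "chol": [0.05, -0.02, 0.01, -0.04],
})

# Global importance: mean absolute SHAP value per feature, largest first
importance = shap_toy.abs().mean().sort_values(ascending=False)
print(importance)

# The signed mean hides how strongly cp and thal move individual predictions
print(shap_toy.mean())
```

Applied to the real results, shap_df.abs().mean().sort_values(ascending=False) should list the features in the same order as the bar plot.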
The summary plot can also be drawn as a violin plot to show the positive and negative relationships with the target variable.
shap.summary_plot(shap_values, features=X_test, plot_type='violin')
The violin plot shows the distribution of SHAP values for each feature. Red indicates higher feature values, while blue indicates lower feature values. For example, the higher the cp value, the more positive its SHAP value, meaning cp has a positive global impact on the model output. On the other hand, the higher the thal value, the more negative its SHAP value, meaning thal has a negative global impact on the model output.
So we can conclude that, globally, cp, restecg, thalach, slope, age, and chol have a positive impact, while the other features (thal, exang, oldpeak, ca, sex, trestbps, and fbs) have a negative impact.
fig, axes = plt.subplots(2, 7, figsize=(20, 5))
for ax, col in zip(axes.flat, X_test.columns):
    shap.dependence_plot(ind=col,
                         interaction_index=col,
                         shap_values=shap_values,   # all rows of test data
                         features=X_test,           # all rows of test data
                         display_features=X_test,   # all rows of test data
                         show=False,
                         ax=ax)
    ax.set_title(f'{col.upper()} DEPENDENCE PLOT')
axes[-1,-1].set_visible(False)
plt.tight_layout()
plt.show()
The plots above show how the value of each feature correlates with its SHAP value. The trend is linear because Logistic Regression is a linear model in log-odds space, which is the space in which these SHAP values are computed.
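For a linear model the dependence plots can be predicted in advance. Assuming the independent masker and a model fitted on standardized data, a feature's SHAP value is coef * (x - mean) / std, so against the raw feature values (the x-axis of these plots) it is a straight line with slope coef / std that crosses zero at the training mean; the sign of the slope is the feature's global direction. The sketch below uses made-up numbers for a single feature.

```python
import numpy as np

# Made-up scaler statistics and coefficient for a single feature (think trestbps)
mu, sigma = 130.0, 17.0     # training mean and std used by the scaler
coef = -0.25                # coefficient on the standardized feature

rng = np.random.default_rng(1)
x_raw = rng.normal(mu, sigma, size=60)      # stand-in for X_test[col]
shap_col = coef * (x_raw - mu) / sigma      # linear-model SHAP values (scaled background mean = 0)

# Fit a straight line to (raw value, SHAP value), as the eye does on a dependence plot
slope, intercept = np.polyfit(x_raw, shap_col, deg=1)
print(f"fitted slope {slope:.5f} vs coef / sigma {coef / sigma:.5f}")
print(f"SHAP = 0 at x = {-intercept / slope:.2f} (training mean {mu})")
```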
The conclusion is the same as the one drawn from the violin summary plot: globally, cp, restecg, thalach, slope, age, and chol have a positive impact, while the other features (thal, exang, oldpeak, ca, sex, trestbps, and fbs) have a negative impact. The difference is that the summary plot also ranks the features by importance.
r = shap.decision_plot(base_value=explainer.expected_value,
shap_values=shap_values, # all rows of test data
features=X_test, # all rows of test data
link='logit', # transforms log-odds to probability
highlight=0, # highlight first row observation
return_objects=True) # returning the plot structures
The visualization above is the decision plot for all rows in test data. Moving from the bottom to the top, SHAP values for each feature are added to the model's base value. This shows how each feature contributes to the overall prediction.
- The x-axis represents the model output value, in this case the probability of heart disease (transformed using link='logit'). The plot is centered on the x-axis at the base value.
- The y-axis lists the features, ordered by descending importance.
- Each colored line represents one observation's prediction path based on its features.
We also highlight the first row of the test set, shown as a dotted line. To make it clearer, we can draw it in a separate decision plot as below:
shap.decision_plot(base_value=explainer.expected_value,
shap_values=shap_values[0], # first row of test data
features=X_test.iloc[0,:], # first row of test data
link='logit', # transforms log-odds to probability
feature_order=r.feature_idx) # preserving order of features from previous plot
The plot above essentially combines a force plot and a local importance plot, showing:
- The impact of each feature, either positive (to the right) or negative (to the left)
- The ranking of features by importance (here, the order is kept from the previous plot via feature_order=r.feature_idx)
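The single-row decision plot is essentially a cumulative sum. As a sketch with a made-up base value and SHAP values, and assuming the features are stacked from least to most important, each step adds one feature's SHAP value to the base log-odds, and link='logit' shows the running total as a probability. The last point equals the model's prediction regardless of the order used.

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Made-up base value (log-odds) and SHAP values for one observation
base_logit = 0.30
features = np.array(["cp", "thal", "exang", "age", "fbs"])
shap_row = np.array([0.90, -0.40, 0.25, 0.05, -0.02])

# Bottom-to-top order: least important feature first, most important last
order = np.argsort(np.abs(shap_row))
path_logit = base_logit + np.cumsum(shap_row[order])
path_prob = sigmoid(path_logit)          # what the x-axis shows with link='logit'

for name, p in zip(features[order], path_prob):
    print(f"after {name:>5}: {p:.3f}")
```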
A decision plot can expose a model’s typical prediction paths. Here, we plot all of the predictions in the probability interval [0.9, 1.0] to see what high-scoring predictions have in common. We use feature_order='hclust' to group similar prediction paths.
high_prob = np.where(lr.predict_proba(X_test_scale)[:,1] >= 0.9)[0]
shap.decision_plot(base_value=explainer.expected_value,
shap_values=shap_values[high_prob], # rows with high probability
features=X_test.iloc[high_prob,:], # rows with high probability
feature_order='hclust', # hierarchical clustering
link='logit') # transforms log-odds to probability
Other uses of the decision plot include:
- Show a large number of feature effects clearly
- Visualize multioutput predictions
- Display the cumulative effect of interactions
- Explore feature effects for a range of feature values
- Identify outliers
- Compare and contrast predictions for several models