Posts about Interpretability

Explain the prediction for ImageNet using SHAP

Goal

This post shows how to explain an ImageNet classifier's prediction using SHAP.

Libraries

In [82]:
import keras
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.preprocessing import image
import requests
from skimage.segmentation import slic
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shap
import warnings

%matplotlib inline

Configuration

In [77]:
# make a color map
from matplotlib.colors import LinearSegmentedColormap
colors = []
for l in np.linspace(1, 0, 100):
    colors.append((245 / 255, 39 / 255, 87 / 255, l))
for l in np.linspace(0, 1, 100):
    colors.append((24 / 255, 196 / 255, 93 / 255, l))
cm = LinearSegmentedColormap.from_list("shap", colors)

Load pre-trained VGG16 model

In [2]:
# load model data
r = requests.get('https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json')
feature_names = r.json()
model = VGG16()
WARNING:tensorflow:From /Users/hiro/anaconda3/envs/py367/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels.h5
553467904/553467096 [==============================] - 62s 0us/step

Load an image

In [16]:
# load an image
file = "../images/apple-banana.jpg"
img = image.load_img(file, target_size=(224, 224))
img_orig = image.img_to_array(img)
plt.imshow(img);
plt.axis('off');

Segmentation

In [18]:
# Create segmentation to explain by segment, not every pixel
segments_slic = slic(img, n_segments=30, compactness=30, sigma=3)

plt.imshow(segments_slic);
plt.axis('off');
In [37]:
segments_slic
Out[37]:
array([[ 0,  0,  0, ...,  4,  4,  4],
       [ 0,  0,  0, ...,  4,  4,  4],
       [ 0,  0,  0, ...,  4,  4,  4],
       ...,
       [22, 22, 22, ..., 21, 21, 21],
       [22, 22, 22, ..., 21, 21, 21],
       [22, 22, 22, ..., 21, 21, 21]])
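
The label array above determines how many on/off flags the masking step needs: one per segment. A minimal sketch of that bookkeeping on a hypothetical 4x4 label array (`toy_segments` is made up for illustration, and labels are assumed to start at 0 and be contiguous, as in the output above; if your scikit-image version starts labels at 1, the mapping shifts by one):

```python
import numpy as np

# Hypothetical label array standing in for segments_slic
toy_segments = np.array([[0, 0, 1, 1],
                         [0, 0, 1, 1],
                         [2, 2, 3, 3],
                         [2, 2, 3, 3]])

# One mask column per segment label (labels assumed to be 0..max)
n_segments = int(toy_segments.max()) + 1

# Number of pixels in each segment
segment_sizes = np.bincount(toy_segments.ravel())

# A mask row with every segment switched on
all_on = np.ones((1, n_segments))

print(n_segments, segment_sizes, all_on.shape)
```

Mask columns beyond the largest label match no pixels, so an oversized mask is harmless but adds features that Kernel SHAP has to sample.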

Utility Functions for masking and preprocessing

In [19]:
# define a function that depends on a binary mask representing if an image region is hidden
def mask_image(zs, segmentation, image, background=None):
    
    if background is None:
        background = image.mean((0, 1))
        
    # Create an empty 4D array
    out = np.zeros((zs.shape[0], 
                    image.shape[0], 
                    image.shape[1], 
                    image.shape[2]))
    
    for i in range(zs.shape[0]):
        out[i, :, :, :] = image
        for j in range(zs.shape[1]):
            if zs[i, j] == 0:
                out[i][segmentation == j, :] = background
    return out


def f(z):
    return model.predict(
        preprocess_input(mask_image(z, segments_slic, img_orig, 255)))

def fill_segmentation(values, segmentation):
    out = np.zeros(segmentation.shape)
    for i in range(len(values)):
        out[segmentation == i] = values[i]
    return out
In [39]:
masked_images = mask_image(np.zeros((1,50)), segments_slic, img_orig, 255)

plt.imshow(masked_images[0][:,:, 0]);
plt.axis('off');
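
To make the masking step concrete, here is a minimal sketch on a toy image. The function repeats the logic of `mask_image` so the block runs on its own; `toy_image`, `toy_segments` and `zs` are hypothetical data, not taken from the post.

```python
import numpy as np

def mask_image(zs, segmentation, image, background=None):
    """Return one copy of `image` per row of `zs`, hiding segments flagged 0."""
    if background is None:
        background = image.mean((0, 1))
    out = np.zeros((zs.shape[0],) + image.shape)
    for i in range(zs.shape[0]):
        out[i] = image
        for j in range(zs.shape[1]):
            if zs[i, j] == 0:
                out[i][segmentation == j, :] = background
    return out

# Hypothetical 2x2 RGB image split into a left (0) and a right (1) segment
toy_image = np.arange(12, dtype=float).reshape(2, 2, 3)
toy_segments = np.array([[0, 1],
                         [0, 1]])

# Row 0 keeps both segments, row 1 hides segment 1
zs = np.array([[1, 1],
               [1, 0]])
masked = mask_image(zs, toy_segments, toy_image, background=255)

print(masked.shape)     # one masked image per row of zs
print(masked[1, :, 1])  # pixels of the hidden right column
```

Each row of `zs` yields one image, so a batch of masks can be passed to the model in a single `predict` call.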

Create an explainer and shap values

In [83]:
# use Kernel SHAP to explain the network's predictions
explainer = shap.KernelExplainer(f, np.zeros((1,50)))

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    shap_values = explainer.shap_values(np.ones((1,50)), nsamples=100) # runs VGG16 on roughly 100 masked images
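
Kernel SHAP estimates Shapley values from a sample of masks because enumerating every mask is infeasible for dozens of segments. With only a few segments the values can be computed exactly, which may help to see what is being approximated. The sketch below is not the SHAP library's implementation; `exact_shapley` and `toy_score` are hypothetical names, and the toy score stands in for the model output as a function of which segments are visible.

```python
from itertools import combinations
from math import factorial

def exact_shapley(value_fn, n_players):
    """Exact Shapley values by enumerating every coalition (small n only)."""
    players = range(n_players)
    phi = [0.0] * n_players
    for i in players:
        others = [p for p in players if p != i]
        for size in range(len(others) + 1):
            # Shapley weight of a coalition of this size
            weight = (factorial(size) * factorial(n_players - size - 1)
                      / factorial(n_players))
            for subset in combinations(others, size):
                coalition = frozenset(subset)
                gain = value_fn(coalition | {i}) - value_fn(coalition)
                phi[i] += weight * gain
    return phi

# Hypothetical "model": score depends on which of 3 segments are visible
def toy_score(visible):
    score = 0.0
    if 0 in visible:
        score += 0.5
    if 1 in visible and 2 in visible:
        score += 0.4  # segments 1 and 2 only help together
    return score

phi = exact_shapley(toy_score, 3)
print(phi)  # approximately [0.5, 0.2, 0.2]
```

The values sum to the difference between the score with all segments visible and the score with none, and the joint contribution of segments 1 and 2 is split evenly between them.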

Obtain the prediction with the highest probability

In [85]:
predictions = model.predict(preprocess_input(np.expand_dims(img_orig.copy(), axis=0)))
top_preds = np.argsort(-predictions)
inds = top_preds[0]  # class indices sorted from most to least probable
In [86]:
pd.Series(data={feature_names[str(inds[i])][1]:predictions[0, inds[i]]  for i in range(10)}).plot(kind='bar', title='Top 10 Predictions');
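
The ranking step can be sketched on toy data. The probabilities and the class index below are hypothetical stand-ins (the IDs and labels are made up); only the `argsort` pattern and the `[id, label]` layout of the class index mirror the code above.

```python
import numpy as np

# Hypothetical stand-ins for the model output and the class index
predictions = np.array([[0.05, 0.70, 0.10, 0.15]])
feature_names = {"0": ["n000", "lemon"], "1": ["n001", "banana"],
                 "2": ["n002", "orange"], "3": ["n003", "apple"]}

# argsort on the negated scores orders class indices from most to least likely
top_preds = np.argsort(-predictions)
inds = top_preds[0]

# Pair each of the three best class indices with its label and probability
top3 = [(feature_names[str(k)][1], float(predictions[0, k])) for k in inds[:3]]
print(top3)
```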

Explain the prediction by visualization

In the figure below, the SHAP values explain the banana prediction well.

In [87]:
# Visualize the explanations
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(12,4))
inds = top_preds[0]
axes[0].imshow(img)
axes[0].axis('off')

max_val = np.max([np.max(np.abs(shap_values[i][:,:-1])) for i in range(len(shap_values))])
for i in range(3):
    m = fill_segmentation(shap_values[inds[i]][0], segments_slic)
    axes[i+1].set_title(feature_names[str(inds[i])][1])
    axes[i+1].imshow(np.array(img.convert('LA'))[:, :, 0], alpha=0.15)
    im = axes[i+1].imshow(m, cmap=cm, vmin=-max_val, vmax=max_val)
    axes[i+1].axis('off')
cb = fig.colorbar(im, ax=axes.ravel().tolist(), label="SHAP value", orientation="horizontal", aspect=60)
cb.outline.set_visible(False)
plt.show()

Explain Iris classification by SHAP

Goal

This post shows how to explain an Iris classifier's predictions using SHAP.

Libraries

In [8]:
import sklearn.neighbors  # explicit submodule import so sklearn.neighbors is available
from sklearn.model_selection import train_test_split
import numpy as np
import shap
import time
shap.initjs()

Load Iris Data

In [2]:
X_train, X_test, Y_train, Y_test = train_test_split(
    *shap.datasets.iris(), test_size=0.2, random_state=0)
In [3]:
# Predictor 
X_train.head()
Out[3]:
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)
137 6.4 3.1 5.5 1.8
84 5.4 3.0 4.5 1.5
27 5.2 3.5 1.5 0.2
127 6.1 3.0 4.9 1.8
132 6.4 2.8 5.6 2.2
In [19]:
# Label 
Y_train[:5]
Out[19]:
array([2, 1, 0, 2, 2])

Train K-nearest neighbors

In [4]:
clf = sklearn.neighbors.KNeighborsClassifier()
clf.fit(X_train, Y_train)
Out[4]:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=None, n_neighbors=5, p=2,
           weights='uniform')

Create an explainer

In [10]:
explainer = shap.KernelExplainer(clf.predict_proba, X_train)
Using 120 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K weighted samples.

Use X summarized by k-means

In [21]:
X_train_summary = shap.kmeans(X_train, 50)
explainer = shap.KernelExplainer(clf.predict_proba, X_train_summary)

Explain one test prediction

In [22]:
shap_values = explainer.shap_values(X_test.iloc[0, :])
shap.force_plot(explainer.expected_value[0], shap_values[0], X_test.iloc[0, :])
/Users/hiro/anaconda3/envs/py367/lib/python3.6/site-packages/shap/explainers/kernel.py:545: UserWarning: l1_reg="auto" is deprecated and in the next version (v0.29) the behavior will change from a conditional use of AIC to simply "num_features(10)"!
  "l1_reg=\"auto\" is deprecated and in the next version (v0.29) the behavior will change from a " \

Explain Image Classification by SHAP Deep Explainer

Goal

This post shows how to explain image classification (with a model trained in PyTorch) using the SHAP Deep Explainer.

SHAP is a library for making black-box models interpretable. For example, an image classification can be explained by assigning a score to each pixel of the input image, indicating how much that pixel contributes, positively or negatively, to the predicted probability.

Interpretability of prediction for Boston Housing using SHAP

Goal

This post shows how to interpret predictions for the Boston Housing dataset using SHAP.

What is SHAP?

SHAP is a library for making the predictions of machine learning models interpretable by showing which features affect the predicted value. Specifically, it computes SHAP values, i.e., how much each feature increases or decreases the prediction for a given sample relative to the average prediction.
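
A small worked example may make this concrete. For a linear model whose features are treated as independent, the SHAP value of a feature is its weight times the feature's deviation from the background mean, and the values add up to the gap between the prediction and the average prediction. The sketch below uses a hypothetical model and background data (`weights`, `bias`, `background` and `x` are made up) and does not call the SHAP library.

```python
import numpy as np

# Hypothetical linear model and background data (not from the post)
weights = np.array([2.0, -1.0, 0.5])
bias = 3.0
background = np.array([[1.0, 2.0, 0.0],
                       [3.0, 0.0, 4.0],
                       [2.0, 4.0, 2.0]])

def predict(x):
    return x @ weights + bias

# Sample to explain
x = np.array([4.0, 1.0, 2.0])

# Base value: average prediction over the background data
base_value = predict(background).mean()

# SHAP values of a linear model with independent features
shap_values = weights * (x - background.mean(axis=0))

print(base_value, shap_values, predict(x))
```

Here the base value plus the sum of the SHAP values reproduces the prediction for `x`, which is the additivity property that the force plots in this post visualize.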

Libraries

In [11]:
import xgboost
import shap
shap.initjs()

Load Boston Housing Dataset

In [3]:
X, y = shap.datasets.boston()
X[:5]
Out[3]:
CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT
0 0.00632 18.0 2.31 0.0 0.538 6.575 65.2 4.0900 1.0 296.0 15.3 396.90 4.98
1 0.02731 0.0 7.07 0.0 0.469 6.421 78.9 4.9671 2.0 242.0 17.8 396.90 9.14
2 0.02729 0.0 7.07 0.0 0.469 7.185 61.1 4.9671 2.0 242.0 17.8 392.83 4.03
3 0.03237 0.0 2.18 0.0 0.458 6.998 45.8 6.0622 3.0 222.0 18.7 394.63 2.94
4 0.06905 0.0 2.18 0.0 0.458 7.147 54.2 6.0622 3.0 222.0 18.7 396.90 5.33
In [4]:
y[:5]
Out[4]:
array([24. , 21.6, 34.7, 33.4, 36.2])

Train a predictor by xgboost

In [10]:
d_param = {
    "learning_rate": 0.01
}

model = xgboost.train(params=d_param,
                      dtrain=xgboost.DMatrix(X, label=y), 
                      num_boost_round=100)

Create an explainer

In [12]:
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

Outcome of SHAP

Single prediction explainer

The visualization below explains a single prediction, the one for the i-th sample.

  • red: positive impacts on the prediction
  • blue: negative impacts on the prediction
In [19]:
i = 0
shap.force_plot(explainer.expected_value, shap_values[i,:], X.iloc[i,:])

All prediction explainers

The summary plot below combines the explanations for all samples into one graph, showing the distribution of SHAP values for each feature.

In [14]:
shap.summary_plot(shap_values, X, plot_type="violin")

Variable importance

The variable importance shown below aggregates the plot above by taking, for each feature, the mean of the absolute SHAP values over all data points.

In [13]:
shap.summary_plot(shap_values, X, plot_type="bar")
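
The bar plot ranks features by their mean absolute SHAP value. A minimal sketch of that aggregation on a hypothetical SHAP matrix (the numbers are made up, and the feature names are borrowed from the dataset only for readability):

```python
import numpy as np

# Hypothetical SHAP matrix: 4 samples (rows) x 3 features (columns)
feature_names = ["RM", "LSTAT", "CRIM"]
shap_values = np.array([[ 2.0, -1.0,  0.1],
                        [-1.0, -3.0,  0.0],
                        [ 3.0,  2.0, -0.1],
                        [-2.0,  4.0,  0.2]])

# Mean absolute SHAP value per feature
importance = np.abs(shap_values).mean(axis=0)

# Feature indices sorted from most to least important
order = np.argsort(-importance)
ranking = [(feature_names[k], float(importance[k])) for k in order]
print(ranking)
```

Taking absolute values first matters: a feature that pushes some predictions up and others down would otherwise average out to a small score.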

Force Plot

Another way to visualize SHAP values is to stack the force plots of all samples side by side; the samples can be ordered by similarity or by the value of a chosen feature.

In [16]:
shap.force_plot(explainer.expected_value, shap_values, X)

Dependence Plot

This plot shows a feature's value against its SHAP value as a scatter plot. The points are colored by a second feature, selected automatically as the one that appears to interact most strongly with the plotted feature.

In [25]:
# specify by the index of the features
shap.dependence_plot(ind=12, shap_values=shap_values, features=X)
In [17]:
# specify by the feature name
shap.dependence_plot(ind="RM", shap_values=shap_values, features=X)