Explaining Iris classification with SHAP


This post shows how to use SHAP to explain the predictions of a K-nearest neighbors classifier trained on the Iris dataset.



In [8]:
import sklearn.neighbors  # explicit submodule import so sklearn.neighbors.KNeighborsClassifier resolves
from sklearn.model_selection import train_test_split
import numpy as np
import shap
import time

Load Iris Data

In [2]:
X_train, X_test, Y_train, Y_test = train_test_split(
    *shap.datasets.iris(), test_size=0.2, random_state=0)
In [3]:
# Predictor
X_train.head()
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
137                6.4               3.1                5.5               1.8
84                 5.4               3.0                4.5               1.5
27                 5.2               3.5                1.5               0.2
127                6.1               3.0                4.9               1.8
132                6.4               2.8                5.6               2.2
In [19]:
# Label
Y_train[:5]
array([2, 1, 0, 2, 2])

Train K-nearest neighbors

In [4]:
clf = sklearn.neighbors.KNeighborsClassifier()
clf.fit(X_train, Y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=None, n_neighbors=5, p=2,
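
The output that SHAP will explain later is `clf.predict_proba`. For a K-nearest neighbors classifier with the defaults shown above (`n_neighbors=5`, `p=2`), that probability is the share of the five nearest training points belonging to each class. A minimal pure-Python sketch of the idea follows. The function `knn_predict_proba` and the toy data are hypothetical illustrations, not scikit-learn's implementation.

```python
import math
from collections import Counter


def knn_predict_proba(train_X, train_y, query, k=5, classes=(0, 1, 2)):
    """Return the fraction of the k nearest training points in each class."""
    # Euclidean distance corresponds to metric='minkowski' with p=2.
    by_distance = sorted(
        zip(train_X, train_y),
        key=lambda pair: math.dist(pair[0], query),
    )
    votes = Counter(label for _, label in by_distance[:k])
    return [votes[c] / k for c in classes]


# Hypothetical toy data: (petal length, petal width) for nine flowers.
train_X = [(1.4, 0.2), (1.5, 0.2), (1.3, 0.3),
           (4.5, 1.5), (4.7, 1.4), (4.4, 1.3),
           (5.5, 1.8), (5.6, 2.2), (4.9, 1.8)]
train_y = [0, 0, 0, 1, 1, 1, 2, 2, 2]

proba = knn_predict_proba(train_X, train_y, query=(4.8, 1.6))
print(proba)  # three of the five nearest points are class 1, two are class 2
```

Because the probabilities move in steps of 1/k, the SHAP values for a KNN model are attributions of these vote shares to the individual features.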

Create an explainer

In [10]:
explainer = shap.KernelExplainer(clf.predict_proba, X_train)
Using 120 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K weighted samples.
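
To make the role of the background data concrete, here is a small sketch of the quantity that `KernelExplainer` estimates. Features that are "missing" from a coalition are filled in from the background rows, and each feature's Shapley value is its weighted average marginal contribution. The sketch enumerates every coalition exactly, which is only feasible for a handful of features; `KernelExplainer` itself approximates this by sampling coalitions. The `model` function and the data are hypothetical stand-ins for `clf.predict_proba` and `X_train`.

```python
from itertools import combinations
from math import factorial


def model(x):
    # Hypothetical stand-in for one output of clf.predict_proba.
    return 2.0 * x[0] + x[1] * x[2]


def coalition_value(f, x, background, subset):
    """Average f over the background with the features in `subset` fixed to x."""
    total = 0.0
    for row in background:
        mixed = [x[i] if i in subset else row[i] for i in range(len(x))]
        total += f(mixed)
    return total / len(background)


def exact_shapley(f, x, background):
    """Exact Shapley values by enumerating all coalitions (exponential cost)."""
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                gain = (coalition_value(f, x, background, s | {i})
                        - coalition_value(f, x, background, s))
                phi[i] += weight * gain
    return phi


# Hypothetical background rows and the instance to explain.
background = [[0.0, 1.0, 1.0], [1.0, 0.0, 2.0], [2.0, 2.0, 0.0]]
x = [3.0, 1.0, 4.0]

# Base value: the average model output over the background data.
expected_value = coalition_value(model, x, background, set())
phi = exact_shapley(model, x, background)

print(expected_value, phi)
# Base value plus all attributions recovers the model output for x.
print(expected_value + sum(phi), model(x))
```

Every coalition value averages over all background rows, which is why the warning above suggests shrinking 120 background samples to a smaller weighted set.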

Summarize X with k-means

In [21]:
X_train_summary = shap.kmeans(X_train, 50)
explainer = shap.KernelExplainer(clf.predict_proba, X_train_summary)
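
The warning message describes the summary as "K weighted samples". One way to picture this: cluster the background rows, keep one centroid per cluster, and weight each centroid by the share of rows it represents. The sketch below uses a plain Lloyd's-algorithm k-means in pure Python. It is an illustration under that assumption, not the code inside `shap.kmeans`, and `kmeans_summary` and the toy points are hypothetical.

```python
import math


def assign(points, centroids):
    """Group each point with its nearest centroid."""
    clusters = [[] for _ in centroids]
    for p in points:
        nearest = min(range(len(centroids)),
                      key=lambda c: math.dist(p, centroids[c]))
        clusters[nearest].append(p)
    return clusters


def kmeans_summary(points, k, iterations=20):
    """Return (centroids, weights); each weight is a cluster's share of the points."""
    # Deterministic start for illustration only: the first k points.
    centroids = [list(p) for p in points[:k]]
    for _ in range(iterations):
        for c, members in enumerate(assign(points, centroids)):
            if members:  # keep the old centroid if a cluster ends up empty
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    sizes = [len(m) for m in assign(points, centroids)]
    weights = [s / len(points) for s in sizes]
    return centroids, weights


# Hypothetical toy data: (petal length, petal width) for seven flowers.
points = [(1.4, 0.2), (1.5, 0.2), (1.3, 0.3),
          (4.5, 1.5), (4.7, 1.4), (4.9, 1.8), (4.4, 1.3)]

centroids, weights = kmeans_summary(points, k=2)
print(centroids)
print(weights)  # the larger cluster gets the larger weight
```

An explainer can then average model outputs over the centroids using these weights instead of looping over every original row, trading a little fidelity for speed.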

Explain one test prediction

In [22]:
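shap.initjs()  # load the JavaScript needed to render the force plot in the notebook
shap_values = explainer.shap_values(X_test.iloc[0, :])
# Index 0 selects class 0 in both the base value and the SHAP values
shap.force_plot(explainer.expected_value[0], shap_values[0], X_test.iloc[0, :])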
/Users/hiro/anaconda3/envs/py367/lib/python3.6/site-packages/shap/explainers/kernel.py:545: UserWarning: l1_reg="auto" is deprecated and in the next version (v0.29) the behavior will change from a conditional use of AIC to simply "num_features(10)"!
  "l1_reg=\"auto\" is deprecated and in the next version (v0.29) the behavior will change from a " \

