Interpreting Random Forest Predictions for MNIST-Style Digit Classification Using LIME
Goal
This post shows how to interpret a Random Forest's classification of MNIST-style digit images using LIME, which builds an explanation for each individual prediction to help humans understand what drives that prediction.
Reference
Libraries
In [55]:
import pandas as pd
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
%matplotlib inline
# Scikit Learn
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import Normalizer
from sklearn.model_selection import train_test_split
from skimage.color import gray2rgb, rgb2gray, label2rgb #
# LIME
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm
Load the Digits Dataset (scikit-learn's 8×8 MNIST-like handwritten digits)
In [3]:
# scikit-learn's bundled digits dataset (8x8 grayscale images, MNIST-like but
# not the 28x28 MNIST set). n_class=10 is the default: keep all ten digits.
mnist = load_digits(n_class=10)
In [7]:
# LIME's image explainer segments X_test[0] as an image, and the pipeline's
# first step calls rgb2gray on every sample, so both need RGB images rather
# than the flattened 64-value vectors in ``mnist.data``. Convert each 2-D
# grayscale image in ``mnist.images`` to a 3-channel RGB image first.
import numpy as np

X_rgb = np.stack([gray2rgb(img) for img in mnist.images])
# Fixed random_state so the test image explained below (X_test[0]) is
# the same on every run.
X_train, X_test, y_train, y_test = train_test_split(
    X_rgb, mnist.target,
    train_size=0.55, test_size=0.45,
    random_state=0,
)
Create a Preprocessing Pipeline
In [61]:
class PipeStep(object):
    """Wrap a plain function as a stateless scikit-learn pipeline step.

    The step learns nothing: ``fit`` returns the step unchanged and
    ``transform`` applies the wrapped function to the whole batch.

    Parameters
    ----------
    step_func : callable
        Function applied to the batch ``X`` given to ``transform``; its
        return value is passed on to the next pipeline step.
    """

    def __init__(self, step_func):
        self._step_func = step_func

    def fit(self, *args, **kwargs):
        """Ignore all arguments and return ``self``.

        Accepts arbitrary positional and keyword arguments (``X``, ``y=...``,
        forwarded fit parameters) so callers that pass keywords do not fail.
        """
        return self

    def transform(self, X):
        """Return ``step_func(X)``."""
        return self._step_func(X)
# The pipeline takes batches of RGB images (LIME passes perturbed images to
# its predict_proba) and reduces them to what the forest needs:
# RGB -> grayscale -> one flat feature vector per image.
makegray_step = PipeStep(lambda img_list: [rgb2gray(img) for img in img_list])
flatten_step = PipeStep(lambda img_list: [img.ravel() for img in img_list])

simple_rf_pipeline = Pipeline([
    ('Make Gray', makegray_step),
    ('Flatten Image', flatten_step),
    # Fixed seed so CV scores and LIME explanations are reproducible.
    ('RF', RandomForestClassifier(random_state=0)),
])
In [62]:
# Fit the whole preprocessing + Random Forest pipeline on the training split.
simple_rf_pipeline.fit(X=X_train, y=y_train)
Out[62]:
Train a Random Forest Model
In [70]:
# 5-fold cross-validation on the training split. cross_val_score fits its own
# clones of the pipeline, so simple_rf_pipeline is then refit on the full
# training split; that fitted model is the one LIME explains below.
scores = cross_val_score(simple_rf_pipeline, X_train, y_train, cv=5)
simple_rf_pipeline.fit(X_train, y_train)

# Number the folds 1..k on the x-axis instead of 0-based indices.
folds = range(1, len(scores) + 1)
plt.plot(folds, scores, '.-')
plt.xticks(folds)
plt.title('Cross Validation by Random Forest for MNIST')
plt.xlabel('# of Validation')
plt.ylabel('Accuracy');
Create a LIME Explainer
In [63]:
# LIME explainer for image classifiers.
explainer = lime_image.LimeImageExplainer()

# Quickshift superpixel segmentation used by LIME to decide which image
# regions to switch on/off. kernel_size=1 is presumably chosen because the
# digit images are tiny (a few pixels across) — confirm if changing data.
segmenter = SegmentationAlgorithm(
    'quickshift',
    kernel_size=1,
    max_dist=200,
    ratio=0.2,
)
In [64]:
# Explain the fitted pipeline's prediction for the first test image.
# top_labels=10 requests explanations for the 10 highest-probability labels;
# hide_color=0 fills switched-off superpixels with 0.
explanation = explainer.explain_instance(
    X_test[0],
    classifier_fn=simple_rf_pipeline.predict_proba,
    segmentation_fn=segmenter,
    top_labels=10,
    hide_color=0,
    num_samples=10000,
)
In [66]:
# Side-by-side view of the explanation for the true label of X_test[0]:
# left panel shows only supporting regions, right panel shows supporting
# and opposing regions.
true_label = y_test[0]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
panels = (
    (ax1, True, 'Positive Regions for {}'),
    (ax2, False, 'Positive/Negative Regions for {}'),
)
for panel_ax, pos_only, title_fmt in panels:
    temp, mask = explanation.get_image_and_mask(
        true_label,
        positive_only=pos_only,
        num_features=10,
        hide_rest=False,
        min_weight=0.01,
    )
    # Shift the signed mask for the positive/negative panel so label2rgb
    # gives the two region kinds distinct colours.
    label_img = mask if pos_only else 3 - mask
    panel_ax.imshow(label2rgb(label_img, temp, bg_label=0), interpolation='nearest')
    panel_ax.set_title(title_fmt.format(true_label))
In [68]:
# Now show, for each of the 10 digit classes, which superpixels of the SAME
# explained image (X_test[0]) LIME found to support that class.
fig, m_axs = plt.subplots(2, 5, figsize=(12, 6))
for i, c_ax in enumerate(m_axs.flatten()):
    _, mask = explanation.get_image_and_mask(
        i, positive_only=True, num_features=1000, hide_rest=False, min_weight=0.01)
    # The explanation was computed for X_test[0], so overlay on that image and
    # report its true label (previously used X_test[i] / y_test[i] by mistake).
    c_ax.imshow(label2rgb(mask, X_test[0], bg_label=0), interpolation='nearest')
    c_ax.set_title('Positive for {}\nActual {}'.format(i, y_test[0]))
    c_ax.axis('off')
Comments
Comments powered by Disqus