Interpretability of Random Forest Prediction for MNIST classification using LIME
Goal¶
This post shows how to interpret a Random Forest classifier's predictions on MNIST digit images using LIME, which generates an explanation for each individual prediction to help humans understand what drives that prediction.

Reference
Libraries¶
In [55]:
import pandas as pd
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
%matplotlib inline
# Scikit Learn
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import Normalizer
from sklearn.model_selection import train_test_split
from skimage.color import gray2rgb, rgb2gray, label2rgb #
# LIME
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm
Load MNIST Dataset¶
In [3]:
# Load the digits dataset through scikit-learn's `load_digits` helper; the
# result is kept in `mnist` for the train/test split in the next cell.
mnist = load_digits()
In [7]:
# The pipeline's first step calls rgb2gray on every sample and LIME's image
# explainer segments each sample as a picture, so the samples must be images
# rather than the flattened rows in `mnist.data`. Use the 2-D `mnist.images`
# and expand them to RGB with gray2rgb (adds a trailing channel axis).
X_train, X_test, y_train, y_test = train_test_split(
    gray2rgb(mnist.images), mnist.target, train_size=0.55, test_size=0.45)
Create a preprocessing pipeline¶
In [61]:
class PipeStep(object):
    """Wrap a plain function so it can be used as a no-fit pipeline transform.

    Parameters
    ----------
    step_func : callable
        Called as ``step_func(X)`` by :meth:`transform`; whatever it returns
        is handed back unchanged.
    """

    def __init__(self, step_func):
        self._step_func = step_func

    def fit(self, *args, **kwargs):
        """Do nothing and return ``self``.

        Any positional or keyword arguments (for example ``y=...`` or extra
        fit parameters) are accepted and ignored, since there is nothing to
        learn. Accepting keywords keeps ``step.fit(X, y=y)`` from raising
        ``TypeError``.
        """
        return self

    def transform(self, X):
        """Return the result of applying the wrapped function to ``X``."""
        return self._step_func(X)
def _to_gray(img_list):
    """Convert every image in ``img_list`` with rgb2gray."""
    return [rgb2gray(img) for img in img_list]


def _flatten(img_list):
    """Ravel every image in ``img_list`` into a 1-D vector."""
    return [img.ravel() for img in img_list]


makegray_step = PipeStep(_to_gray)
flatten_step = PipeStep(_flatten)

# Steps run in order: gray conversion, flattening, then the random forest.
# The two disabled steps below are kept for reference only.
_rf_steps = [
    ('Make Gray', makegray_step),
    ('Flatten Image', flatten_step),
    # ('Normalize', Normalizer()),
    # ('PCA', PCA(16)),
    ('RF', RandomForestClassifier()),
]
simple_rf_pipeline = Pipeline(_rf_steps)
In [62]:
# Fit the whole pipeline (gray conversion -> flatten -> random forest) on the
# training split.
simple_rf_pipeline.fit(X_train, y_train)
Out[62]:
Train a Random Forest model¶
In [70]:
# NOTE(review): `clf` is not referenced by any other visible cell; the
# cross-validation below scores `simple_rf_pipeline`, whose 'RF' step is built
# with default arguments rather than n_estimators=10 -- confirm intent.
clf = RandomForestClassifier(n_estimators=10)
# 5-fold cross-validation of the full pipeline on the training split.
scores = cross_val_score(simple_rf_pipeline, X_train, y_train, cv=5)
# Refit on the whole training split; later cells use the pipeline's
# predict_proba.
simple_rf_pipeline.fit(X_train, y_train)
# One point per fold, joined by a line.
plt.plot(scores, '.-');
plt.title('Cross Validation by Random Forest for MNIST')
plt.xlabel('# of Validation');
plt.ylabel('Accuracy');
Create a LIME explainer¶
In [63]:
# Explainer object from LIME's image module; used below via explain_instance.
explainer = lime_image.LimeImageExplainer()
# Quickshift-based segmentation, handed to the explainer as `segmentation_fn`
# in the next cell.
segmenter = SegmentationAlgorithm('quickshift', kernel_size=1, max_dist=200, ratio=0.2)
In [64]:
# Explain the pipeline's prediction for the first test sample. The explainer
# queries the model through `predict_proba` and splits the image into regions
# with the `segmenter` configured above.
explanation = explainer.explain_instance(
    X_test[0],
    classifier_fn=simple_rf_pipeline.predict_proba,
    top_labels=10,
    hide_color=0,
    num_samples=10000,
    segmentation_fn=segmenter,
)
In [66]:
# Two side-by-side panels for the true label of the explained sample.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

# Left panel: only regions that count in favour of the label.
temp, mask = explanation.get_image_and_mask(
    y_test[0],
    positive_only=True,
    num_features=10,
    hide_rest=False,
    min_weight=0.01,
)
ax1.imshow(label2rgb(mask, temp, bg_label=0), interpolation='nearest')
ax1.set_title('Positive Regions for {}'.format(y_test[0]))

# Right panel: regions counting for and against the label.
temp, mask = explanation.get_image_and_mask(
    y_test[0],
    positive_only=False,
    num_features=10,
    hide_rest=False,
    min_weight=0.01,
)
ax2.imshow(label2rgb(3 - mask, temp, bg_label=0), interpolation='nearest')
ax2.set_title('Positive/Negative Regions for {}'.format(y_test[0]))
In [68]:
# Now show the positive regions of the explained sample for each class.
# `explanation` was computed for X_test[0] only, so every panel overlays its
# class-specific mask on that same image and reports that image's true label
# (indexing X_test / y_test by the class index `i` would mix in unrelated
# samples).
fig, m_axs = plt.subplots(2, 5, figsize=(12, 6))
for i, c_ax in enumerate(m_axs.flatten()):
    temp, mask = explanation.get_image_and_mask(
        i, positive_only=True, num_features=1000, hide_rest=False, min_weight=0.01)
    c_ax.imshow(label2rgb(mask, X_test[0],
                          bg_label=0), interpolation='nearest')
    c_ax.set_title('Positive for {}\nActual {}'.format(i, y_test[0]))
    c_ax.axis('off')
Comments
Comments powered by Disqus