Interpretability of Random Forest Prediction for MNIST classification using LIME

Goal

This post introduces how to interpret Random Forest classification of MNIST digit images using LIME, which builds a local explanation for each individual prediction so that a human can see which parts of the image drove the model's decision.

Libraries

In [55]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
%matplotlib inline

# Scikit Learn
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import Normalizer
from sklearn.model_selection import train_test_split
from skimage.color import gray2rgb, rgb2gray, label2rgb  # image color-space utilities
# LIME
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm

Load MNIST Dataset

Note that load_digits returns scikit-learn's small 8x8 digits dataset, used here as a lightweight stand-in for MNIST. Each sample arrives as a flat 64-value row, so it is reshaped to 8x8 and converted to RGB up front: both LIME's image explainer and the rgb2gray step in the pipeline below expect RGB images.

In [3]:
mnist = load_digits()
# Reshape each flat 64-value row to an 8x8 image and convert grayscale -> RGB
X_images = np.stack([gray2rgb(img) for img in mnist.data.reshape((-1, 8, 8))], 0)
In [7]:
X_train, X_test, y_train, y_test = train_test_split(X_images, mnist.target,
                                                    train_size=0.55, test_size=0.45)
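
As a quick sanity check, the split arrays should now hold 8x8 RGB images (a minimal sketch; the exact counts depend on the random split):

In [ ]:
print(X_train.shape, X_test.shape)  # expected: (n_train, 8, 8, 3) (n_test, 8, 8, 3)
print(y_train.shape, y_test.shape)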

Create a preprocessing pipeline

In [61]:
class PipeStep(object):
    """
    Wrapper for turning plain functions into pipeline transforms (no fitting needed)
    """
    def __init__(self, step_func):
        self._step_func = step_func

    def fit(self, *args):
        # Nothing to fit; return self so the step is pipeline-compatible
        return self

    def transform(self, X):
        return self._step_func(X)


# Convert RGB images to grayscale, then flatten each image into a 1-D feature vector
makegray_step = PipeStep(lambda img_list: [rgb2gray(img) for img in img_list])
flatten_step = PipeStep(lambda img_list: [img.ravel() for img in img_list])

simple_rf_pipeline = Pipeline([
    ('Make Gray', makegray_step),
    ('Flatten Image', flatten_step),
    #('Normalize', Normalizer()),
    #('PCA', PCA(16)),
    ('RF', RandomForestClassifier())])
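
Since LIME will later call the pipeline with batches of perturbed RGB images, it is worth confirming that the two preprocessing steps reduce each 8x8x3 image back to the 64-feature vector the forest expects. A minimal sketch using the steps defined above:

In [ ]:
# Run only the preprocessing steps on a small batch
gray = makegray_step.transform(X_train[:2])  # list of (8, 8) grayscale arrays
flat = flatten_step.transform(gray)          # list of 64-element vectors
print(np.asarray(flat).shape)                # expected: (2, 64)
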
In [62]:
simple_rf_pipeline.fit(X_train, y_train)
/Users/hiro/anaconda3/envs/py367/lib/python3.6/site-packages/sklearn/ensemble/forest.py:246: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.
  "10 in version 0.20 to 100 in 0.22.", FutureWarning)
Out[62]:
Pipeline(memory=None,
     steps=[('Make Gray', <__main__.PipeStep object at 0x12479b470>), ('Flatten Image', <__main__.PipeStep object at 0x12479b390>), ('RF', RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impu...obs=None,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False))])

Train a Random Forest model

In [70]:
scores = cross_val_score(simple_rf_pipeline, X_train, y_train, cv=5)
simple_rf_pipeline.fit(X_train, y_train)
plt.plot(scores, '.-')
plt.title('Cross Validation by Random Forest for MNIST')
plt.xlabel('# of Validation')
plt.ylabel('Accuracy');
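
The plot shows per-fold accuracy; a one-line summary (a minimal sketch) makes the result easier to quote:

In [ ]:
print('CV accuracy: {:.3f} +/- {:.3f}'.format(scores.mean(), scores.std()))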

Create a LIME explainer

In [63]:
explainer = lime_image.LimeImageExplainer()
segmenter = SegmentationAlgorithm('quickshift', kernel_size=1, max_dist=200, ratio=0.2)
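
LIME explains an image by perturbing superpixels, so the segmenter matters: quickshift with kernel_size=1 yields many tiny superpixels, which suits these small 8x8 digits. SegmentationAlgorithm objects are callable in lime, so the segmentation can be inspected directly (a minimal sketch):

In [ ]:
segments = segmenter(X_test[0])  # one superpixel label per pixel
print('number of superpixels:', len(np.unique(segments)))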
In [64]:
explanation = explainer.explain_instance(X_test[0],
                                         classifier_fn=simple_rf_pipeline.predict_proba,
                                         top_labels=10,
                                         hide_color=0,
                                         num_samples=10000,
                                         segmentation_fn=segmenter)
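
Before plotting, the raw explanation can be inspected directly. The attribute names below (top_labels, local_exp) come from lime's ImageExplanation object and may differ across versions; local_exp maps each explained label to (superpixel, weight) pairs, where positive weights support the label and negative weights count against it:

In [ ]:
print('explained labels:', explanation.top_labels)
weights = dict(explanation.local_exp[y_test[0]])  # {superpixel id: weight}
# Five most influential superpixels for the true label, by absolute weight
print(sorted(weights.items(), key=lambda kv: -abs(kv[1]))[:5])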
In [66]:
temp, mask = explanation.get_image_and_mask(y_test[0],
                                            positive_only=True,
                                            num_features=10,
                                            hide_rest=False,
                                            min_weight=0.01)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.imshow(label2rgb(mask, temp, bg_label=0), interpolation='nearest')
ax1.set_title('Positive Regions for {}'.format(y_test[0]))
temp, mask = explanation.get_image_and_mask(y_test[0],
                                            positive_only=False,
                                            num_features=10,
                                            hide_rest=False,
                                            min_weight=0.01)

# 3 - mask remaps the label values so that label2rgb assigns distinct colors
# to the positive and negative regions
ax2.imshow(label2rgb(3 - mask, temp, bg_label=0), interpolation='nearest')
ax2.set_title('Positive/Negative Regions for {}'.format(y_test[0]))
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [68]:
# now show them for each class
fig, m_axs = plt.subplots(2, 5, figsize=(12, 6))
for i, c_ax in enumerate(m_axs.flatten()):
    temp, mask = explanation.get_image_and_mask(
        i, positive_only=True, num_features=1000, hide_rest=False, min_weight=0.01)
    c_ax.imshow(label2rgb(mask, X_test[i],
                          bg_label=0), interpolation='nearest')
    c_ax.set_title('Positive for {}\nActual {}'.format(i, y_test[i]))
    c_ax.axis('off')
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
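
The per-class maps are only as trustworthy as the underlying model, so a final hold-out accuracy check is a reasonable companion (a minimal sketch using the pipeline fitted above):

In [ ]:
from sklearn.metrics import accuracy_score
print('hold-out accuracy:', accuracy_score(y_test, simple_rf_pipeline.predict(X_test)))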
