Explainable AI (XAI) - Seminar Demo
Libraries:
In [1]:
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
from IPython.display import display, HTML
# SHAP imports
import shap
# LIME imports
from lime import lime_image
from skimage.segmentation import mark_boundaries
# Grad-Cam imports
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# Captum imports
from captum.attr import (
    IntegratedGradients, 
    LayerIntegratedGradients,
    visualization as viz
)
from transformers import (
    BertTokenizer, 
    BertForSequenceClassification, 
    BertForQuestionAnswering
)
1) SHAP
In [2]:
# Read Dataset
dataset = pd.read_csv("shap_dataset.csv")
dataset
Out[2]:
| | age | income | credit_score | loan_amount | weight | approved |
|---|---|---|---|---|---|---|
| 0 | 30 | 5953.19 | 659 | 133670 | 76.30 | 1 | 
| 1 | 61 | 1000.00 | 539 | 96670 | 63.03 | 0 | 
| 2 | 53 | 4102.22 | 513 | 111530 | 63.03 | 0 | 
| 3 | 29 | 5951.58 | 795 | 89060 | 58.83 | 1 | 
| 4 | 39 | 4984.51 | 615 | 57670 | 59.40 | 1 | 
| ... | ... | ... | ... | ... | ... | ... | 
| 4995 | 46 | 1841.62 | 610 | 172500 | 43.05 | 0 | 
| 4996 | 39 | 6528.76 | 734 | 131510 | 57.71 | 1 | 
| 4997 | 53 | 2002.29 | 555 | 122580 | 92.28 | 0 | 
| 4998 | 48 | 1197.70 | 554 | 157050 | 74.05 | 0 | 
| 4999 | 36 | 4270.00 | 782 | 112760 | 67.76 | 1 | 
5000 rows × 6 columns
In [3]:
# EDA & Preprocessing Part Skipped... (Not Relevant for the XAI Demo)
# ...
# Split the dataset
X = dataset.drop("approved", axis=1)
y = dataset["approved"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
# Train Model
RF = RandomForestClassifier(n_estimators=150, n_jobs=-1)
RF.fit(X_train, y_train)
# Evaluation & Fine-Tuning Part Skipped... (Not Relevant for the XAI Demo)
# ...
# Prediction
y_pred = RF.predict(X_test)
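Before turning to SHAP, a quick baseline for later comparison: the forest's built-in impurity-based importances (a minimal sketch reusing the `RF` model and `X` from above).

```python
# Sketch: impurity-based importances from the trained forest, sorted
# descending -- a rough global ranking to compare against SHAP later.
for name, imp in sorted(zip(X.columns, RF.feature_importances_),
                        key=lambda t: t[1], reverse=True):
    print(f"{name:12s}: {imp:.3f}")
```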
In [4]:
# Visualizations
fig, ax = plt.subplots(figsize=(8, 8))
cf = ConfusionMatrixDisplay(confusion_matrix(y_test, y_pred, labels=RF.classes_))
cf.plot(ax=ax)
plt.title("Confusion Matrix")
plt.show()
In [5]:
# XAI
explainer = shap.TreeExplainer(RF)
shap_values = explainer.shap_values(X_test)
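A useful sanity check before plotting (a sketch, assuming the per-class list output that `TreeExplainer` produces for this classifier): SHAP values are additive, so the expected value plus the sum of a row's SHAP values should reproduce the model's predicted probability for that row.

```python
# Sketch: verify SHAP's additivity property for one test row.
row = 0
reconstructed = explainer.expected_value[1] + shap_values[1][row].sum()
predicted = RF.predict_proba(X_test.iloc[[row]])[0, 1]
print(f"expected_value + sum(SHAP) = {reconstructed:.4f}")
print(f"model P(approved)          = {predicted:.4f}")  # should match closely
```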
Bar Plot
In [6]:
shap_exp = shap.Explanation(
    values=shap_values[1],
    base_values=explainer.expected_value[1],
    data=X_test.values,
    feature_names=X_test.columns.tolist()
)
shap.plots.bar(shap_exp)
Waterfall Plot
In [7]:
X_test.iloc[5].to_frame().T
Out[7]:
| | age | income | credit_score | loan_amount | weight |
|---|---|---|---|---|---|
| 106 | 50.0 | 8019.34 | 680.0 | 76060.0 | 97.76 | 
In [8]:
shap.plots.waterfall(shap_exp[5], max_display=14)
Summary Plot
In [9]:
shap.summary_plot(shap_values[1], X_test, plot_size=(15, 6))
Partial Dependence Plot
In [10]:
fig, ax = plt.subplots(figsize=(12, 6))
shap.plots.scatter(shap_exp[:, "income"], ax=ax)
In [11]:
fig, ax = plt.subplots(figsize=(12, 6))
shap.plots.scatter(shap_exp[:, "age"], ax=ax)
In [12]:
# Unimportant Feature
fig, ax = plt.subplots(figsize=(12, 6))
shap.plots.scatter(shap_exp[:, "weight"], ax=ax)
In [13]:
fig, ax = plt.subplots(figsize=(12, 6))
shap.dependence_plot(
    'age', 
    shap_values[1], 
    X_test, 
    ax=ax
)
plt.show()
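As an optional cross-check (a sketch assuming scikit-learn >= 1.0, where `PartialDependenceDisplay` is available), the classical partial dependence curve for `age` can be set against the SHAP dependence plots above: PDP averages the model output over the data, while the SHAP scatter shows per-row attributions.

```python
# Sketch: classical partial dependence for 'age', for comparison with the
# SHAP dependence plots above.
from sklearn.inspection import PartialDependenceDisplay

PartialDependenceDisplay.from_estimator(RF, X_test, features=["age"])
plt.show()
```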
2) LIME
Google MobileNet V2
In [2]:
# Load model and Preprocessor
preprocessor = AutoImageProcessor.from_pretrained("google/mobilenet_v2_1.0_224")
model = AutoModelForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
In [3]:
image_example_1 = Image.open("pictures/dog.jpg")
image_example_2 = Image.open("pictures/cat_tiger_cat.png")
image_example_3 = Image.open("pictures/cat-dog.jpg")
fig, axs = plt.subplots(1, 3, figsize=(12,12))
for ax, i, d in zip(
    axs,
    [image_example_1, image_example_2, image_example_3],
    ['Dog', 'Cat', 'Cat and Chihuahua']
):
    ax.imshow(i)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f'{d}')
plt.tight_layout()
plt.show()
In [4]:
class LIME_Wrapper:
    def __init__(self, model, preprocessor):
        self.model = model
        self.preprocessor = preprocessor
        self.model.eval()  
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        
    def run_image(self, img, top_k=5):
        img = img.convert('RGB')
        plt.imshow(img)
        plt.axis('off')
        plt.title("Input Image")
        plt.show()
        inputs = self.preprocessor(images=img, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=1)
        top_k_probs, top_k_indices = torch.topk(probs, k=top_k, dim=1)
        top_k_predictions = [
            (p.item(), c.item(), self.model.config.id2label[c.item()])  # Convert tensors to Python values
            for p, c in zip(top_k_probs.squeeze(), top_k_indices.squeeze())]
        print(f"Top-{top_k} Predictions:")
        for prob, class_id, label in top_k_predictions:
            print(f"Class ID: {class_id}, Label: {label}, Confidence: {prob:.5f}")
        return img, top_k_predictions
    def get_input_transform(self):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transf = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize
        ])
        return transf
    def get_input_tensors(self, img):
        transf = self.get_input_transform()
        return transf(img).unsqueeze(0)
    def get_pil_transform(self):
        transf = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224)
        ])
        return transf
    def batch_predict(self, images):
        inputs = self.preprocessor(images=images, return_tensors="pt")
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=1)
        return probs.detach().cpu().numpy()
    def visualize_lime_explanation(self, K, explanation, top_predictions, pos_features, all_features):
        temp1, mask1 = explanation.get_image_and_mask(
            explanation.top_labels[K], positive_only=True, num_features=pos_features, hide_rest=False
        )
        img_boundry1 = mark_boundaries(temp1 / 255.0, mask1)
        temp2, mask2 = explanation.get_image_and_mask(
            explanation.top_labels[K], positive_only=False, num_features=all_features, hide_rest=False
        )
        img_boundry2 = mark_boundaries(temp2 / 255.0, mask2)
        fig, ax = plt.subplots(1, 3, figsize=(12, 5))
        ax[0].imshow(explanation.image)
        ax[0].axis("off")
        ax[0].set_title(f"Original Image (Class: {top_predictions[K][2]})")
        ax[1].imshow(img_boundry1)
        ax[1].axis("off")
        ax[1].set_title(f"LIME: Positive Features (K={K+1})")
        ax[2].imshow(img_boundry2)
        ax[2].axis("off")
        ax[2].set_title(f"LIME: Pos & Neg Features (K={K+1})")
        plt.suptitle(
            f"K={K+1}, Class ID: {top_predictions[K][1]}, Label: {top_predictions[K][2]}, "
            f"Confidence: {top_predictions[K][0]:.5f}",
            fontsize=14,
            y=1.05,
        )
        plt.tight_layout()
        plt.show()
    def explain_image(self, img, top_labels=10, num_samples=1000):
        pil_transf = self.get_pil_transform()

        explainer = lime_image.LimeImageExplainer()
        explanation = explainer.explain_instance(
            np.array(pil_transf(img)),
            self.batch_predict,
            top_labels=top_labels,
            hide_color=0,
            num_samples=num_samples
        )
        return explanation
    
    def visualize_class_difference(self, explanation, top_predictions, class_idx_A, class_idx_B, num_features=10):
        _, mask_A = explanation.get_image_and_mask(
            label=explanation.top_labels[class_idx_A],
            positive_only=False,
            num_features=num_features,
            hide_rest=False
        )
        _, mask_B = explanation.get_image_and_mask(
            label=explanation.top_labels[class_idx_B],
            positive_only=False,
            num_features=num_features,
            hide_rest=False
        )
        diff_mask = mask_A.astype(int) - mask_B.astype(int) 
        original_image = explanation.image / 255.0
        cmap = cm.bwr  
        norm = Normalize(vmin=-1, vmax=1)
        heatmap = cmap(norm(diff_mask))[:, :, :3]  
        overlay = 0.5 * original_image + 0.5 * heatmap
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        ax[0].imshow(mask_A, cmap='gray')
        ax[0].set_title(f"Class A: {top_predictions[class_idx_A][2]}")
        ax[0].axis('off')
        ax[1].imshow(mask_B, cmap='gray')
        ax[1].set_title(f"Class B: {top_predictions[class_idx_B][2]}")
        ax[1].axis('off')
        ax[2].imshow(overlay)
        ax[2].set_title(f"Difference (A - B)")
        ax[2].axis('off')
        plt.tight_layout()
        plt.show()
    def visualize_heatmap_intensity(self, explanation, top_predictions, class_idx, num_features=10):
        local_exp = explanation.local_exp[explanation.top_labels[class_idx]]
        local_exp = sorted(local_exp, key=lambda x: abs(x[1]), reverse=True)[:num_features]
        weight_dict = dict(local_exp)
        segments = explanation.segments
        heatmap = np.zeros(segments.shape, dtype=float)
        for sp_idx, weight in weight_dict.items():
            heatmap[segments == sp_idx] = weight
        max_abs = max(abs(np.min(heatmap)), abs(np.max(heatmap)))
        norm = Normalize(vmin=-max_abs, vmax=max_abs)
        cmap = cm.bwr
        heatmap_colors = cmap(norm(heatmap))[:, :, :3]
        img = explanation.image / 255.0
        overlay = 0.6 * img + 0.4 * heatmap_colors
        overlay = np.clip(overlay, 0, 1)
        plt.figure(figsize=(8, 8))
        plt.imshow(overlay)
        plt.title(
            f"LIME Heatmap Intensity for Class: {top_predictions[class_idx][2]}\n"
            f"Confidence: {top_predictions[class_idx][0]:.4f}"
        )
        plt.axis('off')
        # Pass ax explicitly so matplotlib knows which Axes to steal space from
        plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=plt.gca(), label='Superpixel Importance')
        plt.show()
In [5]:
lime_wrapper = LIME_Wrapper(model, preprocessor)
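LIME perturbs the image at the superpixel level; by default `LimeImageExplainer` segments with quickshift (the parameters below mirror what we understand LIME's defaults to be). A sketch to preview the segmentation it will work over:

```python
# Sketch: preview the superpixel segmentation LIME perturbs (quickshift,
# parameters believed to match LIME's defaults).
from skimage.segmentation import quickshift

img_np = np.array(lime_wrapper.get_pil_transform()(image_example_1.convert("RGB")))
segments = quickshift(img_np, kernel_size=4, max_dist=200, ratio=0.2)
print(f"number of superpixels: {segments.max() + 1}")
plt.imshow(mark_boundaries(img_np / 255.0, segments))
plt.axis("off")
plt.show()
```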
Example 1: Dog
In [18]:
img1, top_predictions_1 = lime_wrapper.run_image(image_example_1, top_k=10)
lime_explanation_1 = lime_wrapper.explain_image(img1, top_labels=10)
Top-10 Predictions:
Class ID: 255, Label: pug, pug-dog, Confidence: 0.81005
Class ID: 263, Label: Brabancon griffon, Confidence: 0.01426
Class ID: 244, Label: bull mastiff, Confidence: 0.00916
Class ID: 175, Label: Norwegian elkhound, elkhound, Confidence: 0.00302
Class ID: 246, Label: French bulldog, Confidence: 0.00210
Class ID: 853, Label: tennis ball, Confidence: 0.00189
Class ID: 436, Label: bathtub, bathing tub, bath, tub, Confidence: 0.00181
Class ID: 677, Label: muzzle, Confidence: 0.00165
Class ID: 508, Label: combination lock, Confidence: 0.00137
Class ID: 792, Label: shopping cart, Confidence: 0.00131
In [19]:
# LIME Heatmap
lime_wrapper.visualize_lime_explanation(
    0, 
    lime_explanation_1, 
    top_predictions_1, 
    pos_features=5, 
    all_features=10
)
lime_wrapper.visualize_lime_explanation(
    5, 
    lime_explanation_1, 
    top_predictions_1, 
    pos_features=5, 
    all_features=15
)
In [20]:
# Intensity Heatmap
lime_wrapper.visualize_heatmap_intensity(
    explanation=lime_explanation_1,
    top_predictions=top_predictions_1,
    class_idx=0,    
    num_features=10    
)
In [21]:
# Difference Maps Between Classes
lime_wrapper.visualize_class_difference(
    explanation=lime_explanation_1,
    top_predictions=top_predictions_1,
    class_idx_A=0,
    class_idx_B=5,
    num_features=10
)
Example 2: Cat
In [22]:
img2, top_predictions_2 = lime_wrapper.run_image(image_example_2, top_k=10)
lime_explanation_2 = lime_wrapper.explain_image(img2, top_labels=10)
Top-10 Predictions:
Class ID: 283, Label: tiger cat, Confidence: 0.62426
Class ID: 288, Label: lynx, catamount, Confidence: 0.13119
Class ID: 292, Label: lion, king of beasts, Panthera leo, Confidence: 0.08663
Class ID: 293, Label: tiger, Panthera tigris, Confidence: 0.04508
Class ID: 282, Label: tabby, tabby cat, Confidence: 0.01313
Class ID: 286, Label: Egyptian cat, Confidence: 0.01046
Class ID: 284, Label: Persian cat, Confidence: 0.00840
Class ID: 287, Label: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor, Confidence: 0.00409
Class ID: 291, Label: jaguar, panther, Panthera onca, Felis onca, Confidence: 0.00227
Class ID: 261, Label: chow, chow chow, Confidence: 0.00164
In [23]:
lime_wrapper.visualize_lime_explanation(
    0, 
    lime_explanation_2, 
    top_predictions_2, 
    pos_features=5, 
    all_features=10
)
In [24]:
lime_wrapper.visualize_heatmap_intensity(
    explanation=lime_explanation_2,
    top_predictions=top_predictions_2,
    class_idx=0,    
    num_features=10    
)
Example 3: Cat and Dog
In [6]:
img3, top_predictions_3 = lime_wrapper.run_image(image_example_3, top_k=5)
lime_explanation_3 = lime_wrapper.explain_image(img3, top_labels=5)
Top-5 Predictions:
Class ID: 152, Label: Chihuahua, Confidence: 0.25436
Class ID: 158, Label: papillon, Confidence: 0.14414
Class ID: 285, Label: Siamese cat, Siamese, Confidence: 0.13388
Class ID: 153, Label: Japanese spaniel, Confidence: 0.01501
Class ID: 155, Label: Pekinese, Pekingese, Peke, Confidence: 0.01313
In [7]:
lime_wrapper.visualize_lime_explanation(
    0, 
    lime_explanation_3, 
    top_predictions_3, 
    pos_features=3, 
    all_features=13
)
lime_wrapper.visualize_lime_explanation(
    2, 
    lime_explanation_3, 
    top_predictions_3, 
    pos_features=5, 
    all_features=10
)
3) Saliency Maps: Grad-CAM
In [8]:
class GradCamWrapper(torch.nn.Module):
    def __init__(self, model): 
        super(GradCamWrapper, self).__init__()
        self.model = model
        
    def forward(self, x):
        return self.model(x).logits
    def image_class(self, img_logits):
        predicted_class_idx = img_logits.argmax(-1).item()
        return predicted_class_idx, self.model.config.id2label[predicted_class_idx] 
    def topk_classes(self, img_logits, k):
        # Single topk call instead of two redundant ones
        topk_logits, topk_idx = img_logits.topk(k)
        topk_idx = topk_idx.numpy().flatten()
        topk_logits = topk_logits.detach().numpy().flatten()
        topk_classes = [self.model.config.id2label[x] for x in topk_idx]
        return [(i, c, l) for i, c, l in zip(topk_idx, topk_classes, topk_logits)]
    def plot_gradcam(self, img, target_l, target_c, title=''):
        logits = self.forward(img)
        img_plot = np.transpose(img.detach().numpy().squeeze(), axes=[1, 2, 0])
        img_plot = (img_plot - img_plot.min())/(img_plot.max()-img_plot.min())
        
        # Use self (this wrapper) rather than relying on the global `wrapper`
        cam = GradCAM(model=self, target_layers=target_l)
        
        targets = []
        for t in target_c:
            targets.append([ClassifierOutputTarget(t)])
    
        cams = []
        for t in targets:
            grayscale_cam = cam(input_tensor=img, targets=t, aug_smooth=True)
            cams.append(grayscale_cam[0, :])
        
        visualization = []
        for i in cams:
            visualization.append(show_cam_on_image(img_plot, i, use_rgb=True))
        
        fig, axs = plt.subplots(1, len(cams), figsize=(8,len(cams)*4))
        for ax, i, v in zip(axs, target_c, visualization):
            ax.imshow(v)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(f'{self.model.config.id2label[i] if len(self.model.config.id2label[i]) < 30 else self.model.config.id2label[i][:30]+"..."} \n logits: {logits[:,i].item():.3f}')
        plt.tight_layout()
        fig.suptitle(title, y=0.68)
        plt.show()
In [9]:
wrapper = GradCamWrapper(model)
Example 1: Dog
In [12]:
inputs = preprocessor(images=image_example_1, return_tensors="pt")
image_tensor = inputs['pixel_values']
predicted_id = wrapper.image_class(wrapper(image_tensor))[0]
target_ids = [predicted_id]+[263, 5]
# define target layer and get gradcam plot
target_layers = [wrapper.model.mobilenet_v2.layer[-1]]
wrapper.plot_gradcam(image_tensor, target_layers, target_ids, 'last layer')
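For a single raw heatmap without the plotting helper, a minimal sketch (reusing `image_tensor` and `predicted_id` from the cell above):

```python
# Sketch: one grayscale Grad-CAM map for the predicted class only.
cam = GradCAM(model=wrapper, target_layers=[wrapper.model.mobilenet_v2.layer[-1]])
grayscale = cam(input_tensor=image_tensor,
                targets=[ClassifierOutputTarget(predicted_id)])[0]
plt.imshow(grayscale, cmap="jet")
plt.axis("off")
plt.title("Raw Grad-CAM heatmap")
plt.show()
```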
Example 2: Cat
In [14]:
inputs = preprocessor(images=image_example_2, return_tensors="pt")
image_tensor = inputs['pixel_values']
predicted_id = wrapper.image_class(wrapper(image_tensor))[0]
target_ids = [predicted_id]+[288, 100]
# define target layer and get gradcam plot
target_layers = [wrapper.model.mobilenet_v2.layer[-1]]
wrapper.plot_gradcam(image_tensor, target_layers, target_ids, 'last layer')
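The choice of target layer matters: later blocks give coarse, semantically focused maps, while earlier blocks give finer but noisier ones. A sketch probing an earlier block (assuming the HF MobileNetV2 `layer` ModuleList has at least seven blocks; index 6 is an arbitrary mid-network pick):

```python
# Sketch: Grad-CAM at a mid-network block, for comparison with the last layer.
early_layers = [wrapper.model.mobilenet_v2.layer[6]]
wrapper.plot_gradcam(image_tensor, early_layers, target_ids, 'mid-network layer')
```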
4) Captum
bert-base-uncased-imdb
In [23]:
model_name = "textattack/bert-base-uncased-imdb"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
model.eval()
model.config.id2label = {0: "negative", 1: "positive"}
model.config.label2id = {"negative": 0, "positive": 1}
embedding_layer = model.bert.embeddings.word_embeddings
def explain_sentiment_with_ig(text, target_class=None):
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    input_embed = embedding_layer(input_ids)
    def forward_func(embeds):
        outputs = model(inputs_embeds=embeds, attention_mask=attention_mask)
        return torch.softmax(outputs.logits, dim=1)
    with torch.no_grad():
        probs = forward_func(input_embed)
        pred_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0, pred_class].item()
    
    class_names = model.config.id2label
    pred_label = class_names[pred_class]
    print(f"🔍 Prediction: **{pred_label.upper()}** ({confidence:.4f} confidence)")
    target = pred_class if target_class is None else target_class
    ig = IntegratedGradients(forward_func)
    baseline = torch.zeros_like(input_embed)
    attributions, _ = ig.attribute(
        inputs=input_embed,
        baselines=baseline,
        target=target,
        return_convergence_delta=True
    )
    attributions_sum = attributions.sum(dim=-1).squeeze(0)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    attr_values = attributions_sum.detach().numpy()
    attr_norm = (attr_values - attr_values.min()) / (attr_values.max() - attr_values.min() + 1e-8)
    def color_token(token, value):
        color = matplotlib.colors.rgb2hex(plt.cm.Reds(value)[:3])
        return f'<span style="background-color:{color}; padding:2px; margin:1px; border-radius:3px;">{token}</span>'
    colored_tokens = [color_token(tok, val) for tok, val in zip(tokens, attr_norm)]
    html_content = ' '.join(colored_tokens).replace(' ##', '') 
    display(HTML(f"<div style='font-family:monospace; line-height:1.5'>{html_content}</div>"))
    print("\n🧠 Token Attributions:")
    for token, score in zip(tokens, attributions_sum):
        print(f"{token:15s}: {score.item():.4f}")
In [24]:
explain_sentiment_with_ig("This movie was absolutely amazing. The performances were stunning.")
🔍 Prediction: **POSITIVE** (0.9991 confidence)
[CLS] this movie was absolutely amazing . the performances were stunning . [SEP]
🧠 Token Attributions:
[CLS]          : -0.0115
this           : 0.0039
movie          : 0.0169
was            : 0.1046
absolutely     : 0.1221
amazing        : 0.1691
.              : -0.0294
the            : 0.1195
performances   : 0.0996
were           : 0.0352
stunning       : -0.0055
.              : -0.0332
[SEP]          : -1.3094
In [38]:
explain_sentiment_with_ig("This movie was super disgusting.")
🔍 Prediction: **NEGATIVE** (0.9743 confidence)
[CLS] this movie was super disgusting . [SEP]
🧠 Token Attributions:
[CLS]          : 0.0159
this           : 0.0748
movie          : -0.0783
was            : -0.1325
super          : -0.1031
disgusting     : 0.2851
.              : 0.0152
[SEP]          : 0.0938
In [36]:
explain_sentiment_with_ig("This movie was normal.")
🔍 Prediction: **POSITIVE** (0.5586 confidence)
[CLS] this movie was normal . [SEP]
🧠 Token Attributions:
[CLS]          : 0.0001
this           : -0.0090
movie          : -0.2239
was            : 0.2493
normal         : 0.1685
.              : 0.3094
[SEP]          : 0.0596
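Captum also ships its own notebook renderer for token attributions (imported above as `viz`). A sketch feeding it the values from the last run (arguments passed positionally, since keyword names have shifted across Captum versions):

```python
# Sketch: render the last example's attributions with Captum's visualizer.
tokens = ["[CLS]", "this", "movie", "was", "normal", ".", "[SEP]"]
scores = [0.0001, -0.0090, -0.2239, 0.2493, 0.1685, 0.3094, 0.0596]
record = viz.VisualizationDataRecord(
    scores,       # word attributions
    0.5586,       # predicted probability
    "positive",   # predicted class
    "positive",   # true class (unknown here; shown for layout)
    "positive",   # class the attributions target
    sum(scores),  # total attribution
    tokens,       # raw tokens
    0.0,          # convergence delta (omitted in this sketch)
)
viz.visualize_text([record])
```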