XAI optim

This commit is contained in:
everbarry 2026-03-19 13:19:24 +01:00
parent dbb1656ef0
commit c300289aec
3 changed files with 69 additions and 59 deletions

View File

@ -55,7 +55,6 @@ The server starts at **http://localhost:5000**.
assignment-HAI/
├── dashboard.py # Flask application (main entry point)
├── train.py # Standalone training script
├── final.ipynb # Training, evaluation, and XAI notebook
├── map.webp # Park map background image
├── pyproject.toml # Project metadata and dependencies
├── uv.lock # Locked dependency versions

View File

@ -2,15 +2,6 @@
This guide explains how to use the Wildlife Monitoring Dashboard as an end user (e.g. a park ranger).
## Starting the Dashboard
After setup (see [README.md](README.md)), run:
```bash
uv run python dashboard.py
```
Open **http://localhost:5000** in your browser.
## Home Page — Live Map

View File

@ -9,7 +9,6 @@ Run: uv run python dashboard.py
"""
import csv
import io
import json
import random
import threading
@ -20,8 +19,8 @@ from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image
import matplotlib.cm as mpl_cm
from PIL import Image, ImageDraw
import torch
from torch import nn
@ -170,33 +169,41 @@ def _save_img(arr, path):
Image.fromarray(arr).save(path)
def compute_xai(det):
out = XAI_DIR / det["id"]
if (out / "meta.json").exists():
return
out.mkdir(exist_ok=True)
idx = det["idx"]
tensor, label = test_ds[idx]
img_path, _ = test_ds.samples[idx]
raw = np.array(
Image.open(img_path).convert("RGB").resize((224, 224)),
dtype=np.float64,
) / 255.0
inp = tensor.unsqueeze(0).to(device)
_save_img(raw, out / "original.png")
# single forward pass: probabilities + feature vector for neighbours
def _forward(inp):
"""Single forward pass returning probs, pred_idx, and neighbour query vector."""
with torch.no_grad():
feats = model.features(inp)
pooled = model.avgpool(feats)
pooled = model.avgpool(model.features(inp))
logits = model.classifier(torch.flatten(pooled, 1))
probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze()
qf = torch.nn.functional.normalize(
torch.flatten(pooled, 1).cpu(), dim=1,
).squeeze()
pred_idx = probs.argmax().item()
return probs, probs.argmax().item(), qf
def compute_xai(det, pre=None):
"""Compute all XAI artefacts. `pre` is an optional dict with keys
inp, raw, probs, pred_idx, qf to skip the duplicate forward pass."""
out = XAI_DIR / det["id"]
if (out / "meta.json").exists():
return
out.mkdir(exist_ok=True)
if pre is not None:
inp, raw, probs, pred_idx, qf = (
pre["inp"], pre["raw"], pre["probs"], pre["pred_idx"], pre["qf"])
else:
idx = det["idx"]
tensor, _ = test_ds[idx]
img_path, _ = test_ds.samples[idx]
raw = np.array(
Image.open(img_path).convert("RGB").resize((224, 224)),
dtype=np.float64,
) / 255.0
inp = tensor.unsqueeze(0).to(device)
_save_img(raw, out / "original.png")
probs, pred_idx, qf = _forward(inp)
entropy = -(probs * torch.log(probs + 1e-12)).sum().item()
# ScoreCAM
@ -206,7 +213,7 @@ def compute_xai(det):
out / "scorecam.png",
)
# LIME (reduced from 1000 to 500 samples)
# LIME (500 samples)
expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=500)
c1, c2 = expl.top_labels[0], expl.top_labels[1]
t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False)
@ -221,15 +228,29 @@ def compute_xai(det):
diff[segs == sid] = w1.get(sid, 0) - w2.get(sid, 0)
mx = max(np.abs(diff).max(), 1e-8)
diff /= mx
col = plt.cm.RdBu_r((diff + 1) / 2)[:, :, :3]
col = mpl_cm.RdBu_r((diff + 1) / 2)[:, :, :3]
_save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png")
# nearest neighbours (reuse qf from above)
sims, idxs = (train_feats @ qf).topk(3)
# nearest neighbours — 2x2 composite with labels
sims, idxs = (train_feats @ qf).topk(4)
nbs = []
for k, (ni, ns) in enumerate(zip(idxs, sims)):
_save_img(_to_display(train_ds[ni.item()][0]), out / f"nb{k+1}.png")
nbs.append({"cls": CLASS_NAMES[train_labels[ni]], "sim": f"{ns:.3f}"})
nb_pils = []
for ni, ns in zip(idxs, sims):
arr = _to_display(train_ds[ni.item()][0])
cls = CLASS_NAMES[train_labels[ni]]
sim = f"{ns:.3f}"
nbs.append({"cls": cls, "sim": sim})
pil = Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8))
draw = ImageDraw.Draw(pil)
label = f"{cls} ({sim})"
draw.rectangle([(0, pil.height - 22), (pil.width, pil.height)], fill=(0, 0, 0, 180))
draw.text((6, pil.height - 20), label, fill=(255, 255, 255))
nb_pils.append(pil)
w, h = nb_pils[0].size
grid = Image.new("RGB", (w * 2, h * 2))
for i, p in enumerate(nb_pils):
grid.paste(p, ((i % 2) * w, (i // 2) * h))
grid.save(out / "neighbours.png")
meta = {
"pred": CLASS_NAMES[pred_idx],
@ -252,10 +273,10 @@ _xai_events: dict[str, threading.Event] = {}
_xai_lock = threading.Lock()
def _precompute_xai(det):
def _precompute_xai(det, pre=None):
try:
with _xai_lock:
compute_xai(det)
compute_xai(det, pre)
finally:
ev = _xai_events.get(det["id"])
if ev:
@ -280,12 +301,15 @@ def detail(det_id):
def api_simulate():
idx = random.randint(0, len(test_ds) - 1)
cam_id = random.choice(list(CAMERAS.keys()))
tensor, label = test_ds[idx]
with torch.no_grad():
probs = torch.nn.functional.softmax(
model(tensor.unsqueeze(0).to(device)), dim=1,
).cpu().squeeze()
pred_idx = probs.argmax().item()
tensor, _ = test_ds[idx]
inp = tensor.unsqueeze(0).to(device)
probs, pred_idx, qf = _forward(inp)
img_path, _ = test_ds.samples[idx]
raw = np.array(
Image.open(img_path).convert("RGB").resize((224, 224)), dtype=np.float64,
) / 255.0
det = {
"id": uuid.uuid4().hex[:8],
"idx": idx,
@ -300,14 +324,12 @@ def api_simulate():
detections.append(det)
out = XAI_DIR / det["id"]
out.mkdir(exist_ok=True)
img_path, _ = test_ds.samples[idx]
raw = np.array(
Image.open(img_path).convert("RGB").resize((224, 224)), dtype=np.float64,
) / 255.0
_save_img(raw, out / "original.png")
pre = {"inp": inp, "raw": raw, "probs": probs, "pred_idx": pred_idx, "qf": qf}
ev = threading.Event()
_xai_events[det["id"]] = ev
threading.Thread(target=_precompute_xai, args=(det,), daemon=True).start()
threading.Thread(target=_precompute_xai, args=(det, pre), daemon=True).start()
return jsonify(det)
@ -326,7 +348,7 @@ def api_xai(det_id):
meta["urls"] = {
k: f"{base}/{k}.png"
for k in ["original", "scorecam", "lime1", "lime2", "contrastive",
"nb1", "nb2", "nb3"]
"neighbours"]
}
return jsonify(meta)
@ -962,11 +984,9 @@ document.addEventListener('keydown',e=>{
cap:'Green superpixels support this class; red superpixels oppose it.'},
{src:u.contrastive, title:'Contrastive Explanation',
cap:d.contrast_leg},
{src:u.neighbours, title:'Nearest Training Neighbours',
cap:d.nbs.map((nb,i)=>`#${i+1}: ${nb.cls} (sim ${nb.sim})`).join(' | ')},
];
d.nbs.forEach((nb,i)=>{
slides.push({src:u['nb'+(i+1)], title:'Nearest Neighbour '+(i+1),
cap:nb.cls+' (cosine similarity '+nb.sim+')'});
});
let dots='';
slides.forEach((_,i)=>{dots+=`<button class="dot${i===0?' active':''}" onclick="cur=${i};render()"></button>`;});