XAI optim
This commit is contained in:
parent
dbb1656ef0
commit
c300289aec
@ -55,7 +55,6 @@ The server starts at **http://localhost:5000**.
|
||||
assignment-HAI/
|
||||
├── dashboard.py # Flask application (main entry point)
|
||||
├── train.py # Standalone training script
|
||||
├── final.ipynb # Training, evaluation, and XAI notebook
|
||||
├── map.webp # Park map background image
|
||||
├── pyproject.toml # Project metadata and dependencies
|
||||
├── uv.lock # Locked dependency versions
|
||||
|
||||
@ -2,15 +2,6 @@
|
||||
|
||||
This guide explains how to use the Wildlife Monitoring Dashboard as an end user (e.g. a park ranger).
|
||||
|
||||
## Starting the Dashboard
|
||||
|
||||
After setup (see [README.md](README.md)), run:
|
||||
|
||||
```bash
|
||||
uv run python dashboard.py
|
||||
```
|
||||
|
||||
Open **http://localhost:5000** in your browser.
|
||||
|
||||
## Home Page — Live Map
|
||||
|
||||
|
||||
118
dashboard.py
118
dashboard.py
@ -9,7 +9,6 @@ Run: uv run python dashboard.py
|
||||
"""
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import random
|
||||
import threading
|
||||
@ -20,8 +19,8 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import matplotlib.cm as mpl_cm
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -170,33 +169,41 @@ def _save_img(arr, path):
|
||||
Image.fromarray(arr).save(path)
|
||||
|
||||
|
||||
def compute_xai(det):
|
||||
out = XAI_DIR / det["id"]
|
||||
if (out / "meta.json").exists():
|
||||
return
|
||||
out.mkdir(exist_ok=True)
|
||||
|
||||
idx = det["idx"]
|
||||
tensor, label = test_ds[idx]
|
||||
img_path, _ = test_ds.samples[idx]
|
||||
raw = np.array(
|
||||
Image.open(img_path).convert("RGB").resize((224, 224)),
|
||||
dtype=np.float64,
|
||||
) / 255.0
|
||||
inp = tensor.unsqueeze(0).to(device)
|
||||
|
||||
_save_img(raw, out / "original.png")
|
||||
|
||||
# single forward pass: probabilities + feature vector for neighbours
|
||||
def _forward(inp):
|
||||
"""Single split forward pass returning probs, pred_idx, and neighbour query vector."""
|
||||
with torch.no_grad():
|
||||
feats = model.features(inp)
|
||||
pooled = model.avgpool(feats)
|
||||
pooled = model.avgpool(model.features(inp))
|
||||
logits = model.classifier(torch.flatten(pooled, 1))
|
||||
probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze()
|
||||
qf = torch.nn.functional.normalize(
|
||||
torch.flatten(pooled, 1).cpu(), dim=1,
|
||||
).squeeze()
|
||||
pred_idx = probs.argmax().item()
|
||||
return probs, probs.argmax().item(), qf
|
||||
|
||||
|
||||
def compute_xai(det, pre=None):
|
||||
"""Compute all XAI artefacts. `pre` is an optional dict with keys
|
||||
inp, raw, probs, pred_idx, qf to skip the duplicate forward pass."""
|
||||
out = XAI_DIR / det["id"]
|
||||
if (out / "meta.json").exists():
|
||||
return
|
||||
out.mkdir(exist_ok=True)
|
||||
|
||||
if pre is not None:
|
||||
inp, raw, probs, pred_idx, qf = (
|
||||
pre["inp"], pre["raw"], pre["probs"], pre["pred_idx"], pre["qf"])
|
||||
else:
|
||||
idx = det["idx"]
|
||||
tensor, _ = test_ds[idx]
|
||||
img_path, _ = test_ds.samples[idx]
|
||||
raw = np.array(
|
||||
Image.open(img_path).convert("RGB").resize((224, 224)),
|
||||
dtype=np.float64,
|
||||
) / 255.0
|
||||
inp = tensor.unsqueeze(0).to(device)
|
||||
_save_img(raw, out / "original.png")
|
||||
probs, pred_idx, qf = _forward(inp)
|
||||
|
||||
entropy = -(probs * torch.log(probs + 1e-12)).sum().item()
|
||||
|
||||
# ScoreCAM
|
||||
@ -206,7 +213,7 @@ def compute_xai(det):
|
||||
out / "scorecam.png",
|
||||
)
|
||||
|
||||
# LIME (reduced from 1000 to 500 samples)
|
||||
# LIME (500 samples)
|
||||
expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=500)
|
||||
c1, c2 = expl.top_labels[0], expl.top_labels[1]
|
||||
t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False)
|
||||
@ -221,15 +228,29 @@ def compute_xai(det):
|
||||
diff[segs == sid] = w1.get(sid, 0) - w2.get(sid, 0)
|
||||
mx = max(np.abs(diff).max(), 1e-8)
|
||||
diff /= mx
|
||||
col = plt.cm.RdBu_r((diff + 1) / 2)[:, :, :3]
|
||||
col = mpl_cm.RdBu_r((diff + 1) / 2)[:, :, :3]
|
||||
_save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png")
|
||||
|
||||
# nearest neighbours (reuse qf from above)
|
||||
sims, idxs = (train_feats @ qf).topk(3)
|
||||
# nearest neighbours — 2x2 composite with labels
|
||||
sims, idxs = (train_feats @ qf).topk(4)
|
||||
nbs = []
|
||||
for k, (ni, ns) in enumerate(zip(idxs, sims)):
|
||||
_save_img(_to_display(train_ds[ni.item()][0]), out / f"nb{k+1}.png")
|
||||
nbs.append({"cls": CLASS_NAMES[train_labels[ni]], "sim": f"{ns:.3f}"})
|
||||
nb_pils = []
|
||||
for ni, ns in zip(idxs, sims):
|
||||
arr = _to_display(train_ds[ni.item()][0])
|
||||
cls = CLASS_NAMES[train_labels[ni]]
|
||||
sim = f"{ns:.3f}"
|
||||
nbs.append({"cls": cls, "sim": sim})
|
||||
pil = Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8))
|
||||
draw = ImageDraw.Draw(pil)
|
||||
label = f"{cls} ({sim})"
|
||||
draw.rectangle([(0, pil.height - 22), (pil.width, pil.height)], fill=(0, 0, 0, 180))
|
||||
draw.text((6, pil.height - 20), label, fill=(255, 255, 255))
|
||||
nb_pils.append(pil)
|
||||
w, h = nb_pils[0].size
|
||||
grid = Image.new("RGB", (w * 2, h * 2))
|
||||
for i, p in enumerate(nb_pils):
|
||||
grid.paste(p, ((i % 2) * w, (i // 2) * h))
|
||||
grid.save(out / "neighbours.png")
|
||||
|
||||
meta = {
|
||||
"pred": CLASS_NAMES[pred_idx],
|
||||
@ -252,10 +273,10 @@ _xai_events: dict[str, threading.Event] = {}
|
||||
_xai_lock = threading.Lock()
|
||||
|
||||
|
||||
def _precompute_xai(det):
|
||||
def _precompute_xai(det, pre=None):
|
||||
try:
|
||||
with _xai_lock:
|
||||
compute_xai(det)
|
||||
compute_xai(det, pre)
|
||||
finally:
|
||||
ev = _xai_events.get(det["id"])
|
||||
if ev:
|
||||
@ -280,12 +301,15 @@ def detail(det_id):
|
||||
def api_simulate():
|
||||
idx = random.randint(0, len(test_ds) - 1)
|
||||
cam_id = random.choice(list(CAMERAS.keys()))
|
||||
tensor, label = test_ds[idx]
|
||||
with torch.no_grad():
|
||||
probs = torch.nn.functional.softmax(
|
||||
model(tensor.unsqueeze(0).to(device)), dim=1,
|
||||
).cpu().squeeze()
|
||||
pred_idx = probs.argmax().item()
|
||||
tensor, _ = test_ds[idx]
|
||||
inp = tensor.unsqueeze(0).to(device)
|
||||
probs, pred_idx, qf = _forward(inp)
|
||||
|
||||
img_path, _ = test_ds.samples[idx]
|
||||
raw = np.array(
|
||||
Image.open(img_path).convert("RGB").resize((224, 224)), dtype=np.float64,
|
||||
) / 255.0
|
||||
|
||||
det = {
|
||||
"id": uuid.uuid4().hex[:8],
|
||||
"idx": idx,
|
||||
@ -300,14 +324,12 @@ def api_simulate():
|
||||
detections.append(det)
|
||||
out = XAI_DIR / det["id"]
|
||||
out.mkdir(exist_ok=True)
|
||||
img_path, _ = test_ds.samples[idx]
|
||||
raw = np.array(
|
||||
Image.open(img_path).convert("RGB").resize((224, 224)), dtype=np.float64,
|
||||
) / 255.0
|
||||
_save_img(raw, out / "original.png")
|
||||
|
||||
pre = {"inp": inp, "raw": raw, "probs": probs, "pred_idx": pred_idx, "qf": qf}
|
||||
ev = threading.Event()
|
||||
_xai_events[det["id"]] = ev
|
||||
threading.Thread(target=_precompute_xai, args=(det,), daemon=True).start()
|
||||
threading.Thread(target=_precompute_xai, args=(det, pre), daemon=True).start()
|
||||
return jsonify(det)
|
||||
|
||||
|
||||
@ -326,7 +348,7 @@ def api_xai(det_id):
|
||||
meta["urls"] = {
|
||||
k: f"{base}/{k}.png"
|
||||
for k in ["original", "scorecam", "lime1", "lime2", "contrastive",
|
||||
"nb1", "nb2", "nb3"]
|
||||
"neighbours"]
|
||||
}
|
||||
return jsonify(meta)
|
||||
|
||||
@ -962,11 +984,9 @@ document.addEventListener('keydown',e=>{
|
||||
cap:'Green superpixels support this class; red superpixels oppose it.'},
|
||||
{src:u.contrastive, title:'Contrastive Explanation',
|
||||
cap:d.contrast_leg},
|
||||
{src:u.neighbours, title:'Nearest Training Neighbours',
|
||||
cap:d.nbs.map((nb,i)=>`#${i+1}: ${nb.cls} (sim ${nb.sim})`).join(' | ')},
|
||||
];
|
||||
d.nbs.forEach((nb,i)=>{
|
||||
slides.push({src:u['nb'+(i+1)], title:'Nearest Neighbour '+(i+1),
|
||||
cap:nb.cls+' (cosine similarity '+nb.sim+')'});
|
||||
});
|
||||
|
||||
let dots='';
|
||||
slides.forEach((_,i)=>{dots+=`<button class="dot${i===0?' active':''}" onclick="cur=${i};render()"></button>`;});
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user