XAI optim
This commit is contained in:
parent
dbb1656ef0
commit
c300289aec
@ -55,7 +55,6 @@ The server starts at **http://localhost:5000**.
|
|||||||
assignment-HAI/
|
assignment-HAI/
|
||||||
├── dashboard.py # Flask application (main entry point)
|
├── dashboard.py # Flask application (main entry point)
|
||||||
├── train.py # Standalone training script
|
├── train.py # Standalone training script
|
||||||
├── final.ipynb # Training, evaluation, and XAI notebook
|
|
||||||
├── map.webp # Park map background image
|
├── map.webp # Park map background image
|
||||||
├── pyproject.toml # Project metadata and dependencies
|
├── pyproject.toml # Project metadata and dependencies
|
||||||
├── uv.lock # Locked dependency versions
|
├── uv.lock # Locked dependency versions
|
||||||
|
|||||||
@ -2,15 +2,6 @@
|
|||||||
|
|
||||||
This guide explains how to use the Wildlife Monitoring Dashboard as an end user (e.g. a park ranger).
|
This guide explains how to use the Wildlife Monitoring Dashboard as an end user (e.g. a park ranger).
|
||||||
|
|
||||||
## Starting the Dashboard
|
|
||||||
|
|
||||||
After setup (see [README.md](README.md)), run:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uv run python dashboard.py
|
|
||||||
```
|
|
||||||
|
|
||||||
Open **http://localhost:5000** in your browser.
|
|
||||||
|
|
||||||
## Home Page — Live Map
|
## Home Page — Live Map
|
||||||
|
|
||||||
|
|||||||
102
dashboard.py
102
dashboard.py
@ -9,7 +9,6 @@ Run: uv run python dashboard.py
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import io
|
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
@ -20,8 +19,8 @@ from pathlib import Path
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.cm as mpl_cm
|
||||||
from PIL import Image
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -170,33 +169,41 @@ def _save_img(arr, path):
|
|||||||
Image.fromarray(arr).save(path)
|
Image.fromarray(arr).save(path)
|
||||||
|
|
||||||
|
|
||||||
def compute_xai(det):
|
def _forward(inp):
|
||||||
|
"""Single split forward pass returning probs, pred_idx, and neighbour query vector."""
|
||||||
|
with torch.no_grad():
|
||||||
|
pooled = model.avgpool(model.features(inp))
|
||||||
|
logits = model.classifier(torch.flatten(pooled, 1))
|
||||||
|
probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze()
|
||||||
|
qf = torch.nn.functional.normalize(
|
||||||
|
torch.flatten(pooled, 1).cpu(), dim=1,
|
||||||
|
).squeeze()
|
||||||
|
return probs, probs.argmax().item(), qf
|
||||||
|
|
||||||
|
|
||||||
|
def compute_xai(det, pre=None):
|
||||||
|
"""Compute all XAI artefacts. `pre` is an optional dict with keys
|
||||||
|
inp, raw, probs, pred_idx, qf to skip the duplicate forward pass."""
|
||||||
out = XAI_DIR / det["id"]
|
out = XAI_DIR / det["id"]
|
||||||
if (out / "meta.json").exists():
|
if (out / "meta.json").exists():
|
||||||
return
|
return
|
||||||
out.mkdir(exist_ok=True)
|
out.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
if pre is not None:
|
||||||
|
inp, raw, probs, pred_idx, qf = (
|
||||||
|
pre["inp"], pre["raw"], pre["probs"], pre["pred_idx"], pre["qf"])
|
||||||
|
else:
|
||||||
idx = det["idx"]
|
idx = det["idx"]
|
||||||
tensor, label = test_ds[idx]
|
tensor, _ = test_ds[idx]
|
||||||
img_path, _ = test_ds.samples[idx]
|
img_path, _ = test_ds.samples[idx]
|
||||||
raw = np.array(
|
raw = np.array(
|
||||||
Image.open(img_path).convert("RGB").resize((224, 224)),
|
Image.open(img_path).convert("RGB").resize((224, 224)),
|
||||||
dtype=np.float64,
|
dtype=np.float64,
|
||||||
) / 255.0
|
) / 255.0
|
||||||
inp = tensor.unsqueeze(0).to(device)
|
inp = tensor.unsqueeze(0).to(device)
|
||||||
|
|
||||||
_save_img(raw, out / "original.png")
|
_save_img(raw, out / "original.png")
|
||||||
|
probs, pred_idx, qf = _forward(inp)
|
||||||
|
|
||||||
# single forward pass: probabilities + feature vector for neighbours
|
|
||||||
with torch.no_grad():
|
|
||||||
feats = model.features(inp)
|
|
||||||
pooled = model.avgpool(feats)
|
|
||||||
logits = model.classifier(torch.flatten(pooled, 1))
|
|
||||||
probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze()
|
|
||||||
qf = torch.nn.functional.normalize(
|
|
||||||
torch.flatten(pooled, 1).cpu(), dim=1,
|
|
||||||
).squeeze()
|
|
||||||
pred_idx = probs.argmax().item()
|
|
||||||
entropy = -(probs * torch.log(probs + 1e-12)).sum().item()
|
entropy = -(probs * torch.log(probs + 1e-12)).sum().item()
|
||||||
|
|
||||||
# ScoreCAM
|
# ScoreCAM
|
||||||
@ -206,7 +213,7 @@ def compute_xai(det):
|
|||||||
out / "scorecam.png",
|
out / "scorecam.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
# LIME (reduced from 1000 to 500 samples)
|
# LIME (500 samples)
|
||||||
expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=500)
|
expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=500)
|
||||||
c1, c2 = expl.top_labels[0], expl.top_labels[1]
|
c1, c2 = expl.top_labels[0], expl.top_labels[1]
|
||||||
t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False)
|
t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False)
|
||||||
@ -221,15 +228,29 @@ def compute_xai(det):
|
|||||||
diff[segs == sid] = w1.get(sid, 0) - w2.get(sid, 0)
|
diff[segs == sid] = w1.get(sid, 0) - w2.get(sid, 0)
|
||||||
mx = max(np.abs(diff).max(), 1e-8)
|
mx = max(np.abs(diff).max(), 1e-8)
|
||||||
diff /= mx
|
diff /= mx
|
||||||
col = plt.cm.RdBu_r((diff + 1) / 2)[:, :, :3]
|
col = mpl_cm.RdBu_r((diff + 1) / 2)[:, :, :3]
|
||||||
_save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png")
|
_save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png")
|
||||||
|
|
||||||
# nearest neighbours (reuse qf from above)
|
# nearest neighbours — 2x2 composite with labels
|
||||||
sims, idxs = (train_feats @ qf).topk(3)
|
sims, idxs = (train_feats @ qf).topk(4)
|
||||||
nbs = []
|
nbs = []
|
||||||
for k, (ni, ns) in enumerate(zip(idxs, sims)):
|
nb_pils = []
|
||||||
_save_img(_to_display(train_ds[ni.item()][0]), out / f"nb{k+1}.png")
|
for ni, ns in zip(idxs, sims):
|
||||||
nbs.append({"cls": CLASS_NAMES[train_labels[ni]], "sim": f"{ns:.3f}"})
|
arr = _to_display(train_ds[ni.item()][0])
|
||||||
|
cls = CLASS_NAMES[train_labels[ni]]
|
||||||
|
sim = f"{ns:.3f}"
|
||||||
|
nbs.append({"cls": cls, "sim": sim})
|
||||||
|
pil = Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8))
|
||||||
|
draw = ImageDraw.Draw(pil)
|
||||||
|
label = f"{cls} ({sim})"
|
||||||
|
draw.rectangle([(0, pil.height - 22), (pil.width, pil.height)], fill=(0, 0, 0, 180))
|
||||||
|
draw.text((6, pil.height - 20), label, fill=(255, 255, 255))
|
||||||
|
nb_pils.append(pil)
|
||||||
|
w, h = nb_pils[0].size
|
||||||
|
grid = Image.new("RGB", (w * 2, h * 2))
|
||||||
|
for i, p in enumerate(nb_pils):
|
||||||
|
grid.paste(p, ((i % 2) * w, (i // 2) * h))
|
||||||
|
grid.save(out / "neighbours.png")
|
||||||
|
|
||||||
meta = {
|
meta = {
|
||||||
"pred": CLASS_NAMES[pred_idx],
|
"pred": CLASS_NAMES[pred_idx],
|
||||||
@ -252,10 +273,10 @@ _xai_events: dict[str, threading.Event] = {}
|
|||||||
_xai_lock = threading.Lock()
|
_xai_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def _precompute_xai(det):
|
def _precompute_xai(det, pre=None):
|
||||||
try:
|
try:
|
||||||
with _xai_lock:
|
with _xai_lock:
|
||||||
compute_xai(det)
|
compute_xai(det, pre)
|
||||||
finally:
|
finally:
|
||||||
ev = _xai_events.get(det["id"])
|
ev = _xai_events.get(det["id"])
|
||||||
if ev:
|
if ev:
|
||||||
@ -280,12 +301,15 @@ def detail(det_id):
|
|||||||
def api_simulate():
|
def api_simulate():
|
||||||
idx = random.randint(0, len(test_ds) - 1)
|
idx = random.randint(0, len(test_ds) - 1)
|
||||||
cam_id = random.choice(list(CAMERAS.keys()))
|
cam_id = random.choice(list(CAMERAS.keys()))
|
||||||
tensor, label = test_ds[idx]
|
tensor, _ = test_ds[idx]
|
||||||
with torch.no_grad():
|
inp = tensor.unsqueeze(0).to(device)
|
||||||
probs = torch.nn.functional.softmax(
|
probs, pred_idx, qf = _forward(inp)
|
||||||
model(tensor.unsqueeze(0).to(device)), dim=1,
|
|
||||||
).cpu().squeeze()
|
img_path, _ = test_ds.samples[idx]
|
||||||
pred_idx = probs.argmax().item()
|
raw = np.array(
|
||||||
|
Image.open(img_path).convert("RGB").resize((224, 224)), dtype=np.float64,
|
||||||
|
) / 255.0
|
||||||
|
|
||||||
det = {
|
det = {
|
||||||
"id": uuid.uuid4().hex[:8],
|
"id": uuid.uuid4().hex[:8],
|
||||||
"idx": idx,
|
"idx": idx,
|
||||||
@ -300,14 +324,12 @@ def api_simulate():
|
|||||||
detections.append(det)
|
detections.append(det)
|
||||||
out = XAI_DIR / det["id"]
|
out = XAI_DIR / det["id"]
|
||||||
out.mkdir(exist_ok=True)
|
out.mkdir(exist_ok=True)
|
||||||
img_path, _ = test_ds.samples[idx]
|
|
||||||
raw = np.array(
|
|
||||||
Image.open(img_path).convert("RGB").resize((224, 224)), dtype=np.float64,
|
|
||||||
) / 255.0
|
|
||||||
_save_img(raw, out / "original.png")
|
_save_img(raw, out / "original.png")
|
||||||
|
|
||||||
|
pre = {"inp": inp, "raw": raw, "probs": probs, "pred_idx": pred_idx, "qf": qf}
|
||||||
ev = threading.Event()
|
ev = threading.Event()
|
||||||
_xai_events[det["id"]] = ev
|
_xai_events[det["id"]] = ev
|
||||||
threading.Thread(target=_precompute_xai, args=(det,), daemon=True).start()
|
threading.Thread(target=_precompute_xai, args=(det, pre), daemon=True).start()
|
||||||
return jsonify(det)
|
return jsonify(det)
|
||||||
|
|
||||||
|
|
||||||
@ -326,7 +348,7 @@ def api_xai(det_id):
|
|||||||
meta["urls"] = {
|
meta["urls"] = {
|
||||||
k: f"{base}/{k}.png"
|
k: f"{base}/{k}.png"
|
||||||
for k in ["original", "scorecam", "lime1", "lime2", "contrastive",
|
for k in ["original", "scorecam", "lime1", "lime2", "contrastive",
|
||||||
"nb1", "nb2", "nb3"]
|
"neighbours"]
|
||||||
}
|
}
|
||||||
return jsonify(meta)
|
return jsonify(meta)
|
||||||
|
|
||||||
@ -962,11 +984,9 @@ document.addEventListener('keydown',e=>{
|
|||||||
cap:'Green superpixels support this class; red superpixels oppose it.'},
|
cap:'Green superpixels support this class; red superpixels oppose it.'},
|
||||||
{src:u.contrastive, title:'Contrastive Explanation',
|
{src:u.contrastive, title:'Contrastive Explanation',
|
||||||
cap:d.contrast_leg},
|
cap:d.contrast_leg},
|
||||||
|
{src:u.neighbours, title:'Nearest Training Neighbours',
|
||||||
|
cap:d.nbs.map((nb,i)=>`#${i+1}: ${nb.cls} (sim ${nb.sim})`).join(' | ')},
|
||||||
];
|
];
|
||||||
d.nbs.forEach((nb,i)=>{
|
|
||||||
slides.push({src:u['nb'+(i+1)], title:'Nearest Neighbour '+(i+1),
|
|
||||||
cap:nb.cls+' (cosine similarity '+nb.sim+')'});
|
|
||||||
});
|
|
||||||
|
|
||||||
let dots='';
|
let dots='';
|
||||||
slides.forEach((_,i)=>{dots+=`<button class="dot${i===0?' active':''}" onclick="cur=${i};render()"></button>`;});
|
slides.forEach((_,i)=>{dots+=`<button class="dot${i===0?' active':''}" onclick="cur=${i};render()"></button>`;});
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user