From c300289aec6fcf475da4094a5abf4c735f9ac304 Mon Sep 17 00:00:00 2001 From: everbarry Date: Thu, 19 Mar 2026 13:19:24 +0100 Subject: [PATCH] XAI optim --- README.md | 1 - USER_GUIDE.md | 9 ---- dashboard.py | 118 +++++++++++++++++++++++++++++--------------------- 3 files changed, 69 insertions(+), 59 deletions(-) diff --git a/README.md b/README.md index e3b8974..2431c39 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,6 @@ The server starts at **http://localhost:5000**. assignment-HAI/ ├── dashboard.py # Flask application (main entry point) ├── train.py # Standalone training script -├── final.ipynb # Training, evaluation, and XAI notebook ├── map.webp # Park map background image ├── pyproject.toml # Project metadata and dependencies ├── uv.lock # Locked dependency versions diff --git a/USER_GUIDE.md b/USER_GUIDE.md index a7f01d6..ae80fe6 100644 --- a/USER_GUIDE.md +++ b/USER_GUIDE.md @@ -2,15 +2,6 @@ This guide explains how to use the Wildlife Monitoring Dashboard as an end user (e.g. a park ranger). -## Starting the Dashboard - -After setup (see [README.md](README.md)), run: - -```bash -uv run python dashboard.py -``` - -Open **http://localhost:5000** in your browser. 
## Home Page — Live Map diff --git a/dashboard.py b/dashboard.py index d9b3a4d..f05f5e5 100644 --- a/dashboard.py +++ b/dashboard.py @@ -9,7 +9,6 @@ Run: uv run python dashboard.py """ import csv -import io import json import random import threading @@ -20,8 +19,8 @@ from pathlib import Path import numpy as np import matplotlib matplotlib.use("Agg") -import matplotlib.pyplot as plt -from PIL import Image +import matplotlib.cm as mpl_cm +from PIL import Image, ImageDraw import torch from torch import nn @@ -170,33 +169,41 @@ def _save_img(arr, path): Image.fromarray(arr).save(path) -def compute_xai(det): - out = XAI_DIR / det["id"] - if (out / "meta.json").exists(): - return - out.mkdir(exist_ok=True) - - idx = det["idx"] - tensor, label = test_ds[idx] - img_path, _ = test_ds.samples[idx] - raw = np.array( - Image.open(img_path).convert("RGB").resize((224, 224)), - dtype=np.float64, - ) / 255.0 - inp = tensor.unsqueeze(0).to(device) - - _save_img(raw, out / "original.png") - - # single forward pass: probabilities + feature vector for neighbours +def _forward(inp): + """Run the model once, split into stages (features -> avgpool -> classifier), and return (probs, pred_idx, qf), where qf is the L2-normalised pooled feature vector used for nearest-neighbour lookup.""" with torch.no_grad(): - feats = model.features(inp) - pooled = model.avgpool(feats) + pooled = model.avgpool(model.features(inp)) logits = model.classifier(torch.flatten(pooled, 1)) probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze() qf = torch.nn.functional.normalize( torch.flatten(pooled, 1).cpu(), dim=1, ).squeeze() - pred_idx = probs.argmax().item() + return probs, probs.argmax().item(), qf + + +def compute_xai(det, pre=None): + """Compute all XAI artefacts.
`pre` is an optional dict with keys + inp, raw, probs, pred_idx, qf to skip the duplicate forward pass.""" + out = XAI_DIR / det["id"] + if (out / "meta.json").exists(): + return + out.mkdir(exist_ok=True) + + if pre is not None: + inp, raw, probs, pred_idx, qf = ( + pre["inp"], pre["raw"], pre["probs"], pre["pred_idx"], pre["qf"]) + else: + idx = det["idx"] + tensor, _ = test_ds[idx] + img_path, _ = test_ds.samples[idx] + raw = np.array( + Image.open(img_path).convert("RGB").resize((224, 224)), + dtype=np.float64, + ) / 255.0 + inp = tensor.unsqueeze(0).to(device) + _save_img(raw, out / "original.png") + probs, pred_idx, qf = _forward(inp) + entropy = -(probs * torch.log(probs + 1e-12)).sum().item() # ScoreCAM @@ -206,7 +213,7 @@ def compute_xai(det): out / "scorecam.png", ) - # LIME (reduced from 1000 to 500 samples) + # LIME (500 samples) expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=500) c1, c2 = expl.top_labels[0], expl.top_labels[1] t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False) @@ -221,15 +228,29 @@ def compute_xai(det): diff[segs == sid] = w1.get(sid, 0) - w2.get(sid, 0) mx = max(np.abs(diff).max(), 1e-8) diff /= mx - col = plt.cm.RdBu_r((diff + 1) / 2)[:, :, :3] + col = mpl_cm.RdBu_r((diff + 1) / 2)[:, :, :3] _save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png") - # nearest neighbours (reuse qf from above) - sims, idxs = (train_feats @ qf).topk(3) + # nearest neighbours — 2x2 composite with labels + sims, idxs = (train_feats @ qf).topk(4) nbs = [] - for k, (ni, ns) in enumerate(zip(idxs, sims)): - _save_img(_to_display(train_ds[ni.item()][0]), out / f"nb{k+1}.png") - nbs.append({"cls": CLASS_NAMES[train_labels[ni]], "sim": f"{ns:.3f}"}) + nb_pils = [] + for ni, ns in zip(idxs, sims): + arr = _to_display(train_ds[ni.item()][0]) + cls = CLASS_NAMES[train_labels[ni]] + sim = f"{ns:.3f}" + nbs.append({"cls": cls, "sim": sim}) + pil = 
Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8)) + draw = ImageDraw.Draw(pil) + label = f"{cls} ({sim})" + draw.rectangle([(0, pil.height - 22), (pil.width, pil.height)], fill=(0, 0, 0, 180)) + draw.text((6, pil.height - 20), label, fill=(255, 255, 255)) + nb_pils.append(pil) + w, h = nb_pils[0].size + grid = Image.new("RGB", (w * 2, h * 2)) + for i, p in enumerate(nb_pils): + grid.paste(p, ((i % 2) * w, (i // 2) * h)) + grid.save(out / "neighbours.png") meta = { "pred": CLASS_NAMES[pred_idx], @@ -252,10 +273,10 @@ _xai_events: dict[str, threading.Event] = {} _xai_lock = threading.Lock() -def _precompute_xai(det): +def _precompute_xai(det, pre=None): try: with _xai_lock: - compute_xai(det) + compute_xai(det, pre) finally: ev = _xai_events.get(det["id"]) if ev: @@ -280,12 +301,15 @@ def detail(det_id): def api_simulate(): idx = random.randint(0, len(test_ds) - 1) cam_id = random.choice(list(CAMERAS.keys())) - tensor, label = test_ds[idx] - with torch.no_grad(): - probs = torch.nn.functional.softmax( - model(tensor.unsqueeze(0).to(device)), dim=1, - ).cpu().squeeze() - pred_idx = probs.argmax().item() + tensor, _ = test_ds[idx] + inp = tensor.unsqueeze(0).to(device) + probs, pred_idx, qf = _forward(inp) + + img_path, _ = test_ds.samples[idx] + raw = np.array( + Image.open(img_path).convert("RGB").resize((224, 224)), dtype=np.float64, + ) / 255.0 + det = { "id": uuid.uuid4().hex[:8], "idx": idx, @@ -300,14 +324,12 @@ def api_simulate(): detections.append(det) out = XAI_DIR / det["id"] out.mkdir(exist_ok=True) - img_path, _ = test_ds.samples[idx] - raw = np.array( - Image.open(img_path).convert("RGB").resize((224, 224)), dtype=np.float64, - ) / 255.0 _save_img(raw, out / "original.png") + + pre = {"inp": inp, "raw": raw, "probs": probs, "pred_idx": pred_idx, "qf": qf} ev = threading.Event() _xai_events[det["id"]] = ev - threading.Thread(target=_precompute_xai, args=(det,), daemon=True).start() + threading.Thread(target=_precompute_xai, args=(det, pre), 
daemon=True).start() return jsonify(det) @@ -326,7 +348,7 @@ def api_xai(det_id): meta["urls"] = { k: f"{base}/{k}.png" for k in ["original", "scorecam", "lime1", "lime2", "contrastive", - "nb1", "nb2", "nb3"] + "neighbours"] } return jsonify(meta) @@ -962,11 +984,9 @@ document.addEventListener('keydown',e=>{ cap:'Green superpixels support this class; red superpixels oppose it.'}, {src:u.contrastive, title:'Contrastive Explanation', cap:d.contrast_leg}, + {src:u.neighbours, title:'Nearest Training Neighbours', + cap:d.nbs.map((nb,i)=>`#${i+1}: ${nb.cls} (sim ${nb.sim})`).join(' | ')}, ]; - d.nbs.forEach((nb,i)=>{ - slides.push({src:u['nb'+(i+1)], title:'Nearest Neighbour '+(i+1), - cap:nb.cls+' (cosine similarity '+nb.sim+')'}); - }); let dots=''; slides.forEach((_,i)=>{dots+=``;});