""" Wildlife Monitoring Dashboard — Yellowstone National Park Flask app with two pages: / - fullscreen map with camera markers + togglable sidebar /det/ - detection detail page with all XAI visualisations Run: uv run python dashboard.py """ import csv import json import random import threading import uuid from datetime import datetime from pathlib import Path import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.cm as mpl_cm from PIL import Image, ImageDraw import torch from torch import nn from torch.utils.data import Dataset, DataLoader from torchvision import transforms from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights from pytorch_grad_cam import ScoreCAM from pytorch_grad_cam.utils.image import show_cam_on_image from lime import lime_image from skimage.segmentation import mark_boundaries from flask import Flask, render_template_string, jsonify, send_from_directory, request # ── constants ───────────────────────────────────────────────────────────────── SEED = 42 random.seed(SEED) np.random.seed(SEED) torch.manual_seed(SEED) CLASS_NAMES = ["bear", "deer", "fox", "hare", "moose", "person", "wolf"] WEIGHTS_PATH = Path("efficientnet_v2_wild_forest_animals.pt") DATASET_DIR = Path("wild-forest-animals-and-person-1") XAI_DIR = Path("_xai_cache") XAI_DIR.mkdir(exist_ok=True) # Camera positions as percentages of the map image (adjust to match map.webp) CAMERAS = { "CAM-01": {"name": "Lamar Valley", "px": 65, "py": 17, "desc": "Northeast corridor, prime wolf and bison territory"}, "CAM-02": {"name": "Hayden Valley", "px": 48, "py": 38, "desc": "Central meadows between canyon and lake"}, "CAM-03": {"name": "Mammoth Hot Springs", "px": 28, "py": 12, "desc": "Northern range, year-round elk habitat"}, "CAM-04": {"name": "Old Faithful", "px": 24, "py": 54, "desc": "Upper Geyser Basin, forested southwest"}, "CAM-05": {"name": "Yellowstone Lake", "px": 55, "py": 48, "desc": "Eastern shoreline, moose and waterfowl corridor"}, } 
# Emoji shown next to each species in the UI.
# NOTE(review): "deer" and "moose" share the same glyph (U+1F98C) — confirm intended.
SPECIES_ICON = {
    "bear": "\U0001f43b", "deer": "\U0001f98c", "fox": "\U0001f98a",
    "hare": "\U0001f407", "moose": "\U0001f98c", "person": "\U0001f9d1",
    "wolf": "\U0001f43a",
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ── dataset ───────────────────────────────────────────────────────────────────
class WildForestAnimalsDataset(Dataset):
    """Image dataset backed by a Roboflow-style `<split>/_classes.csv`.

    The CSV has one-hot columns named after CLASS_NAMES; the first positive
    column becomes the label. Rows with no positive column are skipped.
    """

    def __init__(self, root, split, transform=None):
        self.root, self.transform = Path(root), transform
        split_dir = self.root / split
        self.samples = []
        with (split_dir / "_classes.csv").open(newline="") as f:
            for row in csv.DictReader(f):
                oh = [int(row[n]) for n in CLASS_NAMES]
                try:
                    label = oh.index(1)  # first positive one-hot column
                except ValueError:
                    continue  # unlabeled row — skip
                self.samples.append((split_dir / row["filename"], label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        p, lbl = self.samples[idx]
        img = Image.open(p).convert("RGB")
        return self.transform(img) if self.transform else img, lbl


# Evaluation transform: ImageNet normalization to match the pretrained backbone.
_eval_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ── dataset auto-download ─────────────────────────────────────────────────────
if not DATASET_DIR.exists():
    print("Dataset not found -- downloading from Roboflow ...")
    from roboflow import Roboflow

    # SECURITY NOTE(review): API key is hard-coded in source — rotate this key
    # and load it from an environment variable instead of committing it.
    rf = Roboflow(api_key="VCZWezdoCHQz7juipBdt")
    project = rf.workspace("forestanimals").project("wild-forest-animals-and-person")
    project.version(1).download("multiclass")
    if not DATASET_DIR.exists():
        raise RuntimeError(f"Download finished but {DATASET_DIR} not found.")

# ── model ─────────────────────────────────────────────────────────────────────
print("Loading model and data …")
model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device, weights_only=True))
model.to(device).eval()

test_ds = WildForestAnimalsDataset(DATASET_DIR, "test", transform=_eval_tf)
train_ds = WildForestAnimalsDataset(DATASET_DIR, "train", transform=_eval_tf)

_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Inverse of _norm: maps a normalized tensor back towards [0, 1] RGB for display.
_unnorm = transforms.Normalize(
    mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
    std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
)

score_cam = ScoreCAM(model=model, target_layers=[model.features[-1]])
lime_exp = lime_image.LimeImageExplainer(random_state=SEED)


def _to_display(t):
    """Un-normalize a CHW tensor and return an HWC float array clamped to [0, 1]."""
    return torch.clamp(_unnorm(t), 0, 1).numpy().transpose(1, 2, 0)


def _lime_predict(images):
    """Batch-prediction callback for LIME: HWC float images -> softmax probs."""
    batch = torch.stack([
        _norm(torch.tensor(im.transpose(2, 0, 1), dtype=torch.float32))
        for im in images
    ]).to(device)
    with torch.no_grad():
        return torch.nn.functional.softmax(model(batch), dim=1).cpu().numpy()


print("Pre-computing training features …")


def _extract_features(ds, bs=64):
    """Run the backbone over `ds`; return (L2-normalized pooled features, labels)."""
    loader = DataLoader(ds, batch_size=bs, shuffle=False)
    feats, labs = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            x = torch.flatten(model.avgpool(model.features(imgs.to(device))), 1)
            feats.append(x.cpu())
            labs.extend(lbls.tolist())
    return torch.nn.functional.normalize(torch.cat(feats), dim=1), labs


train_feats, train_labels = _extract_features(train_ds)
print("Ready.")


# ── XAI helpers ───────────────────────────────────────────────────────────────
def _save_img(arr, path):
    """Save an HWC array (float in [0,1] or uint8) as an image file."""
    if arr.dtype != np.uint8:
        arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
    Image.fromarray(arr).save(path)


def _forward(inp):
    """Single split forward pass returning probs, pred_idx, and neighbour query vector."""
    with torch.no_grad():
        pooled = model.avgpool(model.features(inp))
        logits = model.classifier(torch.flatten(pooled, 1))
        probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze()
        # Query feature for nearest-neighbour lookup against train_feats.
        qf = torch.nn.functional.normalize(
            torch.flatten(pooled, 1).cpu(), dim=1,
        ).squeeze()
        return probs, probs.argmax().item(), qf


def compute_xai(det, pre=None):
    """Compute all XAI artefacts for detection `det` into XAI_DIR/<id>.

    `pre` is an optional dict with keys inp, raw, probs, pred_idx, qf to skip
    the duplicate forward pass. Idempotent: returns early if meta.json exists.
    """
    out = XAI_DIR / det["id"]
    if (out / "meta.json").exists():
        return
    out.mkdir(exist_ok=True)

    if pre is not None:
        inp, raw, probs, pred_idx, qf = (
            pre["inp"], pre["raw"], pre["probs"], pre["pred_idx"], pre["qf"])
    else:
        idx = det["idx"]
        tensor, _ = test_ds[idx]
        img_path, _ = test_ds.samples[idx]
        raw = np.array(
            Image.open(img_path).convert("RGB").resize((224, 224)),
            dtype=np.float64,
        ) / 255.0
        inp = tensor.unsqueeze(0).to(device)
        _save_img(raw, out / "original.png")
        probs, pred_idx, qf = _forward(inp)

    # Perplexity of the class distribution (exp of entropy) as an uncertainty score.
    entropy = -(probs * torch.log(probs + 1e-12)).sum().item()

    # ScoreCAM heat-map overlay.
    hm = score_cam(input_tensor=inp, targets=None)
    _save_img(
        show_cam_on_image(raw.astype(np.float32), hm.squeeze(), use_rgb=True,
                          image_weight=0.75),
        out / "scorecam.png",
    )

    # LIME (500 samples): superpixel explanations for the two top-scoring classes.
    expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2,
                                     hide_color=0, num_samples=500)
    c1, c2 = expl.top_labels[0], expl.top_labels[1]
    t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10,
                                     hide_rest=False)
    t2, m2 = expl.get_image_and_mask(c2, positive_only=False, num_features=10,
                                     hide_rest=False)
    _save_img(mark_boundaries(t1, m1), out / "lime1.png")
    _save_img(mark_boundaries(t2, m2), out / "lime2.png")

    # Contrastive map: per-superpixel LIME weight difference (c1 minus c2).
    segs = expl.segments
    w1, w2 = dict(expl.local_exp[c1]), dict(expl.local_exp[c2])
    diff = np.zeros(segs.shape, dtype=np.float64)
    for sid in np.unique(segs):
        diff[segs == sid] = w1.get(sid, 0) - w2.get(sid, 0)
    mx = max(np.abs(diff).max(), 1e-8)  # guard divide-by-zero on all-zero diffs
    diff /= mx
    col = mpl_cm.RdBu_r((diff + 1) / 2)[:, :, :3]
    _save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png")

    # Nearest neighbours — 2x2 composite with labels, ranked by cosine similarity.
    sims, idxs = (train_feats @ qf).topk(4)
    nbs = []
    nb_pils = []
    for ni, ns in zip(idxs, sims):
        arr = _to_display(train_ds[ni.item()][0])
        cls = CLASS_NAMES[train_labels[ni]]
        sim = f"{ns:.3f}"
        nbs.append({"cls": cls, "sim": sim})
        pil = Image.fromarray((np.clip(arr, 0, 1) * 255).astype(np.uint8))
        draw = ImageDraw.Draw(pil)
        label = f"{cls} ({sim})"
        draw.rectangle([(0, pil.height - 22), (pil.width, pil.height)],
                       fill=(0, 0, 0, 180))
        draw.text((6, pil.height - 20), label, fill=(255, 255, 255))
        nb_pils.append(pil)
    w, h = nb_pils[0].size
    grid = Image.new("RGB", (w * 2, h * 2))
    for i, p in enumerate(nb_pils):
        grid.paste(p, ((i % 2) * w, (i // 2) * h))
    grid.save(out / "neighbours.png")

    meta = {
        "pred": CLASS_NAMES[pred_idx],
        "conf": round(probs[pred_idx].item() * 100, 1),
        "ppl": round(float(np.exp(entropy)), 2),
        "probs": {CLASS_NAMES[i]: round(float(p), 4) for i, p in enumerate(probs)},
        "lime1_cls": CLASS_NAMES[c1],
        "lime2_cls": CLASS_NAMES[c2],
        # BUG FIX: RdBu_r maps positive (pro-c1) diffs to the red end of the
        # colormap and negative (pro-c2) diffs to blue — the legend previously
        # had the colours swapped.
        "contrast_leg": f"Red = {CLASS_NAMES[c1]} | Blue = {CLASS_NAMES[c2]}",
        "nbs": nbs,
    }
    (out / "meta.json").write_text(json.dumps(meta))


# ── Flask app ─────────────────────────────────────────────────────────────────
app = Flask(__name__)
detections: list[dict] = []                    # in-memory detection log (newest last)
_xai_events: dict[str, threading.Event] = {}   # set once a detection's XAI is ready
_xai_lock = threading.Lock()                   # serialize GPU-heavy XAI computation


def _precompute_xai(det, pre=None):
    """Background worker: compute XAI artefacts, then always signal waiters."""
    try:
        with _xai_lock:
            compute_xai(det, pre)
    finally:
        ev = _xai_events.get(det["id"])
        if ev:
            ev.set()


@app.route("/")
@app.route("/home")
def home():
    return render_template_string(HOME_HTML, cameras=CAMERAS,
                                  class_names=CLASS_NAMES)


# BUG FIX: the route rules below were missing their URL parameter placeholders
# (e.g. "/det/" for a view taking det_id), so Flask would call each view
# without its argument and raise TypeError on every request.
@app.route("/det/<det_id>")
def detail(det_id):
    det = next((d for d in detections if d["id"] == det_id), None)
    if det is None:
        return "Detection not found", 404
    return render_template_string(DETAIL_HTML, det=det, cameras=CAMERAS,
                                  class_names=CLASS_NAMES)


@app.route("/api/simulate", methods=["POST"])
def api_simulate():
    """Classify a random test image, register a detection, and start the XAI
    pre-computation in a daemon thread. Returns the new detection as JSON."""
    idx = random.randint(0, len(test_ds) - 1)
    cam_id = random.choice(list(CAMERAS.keys()))
    tensor, _ = test_ds[idx]
    inp = tensor.unsqueeze(0).to(device)
    probs, pred_idx, qf = _forward(inp)
    img_path, _ = test_ds.samples[idx]
    raw = np.array(
        Image.open(img_path).convert("RGB").resize((224, 224)),
        dtype=np.float64,
    ) / 255.0
    det = {
        "id": uuid.uuid4().hex[:8],
        "idx": idx,
        "cam": cam_id,
        "cam_name": CAMERAS[cam_id]["name"],
        "pred": CLASS_NAMES[pred_idx],
        "conf": round(probs[pred_idx].item() * 100, 1),
        "time": datetime.now().strftime("%H:%M:%S"),
        "verified": False,
        "manual": False,
    }
    detections.append(det)
    # Save the original frame immediately so the detail page has something to
    # show while the rest of the XAI artefacts are computed in the background.
    out = XAI_DIR / det["id"]
    out.mkdir(exist_ok=True)
    _save_img(raw, out / "original.png")
    pre = {"inp": inp, "raw": raw, "probs": probs, "pred_idx": pred_idx, "qf": qf}
    ev = threading.Event()
    _xai_events[det["id"]] = ev
    threading.Thread(target=_precompute_xai, args=(det, pre), daemon=True).start()
    return jsonify(det)


@app.route("/api/xai/<det_id>")
def api_xai(det_id):
    det = next((d for d in detections if d["id"] == det_id), None)
    if det is None:
        return jsonify(error="not found"), 404
    ev = _xai_events.get(det_id)
    if ev:
        ev.wait()  # block until the background worker has finished
    else:
        compute_xai(det)  # no worker registered — compute inline (idempotent)
    meta = json.loads((XAI_DIR / det_id / "meta.json").read_text())
    base = f"/xai/{det_id}"
    meta["urls"] = {
        k: f"{base}/{k}.png"
        for k in ["original", "scorecam", "lime1", "lime2", "contrastive",
                  "neighbours"]
    }
    return jsonify(meta)


@app.route("/cam/<cam_id>")
def camera(cam_id):
    if cam_id not in CAMERAS:
        return "Camera not found", 404
    # Newest first for the feed view.
    cam_dets = [d for d in reversed(detections) if d["cam"] == cam_id]
    return render_template_string(
        CAM_HTML, cam_id=cam_id, cam=CAMERAS[cam_id], dets=cam_dets,
        class_names=CLASS_NAMES,
    )


@app.route("/api/verify/<det_id>", methods=["POST"])
def api_verify(det_id):
    """Ranger feedback: confirm a prediction or overwrite it with the true class."""
    det = next((d for d in detections if d["id"] == det_id), None)
    if det is None:
        return jsonify(error="not found"), 404
    # silent=True avoids a 400/None crash when the body is missing or not JSON.
    data = request.get_json(silent=True) or {}
    if data.get("action") == "correct":
        det["verified"] = True
    elif data.get("action") == "wrong":
        det["verified"] = True
        det["manual"] = True
        det["orig_pred"] = det["pred"]
        det["orig_conf"] = det["conf"]
        det["pred"] = data["true_class"]
        det["conf"] = 100.0
    return jsonify(det)


@app.route("/api/detections")
def api_detections():
    return jsonify(detections)


@app.route("/map.jpg")
def serve_map():
    return send_from_directory(".", "yellowstone-camping-map.jpg")


@app.route("/xai/<det_id>/<filename>")
def serve_xai(det_id, filename):
    # det_id comes from the URL and is used as a directory component;
    # restrict it to the alphanumeric ids we generate so it cannot traverse
    # outside XAI_DIR (send_from_directory only protects `filename`).
    if not det_id.isalnum():
        return "Not found", 404
    return send_from_directory(str(XAI_DIR / det_id), filename)


# ── HTML templates ────────────────────────────────────────────────────────────
# NOTE(review): the template strings below appear to have lost their HTML
# markup in this copy (only text content and Jinja tags remain) — restore the
# full markup from the original file before deploying.

# ── HTML: home page ───────────────────────────────────────────────────────────
HOME_HTML = r"""
Yellowstone — Wildlife Monitor
Yellowstone National Park

Wildlife Camera Monitoring

Park map {% for cid, c in cameras.items() %}
{{ cid }} · {{ c.name }}
{% endfor %}
"""

# ── HTML: detail page ─────────────────────────────────────────────────────────
DETAIL_HTML = r"""
Detection {{ det.id }} — Yellowstone
← Dashboard
Detection {{ det.id }}
{{ det.cam }} · {{ cameras[det.cam].name }} · {{ det.time }}
{% if det.verified %} {% if det.manual %}
Manually corrected Model predicted {{ det.orig_pred }} — ranger corrected to {{ det.pred }}
{% else %}
Verified correct A ranger confirmed this detection is accurate.
{% endif %} {% else %}
Is this detection correct?
{% endif %}
Computing explanations …
Loading…

Class Probabilities

"""

# ── HTML: camera feed page ────────────────────────────────────────────────────
CAM_HTML = r"""
{{ cam_id }} — {{ cam.name }}
← Dashboard
{{ cam_id }} — {{ cam.name }}
{{ cam.desc }}

Detections by Species (this camera)

{% if dets %} {% else %}
No detections from this camera yet.
Detections will appear here automatically.
{% endif %}
"""

if __name__ == "__main__":
    app.run(debug=True, port=5000)