From dbb1656ef0c807ad0d5170c1e9c6660743997995 Mon Sep 17 00:00:00 2001 From: everbarry Date: Wed, 18 Mar 2026 12:48:42 +0100 Subject: [PATCH] XAI optimizatoin --- dashboard.py | 40 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/dashboard.py b/dashboard.py index 7d1ab65..d9b3a4d 100644 --- a/dashboard.py +++ b/dashboard.py @@ -127,7 +127,7 @@ _unnorm = transforms.Normalize( mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], std=[1 / 0.229, 1 / 0.224, 1 / 0.225], ) -score_cam = ScoreCAM(model=model, target_layers=[model.features[7][0]]) +score_cam = ScoreCAM(model=model, target_layers=[model.features[-1]]) lime_exp = lime_image.LimeImageExplainer(random_state=SEED) @@ -170,23 +170,6 @@ def _save_img(arr, path): Image.fromarray(arr).save(path) -def _prob_chart(probs, pred_idx, path): - fig, ax = plt.subplots(figsize=(5, 2.6)) - colors = ["#e8832a" if i == pred_idx else "#4a90d9" for i in range(len(CLASS_NAMES))] - bars = ax.barh(CLASS_NAMES, probs, color=colors) - ax.set_xlim(0, 1) - ax.set_xlabel("Probability") - for bar, p in zip(bars, probs): - if p > 0.05: - ax.text(bar.get_width() - 0.01, - bar.get_y() + bar.get_height() / 2, - f"{p:.2f}", va="center", ha="right", - fontsize=8, color="white", fontweight="bold") - plt.tight_layout() - fig.savefig(path, dpi=120, bbox_inches="tight") - plt.close(fig) - - def compute_xai(det): out = XAI_DIR / det["id"] if (out / "meta.json").exists(): @@ -204,19 +187,27 @@ def compute_xai(det): _save_img(raw, out / "original.png") + # single forward pass: probabilities + feature vector for neighbours with torch.no_grad(): - probs = torch.nn.functional.softmax(model(inp), dim=1).cpu().squeeze() + feats = model.features(inp) + pooled = model.avgpool(feats) + logits = model.classifier(torch.flatten(pooled, 1)) + probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze() + qf = torch.nn.functional.normalize( + torch.flatten(pooled, 1).cpu(), dim=1, + ).squeeze() pred_idx = probs.argmax().item() entropy = -(probs * torch.log(probs + 1e-12)).sum().item() - _prob_chart(probs.numpy(), pred_idx, out / "chart.png") + # ScoreCAM hm = score_cam(input_tensor=inp, targets=None) _save_img( show_cam_on_image(raw.astype(np.float32), hm.squeeze(), use_rgb=True, image_weight=0.75), out / "scorecam.png", ) - expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=1000) + # LIME (reduced from 1000 to 500 samples) + expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=500) c1, c2 = expl.top_labels[0], expl.top_labels[1] t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False) t2, m2 = expl.get_image_and_mask(c2, positive_only=False, num_features=10, hide_rest=False) @@ -233,10 +224,7 @@ def compute_xai(det): col = plt.cm.RdBu_r((diff + 1) / 2)[:, :, :3] _save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png") - with torch.no_grad(): - qf = torch.nn.functional.normalize( - torch.flatten(model.avgpool(model.features(inp)), 1).cpu(), dim=1, - ).squeeze() + # nearest neighbours (reuse qf from above) sims, idxs = (train_feats @ qf).topk(3) nbs = [] for k, (ni, ns) in enumerate(zip(idxs, sims)): @@ -337,7 +325,7 @@ def api_xai(det_id): base = f"/xai/{det_id}" meta["urls"] = { k: f"{base}/{k}.png" - for k in ["original", "chart", "scorecam", "lime1", "lime2", "contrastive", + for k in ["original", "scorecam", "lime1", "lime2", "contrastive", "nb1", "nb2", "nb3"] } return jsonify(meta)