XAI optimization

This commit is contained in:
everbarry 2026-03-18 12:48:42 +01:00
parent 3cb0d59bd6
commit dbb1656ef0

View File

@ -127,7 +127,7 @@ _unnorm = transforms.Normalize(
mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
std=[1 / 0.229, 1 / 0.224, 1 / 0.225], std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
) )
score_cam = ScoreCAM(model=model, target_layers=[model.features[7][0]]) score_cam = ScoreCAM(model=model, target_layers=[model.features[-1]])
lime_exp = lime_image.LimeImageExplainer(random_state=SEED) lime_exp = lime_image.LimeImageExplainer(random_state=SEED)
@ -170,23 +170,6 @@ def _save_img(arr, path):
Image.fromarray(arr).save(path) Image.fromarray(arr).save(path)
def _prob_chart(probs, pred_idx, path):
fig, ax = plt.subplots(figsize=(5, 2.6))
colors = ["#e8832a" if i == pred_idx else "#4a90d9" for i in range(len(CLASS_NAMES))]
bars = ax.barh(CLASS_NAMES, probs, color=colors)
ax.set_xlim(0, 1)
ax.set_xlabel("Probability")
for bar, p in zip(bars, probs):
if p > 0.05:
ax.text(bar.get_width() - 0.01,
bar.get_y() + bar.get_height() / 2,
f"{p:.2f}", va="center", ha="right",
fontsize=8, color="white", fontweight="bold")
plt.tight_layout()
fig.savefig(path, dpi=120, bbox_inches="tight")
plt.close(fig)
def compute_xai(det): def compute_xai(det):
out = XAI_DIR / det["id"] out = XAI_DIR / det["id"]
if (out / "meta.json").exists(): if (out / "meta.json").exists():
@ -204,19 +187,27 @@ def compute_xai(det):
_save_img(raw, out / "original.png") _save_img(raw, out / "original.png")
# single forward pass: probabilities + feature vector for neighbours
with torch.no_grad(): with torch.no_grad():
probs = torch.nn.functional.softmax(model(inp), dim=1).cpu().squeeze() feats = model.features(inp)
pooled = model.avgpool(feats)
logits = model.classifier(torch.flatten(pooled, 1))
probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze()
qf = torch.nn.functional.normalize(
torch.flatten(pooled, 1).cpu(), dim=1,
).squeeze()
pred_idx = probs.argmax().item() pred_idx = probs.argmax().item()
entropy = -(probs * torch.log(probs + 1e-12)).sum().item() entropy = -(probs * torch.log(probs + 1e-12)).sum().item()
_prob_chart(probs.numpy(), pred_idx, out / "chart.png")
# ScoreCAM
hm = score_cam(input_tensor=inp, targets=None) hm = score_cam(input_tensor=inp, targets=None)
_save_img( _save_img(
show_cam_on_image(raw.astype(np.float32), hm.squeeze(), use_rgb=True, image_weight=0.75), show_cam_on_image(raw.astype(np.float32), hm.squeeze(), use_rgb=True, image_weight=0.75),
out / "scorecam.png", out / "scorecam.png",
) )
expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=1000) # LIME (reduced from 1000 to 500 samples)
expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=500)
c1, c2 = expl.top_labels[0], expl.top_labels[1] c1, c2 = expl.top_labels[0], expl.top_labels[1]
t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False) t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False)
t2, m2 = expl.get_image_and_mask(c2, positive_only=False, num_features=10, hide_rest=False) t2, m2 = expl.get_image_and_mask(c2, positive_only=False, num_features=10, hide_rest=False)
@ -233,10 +224,7 @@ def compute_xai(det):
col = plt.cm.RdBu_r((diff + 1) / 2)[:, :, :3] col = plt.cm.RdBu_r((diff + 1) / 2)[:, :, :3]
_save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png") _save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png")
with torch.no_grad(): # nearest neighbours (reuse qf from above)
qf = torch.nn.functional.normalize(
torch.flatten(model.avgpool(model.features(inp)), 1).cpu(), dim=1,
).squeeze()
sims, idxs = (train_feats @ qf).topk(3) sims, idxs = (train_feats @ qf).topk(3)
nbs = [] nbs = []
for k, (ni, ns) in enumerate(zip(idxs, sims)): for k, (ni, ns) in enumerate(zip(idxs, sims)):
@ -337,7 +325,7 @@ def api_xai(det_id):
base = f"/xai/{det_id}" base = f"/xai/{det_id}"
meta["urls"] = { meta["urls"] = {
k: f"{base}/{k}.png" k: f"{base}/{k}.png"
for k in ["original", "chart", "scorecam", "lime1", "lime2", "contrastive", for k in ["original", "scorecam", "lime1", "lime2", "contrastive",
"nb1", "nb2", "nb3"] "nb1", "nb2", "nb3"]
} }
return jsonify(meta) return jsonify(meta)