XAI optimizatoin
This commit is contained in:
parent
3cb0d59bd6
commit
dbb1656ef0
40
dashboard.py
40
dashboard.py
@ -127,7 +127,7 @@ _unnorm = transforms.Normalize(
|
||||
mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
|
||||
std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
|
||||
)
|
||||
score_cam = ScoreCAM(model=model, target_layers=[model.features[7][0]])
|
||||
score_cam = ScoreCAM(model=model, target_layers=[model.features[-1]])
|
||||
lime_exp = lime_image.LimeImageExplainer(random_state=SEED)
|
||||
|
||||
|
||||
@ -170,23 +170,6 @@ def _save_img(arr, path):
|
||||
Image.fromarray(arr).save(path)
|
||||
|
||||
|
||||
def _prob_chart(probs, pred_idx, path):
|
||||
fig, ax = plt.subplots(figsize=(5, 2.6))
|
||||
colors = ["#e8832a" if i == pred_idx else "#4a90d9" for i in range(len(CLASS_NAMES))]
|
||||
bars = ax.barh(CLASS_NAMES, probs, color=colors)
|
||||
ax.set_xlim(0, 1)
|
||||
ax.set_xlabel("Probability")
|
||||
for bar, p in zip(bars, probs):
|
||||
if p > 0.05:
|
||||
ax.text(bar.get_width() - 0.01,
|
||||
bar.get_y() + bar.get_height() / 2,
|
||||
f"{p:.2f}", va="center", ha="right",
|
||||
fontsize=8, color="white", fontweight="bold")
|
||||
plt.tight_layout()
|
||||
fig.savefig(path, dpi=120, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def compute_xai(det):
|
||||
out = XAI_DIR / det["id"]
|
||||
if (out / "meta.json").exists():
|
||||
@ -204,19 +187,27 @@ def compute_xai(det):
|
||||
|
||||
_save_img(raw, out / "original.png")
|
||||
|
||||
# single forward pass: probabilities + feature vector for neighbours
|
||||
with torch.no_grad():
|
||||
probs = torch.nn.functional.softmax(model(inp), dim=1).cpu().squeeze()
|
||||
feats = model.features(inp)
|
||||
pooled = model.avgpool(feats)
|
||||
logits = model.classifier(torch.flatten(pooled, 1))
|
||||
probs = torch.nn.functional.softmax(logits, dim=1).cpu().squeeze()
|
||||
qf = torch.nn.functional.normalize(
|
||||
torch.flatten(pooled, 1).cpu(), dim=1,
|
||||
).squeeze()
|
||||
pred_idx = probs.argmax().item()
|
||||
entropy = -(probs * torch.log(probs + 1e-12)).sum().item()
|
||||
_prob_chart(probs.numpy(), pred_idx, out / "chart.png")
|
||||
|
||||
# ScoreCAM
|
||||
hm = score_cam(input_tensor=inp, targets=None)
|
||||
_save_img(
|
||||
show_cam_on_image(raw.astype(np.float32), hm.squeeze(), use_rgb=True, image_weight=0.75),
|
||||
out / "scorecam.png",
|
||||
)
|
||||
|
||||
expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=1000)
|
||||
# LIME (reduced from 1000 to 500 samples)
|
||||
expl = lime_exp.explain_instance(raw, _lime_predict, top_labels=2, hide_color=0, num_samples=500)
|
||||
c1, c2 = expl.top_labels[0], expl.top_labels[1]
|
||||
t1, m1 = expl.get_image_and_mask(c1, positive_only=False, num_features=10, hide_rest=False)
|
||||
t2, m2 = expl.get_image_and_mask(c2, positive_only=False, num_features=10, hide_rest=False)
|
||||
@ -233,10 +224,7 @@ def compute_xai(det):
|
||||
col = plt.cm.RdBu_r((diff + 1) / 2)[:, :, :3]
|
||||
_save_img((0.6 * raw + 0.4 * col).clip(0, 1), out / "contrastive.png")
|
||||
|
||||
with torch.no_grad():
|
||||
qf = torch.nn.functional.normalize(
|
||||
torch.flatten(model.avgpool(model.features(inp)), 1).cpu(), dim=1,
|
||||
).squeeze()
|
||||
# nearest neighbours (reuse qf from above)
|
||||
sims, idxs = (train_feats @ qf).topk(3)
|
||||
nbs = []
|
||||
for k, (ni, ns) in enumerate(zip(idxs, sims)):
|
||||
@ -337,7 +325,7 @@ def api_xai(det_id):
|
||||
base = f"/xai/{det_id}"
|
||||
meta["urls"] = {
|
||||
k: f"{base}/{k}.png"
|
||||
for k in ["original", "chart", "scorecam", "lime1", "lime2", "contrastive",
|
||||
for k in ["original", "scorecam", "lime1", "lime2", "contrastive",
|
||||
"nb1", "nb2", "nb3"]
|
||||
}
|
||||
return jsonify(meta)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user