Commit: Fix saliency — enable_grad + debug logging
Browse files
File changed: app.py
@@ -248,14 +248,17 @@ def _compute_saliency(bw_t, adj_t, models):

Before (app.py lines 248-261):

    maps = []
    for _, task in sample:
        try:
            adj = adj_t.clone().requires_grad_(True)
            # ... (removed lines 252-255 — their content was not captured in this view)
            if adj.grad is not None:
                maps.append(adj.grad[0].abs().detach().numpy())
        except Exception:
            continue
    if not maps:
        n = adj_t.shape[-1]
|
@@ -522,7 +525,9 @@ def run_gcn(file_path):

Before (app.py lines 522-528):

            net_bounds=net_bounds,
            net_colors=atlas_cfg["net_colors"],
        )
    except Exception:
        sal_img = None

    # ── Verdict ──
|
|
|
After (app.py lines 248-264):

    maps = []
    for _, task in sample:
        try:
            adj = adj_t.clone().detach().requires_grad_(True)
            bw = bw_t.clone().detach()
            with torch.enable_grad():
                out = task.model(bw, adj)
                logits = out[0] if isinstance(out, tuple) else out
                prob = torch.softmax(logits, -1)[0, 1]
                prob.backward()
            if adj.grad is not None:
                maps.append(adj.grad[0].abs().detach().cpu().numpy())
        except Exception as e:
            print(f"[saliency model] {e}")
            continue
    if not maps:
        n = adj_t.shape[-1]
|
|
|
After (app.py lines 525-533):

            net_bounds=net_bounds,
            net_colors=atlas_cfg["net_colors"],
        )
    except Exception as _sal_err:
        print(f"[saliency] failed: {_sal_err}")
        import traceback; traceback.print_exc()
        sal_img = None

    # ── Verdict ──