Yatsuiii commited on
Commit
ee9799c
·
verified ·
1 Parent(s): 33100f8

Fix saliency: enable_grad + debug logging

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -248,14 +248,17 @@ def _compute_saliency(bw_t, adj_t, models):
248
  maps = []
249
  for _, task in sample:
250
  try:
251
- adj = adj_t.clone().requires_grad_(True)
252
- logits = task.model(bw_t, adj)
253
- if isinstance(logits, tuple):
254
- logits = logits[0]
255
- torch.softmax(logits, -1)[0, 1].backward()
 
 
256
  if adj.grad is not None:
257
- maps.append(adj.grad[0].abs().detach().numpy())
258
- except Exception:
 
259
  continue
260
  if not maps:
261
  n = adj_t.shape[-1]
@@ -522,7 +525,9 @@ def run_gcn(file_path):
522
  net_bounds=net_bounds,
523
  net_colors=atlas_cfg["net_colors"],
524
  )
525
- except Exception:
 
 
526
  sal_img = None
527
 
528
  # ── Verdict ──
 
248
  maps = []
249
  for _, task in sample:
250
  try:
251
+ adj = adj_t.clone().detach().requires_grad_(True)
252
+ bw = bw_t.clone().detach()
253
+ with torch.enable_grad():
254
+ out = task.model(bw, adj)
255
+ logits = out[0] if isinstance(out, tuple) else out
256
+ prob = torch.softmax(logits, -1)[0, 1]
257
+ prob.backward()
258
  if adj.grad is not None:
259
+ maps.append(adj.grad[0].abs().detach().cpu().numpy())
260
+ except Exception as e:
261
+ print(f"[saliency model] {e}")
262
  continue
263
  if not maps:
264
  n = adj_t.shape[-1]
 
525
  net_bounds=net_bounds,
526
  net_colors=atlas_cfg["net_colors"],
527
  )
528
+ except Exception as _sal_err:
529
+ print(f"[saliency] failed: {_sal_err}")
530
+ import traceback; traceback.print_exc()
531
  sal_img = None
532
 
533
  # ── Verdict ──