Spaces:
Running on L40S
Running on L40S
GPU memory safety
#2
by sk16er - opened
- neuron_steer/core.py +85 -84
neuron_steer/core.py
CHANGED
|
@@ -423,94 +423,95 @@ def compute_attribution(
|
|
| 423 |
if hasattr(layer.mlp, "neuron_act"):
|
| 424 |
layer.mlp.neuron_act = None
|
| 425 |
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
else:
|
| 439 |
-
counterfactual_logit = sorted_logits[0]
|
| 440 |
-
|
| 441 |
-
else:
|
| 442 |
-
counterfactual_logit = logits[counterfactual_token_id]
|
| 443 |
-
metric = target_logit - counterfactual_logit
|
| 444 |
-
|
| 445 |
-
# Backward through linearized model
|
| 446 |
-
metric.backward()
|
| 447 |
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
layer_stats = {} # diagnostic info
|
| 451 |
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
mlp = layer.mlp
|
| 457 |
-
if not hasattr(mlp, "neuron_act") or mlp.neuron_act is None:
|
| 458 |
-
continue
|
| 459 |
-
if mlp.neuron_act.grad is None:
|
| 460 |
-
continue
|
| 461 |
-
|
| 462 |
-
act = mlp.neuron_act.detach() # [1, T, intermediate_size]
|
| 463 |
-
grad = mlp.neuron_act.grad # [1, T, intermediate_size]
|
| 464 |
-
|
| 465 |
-
# Attribution = gradient * activation (element-wise)
|
| 466 |
-
attr = (grad * act)[0] # [T, intermediate_size]
|
| 467 |
-
T = attr.shape[0]
|
| 468 |
-
|
| 469 |
-
# NaN-safe statistics (exclude NaN from sums)
|
| 470 |
-
valid_mask = ~torch.isnan(attr)
|
| 471 |
-
valid_attr = attr[valid_mask]
|
| 472 |
-
if valid_attr.numel() > 0:
|
| 473 |
-
layer_total = valid_attr.abs().sum().item()
|
| 474 |
-
layer_max = valid_attr.abs().max().item()
|
| 475 |
-
nan_frac = 1.0 - valid_mask.float().mean().item()
|
| 476 |
-
else:
|
| 477 |
-
layer_total = 0.0
|
| 478 |
-
layer_max = 0.0
|
| 479 |
-
nan_frac = 1.0
|
| 480 |
-
layer_stats[i] = {"total": layer_total, "max": layer_max, "nan_frac": nan_frac}
|
| 481 |
-
|
| 482 |
-
if last_n_positions is not None:
|
| 483 |
-
start_pos = max(0, T - last_n_positions)
|
| 484 |
-
elif filter_bos:
|
| 485 |
-
start_pos = 1
|
| 486 |
-
else:
|
| 487 |
-
start_pos = 0
|
| 488 |
-
for p in range(start_pos, T):
|
| 489 |
-
pos_attr = attr[p]
|
| 490 |
-
abs_attr = pos_attr.abs()
|
| 491 |
-
|
| 492 |
-
# NaN-safe topk: replace NaN with 0 so they don't crowd out valid values
|
| 493 |
-
nan_mask = torch.isnan(abs_attr)
|
| 494 |
-
if nan_mask.any():
|
| 495 |
-
abs_attr = abs_attr.clone()
|
| 496 |
-
abs_attr[nan_mask] = 0.0
|
| 497 |
-
|
| 498 |
-
# Keep top-k neurons at this position
|
| 499 |
-
k = min(top_k_per_layer, abs_attr.shape[0])
|
| 500 |
-
top_vals, top_idxs = abs_attr.topk(k)
|
| 501 |
-
|
| 502 |
-
for val, idx in zip(top_vals, top_idxs):
|
| 503 |
-
if val.item() > 1e-8:
|
| 504 |
-
n = idx.item()
|
| 505 |
-
if (i, n) in blacklist_neurons:
|
| 506 |
-
continue
|
| 507 |
-
nidx = NeuronIdx(layer=i, position=p, neuron=n)
|
| 508 |
-
attributions[nidx] = pos_attr[idx].item()
|
| 509 |
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
|
| 515 |
if verbose:
|
| 516 |
print(f" Attribution distribution by layer:")
|
|
|
|
| 423 |
if hasattr(layer.mlp, "neuron_act"):
|
| 424 |
layer.mlp.neuron_act = None
|
| 425 |
|
| 426 |
+
try:
|
| 427 |
+
with torch.enable_grad():
|
| 428 |
+
outputs = model(input_ids)
|
| 429 |
+
logits = outputs.logits[0, position] # [vocab_size]
|
| 430 |
+
|
| 431 |
+
target_logit = logits[target_token_id]
|
| 432 |
+
|
| 433 |
+
if target_only:
|
| 434 |
+
metric = target_logit
|
| 435 |
+
elif counterfactual_token_id is None:
|
| 436 |
+
sorted_logits, sorted_ids = logits.sort(descending=True)
|
| 437 |
+
if sorted_ids[0].item() == target_token_id:
|
| 438 |
+
counterfactual_logit = sorted_logits[1]
|
| 439 |
+
else:
|
| 440 |
+
counterfactual_logit = sorted_logits[0]
|
| 441 |
+
metric = target_logit - counterfactual_logit
|
| 442 |
else:
|
| 443 |
+
counterfactual_logit = logits[counterfactual_token_id]
|
| 444 |
+
metric = target_logit - counterfactual_logit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
|
| 446 |
+
# Backward through linearized model
|
| 447 |
+
metric.backward()
|
|
|
|
| 448 |
|
| 449 |
+
# Collect attributions from saved neuron activations
|
| 450 |
+
attributions = {}
|
| 451 |
+
layer_stats = {} # diagnostic info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
+
for i, layer in enumerate(_get_model_layers(model)):
|
| 454 |
+
if i in blacklist_layers:
|
| 455 |
+
continue
|
| 456 |
+
|
| 457 |
+
mlp = layer.mlp
|
| 458 |
+
if not hasattr(mlp, "neuron_act") or mlp.neuron_act is None:
|
| 459 |
+
continue
|
| 460 |
+
if mlp.neuron_act.grad is None:
|
| 461 |
+
continue
|
| 462 |
+
|
| 463 |
+
act = mlp.neuron_act.detach() # [1, T, intermediate_size]
|
| 464 |
+
grad = mlp.neuron_act.grad # [1, T, intermediate_size]
|
| 465 |
+
|
| 466 |
+
# Attribution = gradient * activation (element-wise)
|
| 467 |
+
attr = (grad * act)[0] # [T, intermediate_size]
|
| 468 |
+
T = attr.shape[0]
|
| 469 |
+
|
| 470 |
+
# NaN-safe statistics (exclude NaN from sums)
|
| 471 |
+
valid_mask = ~torch.isnan(attr)
|
| 472 |
+
valid_attr = attr[valid_mask]
|
| 473 |
+
if valid_attr.numel() > 0:
|
| 474 |
+
layer_total = valid_attr.abs().sum().item()
|
| 475 |
+
layer_max = valid_attr.abs().max().item()
|
| 476 |
+
nan_frac = 1.0 - valid_mask.float().mean().item()
|
| 477 |
+
else:
|
| 478 |
+
layer_total = 0.0
|
| 479 |
+
layer_max = 0.0
|
| 480 |
+
nan_frac = 1.0
|
| 481 |
+
layer_stats[i] = {"total": layer_total, "max": layer_max, "nan_frac": nan_frac}
|
| 482 |
+
|
| 483 |
+
if last_n_positions is not None:
|
| 484 |
+
start_pos = max(0, T - last_n_positions)
|
| 485 |
+
elif filter_bos:
|
| 486 |
+
start_pos = 1
|
| 487 |
+
else:
|
| 488 |
+
start_pos = 0
|
| 489 |
+
for p in range(start_pos, T):
|
| 490 |
+
pos_attr = attr[p]
|
| 491 |
+
abs_attr = pos_attr.abs()
|
| 492 |
+
|
| 493 |
+
# NaN-safe topk: replace NaN with 0 so they don't crowd out valid values
|
| 494 |
+
nan_mask = torch.isnan(abs_attr)
|
| 495 |
+
if nan_mask.any():
|
| 496 |
+
abs_attr = abs_attr.clone()
|
| 497 |
+
abs_attr[nan_mask] = 0.0
|
| 498 |
+
|
| 499 |
+
# Keep top-k neurons at this position
|
| 500 |
+
k = min(top_k_per_layer, abs_attr.shape[0])
|
| 501 |
+
top_vals, top_idxs = abs_attr.topk(k)
|
| 502 |
+
|
| 503 |
+
for val, idx in zip(top_vals, top_idxs):
|
| 504 |
+
if val.item() > 1e-8:
|
| 505 |
+
n = idx.item()
|
| 506 |
+
if (i, n) in blacklist_neurons:
|
| 507 |
+
continue
|
| 508 |
+
nidx = NeuronIdx(layer=i, position=p, neuron=n)
|
| 509 |
+
attributions[nidx] = pos_attr[idx].item()
|
| 510 |
+
finally:
|
| 511 |
+
# Free GPU memory - clear saved activations after collection
|
| 512 |
+
for layer in _get_model_layers(model):
|
| 513 |
+
if hasattr(layer.mlp, "neuron_act"):
|
| 514 |
+
layer.mlp.neuron_act = None
|
| 515 |
|
| 516 |
if verbose:
|
| 517 |
print(f" Attribution distribution by layer:")
|