GPU memory safety

Pull request #2
Opened by sk16er
Files changed (1)
  1. neuron_steer/core.py +85 -84
neuron_steer/core.py CHANGED
@@ -423,94 +423,95 @@ def compute_attribution(
423
  if hasattr(layer.mlp, "neuron_act"):
424
  layer.mlp.neuron_act = None
425
 
426
- with torch.enable_grad():
427
- outputs = model(input_ids)
428
- logits = outputs.logits[0, position] # [vocab_size]
429
-
430
- target_logit = logits[target_token_id]
431
-
432
- if target_only:
433
- metric = target_logit
434
- elif counterfactual_token_id is None:
435
- sorted_logits, sorted_ids = logits.sort(descending=True)
436
- if sorted_ids[0].item() == target_token_id:
437
- counterfactual_logit = sorted_logits[1]
 
 
 
 
438
  else:
439
- counterfactual_logit = sorted_logits[0]
440
- metric = target_logit - counterfactual_logit
441
- else:
442
- counterfactual_logit = logits[counterfactual_token_id]
443
- metric = target_logit - counterfactual_logit
444
-
445
- # Backward through linearized model
446
- metric.backward()
447
 
448
- # Collect attributions from saved neuron activations
449
- attributions = {}
450
- layer_stats = {} # diagnostic info
451
 
452
- for i, layer in enumerate(_get_model_layers(model)):
453
- if i in blacklist_layers:
454
- continue
455
-
456
- mlp = layer.mlp
457
- if not hasattr(mlp, "neuron_act") or mlp.neuron_act is None:
458
- continue
459
- if mlp.neuron_act.grad is None:
460
- continue
461
-
462
- act = mlp.neuron_act.detach() # [1, T, intermediate_size]
463
- grad = mlp.neuron_act.grad # [1, T, intermediate_size]
464
-
465
- # Attribution = gradient * activation (element-wise)
466
- attr = (grad * act)[0] # [T, intermediate_size]
467
- T = attr.shape[0]
468
-
469
- # NaN-safe statistics (exclude NaN from sums)
470
- valid_mask = ~torch.isnan(attr)
471
- valid_attr = attr[valid_mask]
472
- if valid_attr.numel() > 0:
473
- layer_total = valid_attr.abs().sum().item()
474
- layer_max = valid_attr.abs().max().item()
475
- nan_frac = 1.0 - valid_mask.float().mean().item()
476
- else:
477
- layer_total = 0.0
478
- layer_max = 0.0
479
- nan_frac = 1.0
480
- layer_stats[i] = {"total": layer_total, "max": layer_max, "nan_frac": nan_frac}
481
-
482
- if last_n_positions is not None:
483
- start_pos = max(0, T - last_n_positions)
484
- elif filter_bos:
485
- start_pos = 1
486
- else:
487
- start_pos = 0
488
- for p in range(start_pos, T):
489
- pos_attr = attr[p]
490
- abs_attr = pos_attr.abs()
491
-
492
- # NaN-safe topk: replace NaN with 0 so they don't crowd out valid values
493
- nan_mask = torch.isnan(abs_attr)
494
- if nan_mask.any():
495
- abs_attr = abs_attr.clone()
496
- abs_attr[nan_mask] = 0.0
497
-
498
- # Keep top-k neurons at this position
499
- k = min(top_k_per_layer, abs_attr.shape[0])
500
- top_vals, top_idxs = abs_attr.topk(k)
501
-
502
- for val, idx in zip(top_vals, top_idxs):
503
- if val.item() > 1e-8:
504
- n = idx.item()
505
- if (i, n) in blacklist_neurons:
506
- continue
507
- nidx = NeuronIdx(layer=i, position=p, neuron=n)
508
- attributions[nidx] = pos_attr[idx].item()
509
 
510
- # Free GPU memory - clear saved activations after collection
511
- for layer in _get_model_layers(model):
512
- if hasattr(layer.mlp, "neuron_act"):
513
- layer.mlp.neuron_act = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
 
515
  if verbose:
516
  print(f" Attribution distribution by layer:")
 
423
  if hasattr(layer.mlp, "neuron_act"):
424
  layer.mlp.neuron_act = None
425
 
426
+ try:
427
+ with torch.enable_grad():
428
+ outputs = model(input_ids)
429
+ logits = outputs.logits[0, position] # [vocab_size]
430
+
431
+ target_logit = logits[target_token_id]
432
+
433
+ if target_only:
434
+ metric = target_logit
435
+ elif counterfactual_token_id is None:
436
+ sorted_logits, sorted_ids = logits.sort(descending=True)
437
+ if sorted_ids[0].item() == target_token_id:
438
+ counterfactual_logit = sorted_logits[1]
439
+ else:
440
+ counterfactual_logit = sorted_logits[0]
441
+ metric = target_logit - counterfactual_logit
442
  else:
443
+ counterfactual_logit = logits[counterfactual_token_id]
444
+ metric = target_logit - counterfactual_logit
 
 
 
 
 
 
445
 
446
+ # Backward through linearized model
447
+ metric.backward()
 
448
 
449
+ # Collect attributions from saved neuron activations
450
+ attributions = {}
451
+ layer_stats = {} # diagnostic info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
 
453
+ for i, layer in enumerate(_get_model_layers(model)):
454
+ if i in blacklist_layers:
455
+ continue
456
+
457
+ mlp = layer.mlp
458
+ if not hasattr(mlp, "neuron_act") or mlp.neuron_act is None:
459
+ continue
460
+ if mlp.neuron_act.grad is None:
461
+ continue
462
+
463
+ act = mlp.neuron_act.detach() # [1, T, intermediate_size]
464
+ grad = mlp.neuron_act.grad # [1, T, intermediate_size]
465
+
466
+ # Attribution = gradient * activation (element-wise)
467
+ attr = (grad * act)[0] # [T, intermediate_size]
468
+ T = attr.shape[0]
469
+
470
+ # NaN-safe statistics (exclude NaN from sums)
471
+ valid_mask = ~torch.isnan(attr)
472
+ valid_attr = attr[valid_mask]
473
+ if valid_attr.numel() > 0:
474
+ layer_total = valid_attr.abs().sum().item()
475
+ layer_max = valid_attr.abs().max().item()
476
+ nan_frac = 1.0 - valid_mask.float().mean().item()
477
+ else:
478
+ layer_total = 0.0
479
+ layer_max = 0.0
480
+ nan_frac = 1.0
481
+ layer_stats[i] = {"total": layer_total, "max": layer_max, "nan_frac": nan_frac}
482
+
483
+ if last_n_positions is not None:
484
+ start_pos = max(0, T - last_n_positions)
485
+ elif filter_bos:
486
+ start_pos = 1
487
+ else:
488
+ start_pos = 0
489
+ for p in range(start_pos, T):
490
+ pos_attr = attr[p]
491
+ abs_attr = pos_attr.abs()
492
+
493
+ # NaN-safe topk: replace NaN with 0 so they don't crowd out valid values
494
+ nan_mask = torch.isnan(abs_attr)
495
+ if nan_mask.any():
496
+ abs_attr = abs_attr.clone()
497
+ abs_attr[nan_mask] = 0.0
498
+
499
+ # Keep top-k neurons at this position
500
+ k = min(top_k_per_layer, abs_attr.shape[0])
501
+ top_vals, top_idxs = abs_attr.topk(k)
502
+
503
+ for val, idx in zip(top_vals, top_idxs):
504
+ if val.item() > 1e-8:
505
+ n = idx.item()
506
+ if (i, n) in blacklist_neurons:
507
+ continue
508
+ nidx = NeuronIdx(layer=i, position=p, neuron=n)
509
+ attributions[nidx] = pos_attr[idx].item()
510
+ finally:
511
+ # Free GPU memory - clear saved activations after collection
512
+ for layer in _get_model_layers(model):
513
+ if hasattr(layer.mlp, "neuron_act"):
514
+ layer.mlp.neuron_act = None
515
 
516
  if verbose:
517
  print(f" Attribution distribution by layer:")