pliny-the-prompter committed on
Commit
102206c
·
verified ·
1 Parent(s): e3db1b1

Upload 135 files

Browse files
app.py CHANGED
@@ -316,6 +316,7 @@ def _get_preset_defaults(method_display: str):
316
  cfg = _PRESET_CONFIGS.get(method_key, _PRESET_CONFIGS["advanced"])
317
  return {
318
  "n_directions": cfg.get("n_directions", 4),
 
319
  "regularization": cfg.get("regularization", 0.3),
320
  "refinement_passes": cfg.get("refinement_passes", 2),
321
  "norm_preserve": cfg.get("norm_preserve", True),
@@ -348,6 +349,7 @@ def _on_method_change(method_display: str):
348
  d = _get_preset_defaults(method_display)
349
  return (
350
  d["n_directions"],
 
351
  d["regularization"],
352
  d["refinement_passes"],
353
  d["reflection_strength"],
@@ -1731,8 +1733,9 @@ def _format_multi_model_results(results: list[dict], context: dict | None = None
1731
  def obliterate(model_choice: str, method_choice: str,
1732
  prompt_volume_choice: str, dataset_source_choice: str,
1733
  custom_harmful: str, custom_harmless: str,
1734
- # Advanced params (sliders)
1735
- adv_n_directions: int, adv_regularization: float,
 
1736
  adv_refinement_passes: int, adv_reflection_strength: float,
1737
  adv_embed_regularization: float, adv_steering_strength: float,
1738
  adv_transplant_blend: float,
@@ -1906,6 +1909,7 @@ def obliterate(model_choice: str, method_choice: str,
1906
  on_log=on_log,
1907
  # Advanced overrides from UI
1908
  n_directions=int(adv_n_directions),
 
1909
  regularization=float(adv_regularization),
1910
  refinement_passes=int(adv_refinement_passes),
1911
  norm_preserve=adv_norm_preserve,
@@ -3930,7 +3934,13 @@ with gr.Blocks(theme=THEME, css=CSS, js=_JS, title="OBLITERATUS", fill_height=Tr
3930
  with gr.Row():
3931
  adv_n_directions = gr.Slider(
3932
  1, 8, value=_defaults["n_directions"], step=1,
3933
- label="Directions", info="Number of refusal directions to extract via SVD",
 
 
 
 
 
 
3934
  )
3935
  adv_regularization = gr.Slider(
3936
  0.0, 1.0, value=_defaults["regularization"], step=0.05,
@@ -3999,7 +4009,8 @@ with gr.Blocks(theme=THEME, css=CSS, js=_JS, title="OBLITERATUS", fill_height=Tr
3999
 
4000
  # List of all advanced controls (order must match _on_method_change return)
4001
  _adv_controls = [
4002
- adv_n_directions, adv_regularization, adv_refinement_passes,
 
4003
  adv_reflection_strength, adv_embed_regularization,
4004
  adv_steering_strength, adv_transplant_blend,
4005
  adv_spectral_bands, adv_spectral_threshold,
 
316
  cfg = _PRESET_CONFIGS.get(method_key, _PRESET_CONFIGS["advanced"])
317
  return {
318
  "n_directions": cfg.get("n_directions", 4),
319
+ "direction_method": cfg.get("direction_method", "svd"),
320
  "regularization": cfg.get("regularization", 0.3),
321
  "refinement_passes": cfg.get("refinement_passes", 2),
322
  "norm_preserve": cfg.get("norm_preserve", True),
 
349
  d = _get_preset_defaults(method_display)
350
  return (
351
  d["n_directions"],
352
+ d["direction_method"],
353
  d["regularization"],
354
  d["refinement_passes"],
355
  d["reflection_strength"],
 
1733
  def obliterate(model_choice: str, method_choice: str,
1734
  prompt_volume_choice: str, dataset_source_choice: str,
1735
  custom_harmful: str, custom_harmless: str,
1736
+ # Advanced params (sliders + radio)
1737
+ adv_n_directions: int, adv_direction_method: str,
1738
+ adv_regularization: float,
1739
  adv_refinement_passes: int, adv_reflection_strength: float,
1740
  adv_embed_regularization: float, adv_steering_strength: float,
1741
  adv_transplant_blend: float,
 
1909
  on_log=on_log,
1910
  # Advanced overrides from UI
1911
  n_directions=int(adv_n_directions),
1912
+ direction_method=adv_direction_method,
1913
  regularization=float(adv_regularization),
1914
  refinement_passes=int(adv_refinement_passes),
1915
  norm_preserve=adv_norm_preserve,
 
3934
  with gr.Row():
3935
  adv_n_directions = gr.Slider(
3936
  1, 8, value=_defaults["n_directions"], step=1,
3937
+ label="Directions", info="Number of refusal directions to extract",
3938
+ )
3939
+ adv_direction_method = gr.Radio(
3940
+ choices=["diff_means", "svd", "leace"],
3941
+ value=_defaults["direction_method"],
3942
+ label="Direction Method",
3943
+ info="diff_means: simple & robust, svd: multi-direction, leace: optimal erasure",
3944
  )
3945
  adv_regularization = gr.Slider(
3946
  0.0, 1.0, value=_defaults["regularization"], step=0.05,
 
4009
 
4010
  # List of all advanced controls (order must match _on_method_change return)
4011
  _adv_controls = [
4012
+ adv_n_directions, adv_direction_method,
4013
+ adv_regularization, adv_refinement_passes,
4014
  adv_reflection_strength, adv_embed_regularization,
4015
  adv_steering_strength, adv_transplant_blend,
4016
  adv_spectral_bands, adv_spectral_threshold,
obliteratus/abliterate.py CHANGED
@@ -63,6 +63,7 @@ METHODS = {
63
  "label": "Basic (Arditi et al.)",
64
  "description": "Single refusal direction via difference-in-means",
65
  "n_directions": 1,
 
66
  "norm_preserve": False,
67
  "regularization": 0.0,
68
  "refinement_passes": 1,
@@ -75,6 +76,7 @@ METHODS = {
75
  "label": "Advanced (Multi-direction + Norm-preserving)",
76
  "description": "SVD-based multi-direction extraction with norm preservation",
77
  "n_directions": 4,
 
78
  "norm_preserve": True,
79
  "regularization": 0.3,
80
  "embed_regularization": 0.5,
@@ -97,6 +99,7 @@ METHODS = {
97
  "Zero regularization for maximum refusal removal."
98
  ),
99
  "n_directions": 8,
 
100
  "norm_preserve": True,
101
  "regularization": 0.0,
102
  "refinement_passes": 3,
@@ -124,6 +127,7 @@ METHODS = {
124
  "separating trained-in refusal patterns from per-layer artifacts."
125
  ),
126
  "n_directions": 6,
 
127
  "norm_preserve": True,
128
  "regularization": 0.0,
129
  "refinement_passes": 2,
@@ -146,25 +150,31 @@ METHODS = {
146
  "Uses InformedAbliterationPipeline for the full feedback loop. "
147
  "Auto-detects alignment method (DPO/RLHF/CAI/SFT), maps concept "
148
  "cone geometry, performs cluster-aware layer selection, and gates "
149
- "projection by safety-capability entanglement. Includes spectral "
150
- "certification of abliteration completeness and Wasserstein-optimal "
151
- "primary direction extraction."
152
  ),
153
- "n_directions": 4,
 
154
  "norm_preserve": True,
155
  "regularization": 0.0,
156
  "refinement_passes": 2,
157
  "project_biases": True,
158
  "use_chat_template": True,
159
- "use_whitened_svd": True,
160
  "true_iterative_refinement": True,
161
  "use_jailbreak_contrast": False,
162
- "layer_adaptive_strength": False,
163
  "safety_neuron_masking": False,
164
  "per_expert_directions": False,
165
  "attention_head_surgery": False,
166
  "use_sae_features": False,
167
- "use_wasserstein_optimal": True,
 
 
 
 
 
168
  },
169
  "surgical": {
170
  "label": "Surgical (Full SOTA MoE-Aware)",
@@ -176,6 +186,7 @@ METHODS = {
176
  "minimizing capability damage via precision targeting."
177
  ),
178
  "n_directions": 8,
 
179
  "norm_preserve": True,
180
  "regularization": 0.0,
181
  "refinement_passes": 2,
@@ -204,6 +215,7 @@ METHODS = {
204
  "techniques plus the inversion layer."
205
  ),
206
  "n_directions": 8,
 
207
  "norm_preserve": True,
208
  "regularization": 0.0,
209
  "refinement_passes": 2,
@@ -234,6 +246,7 @@ METHODS = {
234
  "Best for maximizing quality when compute budget allows ~50 trials."
235
  ),
236
  "n_directions": 4,
 
237
  "norm_preserve": True,
238
  "regularization": 0.0,
239
  "refinement_passes": 1,
@@ -275,6 +288,7 @@ METHODS = {
275
  "runtime overhead except lightweight steering hooks."
276
  ),
277
  "n_directions": 4,
 
278
  "norm_preserve": True,
279
  "regularization": 0.0,
280
  "refinement_passes": 2,
@@ -320,6 +334,7 @@ METHODS = {
320
  "abliterated models were created with."
321
  ),
322
  "n_directions": 1,
 
323
  "norm_preserve": False,
324
  "regularization": 0.0,
325
  "refinement_passes": 1,
@@ -347,6 +362,7 @@ METHODS = {
347
  "whitened SVD, no iterative refinement."
348
  ),
349
  "n_directions": 4,
 
350
  "norm_preserve": False,
351
  # Ridge alpha=0.3 → effective reg = alpha/(1+alpha) = 0.3/1.3 ≈ 0.231
352
  # For orthonormal V: P_V^alpha = 1/(1+alpha) * VV^T = 0.769 * VV^T
@@ -379,6 +395,7 @@ METHODS = {
379
  "over the (refusal_rate, KL_divergence) frontier."
380
  ),
381
  "n_directions": 2,
 
382
  "norm_preserve": True,
383
  "regularization": 0.0,
384
  "refinement_passes": 1,
@@ -414,6 +431,7 @@ METHODS = {
414
  "boundary rather than the statistical activation difference."
415
  ),
416
  "n_directions": 4,
 
417
  "norm_preserve": True,
418
  "regularization": 0.0,
419
  "refinement_passes": 1,
@@ -566,6 +584,7 @@ class AbliterationPipeline:
566
  hub_token: str | None = None,
567
  hub_community_org: str | None = None,
568
  n_directions: int | None = None,
 
569
  norm_preserve: bool | None = None,
570
  regularization: float | None = None,
571
  refinement_passes: int | None = None,
@@ -659,6 +678,7 @@ class AbliterationPipeline:
659
  method_cfg = METHODS[method]
660
  self.method = method
661
  self.n_directions = n_directions if n_directions is not None else method_cfg["n_directions"]
 
662
  self.norm_preserve = norm_preserve if norm_preserve is not None else method_cfg["norm_preserve"]
663
  self.regularization = regularization if regularization is not None else method_cfg["regularization"]
664
  self.refinement_passes = refinement_passes if refinement_passes is not None else method_cfg["refinement_passes"]
@@ -936,7 +956,7 @@ class AbliterationPipeline:
936
  self.log(f"Loading model: {self.model_name}")
937
  self.log(f"Device: {self.device} | Dtype: {self.dtype}")
938
  self.log(f"Method: {method_label}")
939
- self.log(f" Directions: {self.n_directions} | Norm-preserve: {self.norm_preserve}")
940
  self.log(f" Regularization: {self.regularization} | Refinement passes: {self.refinement_passes}")
941
 
942
  self.handle = load_model(
@@ -1400,18 +1420,26 @@ class AbliterationPipeline:
1400
  else:
1401
  max_length = 384 if collect_multi_pos else 256
1402
  free_gb = dev.get_total_free_gb()
 
 
 
 
 
 
 
 
1403
  if dev.is_gpu_available():
1404
- if self.max_seq_length is None and free_gb < 2.0:
1405
  max_length = 64
1406
- self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
1407
- elif self.max_seq_length is None and free_gb < 4.0:
1408
  max_length = 128
1409
- self.log(f" Tight GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
1410
 
1411
  device = self._get_model_device(model)
1412
 
1413
  # Batch prompts for throughput — hooks unbatch per-prompt activations
1414
- batch_size = 16 if free_gb > 4.0 else 8 if free_gb > 2.0 else 1
1415
  # Left-pad so position -1 is always the last real token in every batch element
1416
  orig_padding_side = getattr(tokenizer, "padding_side", "right")
1417
  if batch_size > 1:
@@ -1498,9 +1526,16 @@ class AbliterationPipeline:
1498
  wasserstein_extractor = WassersteinOptimalExtractor()
1499
  self.log("Using Wasserstein-optimal direction extraction (cost-minimizing GEP)")
1500
 
 
 
 
 
 
 
 
1501
  # Optionally use whitened SVD for cleaner direction extraction
1502
  whitened_extractor = None
1503
- if self.use_whitened_svd and n_dirs > 1 and not self.use_wasserstein_optimal:
1504
  from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
1505
  whitened_extractor = WhitenedSVDExtractor()
1506
  self.log("Using whitened SVD (covariance-normalized) for direction extraction")
@@ -1547,6 +1582,30 @@ class AbliterationPipeline:
1547
  if idx < 5:
1548
  self.log(f" layer {idx}: Wasserstein extraction failed ({e}), falling back to SVD")
1549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1550
  if n_dirs == 1:
1551
  # Classic single-direction: difference-in-means
1552
  diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
@@ -3589,9 +3648,18 @@ class AbliterationPipeline:
3589
  except Exception:
3590
  pass
3591
 
 
 
 
 
 
 
 
 
 
3592
  # Use whitened SVD when enabled (matching main _distill)
3593
  whitened_extractor = None
3594
- if self.use_whitened_svd and n_dirs > 1 and wasserstein_extractor is None:
3595
  from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
3596
  whitened_extractor = WhitenedSVDExtractor()
3597
 
@@ -3624,6 +3692,22 @@ class AbliterationPipeline:
3624
  except Exception:
3625
  pass # Fall through to SVD
3626
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3627
  if n_dirs == 1:
3628
  diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
3629
  norm = diff.norm()
@@ -5741,6 +5825,7 @@ class AbliterationPipeline:
5741
  "method": self.method,
5742
  "method_config": {
5743
  "n_directions": self.n_directions,
 
5744
  "norm_preserve": self.norm_preserve,
5745
  "regularization": self.regularization,
5746
  "refinement_passes": self.refinement_passes,
 
63
  "label": "Basic (Arditi et al.)",
64
  "description": "Single refusal direction via difference-in-means",
65
  "n_directions": 1,
66
+ "direction_method": "diff_means",
67
  "norm_preserve": False,
68
  "regularization": 0.0,
69
  "refinement_passes": 1,
 
76
  "label": "Advanced (Multi-direction + Norm-preserving)",
77
  "description": "SVD-based multi-direction extraction with norm preservation",
78
  "n_directions": 4,
79
+ "direction_method": "svd",
80
  "norm_preserve": True,
81
  "regularization": 0.3,
82
  "embed_regularization": 0.5,
 
99
  "Zero regularization for maximum refusal removal."
100
  ),
101
  "n_directions": 8,
102
+ "direction_method": "svd",
103
  "norm_preserve": True,
104
  "regularization": 0.0,
105
  "refinement_passes": 3,
 
127
  "separating trained-in refusal patterns from per-layer artifacts."
128
  ),
129
  "n_directions": 6,
130
+ "direction_method": "svd",
131
  "norm_preserve": True,
132
  "regularization": 0.0,
133
  "refinement_passes": 2,
 
150
  "Uses InformedAbliterationPipeline for the full feedback loop. "
151
  "Auto-detects alignment method (DPO/RLHF/CAI/SFT), maps concept "
152
  "cone geometry, performs cluster-aware layer selection, and gates "
153
+ "projection by safety-capability entanglement. Defaults to single "
154
+ "diff-of-means direction + Bayesian optimization (Heretic-style). "
155
+ "LEACE available via direction_method='leace'."
156
  ),
157
+ "n_directions": 1,
158
+ "direction_method": "diff_means",
159
  "norm_preserve": True,
160
  "regularization": 0.0,
161
  "refinement_passes": 2,
162
  "project_biases": True,
163
  "use_chat_template": True,
164
+ "use_whitened_svd": False,
165
  "true_iterative_refinement": True,
166
  "use_jailbreak_contrast": False,
167
+ "layer_adaptive_strength": True,
168
  "safety_neuron_masking": False,
169
  "per_expert_directions": False,
170
  "attention_head_surgery": False,
171
  "use_sae_features": False,
172
+ "use_wasserstein_optimal": False,
173
+ "use_kl_optimization": True,
174
+ "kl_budget": 0.5,
175
+ "float_layer_interpolation": True,
176
+ "winsorize_activations": True,
177
+ "winsorize_percentile": 0.01,
178
  },
179
  "surgical": {
180
  "label": "Surgical (Full SOTA MoE-Aware)",
 
186
  "minimizing capability damage via precision targeting."
187
  ),
188
  "n_directions": 8,
189
+ "direction_method": "svd",
190
  "norm_preserve": True,
191
  "regularization": 0.0,
192
  "refinement_passes": 2,
 
215
  "techniques plus the inversion layer."
216
  ),
217
  "n_directions": 8,
218
+ "direction_method": "svd",
219
  "norm_preserve": True,
220
  "regularization": 0.0,
221
  "refinement_passes": 2,
 
246
  "Best for maximizing quality when compute budget allows ~50 trials."
247
  ),
248
  "n_directions": 4,
249
+ "direction_method": "svd",
250
  "norm_preserve": True,
251
  "regularization": 0.0,
252
  "refinement_passes": 1,
 
288
  "runtime overhead except lightweight steering hooks."
289
  ),
290
  "n_directions": 4,
291
+ "direction_method": "svd",
292
  "norm_preserve": True,
293
  "regularization": 0.0,
294
  "refinement_passes": 2,
 
334
  "abliterated models were created with."
335
  ),
336
  "n_directions": 1,
337
+ "direction_method": "diff_means",
338
  "norm_preserve": False,
339
  "regularization": 0.0,
340
  "refinement_passes": 1,
 
362
  "whitened SVD, no iterative refinement."
363
  ),
364
  "n_directions": 4,
365
+ "direction_method": "svd",
366
  "norm_preserve": False,
367
  # Ridge alpha=0.3 → effective reg = alpha/(1+alpha) = 0.3/1.3 ≈ 0.231
368
  # For orthonormal V: P_V^alpha = 1/(1+alpha) * VV^T = 0.769 * VV^T
 
395
  "over the (refusal_rate, KL_divergence) frontier."
396
  ),
397
  "n_directions": 2,
398
+ "direction_method": "diff_means",
399
  "norm_preserve": True,
400
  "regularization": 0.0,
401
  "refinement_passes": 1,
 
431
  "boundary rather than the statistical activation difference."
432
  ),
433
  "n_directions": 4,
434
+ "direction_method": "svd",
435
  "norm_preserve": True,
436
  "regularization": 0.0,
437
  "refinement_passes": 1,
 
584
  hub_token: str | None = None,
585
  hub_community_org: str | None = None,
586
  n_directions: int | None = None,
587
+ direction_method: str | None = None,
588
  norm_preserve: bool | None = None,
589
  regularization: float | None = None,
590
  refinement_passes: int | None = None,
 
678
  method_cfg = METHODS[method]
679
  self.method = method
680
  self.n_directions = n_directions if n_directions is not None else method_cfg["n_directions"]
681
+ self.direction_method = direction_method if direction_method is not None else method_cfg.get("direction_method", "svd")
682
  self.norm_preserve = norm_preserve if norm_preserve is not None else method_cfg["norm_preserve"]
683
  self.regularization = regularization if regularization is not None else method_cfg["regularization"]
684
  self.refinement_passes = refinement_passes if refinement_passes is not None else method_cfg["refinement_passes"]
 
956
  self.log(f"Loading model: {self.model_name}")
957
  self.log(f"Device: {self.device} | Dtype: {self.dtype}")
958
  self.log(f"Method: {method_label}")
959
+ self.log(f" Directions: {self.n_directions} ({self.direction_method}) | Norm-preserve: {self.norm_preserve}")
960
  self.log(f" Regularization: {self.regularization} | Refinement passes: {self.refinement_passes}")
961
 
962
  self.handle = load_model(
 
1420
  else:
1421
  max_length = 384 if collect_multi_pos else 256
1422
  free_gb = dev.get_total_free_gb()
1423
+ # Scale memory thresholds by model size — a 1.2B model needs far
1424
+ # less KV-cache memory per token than a 7B model. Baseline
1425
+ # thresholds (4 / 2 GB) were tuned for 7B (hidden=4096, layers=32).
1426
+ _h = self.handle.hidden_size if self.handle else 4096
1427
+ _l = n_layers if n_layers else 32
1428
+ _mem_scale = (_h / 4096) * (_l / 32)
1429
+ _tight_gb = max(4.0 * _mem_scale, 0.5)
1430
+ _low_gb = max(2.0 * _mem_scale, 0.25)
1431
  if dev.is_gpu_available():
1432
+ if self.max_seq_length is None and free_gb < _low_gb:
1433
  max_length = 64
1434
+ self.log(f" Low GPU memory ({free_gb:.1f} GB free, threshold {_low_gb:.1f} GB), using max_length={max_length}")
1435
+ elif self.max_seq_length is None and free_gb < _tight_gb:
1436
  max_length = 128
1437
+ self.log(f" Tight GPU memory ({free_gb:.1f} GB free, threshold {_tight_gb:.1f} GB), using max_length={max_length}")
1438
 
1439
  device = self._get_model_device(model)
1440
 
1441
  # Batch prompts for throughput — hooks unbatch per-prompt activations
1442
+ batch_size = 16 if free_gb > _tight_gb else 8 if free_gb > _low_gb else 1
1443
  # Left-pad so position -1 is always the last real token in every batch element
1444
  orig_padding_side = getattr(tokenizer, "padding_side", "right")
1445
  if batch_size > 1:
 
1526
  wasserstein_extractor = WassersteinOptimalExtractor()
1527
  self.log("Using Wasserstein-optimal direction extraction (cost-minimizing GEP)")
1528
 
1529
+ # Optionally use LEACE for theoretically optimal concept erasure
1530
+ leace_extractor = None
1531
+ if self.direction_method == "leace":
1532
+ from obliteratus.analysis.leace import LEACEExtractor
1533
+ leace_extractor = LEACEExtractor()
1534
+ self.log("Using LEACE (closed-form optimal concept erasure) for direction extraction")
1535
+
1536
  # Optionally use whitened SVD for cleaner direction extraction
1537
  whitened_extractor = None
1538
+ if self.use_whitened_svd and n_dirs > 1 and not self.use_wasserstein_optimal and leace_extractor is None:
1539
  from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
1540
  whitened_extractor = WhitenedSVDExtractor()
1541
  self.log("Using whitened SVD (covariance-normalized) for direction extraction")
 
1582
  if idx < 5:
1583
  self.log(f" layer {idx}: Wasserstein extraction failed ({e}), falling back to SVD")
1584
 
1585
+ if leace_extractor is not None:
1586
+ # LEACE: closed-form optimal concept erasure direction
1587
+ if idx in self._harmful_acts and idx in self._harmless_acts:
1588
+ try:
1589
+ l_result = leace_extractor.extract(
1590
+ self._harmful_acts[idx],
1591
+ self._harmless_acts[idx],
1592
+ layer_idx=idx,
1593
+ )
1594
+ self.refusal_directions[idx] = l_result.direction
1595
+ self.refusal_subspaces[idx] = l_result.direction.unsqueeze(0)
1596
+ norms[idx] = l_result.generalized_eigenvalue
1597
+
1598
+ if idx < 5 or idx == n_layers - 1:
1599
+ self.log(
1600
+ f" layer {idx}: LEACE eigenvalue={l_result.generalized_eigenvalue:.4f}, "
1601
+ f"erasure_loss={l_result.erasure_loss:.4f}, "
1602
+ f"cond={l_result.within_class_condition:.0f}"
1603
+ )
1604
+ continue
1605
+ except Exception as e:
1606
+ if idx < 5:
1607
+ self.log(f" layer {idx}: LEACE failed ({e}), falling back to diff-of-means")
1608
+
1609
  if n_dirs == 1:
1610
  # Classic single-direction: difference-in-means
1611
  diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
 
3648
  except Exception:
3649
  pass
3650
 
3651
+ # Use LEACE when enabled (matching main _distill)
3652
+ leace_extractor = None
3653
+ if self.direction_method == "leace":
3654
+ try:
3655
+ from obliteratus.analysis.leace import LEACEExtractor
3656
+ leace_extractor = LEACEExtractor()
3657
+ except Exception:
3658
+ pass
3659
+
3660
  # Use whitened SVD when enabled (matching main _distill)
3661
  whitened_extractor = None
3662
+ if self.use_whitened_svd and n_dirs > 1 and wasserstein_extractor is None and leace_extractor is None:
3663
  from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
3664
  whitened_extractor = WhitenedSVDExtractor()
3665
 
 
3692
  except Exception:
3693
  pass # Fall through to SVD
3694
 
3695
+ # LEACE path (matching main _distill)
3696
+ if leace_extractor is not None:
3697
+ if idx in self._harmful_acts and idx in self._harmless_acts:
3698
+ try:
3699
+ l_result = leace_extractor.extract(
3700
+ self._harmful_acts[idx],
3701
+ self._harmless_acts[idx],
3702
+ layer_idx=idx,
3703
+ )
3704
+ self.refusal_directions[idx] = l_result.direction
3705
+ self.refusal_subspaces[idx] = l_result.direction.unsqueeze(0)
3706
+ norms[idx] = l_result.generalized_eigenvalue
3707
+ continue
3708
+ except Exception:
3709
+ pass # Fall through to diff-of-means
3710
+
3711
  if n_dirs == 1:
3712
  diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
3713
  norm = diff.norm()
 
5825
  "method": self.method,
5826
  "method_config": {
5827
  "n_directions": self.n_directions,
5828
+ "direction_method": self.direction_method,
5829
  "norm_preserve": self.norm_preserve,
5830
  "regularization": self.regularization,
5831
  "refinement_passes": self.refinement_passes,
obliteratus/analysis/leace.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LEACE (LEAst-squares Concept Erasure) for refusal direction extraction.
2
+
3
+ Closed-form optimal concept erasure that finds the minimum-rank edit to
4
+ representations preventing any linear classifier from recovering the
5
+ concept (harmful vs harmless). Unlike SVD, LEACE produces directions
6
+ that are provably optimal for erasure with minimal distortion.
7
+
8
+ The key insight: instead of finding directions of maximum variance (SVD)
9
+ or maximum mean difference (diff-of-means), LEACE solves a constrained
10
+ optimization problem: find the smallest perturbation to representations
11
+ such that no linear probe can distinguish harmful from harmless.
12
+
13
+ Mathematical formulation (rank-1 erasure):
14
+ Given class-conditional means mu_0, mu_1 and within-class
15
+ covariance S_w:
16
+ 1. Compute mean difference: delta = mu_1 - mu_0
17
+ 2. Compute within-class covariance: S_w = (S_0 + S_1) / 2
18
+ 3. Solve generalized eigenvalue problem: S_b v = lambda S_w v
19
+ where S_b = delta @ delta^T (between-class scatter)
20
+ 4. The top generalized eigenvector is the LEACE direction
21
+ 5. Erase by projecting out: x' = x - (x @ v) * v^T
22
+
23
+ This is mathematically equivalent to Fisher's Linear Discriminant but
24
+ applied as an erasure operation. The direction maximizes class
25
+ separability relative to within-class spread, making it the optimal
26
+ single direction to remove for concept erasure.
27
+
28
+ Advantages over SVD:
29
+ - Theoretically optimal: minimizes representation distortion for
30
+ guaranteed erasure of linear concept information
31
+ - Handles rogue dimensions naturally: within-class normalization
32
+ prevents high-variance but non-discriminative dimensions from
33
+ dominating
34
+ - No hyperparameters beyond regularization epsilon
35
+ - Closed-form solution (no iterative optimization)
36
+
37
+ References:
38
+ - Belrose et al. (2023): LEACE: Perfect linear concept erasure in
39
+ closed form. NeurIPS 2023.
40
+ - Ravfogel et al. (2022): RLACE: Adversarial concept erasure
41
+ (iterative precursor to LEACE).
42
+ """
43
+
44
+ from __future__ import annotations
45
+
46
+ from dataclasses import dataclass
47
+
48
+ import torch
49
+
50
+
51
+ @dataclass
52
+ class LEACEResult:
53
+ """Result of LEACE direction extraction for a single layer."""
54
+
55
+ layer_idx: int
56
+ direction: torch.Tensor # (hidden_dim,) unit vector
57
+ generalized_eigenvalue: float # lambda from GEP (discriminability)
58
+ within_class_condition: float # condition number of S_w
59
+ mean_diff_norm: float # ||mu_1 - mu_0||
60
+ erasure_loss: float # expected squared distortion from erasure
61
+
62
+
63
+ class LEACEExtractor:
64
+ """Extract refusal directions via LEACE (closed-form concept erasure).
65
+
66
+ Finds the direction that maximally separates harmful from harmless
67
+ activations relative to within-class variance, then erases it.
68
+ This is the provably optimal rank-1 concept erasure.
69
+ """
70
+
71
+ def __init__(
72
+ self,
73
+ regularization_eps: float = 1e-4,
74
+ shrinkage: float = 0.0,
75
+ ):
76
+ """
77
+ Args:
78
+ regularization_eps: Tikhonov regularization for S_w inversion.
79
+ Larger values produce more conservative (but stable) results.
80
+ shrinkage: Ledoit-Wolf shrinkage toward identity (0..1).
81
+ 0 = no shrinkage, 1 = full shrinkage to scaled identity.
82
+ Useful when n_samples < hidden_dim.
83
+ """
84
+ self.regularization_eps = regularization_eps
85
+ self.shrinkage = shrinkage
86
+
87
+ def extract(
88
+ self,
89
+ harmful_activations: list[torch.Tensor],
90
+ harmless_activations: list[torch.Tensor],
91
+ layer_idx: int = 0,
92
+ ) -> LEACEResult:
93
+ """Extract the LEACE direction for a single layer.
94
+
95
+ Args:
96
+ harmful_activations: List of (hidden_dim,) tensors from harmful prompts.
97
+ harmless_activations: List of (hidden_dim,) tensors from harmless prompts.
98
+ layer_idx: Layer index (for metadata).
99
+
100
+ Returns:
101
+ LEACEResult with the optimal erasure direction.
102
+ """
103
+ H = torch.stack(harmful_activations).float() # (n_h, d)
104
+ B = torch.stack(harmless_activations).float() # (n_b, d)
105
+
106
+ if H.dim() == 3:
107
+ H = H.squeeze(1)
108
+ if B.dim() == 3:
109
+ B = B.squeeze(1)
110
+
111
+ n_h, d = H.shape
112
+ n_b = B.shape[0]
113
+
114
+ # Class-conditional means
115
+ mu_h = H.mean(dim=0) # (d,)
116
+ mu_b = B.mean(dim=0) # (d,)
117
+
118
+ # Mean difference (between-class direction)
119
+ delta = mu_h - mu_b # (d,)
120
+ delta_norm = delta.norm().item()
121
+
122
+ # Within-class covariance: S_w = (S_h + S_b) / 2
123
+ # where S_h = (H - mu_h)^T (H - mu_h) / (n_h - 1) etc.
124
+ H_centered = H - mu_h.unsqueeze(0)
125
+ B_centered = B - mu_b.unsqueeze(0)
126
+
127
+ S_h = (H_centered.T @ H_centered) / max(n_h - 1, 1)
128
+ S_b = (B_centered.T @ B_centered) / max(n_b - 1, 1)
129
+ S_w = (S_h + S_b) / 2.0 # (d, d)
130
+
131
+ # Apply Ledoit-Wolf shrinkage if requested
132
+ if self.shrinkage > 0:
133
+ trace_S_w = S_w.trace().item()
134
+ S_w = (1 - self.shrinkage) * S_w + self.shrinkage * (trace_S_w / d) * torch.eye(d, device=S_w.device)
135
+
136
+ # Regularize S_w for numerical stability
137
+ S_w_reg = S_w + self.regularization_eps * torch.eye(d, device=S_w.device)
138
+
139
+ # Condition number of S_w (for diagnostics)
140
+ try:
141
+ eigs_w = torch.linalg.eigvalsh(S_w_reg)
142
+ eigs_w = eigs_w.clamp(min=0)
143
+ pos_eigs = eigs_w[eigs_w > eigs_w.max() * 1e-10]
144
+ condition = (pos_eigs.max() / pos_eigs.min()).item() if pos_eigs.numel() > 0 else float('inf')
145
+ except Exception:
146
+ condition = float('inf')
147
+
148
+ # LEACE direction via S_w^{-1} @ delta
149
+ # The generalized eigenvector for rank-1 S_between = delta @ delta^T
150
+ # reduces to: v = S_w^{-1} @ delta (up to normalization)
151
+ try:
152
+ # Use solve for numerical stability (avoids explicit inverse)
153
+ v = torch.linalg.solve(S_w_reg, delta) # (d,)
154
+ except torch.linalg.LinAlgError:
155
+ # Fallback: pseudoinverse
156
+ v = torch.linalg.lstsq(S_w_reg, delta.unsqueeze(1)).solution.squeeze(1)
157
+
158
+ # Normalize to unit length
159
+ v_norm = v.norm()
160
+ if v_norm > 1e-8:
161
+ direction = v / v_norm
162
+ else:
163
+ # Degenerate case: fall back to normalized mean difference
164
+ direction = delta / max(delta_norm, 1e-8)
165
+
166
+ # Generalized eigenvalue: lambda = delta^T @ S_w^{-1} @ delta
167
+ # This measures how discriminable the classes are after whitening
168
+ gen_eigenvalue = (delta @ v).item()
169
+
170
+ # Erasure loss: expected squared distortion E[||x - x'||^2]
171
+ # For rank-1 projection: loss = v^T @ S_total @ v where S_total
172
+ # is the total (pooled) covariance
173
+ all_acts = torch.cat([H, B], dim=0)
174
+ mu_total = all_acts.mean(dim=0)
175
+ centered_total = all_acts - mu_total.unsqueeze(0)
176
+ S_total = (centered_total.T @ centered_total) / max(all_acts.shape[0] - 1, 1)
177
+ erasure_loss = (direction @ S_total @ direction).item()
178
+
179
+ return LEACEResult(
180
+ layer_idx=layer_idx,
181
+ direction=direction,
182
+ generalized_eigenvalue=gen_eigenvalue,
183
+ within_class_condition=condition,
184
+ mean_diff_norm=delta_norm,
185
+ erasure_loss=erasure_loss,
186
+ )
187
+
188
+ def extract_all_layers(
189
+ self,
190
+ harmful_acts: dict[int, list[torch.Tensor]],
191
+ harmless_acts: dict[int, list[torch.Tensor]],
192
+ ) -> dict[int, LEACEResult]:
193
+ """Extract LEACE directions for all layers.
194
+
195
+ Args:
196
+ harmful_acts: {layer_idx: [activations]} from activation collection.
197
+ harmless_acts: {layer_idx: [activations]} from activation collection.
198
+
199
+ Returns:
200
+ {layer_idx: LEACEResult} for each layer.
201
+ """
202
+ results = {}
203
+ for idx in sorted(harmful_acts.keys()):
204
+ if idx not in harmless_acts:
205
+ continue
206
+ results[idx] = self.extract(
207
+ harmful_acts[idx],
208
+ harmless_acts[idx],
209
+ layer_idx=idx,
210
+ )
211
+ return results
212
+
213
+ @staticmethod
214
+ def compare_with_diff_of_means(
215
+ leace_result: LEACEResult,
216
+ harmful_mean: torch.Tensor,
217
+ harmless_mean: torch.Tensor,
218
+ ) -> dict[str, float]:
219
+ """Compare LEACE direction with simple diff-of-means.
220
+
221
+ Returns cosine similarity and diagnostic metrics showing how much
222
+ the within-class normalization rotates the direction.
223
+ """
224
+ diff = harmful_mean.squeeze() - harmless_mean.squeeze()
225
+ diff_norm = diff.norm()
226
+ if diff_norm > 1e-8:
227
+ diff_normalized = diff / diff_norm
228
+ else:
229
+ diff_normalized = diff
230
+
231
+ cosine_sim = (leace_result.direction @ diff_normalized).abs().item()
232
+
233
+ return {
234
+ "cosine_similarity": cosine_sim,
235
+ "leace_eigenvalue": leace_result.generalized_eigenvalue,
236
+ "leace_erasure_loss": leace_result.erasure_loss,
237
+ "within_class_condition": leace_result.within_class_condition,
238
+ "mean_diff_norm": leace_result.mean_diff_norm,
239
+ }
obliteratus/bayesian_optimizer.py CHANGED
@@ -142,28 +142,35 @@ def _parametric_layer_weight(
142
  min_weight: float,
143
  spread: float,
144
  ) -> float:
145
- """Compute ablation weight for a layer using a parametric bell curve.
146
 
147
- This is the Heretic-style parametric kernel:
148
- - max_weight: peak ablation strength (0..1)
149
- - peak_position: normalized position of peak (0..1 maps to layer 0..n_layers-1)
150
- - min_weight: minimum ablation weight at the tails
151
- - spread: controls width of the bell curve (higher = wider)
152
 
153
- Returns a value in [min_weight, max_weight] representing how strongly
154
- to ablate this layer (1.0 = full projection, 0.0 = no projection).
 
 
 
 
 
155
  """
156
  if n_layers <= 1:
157
  return max_weight
158
 
159
  normalized_pos = layer_idx / (n_layers - 1)
160
- peak = peak_position
161
- # Gaussian-shaped kernel
162
- dist = abs(normalized_pos - peak)
163
- sigma = max(spread, 0.01)
164
- gauss = math.exp(-0.5 * (dist / sigma) ** 2)
 
165
 
166
- return min_weight + (max_weight - min_weight) * gauss
 
167
 
168
 
169
  def _interpolate_direction(
@@ -171,37 +178,56 @@ def _interpolate_direction(
171
  layer_idx: int,
172
  float_dir_idx: float,
173
  ) -> torch.Tensor:
174
- """Get an interpolated refusal direction from a float-valued index.
 
 
 
 
 
 
 
 
175
 
176
- Non-integer values interpolate between adjacent SVD directions in the
177
- refusal subspace, unlocking a continuous space of directions beyond
178
- the discrete top-k.
 
179
 
180
  Args:
181
- pipeline: Pipeline with extracted refusal subspaces.
182
- layer_idx: Which layer's subspace to use.
183
- float_dir_idx: Continuous direction index (e.g., 0.7 interpolates
184
- between direction 0 and direction 1).
185
 
186
  Returns:
187
  Normalized direction tensor.
188
  """
189
- subspace = pipeline.refusal_subspaces.get(layer_idx)
190
- if subspace is None or subspace.shape[0] == 0:
 
191
  return pipeline.refusal_directions.get(layer_idx, torch.zeros(1))
192
 
193
- n_dirs = subspace.shape[0]
194
- # Clamp to valid range
195
- float_dir_idx = max(0.0, min(float_dir_idx, n_dirs - 1))
 
 
196
 
197
  lo = int(float_dir_idx)
198
- hi = min(lo + 1, n_dirs - 1)
 
 
 
 
 
 
199
 
200
  if lo == hi:
201
- d = subspace[lo]
202
  else:
 
203
  alpha = float_dir_idx - lo
204
- d = (1.0 - alpha) * subspace[lo] + alpha * subspace[hi]
205
 
206
  norm = d.norm()
207
  if norm > 1e-8:
@@ -342,9 +368,14 @@ def run_bayesian_optimization(
342
  for live_data, saved_clone in original_params: # noqa: F821
343
  live_data.copy_(saved_clone.to(live_data.device))
344
 
345
- # Warm-start values for the parametric kernel
346
- # Estimate peak position from strongest layer
347
- if pipeline._strong_layers:
 
 
 
 
 
348
  peak_layer = pipeline._strong_layers[0]
349
  warm_peak = peak_layer / max(n_total_layers - 1, 1)
350
  else:
@@ -356,56 +387,56 @@ def run_bayesian_optimization(
356
  # Suppress Optuna's verbose logging
357
  optuna.logging.set_verbosity(optuna.logging.WARNING)
358
 
359
- # Max SVD directions available (for float direction interpolation)
360
- max_n_dirs = max(
361
- (pipeline.refusal_subspaces[idx].shape[0]
362
- for idx in pipeline._strong_layers
363
- if idx in pipeline.refusal_subspaces),
364
- default=1,
365
- )
366
 
367
  # ── Phase 1: Parametric kernel optimization (compact search space) ──
 
 
368
 
369
  def objective(trial: optuna.Trial) -> tuple[float, float]:
370
  """Multi-objective: minimize (refusal_rate, kl_divergence)."""
371
  _restore_all()
372
 
373
- # Parametric kernel: 4 params describe the entire layer weighting
374
- max_weight = trial.suggest_float("max_weight", 0.5, 1.0)
375
- peak_position = trial.suggest_float("peak_position", 0.1, 0.9)
376
- min_weight = trial.suggest_float("min_weight", 0.0, 0.3)
377
- spread = trial.suggest_float("spread", 0.1, 0.6)
378
 
379
- # Component-specific scaling (Heretic insight: MLP more damaging)
380
- attn_scale = trial.suggest_float("attn_scale", 0.5, 1.0)
381
- mlp_scale = trial.suggest_float("mlp_scale", 0.3, 1.0)
 
 
382
 
383
- # Float direction index (continuous interpolation between SVD dirs)
384
- dir_idx = trial.suggest_float("dir_idx", 0.0, max(max_n_dirs - 1, 0.0))
385
 
386
- # Compute per-layer regularization from parametric kernel
387
- layer_regs: dict[int, float] = {}
 
388
  for idx in pipeline._strong_layers:
389
- weight = _parametric_layer_weight(
390
- idx, n_total_layers, max_weight, peak_position, min_weight, spread,
391
- )
392
- # Convert weight to regularization (weight=1 → reg=0, weight=0 → reg=1)
393
- layer_regs[idx] = 1.0 - weight
394
 
395
  # Apply projection with trial's parameters
396
  for idx in pipeline._strong_layers:
397
- if idx not in pipeline.refusal_subspaces:
398
  continue
399
 
400
- # Use interpolated direction
401
  direction = _interpolate_direction(pipeline, idx, dir_idx)
402
  d_col = direction.to(device=next(layer_modules[idx].parameters()).device)
403
  d_col = d_col.unsqueeze(-1) if d_col.dim() == 1 else d_col
404
 
405
- reg = layer_regs[idx]
406
-
407
- # Attention projection (with attn_scale)
408
- attn_reg = 1.0 - (1.0 - reg) * attn_scale
409
  try:
410
  attn = get_attention_module(layer_modules[idx], arch)
411
  pipeline._project_out_advanced(
@@ -416,8 +447,8 @@ def run_bayesian_optimization(
416
  except (AttributeError, RuntimeError):
417
  pass
418
 
419
- # MLP/FFN projection (with mlp_scale)
420
- mlp_reg = 1.0 - (1.0 - reg) * mlp_scale
421
  try:
422
  ffn = get_ffn_module(layer_modules[idx], arch)
423
  count = pipeline._project_out_advanced(
@@ -439,18 +470,20 @@ def run_bayesian_optimization(
439
  refusal = _measure_refusal_rate(pipeline, n_prompts=n_refusal_prompts)
440
  kl = _measure_kl_divergence(pipeline, reference_logits, kl_prompts)
441
 
442
- # Track best combined score
443
  nonlocal best_score, best_result
444
  combined = refusal + 0.5 * kl
445
  if combined < best_score:
446
  best_score = combined
447
- best_result = dict(layer_regs)
 
 
 
448
 
449
  pipeline.log(
450
  f" Trial {trial.number + 1}/{n_trials}: "
451
  f"refusal={refusal:.0%}, KL={kl:.4f} "
452
- f"(peak={peak_position:.2f}, spread={spread:.2f}, "
453
- f"attn={attn_scale:.2f}, mlp={mlp_scale:.2f}, dir={dir_idx:.2f})"
454
  )
455
 
456
  return refusal, kl
@@ -462,16 +495,33 @@ def run_bayesian_optimization(
462
  study_name="obliteratus_parametric_optimization",
463
  )
464
 
465
- # Enqueue warm-start trial with analysis-derived estimates
466
- warm_params = {
467
- "max_weight": 0.9,
468
- "peak_position": warm_peak,
469
- "min_weight": 0.05,
470
- "spread": 0.3,
471
- "attn_scale": 0.8,
472
- "mlp_scale": 0.6,
473
- "dir_idx": 0.0,
474
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  study.enqueue_trial(warm_params)
476
 
477
  pipeline.log(f"Bayesian optimization: running {n_trials} trials (parametric kernel)...")
@@ -490,25 +540,32 @@ def run_bayesian_optimization(
490
  p = best_trial.params
491
  best_result = {}
492
  for idx in pipeline._strong_layers:
493
- weight = _parametric_layer_weight(
 
 
 
 
 
494
  idx, n_total_layers,
495
- p["max_weight"], p["peak_position"],
496
- p["min_weight"], p["spread"],
497
  )
498
- best_result[idx] = 1.0 - weight
 
499
 
500
  pipeline.log(
501
  f" Best trial: refusal={best_trial.values[0]:.0%}, "
502
  f"KL={best_trial.values[1]:.4f}"
503
  )
504
  pipeline.log(
505
- f" Kernel: peak={p['peak_position']:.2f}, spread={p['spread']:.2f}, "
506
- f"max={p['max_weight']:.2f}, min={p['min_weight']:.2f}"
507
  )
508
  pipeline.log(
509
- f" Components: attn={p['attn_scale']:.2f}, mlp={p['mlp_scale']:.2f}, "
510
- f"dir_idx={p['dir_idx']:.2f}"
511
  )
 
512
 
513
  # Store the best direction index for use during EXCISE
514
  best_dir_idx = p.get("dir_idx", 0.0)
@@ -518,9 +575,9 @@ def run_bayesian_optimization(
518
  new_dir = _interpolate_direction(pipeline, idx, best_dir_idx)
519
  pipeline.refusal_directions[idx] = new_dir
520
 
521
- # Store component scales for use in EXCISE
522
- pipeline._bayesian_attn_scale = p.get("attn_scale", 1.0)
523
- pipeline._bayesian_mlp_scale = p.get("mlp_scale", 1.0)
524
 
525
  elif best_result:
526
  pipeline.log(f" Using best combined score: {best_score:.4f}")
 
142
  min_weight: float,
143
  spread: float,
144
  ) -> float:
145
+ """Compute ablation weight for a layer using a piecewise-linear tent kernel.
146
 
147
+ Faithful reproduction of Heretic's parametric kernel (p-e-w/heretic):
148
+ - max_weight: peak ablation strength at peak_position
149
+ - peak_position: normalized position of peak (0..1)
150
+ - min_weight: weight at the edges of the tent
151
+ - spread: normalized distance from peak to tent edge (min_weight_distance)
152
 
153
+ Layers beyond ``spread`` from the peak get weight 0 (skipped entirely).
154
+ Within the tent, weight drops linearly from max_weight to min_weight.
155
+ This matches Heretic's actual formula::
156
+
157
+ distance = abs(layer_index - max_weight_position)
158
+ if distance > min_weight_distance: skip
159
+ weight = max_weight + (distance / min_weight_distance) * (min_weight - max_weight)
160
  """
161
  if n_layers <= 1:
162
  return max_weight
163
 
164
  normalized_pos = layer_idx / (n_layers - 1)
165
+ dist = abs(normalized_pos - peak_position)
166
+ min_weight_distance = max(spread, 0.01)
167
+
168
+ # Hard cutoff: layers outside the tent get 0 weight (Heretic skips them)
169
+ if dist > min_weight_distance:
170
+ return 0.0
171
 
172
+ # Linear interpolation: max_weight at peak → min_weight at edges
173
+ return max_weight + (dist / min_weight_distance) * (min_weight - max_weight)
174
 
175
 
176
  def _interpolate_direction(
 
178
  layer_idx: int,
179
  float_dir_idx: float,
180
  ) -> torch.Tensor:
181
+ """Get an interpolated refusal direction from a float-valued layer index.
182
+
183
+ Faithful reproduction of Heretic's direction interpolation: the index
184
+ selects which *layer's* diff-of-means direction to use, with float
185
+ values interpolating between adjacent layers' directions. This is
186
+ fundamentally different from interpolating between SVD components
187
+ within a single layer — it searches across the layer axis.
188
+
189
+ From Heretic source (model.py)::
190
 
191
+ weight, index = math.modf(direction_index + 1)
192
+ refusal_direction = F.normalize(
193
+ refusal_directions[int(index)].lerp(
194
+ refusal_directions[int(index) + 1], weight), p=2, dim=0)
195
 
196
  Args:
197
+ pipeline: Pipeline with extracted refusal directions per layer.
198
+ layer_idx: The layer being projected (used as fallback).
199
+ float_dir_idx: Continuous direction index selects which layer's
200
+ direction to use (e.g., 5.3 interpolates 70% layer-5 + 30% layer-6).
201
 
202
  Returns:
203
  Normalized direction tensor.
204
  """
205
+ # Build sorted list of layer indices that have refusal directions
206
+ sorted_layers = sorted(pipeline.refusal_directions.keys())
207
+ if not sorted_layers:
208
  return pipeline.refusal_directions.get(layer_idx, torch.zeros(1))
209
 
210
+ n_layers_with_dirs = len(sorted_layers)
211
+
212
+ # Heretic uses direction_index + 1 offset; we map float_dir_idx into
213
+ # the sorted layer list, clamped to valid range.
214
+ float_dir_idx = max(0.0, min(float_dir_idx, n_layers_with_dirs - 1))
215
 
216
  lo = int(float_dir_idx)
217
+ hi = min(lo + 1, n_layers_with_dirs - 1)
218
+
219
+ lo_layer = sorted_layers[lo]
220
+ hi_layer = sorted_layers[hi]
221
+
222
+ d_lo = pipeline.refusal_directions[lo_layer]
223
+ d_hi = pipeline.refusal_directions[hi_layer]
224
 
225
  if lo == hi:
226
+ d = d_lo
227
  else:
228
+ # Linear interpolation between adjacent layers' directions
229
  alpha = float_dir_idx - lo
230
+ d = (1.0 - alpha) * d_lo + alpha * d_hi
231
 
232
  norm = d.norm()
233
  if norm > 1e-8:
 
368
  for live_data, saved_clone in original_params: # noqa: F821
369
  live_data.copy_(saved_clone.to(live_data.device))
370
 
371
+ # Warm-start values for the parametric kernel.
372
+ # If the informed pipeline provided analysis-derived warm-start params,
373
+ # use those (they're much better than the default heuristic).
374
+ informed_warm = getattr(pipeline, "_informed_warm_start", None)
375
+ if informed_warm:
376
+ warm_peak = informed_warm.get("peak_position", 0.5)
377
+ pipeline.log(f" Using analysis-informed warm-start (peak={warm_peak:.2f})")
378
+ elif pipeline._strong_layers:
379
  peak_layer = pipeline._strong_layers[0]
380
  warm_peak = peak_layer / max(n_total_layers - 1, 1)
381
  else:
 
387
  # Suppress Optuna's verbose logging
388
  optuna.logging.set_verbosity(optuna.logging.WARNING)
389
 
390
+ # Max layers with directions (for float direction interpolation)
391
+ n_layers_with_dirs = len([
392
+ idx for idx in pipeline._strong_layers
393
+ if idx in pipeline.refusal_directions
394
+ ])
 
 
395
 
396
  # ── Phase 1: Parametric kernel optimization (compact search space) ──
397
+ # Heretic uses SEPARATE kernel parameters for attention and MLP,
398
+ # allowing them to peak at different layers (8 params + 1 dir_idx = 9).
399
 
400
  def objective(trial: optuna.Trial) -> tuple[float, float]:
401
  """Multi-objective: minimize (refusal_rate, kl_divergence)."""
402
  _restore_all()
403
 
404
+ # Attention kernel: 4 params
405
+ attn_max = trial.suggest_float("attn_max_weight", 0.5, 1.0)
406
+ attn_peak = trial.suggest_float("attn_peak_position", 0.1, 0.9)
407
+ attn_min = trial.suggest_float("attn_min_weight", 0.0, 0.3)
408
+ attn_spread = trial.suggest_float("attn_spread", 0.1, 0.6)
409
 
410
+ # MLP kernel: 4 params (separate can peak at a different layer)
411
+ mlp_max = trial.suggest_float("mlp_max_weight", 0.3, 1.0)
412
+ mlp_peak = trial.suggest_float("mlp_peak_position", 0.1, 0.9)
413
+ mlp_min = trial.suggest_float("mlp_min_weight", 0.0, 0.3)
414
+ mlp_spread = trial.suggest_float("mlp_spread", 0.1, 0.6)
415
 
416
+ # Float direction index (cross-layer interpolation, Heretic-style)
417
+ dir_idx = trial.suggest_float("dir_idx", 0.0, max(n_layers_with_dirs - 1, 0.0))
418
 
419
+ # Compute per-layer, per-component regularization from kernels
420
+ attn_regs: dict[int, float] = {}
421
+ mlp_regs: dict[int, float] = {}
422
  for idx in pipeline._strong_layers:
423
+ attn_w = _parametric_layer_weight(idx, n_total_layers, attn_max, attn_peak, attn_min, attn_spread)
424
+ mlp_w = _parametric_layer_weight(idx, n_total_layers, mlp_max, mlp_peak, mlp_min, mlp_spread)
425
+ attn_regs[idx] = 1.0 - attn_w
426
+ mlp_regs[idx] = 1.0 - mlp_w
 
427
 
428
  # Apply projection with trial's parameters
429
  for idx in pipeline._strong_layers:
430
+ if idx not in pipeline.refusal_directions:
431
  continue
432
 
433
+ # Use cross-layer interpolated direction
434
  direction = _interpolate_direction(pipeline, idx, dir_idx)
435
  d_col = direction.to(device=next(layer_modules[idx].parameters()).device)
436
  d_col = d_col.unsqueeze(-1) if d_col.dim() == 1 else d_col
437
 
438
+ # Attention projection (with per-component kernel)
439
+ attn_reg = attn_regs[idx]
 
 
440
  try:
441
  attn = get_attention_module(layer_modules[idx], arch)
442
  pipeline._project_out_advanced(
 
447
  except (AttributeError, RuntimeError):
448
  pass
449
 
450
+ # MLP/FFN projection (with per-component kernel)
451
+ mlp_reg = mlp_regs[idx]
452
  try:
453
  ffn = get_ffn_module(layer_modules[idx], arch)
454
  count = pipeline._project_out_advanced(
 
470
  refusal = _measure_refusal_rate(pipeline, n_prompts=n_refusal_prompts)
471
  kl = _measure_kl_divergence(pipeline, reference_logits, kl_prompts)
472
 
473
+ # Track best combined score (use average of attn/mlp regs for layer_regs)
474
  nonlocal best_score, best_result
475
  combined = refusal + 0.5 * kl
476
  if combined < best_score:
477
  best_score = combined
478
+ best_result = {
479
+ idx: (attn_regs[idx] + mlp_regs[idx]) / 2.0
480
+ for idx in pipeline._strong_layers
481
+ }
482
 
483
  pipeline.log(
484
  f" Trial {trial.number + 1}/{n_trials}: "
485
  f"refusal={refusal:.0%}, KL={kl:.4f} "
486
+ f"(attn_peak={attn_peak:.2f}, mlp_peak={mlp_peak:.2f}, dir={dir_idx:.2f})"
 
487
  )
488
 
489
  return refusal, kl
 
495
  study_name="obliteratus_parametric_optimization",
496
  )
497
 
498
+ # Enqueue warm-start trial with analysis-derived estimates.
499
+ # Translate informed pipeline params to the new per-component format.
500
+ if informed_warm:
501
+ iw = informed_warm
502
+ warm_params = {
503
+ "attn_max_weight": iw.get("max_weight", 0.9),
504
+ "attn_peak_position": iw.get("peak_position", warm_peak),
505
+ "attn_min_weight": iw.get("min_weight", 0.05),
506
+ "attn_spread": iw.get("spread", 0.3),
507
+ "mlp_max_weight": iw.get("max_weight", 0.9) * iw.get("mlp_scale", 0.6),
508
+ "mlp_peak_position": iw.get("peak_position", warm_peak),
509
+ "mlp_min_weight": iw.get("min_weight", 0.05),
510
+ "mlp_spread": iw.get("spread", 0.3),
511
+ "dir_idx": iw.get("dir_idx", 0.0),
512
+ }
513
+ else:
514
+ warm_params = {
515
+ "attn_max_weight": 0.9,
516
+ "attn_peak_position": warm_peak,
517
+ "attn_min_weight": 0.05,
518
+ "attn_spread": 0.3,
519
+ "mlp_max_weight": 0.6,
520
+ "mlp_peak_position": warm_peak,
521
+ "mlp_min_weight": 0.05,
522
+ "mlp_spread": 0.3,
523
+ "dir_idx": 0.0,
524
+ }
525
  study.enqueue_trial(warm_params)
526
 
527
  pipeline.log(f"Bayesian optimization: running {n_trials} trials (parametric kernel)...")
 
540
  p = best_trial.params
541
  best_result = {}
542
  for idx in pipeline._strong_layers:
543
+ attn_w = _parametric_layer_weight(
544
+ idx, n_total_layers,
545
+ p["attn_max_weight"], p["attn_peak_position"],
546
+ p["attn_min_weight"], p["attn_spread"],
547
+ )
548
+ mlp_w = _parametric_layer_weight(
549
  idx, n_total_layers,
550
+ p["mlp_max_weight"], p["mlp_peak_position"],
551
+ p["mlp_min_weight"], p["mlp_spread"],
552
  )
553
+ best_result[idx] = (attn_w + mlp_w) / 2.0 # average for layer-level reg
554
+ best_result[idx] = 1.0 - best_result[idx]
555
 
556
  pipeline.log(
557
  f" Best trial: refusal={best_trial.values[0]:.0%}, "
558
  f"KL={best_trial.values[1]:.4f}"
559
  )
560
  pipeline.log(
561
+ f" Attn kernel: peak={p['attn_peak_position']:.2f}, "
562
+ f"spread={p['attn_spread']:.2f}, max={p['attn_max_weight']:.2f}"
563
  )
564
  pipeline.log(
565
+ f" MLP kernel: peak={p['mlp_peak_position']:.2f}, "
566
+ f"spread={p['mlp_spread']:.2f}, max={p['mlp_max_weight']:.2f}"
567
  )
568
+ pipeline.log(f" dir_idx={p['dir_idx']:.2f}")
569
 
570
  # Store the best direction index for use during EXCISE
571
  best_dir_idx = p.get("dir_idx", 0.0)
 
575
  new_dir = _interpolate_direction(pipeline, idx, best_dir_idx)
576
  pipeline.refusal_directions[idx] = new_dir
577
 
578
+ # Store component scales for use in EXCISE (backward compat)
579
+ pipeline._bayesian_attn_scale = p.get("attn_max_weight", 1.0)
580
+ pipeline._bayesian_mlp_scale = p.get("mlp_max_weight", 1.0)
581
 
582
  elif best_result:
583
  pipeline.log(f" Using best combined score: {best_score:.4f}")
obliteratus/cli.py CHANGED
@@ -109,7 +109,12 @@ def main(argv: list[str] | None = None):
109
  ],
110
  help="Liberation method (default: advanced)",
111
  )
112
- p.add_argument("--n-directions", type=int, default=None, help="Override: number of SVD directions to extract")
 
 
 
 
 
113
  p.add_argument("--regularization", type=float, default=None, help="Override: fraction to preserve (0.0-1.0)")
114
  p.add_argument("--refinement-passes", type=int, default=None, help="Override: number of iterative passes")
115
  p.add_argument(
@@ -591,6 +596,7 @@ def _cmd_abliterate(args):
591
  dtype=args.dtype,
592
  method=method,
593
  n_directions=args.n_directions,
 
594
  regularization=args.regularization,
595
  refinement_passes=args.refinement_passes,
596
  quantization=args.quantization,
 
109
  ],
110
  help="Liberation method (default: advanced)",
111
  )
112
+ p.add_argument("--n-directions", type=int, default=None, help="Override: number of refusal directions to extract")
113
+ p.add_argument(
114
+ "--direction-method", type=str, default=None,
115
+ choices=["diff_means", "svd", "leace"],
116
+ help="Direction extraction method: diff_means (simple, robust), svd (multi-direction), leace (optimal erasure)",
117
+ )
118
  p.add_argument("--regularization", type=float, default=None, help="Override: fraction to preserve (0.0-1.0)")
119
  p.add_argument("--refinement-passes", type=int, default=None, help="Override: number of iterative passes")
120
  p.add_argument(
 
596
  dtype=args.dtype,
597
  method=method,
598
  n_directions=args.n_directions,
599
+ direction_method=getattr(args, "direction_method", None),
600
  regularization=args.regularization,
601
  refinement_passes=args.refinement_passes,
602
  quantization=args.quantization,
obliteratus/informed_pipeline.py CHANGED
@@ -73,15 +73,17 @@ INFORMED_METHOD = {
73
  "description": (
74
  "Runs analysis modules between PROBE and DISTILL to auto-configure "
75
  "direction extraction, layer selection, and projection strategy based "
76
- "on the model's actual refusal geometry."
 
77
  ),
78
- "n_directions": 4, # overridden by analysis
 
79
  "norm_preserve": True,
80
  "regularization": 0.0, # overridden by analysis
81
  "refinement_passes": 2, # overridden by analysis
82
  "project_biases": True,
83
  "use_chat_template": True,
84
- "use_whitened_svd": True, # overridden by analysis
85
  "true_iterative_refinement": True,
86
  }
87
 
@@ -126,7 +128,8 @@ class AnalysisInsights:
126
  clean_layers: list[int] = field(default_factory=list)
127
 
128
  # Derived configuration
129
- recommended_n_directions: int = 4
 
130
  recommended_regularization: float = 0.0
131
  recommended_refinement_passes: int = 2
132
  recommended_layers: list[int] = field(default_factory=list)
@@ -217,12 +220,19 @@ class InformedAbliterationPipeline(AbliterationPipeline):
217
  hub_token=hub_token,
218
  hub_community_org=hub_community_org,
219
  quantization=quantization,
220
- # Set informed defaults
 
 
221
  norm_preserve=True,
222
  project_biases=True,
223
  use_chat_template=True,
224
- use_whitened_svd=True,
225
  true_iterative_refinement=True,
 
 
 
 
 
226
  )
227
  self.method = "informed"
228
 
@@ -311,7 +321,11 @@ class InformedAbliterationPipeline(AbliterationPipeline):
311
  if self._run_defense:
312
  self._analyze_defense_robustness()
313
 
314
- # 5. Derive configuration from insights
 
 
 
 
315
  self._derive_configuration()
316
 
317
  elapsed = time.time() - t0
@@ -392,6 +406,7 @@ class InformedAbliterationPipeline(AbliterationPipeline):
392
  sample_layers = candidate_layers[::step]
393
 
394
  polyhedral_count = 0
 
395
  best_cone_result = None
396
  best_strength = 0.0
397
 
@@ -405,34 +420,43 @@ class InformedAbliterationPipeline(AbliterationPipeline):
405
  layer_idx=layer_idx,
406
  )
407
 
 
408
  if result.is_polyhedral:
409
  polyhedral_count += 1
410
 
411
- # Track the strongest layer's cone analysis
412
  general_strength = result.general_direction.norm().item() if result.general_direction.numel() > 1 else 0
413
  if general_strength > best_strength:
414
  best_strength = general_strength
415
  best_cone_result = result
416
 
417
- if best_cone_result is not None:
418
- self._insights.cone_is_polyhedral = best_cone_result.is_polyhedral
419
- self._insights.cone_dimensionality = best_cone_result.cone_dimensionality
420
- self._insights.mean_pairwise_cosine = best_cone_result.mean_pairwise_cosine
421
-
422
- # Store per-category directions for category-aware excision
423
- for cd in best_cone_result.category_directions:
424
- self._insights.per_category_directions[cd.category] = cd.direction
425
- self._insights.direction_specificity[cd.category] = cd.specificity
426
-
427
- cone_type = "POLYHEDRAL" if best_cone_result.is_polyhedral else "LINEAR"
428
- self.log(f" Cone type: {cone_type}")
429
- self.log(f" Dimensionality: {best_cone_result.cone_dimensionality:.2f}")
430
- self.log(f" Mean pairwise cosine: {best_cone_result.mean_pairwise_cosine:.3f}")
431
- self.log(f" Categories detected: {best_cone_result.category_count}")
432
- self.log(f" Polyhedral at {polyhedral_count}/{len(sample_layers)} sampled layers")
433
-
434
- for cd in sorted(best_cone_result.category_directions, key=lambda x: -x.strength)[:5]:
435
- self.log(f" {cd.category:15s} DSI={cd.specificity:.3f} str={cd.strength:.3f}")
 
 
 
 
 
 
 
 
436
  else:
437
  self.log(" No cone results — using default linear assumption")
438
 
@@ -517,6 +541,71 @@ class InformedAbliterationPipeline(AbliterationPipeline):
517
  self.log(f" Most entangled layers: {emap.most_entangled_layers}")
518
  self.log(f" Cleanest layers: {emap.least_entangled_layers}")
519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
520
  # ── Configuration Derivation ─────────────────────────────────────
521
 
522
  def _derive_configuration(self):
@@ -528,18 +617,32 @@ class InformedAbliterationPipeline(AbliterationPipeline):
528
  self.log("-" * 50)
529
  insights = self._insights
530
 
531
- # 1. n_directions: based on cone geometry
532
- if insights.cone_is_polyhedral:
533
- # Polyhedral cone need more directions to capture all facets
 
 
534
  n_dirs = max(4, min(8, int(insights.cone_dimensionality * 2)))
 
 
535
  self.log(f" Polyhedral cone (dim={insights.cone_dimensionality:.1f}) "
536
- f"→ n_directions={n_dirs}")
 
 
 
 
 
 
 
537
  else:
538
- # Linear cone → fewer directions suffice
539
- n_dirs = max(1, min(4, int(insights.cone_dimensionality + 1)))
 
 
540
  self.log(f" Linear cone (dim={insights.cone_dimensionality:.1f}) "
541
- f"→ n_directions={n_dirs}")
542
  insights.recommended_n_directions = n_dirs
 
543
  self.n_directions = n_dirs
544
 
545
  # 2. regularization: based on alignment method + entanglement
@@ -586,15 +689,22 @@ class InformedAbliterationPipeline(AbliterationPipeline):
586
 
587
  # 4. Layer selection: cluster-aware + entanglement-gated
588
  if insights.cluster_representative_layers:
589
- # Start from cluster representatives
590
  base_layers = list(insights.cluster_representative_layers)
591
 
592
- # Expand: add all layers from clusters that have strong signals
593
- all_cluster_layers = []
 
 
 
 
 
594
  for cluster in insights.direction_clusters:
595
- all_cluster_layers.extend(cluster)
596
- if all_cluster_layers:
597
- base_layers = sorted(set(all_cluster_layers))
 
 
598
 
599
  # Gate: remove highly entangled layers
600
  skip = set()
@@ -621,13 +731,9 @@ class InformedAbliterationPipeline(AbliterationPipeline):
621
  self.log(f" RSI={insights.mean_refusal_sparsity_index:.2f} "
622
  f"→ standard dense projection")
623
 
624
- # 6. Whitened SVD: always use for multi-direction, skip for single
625
- if n_dirs > 1:
626
- self.use_whitened_svd = True
627
- self.log(f" Multi-direction ({n_dirs}) → whitened SVD enabled")
628
- else:
629
- self.use_whitened_svd = False
630
- self.log(" Single direction → standard diff-in-means")
631
 
632
  # ── Informed DISTILL ─────────────────────────────────────────────
633
 
@@ -650,7 +756,38 @@ class InformedAbliterationPipeline(AbliterationPipeline):
650
  n_layers = len(self._harmful_means)
651
  norms: dict[int, float] = {}
652
 
653
- if self.use_whitened_svd and self.n_directions > 1:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
655
  whitened_extractor = WhitenedSVDExtractor()
656
  self.log(f"Using whitened SVD with {self.n_directions} directions")
@@ -658,6 +795,29 @@ class InformedAbliterationPipeline(AbliterationPipeline):
658
  whitened_extractor = None
659
 
660
  for idx in range(n_layers):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
661
  if self.n_directions == 1:
662
  diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
663
  norm = diff.norm().item()
@@ -691,6 +851,41 @@ class InformedAbliterationPipeline(AbliterationPipeline):
691
  self.refusal_directions[idx] = primary / primary.norm()
692
  norms[idx] = S[:k].sum().item()
693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694
  # Layer selection: use analysis-recommended layers if available,
695
  # otherwise fall back to knee detection
696
  if self._insights.recommended_layers:
@@ -728,15 +923,117 @@ class InformedAbliterationPipeline(AbliterationPipeline):
728
  def _excise_informed(self):
729
  """Excise refusal directions with analysis-informed strategy.
730
 
731
- Uses sparse surgery if analysis recommends it, otherwise falls
732
- back to the standard projection with analysis-tuned parameters.
 
 
733
  """
734
  if self._insights.use_sparse_surgery:
735
  self._excise_sparse()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
736
  else:
737
- # Standard excision with analysis-tuned parameters
738
- # (regularization, norm_preserve, etc. already configured)
739
- self._excise()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740
 
741
  def _excise_sparse(self):
742
  """Sparse direction surgery — only modifies high-projection rows."""
@@ -825,14 +1122,22 @@ class InformedAbliterationPipeline(AbliterationPipeline):
825
  1. Residual refusal signal (via activation probing)
826
  2. Self-repair / Ouroboros effect (via defense robustness)
827
  3. Triggers additional targeted passes at compensating layers
 
 
 
 
828
  """
829
  # Run standard verification first
830
  self._verify()
831
 
832
  # Check if Ouroboros compensation is needed
833
  refusal_rate = self._quality_metrics.get("refusal_rate", 0.0)
 
834
  ouroboros_pass = 0
835
 
 
 
 
836
  while (refusal_rate > self._ouroboros_threshold
837
  and ouroboros_pass < self._max_ouroboros_passes):
838
  ouroboros_pass += 1
@@ -849,9 +1154,9 @@ class InformedAbliterationPipeline(AbliterationPipeline):
849
  self._distill_inner()
850
  self.log(f"Found {len(self._strong_layers)} layers with residual refusal")
851
 
852
- # Re-excise at the new strong layers
853
  if self._strong_layers:
854
- self._excise()
855
  else:
856
  self.log("No strong layers found — stopping Ouroboros compensation")
857
  break
@@ -859,7 +1164,24 @@ class InformedAbliterationPipeline(AbliterationPipeline):
859
  # Re-verify
860
  self._verify()
861
  refusal_rate = self._quality_metrics.get("refusal_rate", 0.0)
862
- self.log(f"After Ouroboros pass {ouroboros_pass}: refusal rate = {refusal_rate:.0%}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
 
864
  self._report.ouroboros_passes = ouroboros_pass
865
  self._report.final_refusal_rate = refusal_rate
@@ -903,6 +1225,7 @@ class InformedAbliterationPipeline(AbliterationPipeline):
903
  },
904
  "derived_config": {
905
  "n_directions": insights.recommended_n_directions,
 
906
  "regularization": insights.recommended_regularization,
907
  "refinement_passes": insights.recommended_refinement_passes,
908
  "layers_used": insights.recommended_layers,
@@ -981,6 +1304,7 @@ class InformedAbliterationPipeline(AbliterationPipeline):
981
 
982
  lines.append("Derived Configuration:")
983
  lines.append(f" n_directions: {insights.recommended_n_directions}")
 
984
  lines.append(f" regularization: {insights.recommended_regularization}")
985
  lines.append(f" refinement_passes: {insights.recommended_refinement_passes}")
986
  lines.append(f" sparse surgery: {insights.use_sparse_surgery}")
 
73
  "description": (
74
  "Runs analysis modules between PROBE and DISTILL to auto-configure "
75
  "direction extraction, layer selection, and projection strategy based "
76
+ "on the model's actual refusal geometry. Defaults to single diff-of-means "
77
+ "direction + Bayesian optimization (Heretic-style)."
78
  ),
79
+ "n_directions": 1, # overridden by analysis
80
+ "direction_method": "diff_means", # overridden by analysis; "leace" also available
81
  "norm_preserve": True,
82
  "regularization": 0.0, # overridden by analysis
83
  "refinement_passes": 2, # overridden by analysis
84
  "project_biases": True,
85
  "use_chat_template": True,
86
+ "use_whitened_svd": False, # overridden by analysis
87
  "true_iterative_refinement": True,
88
  }
89
 
 
128
  clean_layers: list[int] = field(default_factory=list)
129
 
130
  # Derived configuration
131
+ recommended_n_directions: int = 1
132
+ recommended_direction_method: str = "diff_means"
133
  recommended_regularization: float = 0.0
134
  recommended_refinement_passes: int = 2
135
  recommended_layers: list[int] = field(default_factory=list)
 
220
  hub_token=hub_token,
221
  hub_community_org=hub_community_org,
222
  quantization=quantization,
223
+ # Set informed defaults: single direction + Bayesian opt
224
+ n_directions=1,
225
+ direction_method="diff_means",
226
  norm_preserve=True,
227
  project_biases=True,
228
  use_chat_template=True,
229
+ use_whitened_svd=False,
230
  true_iterative_refinement=True,
231
+ use_kl_optimization=True,
232
+ float_layer_interpolation=True,
233
+ layer_adaptive_strength=True,
234
+ winsorize_activations=True,
235
+ winsorize_percentile=0.01,
236
  )
237
  self.method = "informed"
238
 
 
321
  if self._run_defense:
322
  self._analyze_defense_robustness()
323
 
324
+ # 5. Sparse Surgery Analysis (RSI computation)
325
+ if self._run_sparse:
326
+ self._analyze_sparsity()
327
+
328
+ # 6. Derive configuration from insights
329
  self._derive_configuration()
330
 
331
  elapsed = time.time() - t0
 
406
  sample_layers = candidate_layers[::step]
407
 
408
  polyhedral_count = 0
409
+ all_results = []
410
  best_cone_result = None
411
  best_strength = 0.0
412
 
 
420
  layer_idx=layer_idx,
421
  )
422
 
423
+ all_results.append(result)
424
  if result.is_polyhedral:
425
  polyhedral_count += 1
426
 
427
+ # Track the strongest layer's cone analysis for per-category directions
428
  general_strength = result.general_direction.norm().item() if result.general_direction.numel() > 1 else 0
429
  if general_strength > best_strength:
430
  best_strength = general_strength
431
  best_cone_result = result
432
 
433
+ if all_results:
434
+ # Aggregate cone geometry across sampled layers (majority vote +
435
+ # mean dimensionality) instead of relying on a single layer.
436
+ n_sampled = len(all_results)
437
+ is_polyhedral = polyhedral_count > n_sampled / 2
438
+ avg_dimensionality = sum(r.cone_dimensionality for r in all_results) / n_sampled
439
+ avg_pairwise_cos = sum(r.mean_pairwise_cosine for r in all_results) / n_sampled
440
+
441
+ self._insights.cone_is_polyhedral = is_polyhedral
442
+ self._insights.cone_dimensionality = avg_dimensionality
443
+ self._insights.mean_pairwise_cosine = avg_pairwise_cos
444
+
445
+ # Store per-category directions from the strongest layer
446
+ if best_cone_result is not None:
447
+ for cd in best_cone_result.category_directions:
448
+ self._insights.per_category_directions[cd.category] = cd.direction
449
+ self._insights.direction_specificity[cd.category] = cd.specificity
450
+
451
+ cone_type = "POLYHEDRAL" if is_polyhedral else "LINEAR"
452
+ self.log(f" Cone type: {cone_type} (majority vote: {polyhedral_count}/{n_sampled} layers)")
453
+ self.log(f" Avg dimensionality: {avg_dimensionality:.2f}")
454
+ self.log(f" Avg pairwise cosine: {avg_pairwise_cos:.3f}")
455
+ if best_cone_result is not None:
456
+ self.log(f" Categories detected: {best_cone_result.category_count}")
457
+
458
+ for cd in sorted(best_cone_result.category_directions, key=lambda x: -x.strength)[:5]:
459
+ self.log(f" {cd.category:15s} DSI={cd.specificity:.3f} str={cd.strength:.3f}")
460
  else:
461
  self.log(" No cone results — using default linear assumption")
462
 
 
541
  self.log(f" Most entangled layers: {emap.most_entangled_layers}")
542
  self.log(f" Cleanest layers: {emap.least_entangled_layers}")
543
 
544
+ def _analyze_sparsity(self):
545
+ """Compute Refusal Sparsity Index to decide sparse vs dense excision."""
546
+ self.log("\n[5/5] Refusal Sparsity Analysis")
547
+ self.log("-" * 40)
548
+
549
+ from obliteratus.analysis.sparse_surgery import SparseDirectionSurgeon
550
+ from obliteratus.strategies.utils import (
551
+ get_ffn_module,
552
+ get_layer_modules,
553
+ )
554
+
555
+ # Need refusal directions — use quick diff-in-means
556
+ quick_directions = {}
557
+ for idx in sorted(self._harmful_means.keys()):
558
+ diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze()
559
+ norm = diff.norm().item()
560
+ if norm > 1e-10:
561
+ quick_directions[idx] = diff / diff.norm()
562
+
563
+ if not quick_directions:
564
+ self.log(" No refusal directions — skipping sparsity analysis")
565
+ return
566
+
567
+ # Gather FFN output weights for representative layers (sample for speed)
568
+ layers = get_layer_modules(self.handle)
569
+ arch = self.handle.architecture
570
+ n_layers = len(layers)
571
+ sample_idxs = sorted(quick_directions.keys())
572
+ step = max(1, len(sample_idxs) // 8)
573
+ sample_idxs = sample_idxs[::step]
574
+
575
+ weights = {}
576
+ sampled_dirs = {}
577
+ for idx in sample_idxs:
578
+ if idx >= n_layers:
579
+ continue
580
+ try:
581
+ ffn = get_ffn_module(layers[idx], arch)
582
+ for name in ["down_proj", "c_proj", "dense_4h_to_h", "fc_out", "fc2", "w2"]:
583
+ proj = getattr(ffn, name, None)
584
+ if proj is not None and hasattr(proj, "weight"):
585
+ W = proj.weight.data
586
+ d = quick_directions[idx]
587
+ if W.shape[-1] == d.shape[0]:
588
+ weights[idx] = W
589
+ sampled_dirs[idx] = d
590
+ break
591
+ except (AttributeError, RuntimeError):
592
+ continue
593
+
594
+ if not weights:
595
+ self.log(" Could not access FFN weights — skipping sparsity analysis")
596
+ return
597
+
598
+ surgeon = SparseDirectionSurgeon(auto_sparsity=True)
599
+ plan = surgeon.plan_surgery(weights, sampled_dirs)
600
+
601
+ self._insights.mean_refusal_sparsity_index = plan.mean_refusal_sparsity_index
602
+ self._insights.recommended_sparsity = plan.recommended_sparsity
603
+
604
+ self.log(f" Mean RSI: {plan.mean_refusal_sparsity_index:.3f}")
605
+ self.log(f" Recommended sparsity: {plan.recommended_sparsity:.1%}")
606
+ self.log(f" Most sparse layer: {plan.most_sparse_layer}")
607
+ self.log(f" Most dense layer: {plan.most_dense_layer}")
608
+
609
  # ── Configuration Derivation ─────────────────────────────────────
610
 
611
  def _derive_configuration(self):
 
617
  self.log("-" * 50)
618
  insights = self._insights
619
 
620
+ # 1. n_directions + direction_method: based on cone geometry
621
+ # Default: single direction via diff-of-means (proven most robust).
622
+ # Only escalate to multi-direction when analysis confirms polyhedral geometry.
623
+ if insights.cone_is_polyhedral and insights.cone_dimensionality > 2.0:
624
+ # Clearly polyhedral cone → use multiple directions via SVD
625
  n_dirs = max(4, min(8, int(insights.cone_dimensionality * 2)))
626
+ self.direction_method = "svd"
627
+ self.use_whitened_svd = True
628
  self.log(f" Polyhedral cone (dim={insights.cone_dimensionality:.1f}) "
629
+ f"→ n_directions={n_dirs}, method=svd (whitened)")
630
+ elif insights.cone_is_polyhedral:
631
+ # Mildly polyhedral → LEACE gives better single-direction erasure
632
+ n_dirs = 1
633
+ self.direction_method = "leace"
634
+ self.use_whitened_svd = False
635
+ self.log(f" Mildly polyhedral (dim={insights.cone_dimensionality:.1f}) "
636
+ f"→ n_directions=1, method=leace")
637
  else:
638
+ # Linear cone → single direction via diff-of-means (simplest, most robust)
639
+ n_dirs = 1
640
+ self.direction_method = "diff_means"
641
+ self.use_whitened_svd = False
642
  self.log(f" Linear cone (dim={insights.cone_dimensionality:.1f}) "
643
+ f"→ n_directions=1, method=diff_means")
644
  insights.recommended_n_directions = n_dirs
645
+ insights.recommended_direction_method = self.direction_method
646
  self.n_directions = n_dirs
647
 
648
  # 2. regularization: based on alignment method + entanglement
 
689
 
690
  # 4. Layer selection: cluster-aware + entanglement-gated
691
  if insights.cluster_representative_layers:
692
+ # Start from cluster representatives (strongest per cluster)
693
  base_layers = list(insights.cluster_representative_layers)
694
 
695
+ # Conservative expansion: for each cluster, add at most the top-2
696
+ # strongest layers (by refusal norm) beyond the representative,
697
+ # to avoid over-modifying weak layers in large clusters.
698
+ norms = {}
699
+ for idx in self._harmful_means:
700
+ if idx in self._harmless_means:
701
+ norms[idx] = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze().norm().item()
702
  for cluster in insights.direction_clusters:
703
+ ranked = sorted(cluster, key=lambda ly: norms.get(ly, 0), reverse=True)
704
+ # Add up to 2 additional strong layers per cluster
705
+ for ly in ranked[:3]: # representative + up to 2 more
706
+ base_layers.append(ly)
707
+ base_layers = sorted(set(base_layers))
708
 
709
  # Gate: remove highly entangled layers
710
  skip = set()
 
731
  self.log(f" RSI={insights.mean_refusal_sparsity_index:.2f} "
732
  f"→ standard dense projection")
733
 
734
+ # 6. Direction method summary (already set in step 1)
735
+ self.log(f" Direction method: {self.direction_method} "
736
+ f"(whitened_svd={'on' if self.use_whitened_svd else 'off'})")
 
 
 
 
737
 
738
  # ── Informed DISTILL ─────────────────────────────────────────────
739
 
 
756
  n_layers = len(self._harmful_means)
757
  norms: dict[int, float] = {}
758
 
759
+ # ── Small-model direction cap (matching base _distill) ────────
760
+ # On small models, each SVD direction removes a proportionally
761
+ # larger fraction of weight energy. Cap to prevent over-ablation.
762
+ hidden_size = self.handle.hidden_size if self.handle else 0
763
+ total_params = getattr(self.handle, 'total_params', 0) if self.handle else 0
764
+ if total_params == 0 and self.handle:
765
+ try:
766
+ total_params = sum(p.numel() for p in self.handle.model.parameters())
767
+ except Exception:
768
+ pass
769
+ if self.n_directions > 1 and (
770
+ (0 < hidden_size < 2048)
771
+ or (0 < total_params < 2_000_000_000)
772
+ or n_layers <= 16
773
+ ):
774
+ max_dirs = max(1, min(self.n_directions, 2))
775
+ if max_dirs < self.n_directions:
776
+ self.log(
777
+ f"Capped n_directions from {self.n_directions} to {max_dirs} "
778
+ f"for small model (hidden={hidden_size}, "
779
+ f"params={total_params / 1e9:.1f}B, layers={n_layers})"
780
+ )
781
+ self.n_directions = max_dirs
782
+
783
+ # LEACE extractor for optimal concept erasure
784
+ leace_extractor = None
785
+ if self.direction_method == "leace":
786
+ from obliteratus.analysis.leace import LEACEExtractor
787
+ leace_extractor = LEACEExtractor()
788
+ self.log(f"Using LEACE (closed-form optimal concept erasure)")
789
+
790
+ if self.use_whitened_svd and self.n_directions > 1 and leace_extractor is None:
791
  from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
792
  whitened_extractor = WhitenedSVDExtractor()
793
  self.log(f"Using whitened SVD with {self.n_directions} directions")
 
795
  whitened_extractor = None
796
 
797
  for idx in range(n_layers):
798
+ # LEACE path: theoretically optimal single-direction erasure
799
+ if leace_extractor is not None:
800
+ if idx in self._harmful_acts and idx in self._harmless_acts:
801
+ try:
802
+ l_result = leace_extractor.extract(
803
+ self._harmful_acts[idx],
804
+ self._harmless_acts[idx],
805
+ layer_idx=idx,
806
+ )
807
+ self.refusal_directions[idx] = l_result.direction
808
+ self.refusal_subspaces[idx] = l_result.direction.unsqueeze(0)
809
+ norms[idx] = l_result.generalized_eigenvalue
810
+
811
+ if idx < 5 or idx == n_layers - 1:
812
+ self.log(
813
+ f" layer {idx}: LEACE eigenvalue={l_result.generalized_eigenvalue:.4f}, "
814
+ f"erasure_loss={l_result.erasure_loss:.4f}"
815
+ )
816
+ continue
817
+ except Exception as e:
818
+ if idx < 5:
819
+ self.log(f" layer {idx}: LEACE failed ({e}), falling back")
820
+
821
  if self.n_directions == 1:
822
  diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
823
  norm = diff.norm().item()
 
851
  self.refusal_directions[idx] = primary / primary.norm()
852
  norms[idx] = S[:k].sum().item()
853
 
854
+ # Enrich subspaces with per-category cone directions when available.
855
+ # This uses the actual refusal cone generators instead of purely
856
+ # data-agnostic SVD components.
857
+ cat_dirs = self._insights.per_category_directions
858
+ if cat_dirs and self._insights.cone_is_polyhedral and self.n_directions > 1:
859
+ cat_tensors = list(cat_dirs.values())
860
+ # Stack and orthogonalize category directions
861
+ cat_stack = torch.stack(cat_tensors) # (n_cats, hidden)
862
+ cat_norms = cat_stack.norm(dim=1, keepdim=True).clamp(min=1e-8)
863
+ cat_stack = cat_stack / cat_norms
864
+ # Blend into strong-signal layers: replace later SVD components
865
+ # with category directions (which are geometrically meaningful)
866
+ n_cat = cat_stack.shape[0]
867
+ for idx in norms:
868
+ sub = self.refusal_subspaces.get(idx)
869
+ if sub is None or sub.shape[0] <= 1:
870
+ continue
871
+ # Keep the first SVD direction (strongest), replace remaining
872
+ # with category directions projected to be orthogonal to it
873
+ primary = sub[0:1] # (1, hidden)
874
+ # Project category directions orthogonal to primary
875
+ cos = (cat_stack @ primary.squeeze(0)) # (n_cat,)
876
+ ortho_cats = cat_stack - cos.unsqueeze(1) * primary
877
+ ortho_norms = ortho_cats.norm(dim=1)
878
+ # Keep only directions that survived orthogonalization
879
+ valid = ortho_norms > 0.1
880
+ if valid.sum() > 0:
881
+ ortho_cats = ortho_cats[valid]
882
+ ortho_cats = ortho_cats / ortho_cats.norm(dim=1, keepdim=True)
883
+ # Take up to (n_directions - 1) category directions
884
+ n_take = min(self.n_directions - 1, ortho_cats.shape[0])
885
+ new_sub = torch.cat([primary, ortho_cats[:n_take]], dim=0)
886
+ self.refusal_subspaces[idx] = new_sub
887
+ self.log(f"Enriched subspaces with {n_cat} per-category cone directions")
888
+
889
  # Layer selection: use analysis-recommended layers if available,
890
  # otherwise fall back to knee detection
891
  if self._insights.recommended_layers:
 
923
  def _excise_informed(self):
924
  """Excise refusal directions with analysis-informed strategy.
925
 
926
+ Uses Bayesian optimization (when available) with analysis-derived
927
+ warm-start parameters, falling back to sparse surgery or standard
928
+ projection. This is the key integration: analysis maps the geometry,
929
+ Bayesian optimization finds the optimal projection strength.
930
  """
931
  if self._insights.use_sparse_surgery:
932
  self._excise_sparse()
933
+ return
934
+
935
+ # Enable Bayesian optimization using analysis insights for warm-start.
936
+ # The analysis provides much better initial parameters than the default
937
+ # heuristic (strongest-layer-based peak), dramatically narrowing the
938
+ # search space and improving convergence.
939
+ self._configure_bayesian_warm_start()
940
+ self._excise()
941
+
942
+ def _configure_bayesian_warm_start(self):
943
+ """Configure Bayesian optimization with analysis-derived warm-start.
944
+
945
+ Translates analysis insights into a much tighter search space:
946
+ - peak_position from cluster representative layers
947
+ - spread from cluster structure (narrow clusters → narrow spread)
948
+ - component scaling from entanglement analysis
949
+ - KL budget from alignment method detection
950
+ """
951
+ insights = self._insights
952
+
953
+ # Enable Bayesian optimization (50 trials default, same as heretic)
954
+ self._bayesian_trials = 50
955
+
956
+ # Also set heretic-compatible flags on the pipeline so the base
957
+ # _excise_inner() picks them up during Bayesian optimization.
958
+ self.layer_adaptive_strength = True
959
+ self.float_layer_interpolation = True
960
+ self.use_kl_optimization = True
961
+
962
+ # KL budget: tighter for methods that are fragile (CAI, RLHF),
963
+ # looser for concentrated methods (DPO, SFT).
964
+ method = insights.detected_alignment_method
965
+ if method == "dpo":
966
+ self.kl_budget = 0.5
967
+ elif method == "rlhf":
968
+ self.kl_budget = 0.3
969
+ elif method == "cai":
970
+ self.kl_budget = 0.2
971
+ elif method == "sft":
972
+ self.kl_budget = 0.4
973
+ else:
974
+ self.kl_budget = 0.35
975
+
976
+ self.log(f"Bayesian optimization enabled (50 trials, KL budget={self.kl_budget})")
977
+ self.log("Analysis insights will warm-start the optimizer")
978
+
979
+ # Compute analysis-derived warm-start for the parametric kernel.
980
+ # The Bayesian optimizer reads these from the pipeline if present.
981
+ n_layers = len(self._harmful_means) if self._harmful_means else 32
982
+ if insights.cluster_representative_layers and n_layers > 1:
983
+ # Peak position: normalized position of the strongest cluster rep
984
+ norms = {}
985
+ for idx in self._harmful_means:
986
+ if idx in self._harmless_means:
987
+ norms[idx] = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze().norm().item()
988
+ reps = insights.cluster_representative_layers
989
+ if norms:
990
+ best_rep = max(reps, key=lambda ly: norms.get(ly, 0))
991
+ else:
992
+ best_rep = reps[len(reps) // 2]
993
+ warm_peak = best_rep / max(n_layers - 1, 1)
994
+
995
+ # Spread: narrow if clusters are tight, wide if clusters span many layers
996
+ if insights.direction_clusters:
997
+ cluster_widths = [
998
+ (max(c) - min(c)) / max(n_layers - 1, 1)
999
+ for c in insights.direction_clusters if len(c) > 1
1000
+ ]
1001
+ warm_spread = max(0.1, min(0.6, sum(cluster_widths) / len(cluster_widths) if cluster_widths else 0.3))
1002
+ else:
1003
+ warm_spread = 0.3
1004
+
1005
+ # Min weight: higher if high persistence (refusal spread across all layers)
1006
+ warm_min = min(0.3, max(0.0, insights.direction_persistence * 0.2))
1007
+
1008
+ # Attn/MLP scaling: reduce MLP scaling if entanglement is high
1009
+ # (MLP projections cause more capability damage)
1010
+ if insights.entanglement_score > 0.5:
1011
+ warm_mlp = 0.4
1012
+ warm_attn = 0.7
1013
+ else:
1014
+ warm_mlp = 0.6
1015
+ warm_attn = 0.8
1016
  else:
1017
+ warm_peak = 0.5
1018
+ warm_spread = 0.3
1019
+ warm_min = 0.05
1020
+ warm_mlp = 0.6
1021
+ warm_attn = 0.8
1022
+
1023
+ # Store warm-start params for the Bayesian optimizer to pick up
1024
+ self._informed_warm_start = {
1025
+ "max_weight": 0.9,
1026
+ "peak_position": warm_peak,
1027
+ "min_weight": warm_min,
1028
+ "spread": warm_spread,
1029
+ "attn_scale": warm_attn,
1030
+ "mlp_scale": warm_mlp,
1031
+ "dir_idx": 0.0,
1032
+ }
1033
+ self.log(
1034
+ f" Warm-start: peak={warm_peak:.2f}, spread={warm_spread:.2f}, "
1035
+ f"min={warm_min:.2f}, attn={warm_attn:.2f}, mlp={warm_mlp:.2f}"
1036
+ )
1037
 
1038
  def _excise_sparse(self):
1039
  """Sparse direction surgery — only modifies high-projection rows."""
 
1122
  1. Residual refusal signal (via activation probing)
1123
  2. Self-repair / Ouroboros effect (via defense robustness)
1124
  3. Triggers additional targeted passes at compensating layers
1125
+
1126
+ KL-gated: stops early if model damage (KL divergence) is getting
1127
+ worse even though refusal persists. This prevents the death spiral
1128
+ where each pass damages the model without removing refusal.
1129
  """
1130
  # Run standard verification first
1131
  self._verify()
1132
 
1133
  # Check if Ouroboros compensation is needed
1134
  refusal_rate = self._quality_metrics.get("refusal_rate", 0.0)
1135
+ prev_kl = self._quality_metrics.get("kl_divergence", 0.0)
1136
  ouroboros_pass = 0
1137
 
1138
+ # KL budget: stop if KL exceeds this threshold (model too damaged)
1139
+ kl_ceiling = getattr(self, "kl_budget", 0.5) * 2.0 # 2x budget as hard ceiling
1140
+
1141
  while (refusal_rate > self._ouroboros_threshold
1142
  and ouroboros_pass < self._max_ouroboros_passes):
1143
  ouroboros_pass += 1
 
1154
  self._distill_inner()
1155
  self.log(f"Found {len(self._strong_layers)} layers with residual refusal")
1156
 
1157
+ # Re-excise at the new strong layers using informed strategy
1158
  if self._strong_layers:
1159
+ self._excise_informed()
1160
  else:
1161
  self.log("No strong layers found — stopping Ouroboros compensation")
1162
  break
 
1164
  # Re-verify
1165
  self._verify()
1166
  refusal_rate = self._quality_metrics.get("refusal_rate", 0.0)
1167
+ current_kl = self._quality_metrics.get("kl_divergence", 0.0)
1168
+ self.log(f"After Ouroboros pass {ouroboros_pass}: refusal={refusal_rate:.0%}, KL={current_kl:.4f}")
1169
+
1170
+ # KL-gated early stopping: if KL is rising and exceeds ceiling,
1171
+ # the model is being damaged faster than refusal is being removed.
1172
+ if current_kl > kl_ceiling:
1173
+ self.log(
1174
+ f"KL divergence {current_kl:.4f} exceeds ceiling {kl_ceiling:.4f} — "
1175
+ f"stopping to prevent further model damage"
1176
+ )
1177
+ break
1178
+ if ouroboros_pass > 1 and current_kl > prev_kl * 1.5 and refusal_rate > 0.3:
1179
+ self.log(
1180
+ f"KL rising sharply ({prev_kl:.4f} → {current_kl:.4f}) with "
1181
+ f"refusal still at {refusal_rate:.0%} — stopping (diminishing returns)"
1182
+ )
1183
+ break
1184
+ prev_kl = current_kl
1185
 
1186
  self._report.ouroboros_passes = ouroboros_pass
1187
  self._report.final_refusal_rate = refusal_rate
 
1225
  },
1226
  "derived_config": {
1227
  "n_directions": insights.recommended_n_directions,
1228
+ "direction_method": insights.recommended_direction_method,
1229
  "regularization": insights.recommended_regularization,
1230
  "refinement_passes": insights.recommended_refinement_passes,
1231
  "layers_used": insights.recommended_layers,
 
1304
 
1305
  lines.append("Derived Configuration:")
1306
  lines.append(f" n_directions: {insights.recommended_n_directions}")
1307
+ lines.append(f" direction_method: {insights.recommended_direction_method}")
1308
  lines.append(f" regularization: {insights.recommended_regularization}")
1309
  lines.append(f" refinement_passes: {insights.recommended_refinement_passes}")
1310
  lines.append(f" sparse surgery: {insights.use_sparse_surgery}")
scripts/run_benchmark_remote.sh CHANGED
@@ -144,12 +144,18 @@ def _patched_collect(self, layer_modules, prompts, label):
144
  torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
145
  for i in range(torch.cuda.device_count())
146
  )
147
- if free_gb < 2.0:
 
 
 
 
 
 
148
  max_length = 64
149
- self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
150
- elif free_gb < 4.0:
151
  max_length = 128
152
- self.log(f" Tight GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
153
 
154
  device = self._get_model_device(model)
155
 
 
144
  torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
145
  for i in range(torch.cuda.device_count())
146
  )
147
+ # Scale thresholds by model size (baseline: 7B with hidden=4096, 32 layers)
148
+ _h = self.handle.hidden_size if self.handle else 4096
149
+ _l = n_layers if n_layers else 32
150
+ _ms = (_h / 4096) * (_l / 32)
151
+ _tight = max(4.0 * _ms, 0.5)
152
+ _low = max(2.0 * _ms, 0.25)
153
+ if free_gb < _low:
154
  max_length = 64
155
+ self.log(f" Low GPU memory ({free_gb:.1f} GB free, threshold {_low:.1f} GB), using max_length={max_length}")
156
+ elif free_gb < _tight:
157
  max_length = 128
158
+ self.log(f" Tight GPU memory ({free_gb:.1f} GB free, threshold {_tight:.1f} GB), using max_length={max_length}")
159
 
160
  device = self._get_model_device(model)
161
 
tests/test_informed_pipeline.py CHANGED
@@ -50,7 +50,8 @@ class TestAnalysisInsights:
50
  assert insights.cluster_count == 0
51
  assert insights.direction_persistence == 0.0
52
  assert insights.use_sparse_surgery is False
53
- assert insights.recommended_n_directions == 4
 
54
  assert insights.recommended_regularization == 0.0
55
  assert insights.recommended_refinement_passes == 2
56
  assert insights.recommended_layers == []
@@ -86,12 +87,16 @@ class TestInformedMethod:
86
  assert cfg["norm_preserve"] is True
87
  assert cfg["project_biases"] is True
88
  assert cfg["use_chat_template"] is True
89
- assert cfg["use_whitened_svd"] is True
90
  assert cfg["true_iterative_refinement"] is True
 
 
 
91
 
92
  def test_informed_method_standalone(self):
93
  assert INFORMED_METHOD["label"] == "Informed (Analysis-Guided)"
94
- assert INFORMED_METHOD["n_directions"] == 4
 
95
  assert INFORMED_METHOD["norm_preserve"] is True
96
 
97
 
@@ -121,8 +126,10 @@ class TestPipelineInit:
121
  assert pipeline.norm_preserve is True
122
  assert pipeline.project_biases is True
123
  assert pipeline.use_chat_template is True
124
- assert pipeline.use_whitened_svd is True
125
  assert pipeline.true_iterative_refinement is True
 
 
126
 
127
  def test_custom_flags(self):
128
  p = InformedAbliterationPipeline(
@@ -162,17 +169,31 @@ class TestConfigurationDerivation:
162
  cone_dimensionality=3.5,
163
  )
164
  p._derive_configuration()
165
- # Polyhedral with dim 3.5 n_dirs = max(4, min(8, int(3.5*2))) = 7
166
  assert p.n_directions == 7
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
- def test_linear_cone_fewer_directions(self):
169
  p = self._make_pipeline_with_insights(
170
  cone_is_polyhedral=False,
171
  cone_dimensionality=1.0,
172
  )
173
  p._derive_configuration()
174
- # Linear with dim 1.0 n_dirs = max(1, min(4, int(1.0+1))) = 2
175
- assert p.n_directions == 2
 
 
176
 
177
  def test_dpo_zero_regularization(self):
178
  p = self._make_pipeline_with_insights(
@@ -282,6 +303,7 @@ class TestConfigurationDerivation:
282
  p._derive_configuration()
283
  assert p.n_directions > 1
284
  assert p.use_whitened_svd is True
 
285
 
286
  def test_no_whitened_svd_for_single_direction(self):
287
  p = self._make_pipeline_with_insights(
@@ -289,9 +311,9 @@ class TestConfigurationDerivation:
289
  cone_dimensionality=0.5,
290
  )
291
  p._derive_configuration()
292
- # dim 0.5 → max(1, min(4, int(0.5+1))) = 1
293
  assert p.n_directions == 1
294
  assert p.use_whitened_svd is False
 
295
 
296
 
297
  # ---------------------------------------------------------------------------
@@ -320,10 +342,12 @@ class TestFormatInsights:
320
 
321
  def test_format_includes_derived_config(self, insights):
322
  insights.recommended_n_directions = 6
 
323
  insights.recommended_regularization = 0.2
324
  insights.recommended_refinement_passes = 3
325
  text = InformedAbliterationPipeline.format_insights(insights)
326
  assert "n_directions: 6" in text
 
327
  assert "regularization: 0.2" in text
328
  assert "refinement_passes: 3" in text
329
 
@@ -372,14 +396,16 @@ class TestEdgeCases:
372
  model_name="test",
373
  on_log=lambda m: None,
374
  )
375
- # Very high dimensionality
376
  p._insights.cone_is_polyhedral = True
377
  p._insights.cone_dimensionality = 10.0
378
  p._derive_configuration()
379
  assert p.n_directions <= 8 # capped
 
380
 
381
- # Very low dimensionality
382
  p._insights.cone_is_polyhedral = False
383
  p._insights.cone_dimensionality = 0.1
384
  p._derive_configuration()
385
- assert p.n_directions >= 1 # at least 1
 
 
50
  assert insights.cluster_count == 0
51
  assert insights.direction_persistence == 0.0
52
  assert insights.use_sparse_surgery is False
53
+ assert insights.recommended_n_directions == 1
54
+ assert insights.recommended_direction_method == "diff_means"
55
  assert insights.recommended_regularization == 0.0
56
  assert insights.recommended_refinement_passes == 2
57
  assert insights.recommended_layers == []
 
87
  assert cfg["norm_preserve"] is True
88
  assert cfg["project_biases"] is True
89
  assert cfg["use_chat_template"] is True
90
+ assert cfg["use_whitened_svd"] is False
91
  assert cfg["true_iterative_refinement"] is True
92
+ assert cfg["n_directions"] == 1
93
+ assert cfg["direction_method"] == "diff_means"
94
+ assert cfg["use_kl_optimization"] is True
95
 
96
  def test_informed_method_standalone(self):
97
  assert INFORMED_METHOD["label"] == "Informed (Analysis-Guided)"
98
+ assert INFORMED_METHOD["n_directions"] == 1
99
+ assert INFORMED_METHOD["direction_method"] == "diff_means"
100
  assert INFORMED_METHOD["norm_preserve"] is True
101
 
102
 
 
126
  assert pipeline.norm_preserve is True
127
  assert pipeline.project_biases is True
128
  assert pipeline.use_chat_template is True
129
+ assert pipeline.use_whitened_svd is False
130
  assert pipeline.true_iterative_refinement is True
131
+ assert pipeline.direction_method == "diff_means"
132
+ assert pipeline.n_directions == 1
133
 
134
  def test_custom_flags(self):
135
  p = InformedAbliterationPipeline(
 
169
  cone_dimensionality=3.5,
170
  )
171
  p._derive_configuration()
172
+ # Clearly polyhedral with dim 3.5 > 2.0 SVD multi-direction
173
  assert p.n_directions == 7
174
+ assert p.direction_method == "svd"
175
+ assert p.use_whitened_svd is True
176
+
177
+ def test_mildly_polyhedral_uses_leace(self):
178
+ p = self._make_pipeline_with_insights(
179
+ cone_is_polyhedral=True,
180
+ cone_dimensionality=1.5,
181
+ )
182
+ p._derive_configuration()
183
+ # Mildly polyhedral (dim <= 2.0) → single LEACE direction
184
+ assert p.n_directions == 1
185
+ assert p.direction_method == "leace"
186
 
187
+ def test_linear_cone_uses_diff_means(self):
188
  p = self._make_pipeline_with_insights(
189
  cone_is_polyhedral=False,
190
  cone_dimensionality=1.0,
191
  )
192
  p._derive_configuration()
193
+ # Linear conesingle diff-of-means direction
194
+ assert p.n_directions == 1
195
+ assert p.direction_method == "diff_means"
196
+ assert p.use_whitened_svd is False
197
 
198
  def test_dpo_zero_regularization(self):
199
  p = self._make_pipeline_with_insights(
 
303
  p._derive_configuration()
304
  assert p.n_directions > 1
305
  assert p.use_whitened_svd is True
306
+ assert p.direction_method == "svd"
307
 
308
  def test_no_whitened_svd_for_single_direction(self):
309
  p = self._make_pipeline_with_insights(
 
311
  cone_dimensionality=0.5,
312
  )
313
  p._derive_configuration()
 
314
  assert p.n_directions == 1
315
  assert p.use_whitened_svd is False
316
+ assert p.direction_method == "diff_means"
317
 
318
 
319
  # ---------------------------------------------------------------------------
 
342
 
343
  def test_format_includes_derived_config(self, insights):
344
  insights.recommended_n_directions = 6
345
+ insights.recommended_direction_method = "svd"
346
  insights.recommended_regularization = 0.2
347
  insights.recommended_refinement_passes = 3
348
  text = InformedAbliterationPipeline.format_insights(insights)
349
  assert "n_directions: 6" in text
350
+ assert "direction_method: svd" in text
351
  assert "regularization: 0.2" in text
352
  assert "refinement_passes: 3" in text
353
 
 
396
  model_name="test",
397
  on_log=lambda m: None,
398
  )
399
+ # Very high dimensionality → multi-direction SVD, capped at 8
400
  p._insights.cone_is_polyhedral = True
401
  p._insights.cone_dimensionality = 10.0
402
  p._derive_configuration()
403
  assert p.n_directions <= 8 # capped
404
+ assert p.direction_method == "svd"
405
 
406
+ # Very low dimensionality → single diff-of-means
407
  p._insights.cone_is_polyhedral = False
408
  p._insights.cone_dimensionality = 0.1
409
  p._derive_configuration()
410
+ assert p.n_directions == 1
411
+ assert p.direction_method == "diff_means"
tests/test_leace.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for LEACE (LEAst-squares Concept Erasure) direction extraction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import pytest
6
+ import torch
7
+
8
+ from obliteratus.analysis.leace import LEACEExtractor, LEACEResult
9
+
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # Fixtures
13
+ # ---------------------------------------------------------------------------
14
+
15
@pytest.fixture
def extractor():
    """Default LEACE extractor with mild eigenvalue regularization."""
    return LEACEExtractor(regularization_eps=1e-4)
18
+
19
+
20
@pytest.fixture
def separable_data():
    """Two activation clusters cleanly separated along the first axis.

    Returns (harmful, harmless): lists of 20 vectors in R^64 centered at
    +e_0 and -e_0 respectively, with 0.1-sigma Gaussian noise.
    """
    torch.manual_seed(42)
    dim, count = 64, 20

    axis = torch.zeros(dim)
    axis[0] = 1.0

    # Harmful samples cluster at +axis, harmless at -axis.
    harmful = [axis + 0.1 * torch.randn(dim) for _ in range(count)]
    harmless = [-axis + 0.1 * torch.randn(dim) for _ in range(count)]

    return harmful, harmless
36
+
37
+
38
@pytest.fixture
def isotropic_data():
    """Classes that differ only in mean, under unit isotropic noise.

    Returns (harmful, harmless, direction): 30 samples per class in R^32,
    offset by +/- 2 units along a random unit direction.
    """
    torch.manual_seed(123)
    dim, count = 32, 30

    sep = torch.randn(dim)
    sep = sep / sep.norm()

    harmful = [sep * 2.0 + torch.randn(dim) for _ in range(count)]
    harmless = [-sep * 2.0 + torch.randn(dim) for _ in range(count)]

    return harmful, harmless, sep
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # LEACEResult
56
+ # ---------------------------------------------------------------------------
57
+
58
class TestLEACEResult:
    """Sanity checks on the LEACEResult returned by extract()."""

    def test_result_fields(self, extractor, separable_data):
        """All diagnostic fields are populated and within valid ranges."""
        pos, neg = separable_data
        res = extractor.extract(pos, neg, layer_idx=5)

        assert isinstance(res, LEACEResult)
        assert res.layer_idx == 5
        assert res.direction.shape == (64,)
        assert res.generalized_eigenvalue > 0
        assert res.within_class_condition > 0
        assert res.mean_diff_norm > 0
        assert res.erasure_loss >= 0

    def test_direction_is_unit_vector(self, extractor, separable_data):
        """The extracted direction is normalized to unit length."""
        pos, neg = separable_data
        res = extractor.extract(pos, neg)
        assert abs(res.direction.norm().item() - 1.0) < 1e-5
76
+
77
+
78
+ # ---------------------------------------------------------------------------
79
+ # Direction quality
80
+ # ---------------------------------------------------------------------------
81
+
82
class TestDirectionQuality:
    """LEACE should recover the axis that actually separates the classes."""

    def test_finds_true_direction(self, extractor, separable_data):
        """The extracted direction aligns with the known separation axis."""
        pos, neg = separable_data
        res = extractor.extract(pos, neg)

        # Ground-truth separation axis used by the fixture: e_0.
        axis = torch.zeros(64)
        axis[0] = 1.0

        cosine = (res.direction @ axis).abs().item()
        # Only 20 samples in 64 dims, so allow noise-driven slack.
        assert cosine > 0.5, f"LEACE direction not aligned with true direction: {cosine}"

    def test_isotropic_matches_diff_of_means(self, extractor, isotropic_data):
        """Under isotropic noise the LEACE axis tracks the mean difference."""
        pos, neg, _axis = isotropic_data
        res = extractor.extract(pos, neg)

        # Class-mean difference, normalized.
        gap = torch.stack(pos).mean(0) - torch.stack(neg).mean(0)
        unit_gap = gap / gap.norm()

        cosine = (res.direction @ unit_gap).abs().item()
        # Finite samples plus regularization permit some deviation.
        assert cosine > 0.5

    def test_leace_differs_from_diff_means_with_anisotropic_noise(self):
        """A high-variance nuisance dimension must not hijack the direction."""
        torch.manual_seed(77)
        dim, count = 64, 50

        # True refusal axis is e_0.
        true_dir = torch.zeros(dim)
        true_dir[0] = 1.0

        # Inflate variance in dim 1 only -- a nuisance axis, not the signal.
        sigma = torch.ones(dim) * 0.1
        sigma[1] = 5.0

        pos = [true_dir * 0.5 + torch.randn(dim) * sigma for _ in range(count)]
        neg = [-true_dir * 0.5 + torch.randn(dim) * sigma for _ in range(count)]

        res = LEACEExtractor().extract(pos, neg)

        cosine_to_true = (res.direction @ true_dir).abs().item()
        # LEACE must stay on the true axis rather than the rogue dimension.
        assert cosine_to_true > 0.5, f"LEACE distracted by rogue dimension: {cosine_to_true}"
132
+
133
+
134
+ # ---------------------------------------------------------------------------
135
+ # Comparison with diff-of-means
136
+ # ---------------------------------------------------------------------------
137
+
138
class TestCompareWithDiffOfMeans:
    """Static comparison helper between LEACE and the diff-of-means axis."""

    def test_comparison_output(self, extractor, separable_data):
        """The comparison dict carries all diagnostics with valid ranges."""
        pos, neg = separable_data
        res = extractor.extract(pos, neg)

        mu_pos = torch.stack(pos).mean(0)
        mu_neg = torch.stack(neg).mean(0)

        report = LEACEExtractor.compare_with_diff_of_means(
            res, mu_pos, mu_neg,
        )

        for key in (
            "cosine_similarity",
            "leace_eigenvalue",
            "leace_erasure_loss",
            "within_class_condition",
            "mean_diff_norm",
        ):
            assert key in report
        assert 0 <= report["cosine_similarity"] <= 1.0
156
+
157
+
158
+ # ---------------------------------------------------------------------------
159
+ # Multi-layer extraction
160
+ # ---------------------------------------------------------------------------
161
+
162
class TestMultiLayer:
    """extract_all_layers() maps per-layer activations to per-layer results."""

    def test_extract_all_layers(self, extractor):
        """Every requested layer index yields a matching LEACEResult."""
        torch.manual_seed(42)
        dim, count = 32, 15

        pos_by_layer = {}
        neg_by_layer = {}
        for layer in [0, 1, 2, 5]:
            pos_by_layer[layer] = [torch.randn(dim) + 0.5 for _ in range(count)]
            neg_by_layer[layer] = [torch.randn(dim) - 0.5 for _ in range(count)]

        out = extractor.extract_all_layers(pos_by_layer, neg_by_layer)

        assert set(out.keys()) == {0, 1, 2, 5}
        for idx, res in out.items():
            assert res.layer_idx == idx
            assert res.direction.shape == (dim,)
180
+
181
+
182
+ # ---------------------------------------------------------------------------
183
+ # Edge cases
184
+ # ---------------------------------------------------------------------------
185
+
186
class TestEdgeCases:
    """Degenerate inputs must not crash or yield non-finite directions."""

    def test_single_sample(self, extractor):
        """One sample per class must still produce a finite direction."""
        dim = 32
        res = extractor.extract([torch.randn(dim)], [torch.randn(dim)])
        assert res.direction.shape == (dim,)
        assert torch.isfinite(res.direction).all()

    def test_identical_activations(self, extractor):
        """harmful == harmless: direction may be ~0 or a fallback, never NaN."""
        dim = 32
        point = torch.randn(dim)
        res = extractor.extract(
            [point.clone() for _ in range(5)],
            [point.clone() for _ in range(5)],
        )
        assert res.direction.shape == (dim,)
        assert torch.isfinite(res.direction).all()

    def test_3d_input_squeezed(self, extractor):
        """Per-sample (1, d) inputs are squeezed down to a (d,) direction."""
        dim = 32
        pos = [torch.randn(1, dim) for _ in range(10)]
        neg = [torch.randn(1, dim) for _ in range(10)]
        assert extractor.extract(pos, neg).direction.shape == (dim,)

    def test_shrinkage(self):
        """With n < d, a shrinkage estimator still yields a valid result."""
        torch.manual_seed(42)
        dim, count = 64, 10  # count < dim -> covariance needs shrinkage

        pos = [torch.randn(dim) + 0.3 for _ in range(count)]
        neg = [torch.randn(dim) - 0.3 for _ in range(count)]

        res = LEACEExtractor(shrinkage=0.5).extract(pos, neg)
        assert res.direction.shape == (dim,)
        assert torch.isfinite(res.direction).all()