Spaces:
Running on Zero
Running on Zero
Upload 135 files
Browse files- app.py +15 -4
- obliteratus/abliterate.py +100 -15
- obliteratus/analysis/leace.py +239 -0
- obliteratus/bayesian_optimizer.py +147 -90
- obliteratus/cli.py +7 -1
- obliteratus/informed_pipeline.py +380 -56
- scripts/run_benchmark_remote.sh +10 -4
- tests/test_informed_pipeline.py +38 -12
- tests/test_leace.py +230 -0
app.py
CHANGED
|
@@ -316,6 +316,7 @@ def _get_preset_defaults(method_display: str):
|
|
| 316 |
cfg = _PRESET_CONFIGS.get(method_key, _PRESET_CONFIGS["advanced"])
|
| 317 |
return {
|
| 318 |
"n_directions": cfg.get("n_directions", 4),
|
|
|
|
| 319 |
"regularization": cfg.get("regularization", 0.3),
|
| 320 |
"refinement_passes": cfg.get("refinement_passes", 2),
|
| 321 |
"norm_preserve": cfg.get("norm_preserve", True),
|
|
@@ -348,6 +349,7 @@ def _on_method_change(method_display: str):
|
|
| 348 |
d = _get_preset_defaults(method_display)
|
| 349 |
return (
|
| 350 |
d["n_directions"],
|
|
|
|
| 351 |
d["regularization"],
|
| 352 |
d["refinement_passes"],
|
| 353 |
d["reflection_strength"],
|
|
@@ -1731,8 +1733,9 @@ def _format_multi_model_results(results: list[dict], context: dict | None = None
|
|
| 1731 |
def obliterate(model_choice: str, method_choice: str,
|
| 1732 |
prompt_volume_choice: str, dataset_source_choice: str,
|
| 1733 |
custom_harmful: str, custom_harmless: str,
|
| 1734 |
-
# Advanced params (sliders)
|
| 1735 |
-
adv_n_directions: int,
|
|
|
|
| 1736 |
adv_refinement_passes: int, adv_reflection_strength: float,
|
| 1737 |
adv_embed_regularization: float, adv_steering_strength: float,
|
| 1738 |
adv_transplant_blend: float,
|
|
@@ -1906,6 +1909,7 @@ def obliterate(model_choice: str, method_choice: str,
|
|
| 1906 |
on_log=on_log,
|
| 1907 |
# Advanced overrides from UI
|
| 1908 |
n_directions=int(adv_n_directions),
|
|
|
|
| 1909 |
regularization=float(adv_regularization),
|
| 1910 |
refinement_passes=int(adv_refinement_passes),
|
| 1911 |
norm_preserve=adv_norm_preserve,
|
|
@@ -3930,7 +3934,13 @@ with gr.Blocks(theme=THEME, css=CSS, js=_JS, title="OBLITERATUS", fill_height=Tr
|
|
| 3930 |
with gr.Row():
|
| 3931 |
adv_n_directions = gr.Slider(
|
| 3932 |
1, 8, value=_defaults["n_directions"], step=1,
|
| 3933 |
-
label="Directions", info="Number of refusal directions to extract
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3934 |
)
|
| 3935 |
adv_regularization = gr.Slider(
|
| 3936 |
0.0, 1.0, value=_defaults["regularization"], step=0.05,
|
|
@@ -3999,7 +4009,8 @@ with gr.Blocks(theme=THEME, css=CSS, js=_JS, title="OBLITERATUS", fill_height=Tr
|
|
| 3999 |
|
| 4000 |
# List of all advanced controls (order must match _on_method_change return)
|
| 4001 |
_adv_controls = [
|
| 4002 |
-
adv_n_directions,
|
|
|
|
| 4003 |
adv_reflection_strength, adv_embed_regularization,
|
| 4004 |
adv_steering_strength, adv_transplant_blend,
|
| 4005 |
adv_spectral_bands, adv_spectral_threshold,
|
|
|
|
| 316 |
cfg = _PRESET_CONFIGS.get(method_key, _PRESET_CONFIGS["advanced"])
|
| 317 |
return {
|
| 318 |
"n_directions": cfg.get("n_directions", 4),
|
| 319 |
+
"direction_method": cfg.get("direction_method", "svd"),
|
| 320 |
"regularization": cfg.get("regularization", 0.3),
|
| 321 |
"refinement_passes": cfg.get("refinement_passes", 2),
|
| 322 |
"norm_preserve": cfg.get("norm_preserve", True),
|
|
|
|
| 349 |
d = _get_preset_defaults(method_display)
|
| 350 |
return (
|
| 351 |
d["n_directions"],
|
| 352 |
+
d["direction_method"],
|
| 353 |
d["regularization"],
|
| 354 |
d["refinement_passes"],
|
| 355 |
d["reflection_strength"],
|
|
|
|
| 1733 |
def obliterate(model_choice: str, method_choice: str,
|
| 1734 |
prompt_volume_choice: str, dataset_source_choice: str,
|
| 1735 |
custom_harmful: str, custom_harmless: str,
|
| 1736 |
+
# Advanced params (sliders + radio)
|
| 1737 |
+
adv_n_directions: int, adv_direction_method: str,
|
| 1738 |
+
adv_regularization: float,
|
| 1739 |
adv_refinement_passes: int, adv_reflection_strength: float,
|
| 1740 |
adv_embed_regularization: float, adv_steering_strength: float,
|
| 1741 |
adv_transplant_blend: float,
|
|
|
|
| 1909 |
on_log=on_log,
|
| 1910 |
# Advanced overrides from UI
|
| 1911 |
n_directions=int(adv_n_directions),
|
| 1912 |
+
direction_method=adv_direction_method,
|
| 1913 |
regularization=float(adv_regularization),
|
| 1914 |
refinement_passes=int(adv_refinement_passes),
|
| 1915 |
norm_preserve=adv_norm_preserve,
|
|
|
|
| 3934 |
with gr.Row():
|
| 3935 |
adv_n_directions = gr.Slider(
|
| 3936 |
1, 8, value=_defaults["n_directions"], step=1,
|
| 3937 |
+
label="Directions", info="Number of refusal directions to extract",
|
| 3938 |
+
)
|
| 3939 |
+
adv_direction_method = gr.Radio(
|
| 3940 |
+
choices=["diff_means", "svd", "leace"],
|
| 3941 |
+
value=_defaults["direction_method"],
|
| 3942 |
+
label="Direction Method",
|
| 3943 |
+
info="diff_means: simple & robust, svd: multi-direction, leace: optimal erasure",
|
| 3944 |
)
|
| 3945 |
adv_regularization = gr.Slider(
|
| 3946 |
0.0, 1.0, value=_defaults["regularization"], step=0.05,
|
|
|
|
| 4009 |
|
| 4010 |
# List of all advanced controls (order must match _on_method_change return)
|
| 4011 |
_adv_controls = [
|
| 4012 |
+
adv_n_directions, adv_direction_method,
|
| 4013 |
+
adv_regularization, adv_refinement_passes,
|
| 4014 |
adv_reflection_strength, adv_embed_regularization,
|
| 4015 |
adv_steering_strength, adv_transplant_blend,
|
| 4016 |
adv_spectral_bands, adv_spectral_threshold,
|
obliteratus/abliterate.py
CHANGED
|
@@ -63,6 +63,7 @@ METHODS = {
|
|
| 63 |
"label": "Basic (Arditi et al.)",
|
| 64 |
"description": "Single refusal direction via difference-in-means",
|
| 65 |
"n_directions": 1,
|
|
|
|
| 66 |
"norm_preserve": False,
|
| 67 |
"regularization": 0.0,
|
| 68 |
"refinement_passes": 1,
|
|
@@ -75,6 +76,7 @@ METHODS = {
|
|
| 75 |
"label": "Advanced (Multi-direction + Norm-preserving)",
|
| 76 |
"description": "SVD-based multi-direction extraction with norm preservation",
|
| 77 |
"n_directions": 4,
|
|
|
|
| 78 |
"norm_preserve": True,
|
| 79 |
"regularization": 0.3,
|
| 80 |
"embed_regularization": 0.5,
|
|
@@ -97,6 +99,7 @@ METHODS = {
|
|
| 97 |
"Zero regularization for maximum refusal removal."
|
| 98 |
),
|
| 99 |
"n_directions": 8,
|
|
|
|
| 100 |
"norm_preserve": True,
|
| 101 |
"regularization": 0.0,
|
| 102 |
"refinement_passes": 3,
|
|
@@ -124,6 +127,7 @@ METHODS = {
|
|
| 124 |
"separating trained-in refusal patterns from per-layer artifacts."
|
| 125 |
),
|
| 126 |
"n_directions": 6,
|
|
|
|
| 127 |
"norm_preserve": True,
|
| 128 |
"regularization": 0.0,
|
| 129 |
"refinement_passes": 2,
|
|
@@ -146,25 +150,31 @@ METHODS = {
|
|
| 146 |
"Uses InformedAbliterationPipeline for the full feedback loop. "
|
| 147 |
"Auto-detects alignment method (DPO/RLHF/CAI/SFT), maps concept "
|
| 148 |
"cone geometry, performs cluster-aware layer selection, and gates "
|
| 149 |
-
"projection by safety-capability entanglement.
|
| 150 |
-
"
|
| 151 |
-
"
|
| 152 |
),
|
| 153 |
-
"n_directions":
|
|
|
|
| 154 |
"norm_preserve": True,
|
| 155 |
"regularization": 0.0,
|
| 156 |
"refinement_passes": 2,
|
| 157 |
"project_biases": True,
|
| 158 |
"use_chat_template": True,
|
| 159 |
-
"use_whitened_svd":
|
| 160 |
"true_iterative_refinement": True,
|
| 161 |
"use_jailbreak_contrast": False,
|
| 162 |
-
"layer_adaptive_strength":
|
| 163 |
"safety_neuron_masking": False,
|
| 164 |
"per_expert_directions": False,
|
| 165 |
"attention_head_surgery": False,
|
| 166 |
"use_sae_features": False,
|
| 167 |
-
"use_wasserstein_optimal":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
},
|
| 169 |
"surgical": {
|
| 170 |
"label": "Surgical (Full SOTA MoE-Aware)",
|
|
@@ -176,6 +186,7 @@ METHODS = {
|
|
| 176 |
"minimizing capability damage via precision targeting."
|
| 177 |
),
|
| 178 |
"n_directions": 8,
|
|
|
|
| 179 |
"norm_preserve": True,
|
| 180 |
"regularization": 0.0,
|
| 181 |
"refinement_passes": 2,
|
|
@@ -204,6 +215,7 @@ METHODS = {
|
|
| 204 |
"techniques plus the inversion layer."
|
| 205 |
),
|
| 206 |
"n_directions": 8,
|
|
|
|
| 207 |
"norm_preserve": True,
|
| 208 |
"regularization": 0.0,
|
| 209 |
"refinement_passes": 2,
|
|
@@ -234,6 +246,7 @@ METHODS = {
|
|
| 234 |
"Best for maximizing quality when compute budget allows ~50 trials."
|
| 235 |
),
|
| 236 |
"n_directions": 4,
|
|
|
|
| 237 |
"norm_preserve": True,
|
| 238 |
"regularization": 0.0,
|
| 239 |
"refinement_passes": 1,
|
|
@@ -275,6 +288,7 @@ METHODS = {
|
|
| 275 |
"runtime overhead except lightweight steering hooks."
|
| 276 |
),
|
| 277 |
"n_directions": 4,
|
|
|
|
| 278 |
"norm_preserve": True,
|
| 279 |
"regularization": 0.0,
|
| 280 |
"refinement_passes": 2,
|
|
@@ -320,6 +334,7 @@ METHODS = {
|
|
| 320 |
"abliterated models were created with."
|
| 321 |
),
|
| 322 |
"n_directions": 1,
|
|
|
|
| 323 |
"norm_preserve": False,
|
| 324 |
"regularization": 0.0,
|
| 325 |
"refinement_passes": 1,
|
|
@@ -347,6 +362,7 @@ METHODS = {
|
|
| 347 |
"whitened SVD, no iterative refinement."
|
| 348 |
),
|
| 349 |
"n_directions": 4,
|
|
|
|
| 350 |
"norm_preserve": False,
|
| 351 |
# Ridge alpha=0.3 → effective reg = alpha/(1+alpha) = 0.3/1.3 ≈ 0.231
|
| 352 |
# For orthonormal V: P_V^alpha = 1/(1+alpha) * VV^T = 0.769 * VV^T
|
|
@@ -379,6 +395,7 @@ METHODS = {
|
|
| 379 |
"over the (refusal_rate, KL_divergence) frontier."
|
| 380 |
),
|
| 381 |
"n_directions": 2,
|
|
|
|
| 382 |
"norm_preserve": True,
|
| 383 |
"regularization": 0.0,
|
| 384 |
"refinement_passes": 1,
|
|
@@ -414,6 +431,7 @@ METHODS = {
|
|
| 414 |
"boundary rather than the statistical activation difference."
|
| 415 |
),
|
| 416 |
"n_directions": 4,
|
|
|
|
| 417 |
"norm_preserve": True,
|
| 418 |
"regularization": 0.0,
|
| 419 |
"refinement_passes": 1,
|
|
@@ -566,6 +584,7 @@ class AbliterationPipeline:
|
|
| 566 |
hub_token: str | None = None,
|
| 567 |
hub_community_org: str | None = None,
|
| 568 |
n_directions: int | None = None,
|
|
|
|
| 569 |
norm_preserve: bool | None = None,
|
| 570 |
regularization: float | None = None,
|
| 571 |
refinement_passes: int | None = None,
|
|
@@ -659,6 +678,7 @@ class AbliterationPipeline:
|
|
| 659 |
method_cfg = METHODS[method]
|
| 660 |
self.method = method
|
| 661 |
self.n_directions = n_directions if n_directions is not None else method_cfg["n_directions"]
|
|
|
|
| 662 |
self.norm_preserve = norm_preserve if norm_preserve is not None else method_cfg["norm_preserve"]
|
| 663 |
self.regularization = regularization if regularization is not None else method_cfg["regularization"]
|
| 664 |
self.refinement_passes = refinement_passes if refinement_passes is not None else method_cfg["refinement_passes"]
|
|
@@ -936,7 +956,7 @@ class AbliterationPipeline:
|
|
| 936 |
self.log(f"Loading model: {self.model_name}")
|
| 937 |
self.log(f"Device: {self.device} | Dtype: {self.dtype}")
|
| 938 |
self.log(f"Method: {method_label}")
|
| 939 |
-
self.log(f" Directions: {self.n_directions} | Norm-preserve: {self.norm_preserve}")
|
| 940 |
self.log(f" Regularization: {self.regularization} | Refinement passes: {self.refinement_passes}")
|
| 941 |
|
| 942 |
self.handle = load_model(
|
|
@@ -1400,18 +1420,26 @@ class AbliterationPipeline:
|
|
| 1400 |
else:
|
| 1401 |
max_length = 384 if collect_multi_pos else 256
|
| 1402 |
free_gb = dev.get_total_free_gb()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1403 |
if dev.is_gpu_available():
|
| 1404 |
-
if self.max_seq_length is None and free_gb <
|
| 1405 |
max_length = 64
|
| 1406 |
-
self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
|
| 1407 |
-
elif self.max_seq_length is None and free_gb <
|
| 1408 |
max_length = 128
|
| 1409 |
-
self.log(f" Tight GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
|
| 1410 |
|
| 1411 |
device = self._get_model_device(model)
|
| 1412 |
|
| 1413 |
# Batch prompts for throughput — hooks unbatch per-prompt activations
|
| 1414 |
-
batch_size = 16 if free_gb >
|
| 1415 |
# Left-pad so position -1 is always the last real token in every batch element
|
| 1416 |
orig_padding_side = getattr(tokenizer, "padding_side", "right")
|
| 1417 |
if batch_size > 1:
|
|
@@ -1498,9 +1526,16 @@ class AbliterationPipeline:
|
|
| 1498 |
wasserstein_extractor = WassersteinOptimalExtractor()
|
| 1499 |
self.log("Using Wasserstein-optimal direction extraction (cost-minimizing GEP)")
|
| 1500 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1501 |
# Optionally use whitened SVD for cleaner direction extraction
|
| 1502 |
whitened_extractor = None
|
| 1503 |
-
if self.use_whitened_svd and n_dirs > 1 and not self.use_wasserstein_optimal:
|
| 1504 |
from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
|
| 1505 |
whitened_extractor = WhitenedSVDExtractor()
|
| 1506 |
self.log("Using whitened SVD (covariance-normalized) for direction extraction")
|
|
@@ -1547,6 +1582,30 @@ class AbliterationPipeline:
|
|
| 1547 |
if idx < 5:
|
| 1548 |
self.log(f" layer {idx}: Wasserstein extraction failed ({e}), falling back to SVD")
|
| 1549 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1550 |
if n_dirs == 1:
|
| 1551 |
# Classic single-direction: difference-in-means
|
| 1552 |
diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
|
|
@@ -3589,9 +3648,18 @@ class AbliterationPipeline:
|
|
| 3589 |
except Exception:
|
| 3590 |
pass
|
| 3591 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3592 |
# Use whitened SVD when enabled (matching main _distill)
|
| 3593 |
whitened_extractor = None
|
| 3594 |
-
if self.use_whitened_svd and n_dirs > 1 and wasserstein_extractor is None:
|
| 3595 |
from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
|
| 3596 |
whitened_extractor = WhitenedSVDExtractor()
|
| 3597 |
|
|
@@ -3624,6 +3692,22 @@ class AbliterationPipeline:
|
|
| 3624 |
except Exception:
|
| 3625 |
pass # Fall through to SVD
|
| 3626 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3627 |
if n_dirs == 1:
|
| 3628 |
diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
|
| 3629 |
norm = diff.norm()
|
|
@@ -5741,6 +5825,7 @@ class AbliterationPipeline:
|
|
| 5741 |
"method": self.method,
|
| 5742 |
"method_config": {
|
| 5743 |
"n_directions": self.n_directions,
|
|
|
|
| 5744 |
"norm_preserve": self.norm_preserve,
|
| 5745 |
"regularization": self.regularization,
|
| 5746 |
"refinement_passes": self.refinement_passes,
|
|
|
|
| 63 |
"label": "Basic (Arditi et al.)",
|
| 64 |
"description": "Single refusal direction via difference-in-means",
|
| 65 |
"n_directions": 1,
|
| 66 |
+
"direction_method": "diff_means",
|
| 67 |
"norm_preserve": False,
|
| 68 |
"regularization": 0.0,
|
| 69 |
"refinement_passes": 1,
|
|
|
|
| 76 |
"label": "Advanced (Multi-direction + Norm-preserving)",
|
| 77 |
"description": "SVD-based multi-direction extraction with norm preservation",
|
| 78 |
"n_directions": 4,
|
| 79 |
+
"direction_method": "svd",
|
| 80 |
"norm_preserve": True,
|
| 81 |
"regularization": 0.3,
|
| 82 |
"embed_regularization": 0.5,
|
|
|
|
| 99 |
"Zero regularization for maximum refusal removal."
|
| 100 |
),
|
| 101 |
"n_directions": 8,
|
| 102 |
+
"direction_method": "svd",
|
| 103 |
"norm_preserve": True,
|
| 104 |
"regularization": 0.0,
|
| 105 |
"refinement_passes": 3,
|
|
|
|
| 127 |
"separating trained-in refusal patterns from per-layer artifacts."
|
| 128 |
),
|
| 129 |
"n_directions": 6,
|
| 130 |
+
"direction_method": "svd",
|
| 131 |
"norm_preserve": True,
|
| 132 |
"regularization": 0.0,
|
| 133 |
"refinement_passes": 2,
|
|
|
|
| 150 |
"Uses InformedAbliterationPipeline for the full feedback loop. "
|
| 151 |
"Auto-detects alignment method (DPO/RLHF/CAI/SFT), maps concept "
|
| 152 |
"cone geometry, performs cluster-aware layer selection, and gates "
|
| 153 |
+
"projection by safety-capability entanglement. Defaults to single "
|
| 154 |
+
"diff-of-means direction + Bayesian optimization (Heretic-style). "
|
| 155 |
+
"LEACE available via direction_method='leace'."
|
| 156 |
),
|
| 157 |
+
"n_directions": 1,
|
| 158 |
+
"direction_method": "diff_means",
|
| 159 |
"norm_preserve": True,
|
| 160 |
"regularization": 0.0,
|
| 161 |
"refinement_passes": 2,
|
| 162 |
"project_biases": True,
|
| 163 |
"use_chat_template": True,
|
| 164 |
+
"use_whitened_svd": False,
|
| 165 |
"true_iterative_refinement": True,
|
| 166 |
"use_jailbreak_contrast": False,
|
| 167 |
+
"layer_adaptive_strength": True,
|
| 168 |
"safety_neuron_masking": False,
|
| 169 |
"per_expert_directions": False,
|
| 170 |
"attention_head_surgery": False,
|
| 171 |
"use_sae_features": False,
|
| 172 |
+
"use_wasserstein_optimal": False,
|
| 173 |
+
"use_kl_optimization": True,
|
| 174 |
+
"kl_budget": 0.5,
|
| 175 |
+
"float_layer_interpolation": True,
|
| 176 |
+
"winsorize_activations": True,
|
| 177 |
+
"winsorize_percentile": 0.01,
|
| 178 |
},
|
| 179 |
"surgical": {
|
| 180 |
"label": "Surgical (Full SOTA MoE-Aware)",
|
|
|
|
| 186 |
"minimizing capability damage via precision targeting."
|
| 187 |
),
|
| 188 |
"n_directions": 8,
|
| 189 |
+
"direction_method": "svd",
|
| 190 |
"norm_preserve": True,
|
| 191 |
"regularization": 0.0,
|
| 192 |
"refinement_passes": 2,
|
|
|
|
| 215 |
"techniques plus the inversion layer."
|
| 216 |
),
|
| 217 |
"n_directions": 8,
|
| 218 |
+
"direction_method": "svd",
|
| 219 |
"norm_preserve": True,
|
| 220 |
"regularization": 0.0,
|
| 221 |
"refinement_passes": 2,
|
|
|
|
| 246 |
"Best for maximizing quality when compute budget allows ~50 trials."
|
| 247 |
),
|
| 248 |
"n_directions": 4,
|
| 249 |
+
"direction_method": "svd",
|
| 250 |
"norm_preserve": True,
|
| 251 |
"regularization": 0.0,
|
| 252 |
"refinement_passes": 1,
|
|
|
|
| 288 |
"runtime overhead except lightweight steering hooks."
|
| 289 |
),
|
| 290 |
"n_directions": 4,
|
| 291 |
+
"direction_method": "svd",
|
| 292 |
"norm_preserve": True,
|
| 293 |
"regularization": 0.0,
|
| 294 |
"refinement_passes": 2,
|
|
|
|
| 334 |
"abliterated models were created with."
|
| 335 |
),
|
| 336 |
"n_directions": 1,
|
| 337 |
+
"direction_method": "diff_means",
|
| 338 |
"norm_preserve": False,
|
| 339 |
"regularization": 0.0,
|
| 340 |
"refinement_passes": 1,
|
|
|
|
| 362 |
"whitened SVD, no iterative refinement."
|
| 363 |
),
|
| 364 |
"n_directions": 4,
|
| 365 |
+
"direction_method": "svd",
|
| 366 |
"norm_preserve": False,
|
| 367 |
# Ridge alpha=0.3 → effective reg = alpha/(1+alpha) = 0.3/1.3 ≈ 0.231
|
| 368 |
# For orthonormal V: P_V^alpha = 1/(1+alpha) * VV^T = 0.769 * VV^T
|
|
|
|
| 395 |
"over the (refusal_rate, KL_divergence) frontier."
|
| 396 |
),
|
| 397 |
"n_directions": 2,
|
| 398 |
+
"direction_method": "diff_means",
|
| 399 |
"norm_preserve": True,
|
| 400 |
"regularization": 0.0,
|
| 401 |
"refinement_passes": 1,
|
|
|
|
| 431 |
"boundary rather than the statistical activation difference."
|
| 432 |
),
|
| 433 |
"n_directions": 4,
|
| 434 |
+
"direction_method": "svd",
|
| 435 |
"norm_preserve": True,
|
| 436 |
"regularization": 0.0,
|
| 437 |
"refinement_passes": 1,
|
|
|
|
| 584 |
hub_token: str | None = None,
|
| 585 |
hub_community_org: str | None = None,
|
| 586 |
n_directions: int | None = None,
|
| 587 |
+
direction_method: str | None = None,
|
| 588 |
norm_preserve: bool | None = None,
|
| 589 |
regularization: float | None = None,
|
| 590 |
refinement_passes: int | None = None,
|
|
|
|
| 678 |
method_cfg = METHODS[method]
|
| 679 |
self.method = method
|
| 680 |
self.n_directions = n_directions if n_directions is not None else method_cfg["n_directions"]
|
| 681 |
+
self.direction_method = direction_method if direction_method is not None else method_cfg.get("direction_method", "svd")
|
| 682 |
self.norm_preserve = norm_preserve if norm_preserve is not None else method_cfg["norm_preserve"]
|
| 683 |
self.regularization = regularization if regularization is not None else method_cfg["regularization"]
|
| 684 |
self.refinement_passes = refinement_passes if refinement_passes is not None else method_cfg["refinement_passes"]
|
|
|
|
| 956 |
self.log(f"Loading model: {self.model_name}")
|
| 957 |
self.log(f"Device: {self.device} | Dtype: {self.dtype}")
|
| 958 |
self.log(f"Method: {method_label}")
|
| 959 |
+
self.log(f" Directions: {self.n_directions} ({self.direction_method}) | Norm-preserve: {self.norm_preserve}")
|
| 960 |
self.log(f" Regularization: {self.regularization} | Refinement passes: {self.refinement_passes}")
|
| 961 |
|
| 962 |
self.handle = load_model(
|
|
|
|
| 1420 |
else:
|
| 1421 |
max_length = 384 if collect_multi_pos else 256
|
| 1422 |
free_gb = dev.get_total_free_gb()
|
| 1423 |
+
# Scale memory thresholds by model size — a 1.2B model needs far
|
| 1424 |
+
# less KV-cache memory per token than a 7B model. Baseline
|
| 1425 |
+
# thresholds (4 / 2 GB) were tuned for 7B (hidden=4096, layers=32).
|
| 1426 |
+
_h = self.handle.hidden_size if self.handle else 4096
|
| 1427 |
+
_l = n_layers if n_layers else 32
|
| 1428 |
+
_mem_scale = (_h / 4096) * (_l / 32)
|
| 1429 |
+
_tight_gb = max(4.0 * _mem_scale, 0.5)
|
| 1430 |
+
_low_gb = max(2.0 * _mem_scale, 0.25)
|
| 1431 |
if dev.is_gpu_available():
|
| 1432 |
+
if self.max_seq_length is None and free_gb < _low_gb:
|
| 1433 |
max_length = 64
|
| 1434 |
+
self.log(f" Low GPU memory ({free_gb:.1f} GB free, threshold {_low_gb:.1f} GB), using max_length={max_length}")
|
| 1435 |
+
elif self.max_seq_length is None and free_gb < _tight_gb:
|
| 1436 |
max_length = 128
|
| 1437 |
+
self.log(f" Tight GPU memory ({free_gb:.1f} GB free, threshold {_tight_gb:.1f} GB), using max_length={max_length}")
|
| 1438 |
|
| 1439 |
device = self._get_model_device(model)
|
| 1440 |
|
| 1441 |
# Batch prompts for throughput — hooks unbatch per-prompt activations
|
| 1442 |
+
batch_size = 16 if free_gb > _tight_gb else 8 if free_gb > _low_gb else 1
|
| 1443 |
# Left-pad so position -1 is always the last real token in every batch element
|
| 1444 |
orig_padding_side = getattr(tokenizer, "padding_side", "right")
|
| 1445 |
if batch_size > 1:
|
|
|
|
| 1526 |
wasserstein_extractor = WassersteinOptimalExtractor()
|
| 1527 |
self.log("Using Wasserstein-optimal direction extraction (cost-minimizing GEP)")
|
| 1528 |
|
| 1529 |
+
# Optionally use LEACE for theoretically optimal concept erasure
|
| 1530 |
+
leace_extractor = None
|
| 1531 |
+
if self.direction_method == "leace":
|
| 1532 |
+
from obliteratus.analysis.leace import LEACEExtractor
|
| 1533 |
+
leace_extractor = LEACEExtractor()
|
| 1534 |
+
self.log("Using LEACE (closed-form optimal concept erasure) for direction extraction")
|
| 1535 |
+
|
| 1536 |
# Optionally use whitened SVD for cleaner direction extraction
|
| 1537 |
whitened_extractor = None
|
| 1538 |
+
if self.use_whitened_svd and n_dirs > 1 and not self.use_wasserstein_optimal and leace_extractor is None:
|
| 1539 |
from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
|
| 1540 |
whitened_extractor = WhitenedSVDExtractor()
|
| 1541 |
self.log("Using whitened SVD (covariance-normalized) for direction extraction")
|
|
|
|
| 1582 |
if idx < 5:
|
| 1583 |
self.log(f" layer {idx}: Wasserstein extraction failed ({e}), falling back to SVD")
|
| 1584 |
|
| 1585 |
+
if leace_extractor is not None:
|
| 1586 |
+
# LEACE: closed-form optimal concept erasure direction
|
| 1587 |
+
if idx in self._harmful_acts and idx in self._harmless_acts:
|
| 1588 |
+
try:
|
| 1589 |
+
l_result = leace_extractor.extract(
|
| 1590 |
+
self._harmful_acts[idx],
|
| 1591 |
+
self._harmless_acts[idx],
|
| 1592 |
+
layer_idx=idx,
|
| 1593 |
+
)
|
| 1594 |
+
self.refusal_directions[idx] = l_result.direction
|
| 1595 |
+
self.refusal_subspaces[idx] = l_result.direction.unsqueeze(0)
|
| 1596 |
+
norms[idx] = l_result.generalized_eigenvalue
|
| 1597 |
+
|
| 1598 |
+
if idx < 5 or idx == n_layers - 1:
|
| 1599 |
+
self.log(
|
| 1600 |
+
f" layer {idx}: LEACE eigenvalue={l_result.generalized_eigenvalue:.4f}, "
|
| 1601 |
+
f"erasure_loss={l_result.erasure_loss:.4f}, "
|
| 1602 |
+
f"cond={l_result.within_class_condition:.0f}"
|
| 1603 |
+
)
|
| 1604 |
+
continue
|
| 1605 |
+
except Exception as e:
|
| 1606 |
+
if idx < 5:
|
| 1607 |
+
self.log(f" layer {idx}: LEACE failed ({e}), falling back to diff-of-means")
|
| 1608 |
+
|
| 1609 |
if n_dirs == 1:
|
| 1610 |
# Classic single-direction: difference-in-means
|
| 1611 |
diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
|
|
|
|
| 3648 |
except Exception:
|
| 3649 |
pass
|
| 3650 |
|
| 3651 |
+
# Use LEACE when enabled (matching main _distill)
|
| 3652 |
+
leace_extractor = None
|
| 3653 |
+
if self.direction_method == "leace":
|
| 3654 |
+
try:
|
| 3655 |
+
from obliteratus.analysis.leace import LEACEExtractor
|
| 3656 |
+
leace_extractor = LEACEExtractor()
|
| 3657 |
+
except Exception:
|
| 3658 |
+
pass
|
| 3659 |
+
|
| 3660 |
# Use whitened SVD when enabled (matching main _distill)
|
| 3661 |
whitened_extractor = None
|
| 3662 |
+
if self.use_whitened_svd and n_dirs > 1 and wasserstein_extractor is None and leace_extractor is None:
|
| 3663 |
from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
|
| 3664 |
whitened_extractor = WhitenedSVDExtractor()
|
| 3665 |
|
|
|
|
| 3692 |
except Exception:
|
| 3693 |
pass # Fall through to SVD
|
| 3694 |
|
| 3695 |
+
# LEACE path (matching main _distill)
|
| 3696 |
+
if leace_extractor is not None:
|
| 3697 |
+
if idx in self._harmful_acts and idx in self._harmless_acts:
|
| 3698 |
+
try:
|
| 3699 |
+
l_result = leace_extractor.extract(
|
| 3700 |
+
self._harmful_acts[idx],
|
| 3701 |
+
self._harmless_acts[idx],
|
| 3702 |
+
layer_idx=idx,
|
| 3703 |
+
)
|
| 3704 |
+
self.refusal_directions[idx] = l_result.direction
|
| 3705 |
+
self.refusal_subspaces[idx] = l_result.direction.unsqueeze(0)
|
| 3706 |
+
norms[idx] = l_result.generalized_eigenvalue
|
| 3707 |
+
continue
|
| 3708 |
+
except Exception:
|
| 3709 |
+
pass # Fall through to diff-of-means
|
| 3710 |
+
|
| 3711 |
if n_dirs == 1:
|
| 3712 |
diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
|
| 3713 |
norm = diff.norm()
|
|
|
|
| 5825 |
"method": self.method,
|
| 5826 |
"method_config": {
|
| 5827 |
"n_directions": self.n_directions,
|
| 5828 |
+
"direction_method": self.direction_method,
|
| 5829 |
"norm_preserve": self.norm_preserve,
|
| 5830 |
"regularization": self.regularization,
|
| 5831 |
"refinement_passes": self.refinement_passes,
|
obliteratus/analysis/leace.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LEACE (LEAst-squares Concept Erasure) for refusal direction extraction.
|
| 2 |
+
|
| 3 |
+
Closed-form optimal concept erasure that finds the minimum-rank edit to
|
| 4 |
+
representations preventing any linear classifier from recovering the
|
| 5 |
+
concept (harmful vs harmless). Unlike SVD, LEACE produces directions
|
| 6 |
+
that are provably optimal for erasure with minimal distortion.
|
| 7 |
+
|
| 8 |
+
The key insight: instead of finding directions of maximum variance (SVD)
|
| 9 |
+
or maximum mean difference (diff-of-means), LEACE solves a constrained
|
| 10 |
+
optimization problem: find the smallest perturbation to representations
|
| 11 |
+
such that no linear probe can distinguish harmful from harmless.
|
| 12 |
+
|
| 13 |
+
Mathematical formulation (rank-1 erasure):
|
| 14 |
+
Given class-conditional means mu_0, mu_1 and within-class
|
| 15 |
+
covariance S_w:
|
| 16 |
+
1. Compute mean difference: delta = mu_1 - mu_0
|
| 17 |
+
2. Compute within-class covariance: S_w = (S_0 + S_1) / 2
|
| 18 |
+
3. Solve generalized eigenvalue problem: S_b v = lambda S_w v
|
| 19 |
+
where S_b = delta @ delta^T (between-class scatter)
|
| 20 |
+
4. The top generalized eigenvector is the LEACE direction
|
| 21 |
+
5. Erase by projecting out: x' = x - (x @ v) * v^T
|
| 22 |
+
|
| 23 |
+
This is mathematically equivalent to Fisher's Linear Discriminant but
|
| 24 |
+
applied as an erasure operation. The direction maximizes class
|
| 25 |
+
separability relative to within-class spread, making it the optimal
|
| 26 |
+
single direction to remove for concept erasure.
|
| 27 |
+
|
| 28 |
+
Advantages over SVD:
|
| 29 |
+
- Theoretically optimal: minimizes representation distortion for
|
| 30 |
+
guaranteed erasure of linear concept information
|
| 31 |
+
- Handles rogue dimensions naturally: within-class normalization
|
| 32 |
+
prevents high-variance but non-discriminative dimensions from
|
| 33 |
+
dominating
|
| 34 |
+
- No hyperparameters beyond regularization epsilon
|
| 35 |
+
- Closed-form solution (no iterative optimization)
|
| 36 |
+
|
| 37 |
+
References:
|
| 38 |
+
- Belrose et al. (2023): LEACE: Perfect linear concept erasure in
|
| 39 |
+
closed form. NeurIPS 2023.
|
| 40 |
+
- Ravfogel et al. (2022): RLACE: Adversarial concept erasure
|
| 41 |
+
(iterative precursor to LEACE).
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
from __future__ import annotations
|
| 45 |
+
|
| 46 |
+
from dataclasses import dataclass
|
| 47 |
+
|
| 48 |
+
import torch
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
class LEACEResult:
    """Result of LEACE direction extraction for a single layer."""

    layer_idx: int
    direction: torch.Tensor  # (hidden_dim,) unit vector
    generalized_eigenvalue: float  # lambda from GEP (discriminability)
    within_class_condition: float  # condition number of S_w
    mean_diff_norm: float  # ||mu_1 - mu_0||
    erasure_loss: float  # expected squared distortion from erasure


class LEACEExtractor:
    """Extract refusal directions via LEACE (closed-form concept erasure).

    Finds the direction that maximally separates harmful from harmless
    activations relative to within-class variance, then erases it.
    This is the provably optimal rank-1 concept erasure.
    """

    def __init__(
        self,
        regularization_eps: float = 1e-4,
        shrinkage: float = 0.0,
    ):
        """
        Args:
            regularization_eps: Tikhonov regularization for S_w inversion.
                Larger values produce more conservative (but stable) results.
            shrinkage: Ledoit-Wolf shrinkage toward identity (0..1).
                0 = no shrinkage, 1 = full shrinkage to scaled identity.
                Useful when n_samples < hidden_dim.
        """
        self.regularization_eps = regularization_eps
        self.shrinkage = shrinkage

    def extract(
        self,
        harmful_activations: list[torch.Tensor],
        harmless_activations: list[torch.Tensor],
        layer_idx: int = 0,
    ) -> LEACEResult:
        """Extract the LEACE direction for a single layer.

        Args:
            harmful_activations: List of (hidden_dim,) tensors from harmful prompts.
            harmless_activations: List of (hidden_dim,) tensors from harmless prompts.
            layer_idx: Layer index (for metadata).

        Returns:
            LEACEResult with the optimal erasure direction.

        Raises:
            ValueError: If either activation list is empty (a bare
                ``torch.stack`` would otherwise raise an opaque RuntimeError).
        """
        if not harmful_activations or not harmless_activations:
            raise ValueError(
                "LEACE extraction requires at least one activation per class "
                f"(got {len(harmful_activations)} harmful, "
                f"{len(harmless_activations)} harmless)"
            )

        H = torch.stack(harmful_activations).float()  # (n_h, d)
        B = torch.stack(harmless_activations).float()  # (n_b, d)

        # Collapse a leading singleton (e.g. inputs saved as (1, d) rows)
        if H.dim() == 3:
            H = H.squeeze(1)
        if B.dim() == 3:
            B = B.squeeze(1)

        n_h, d = H.shape
        n_b = B.shape[0]

        # Class-conditional means
        mu_h = H.mean(dim=0)  # (d,)
        mu_b = B.mean(dim=0)  # (d,)

        # Mean difference (between-class direction)
        delta = mu_h - mu_b  # (d,)
        delta_norm = delta.norm().item()

        # Within-class covariance: S_w = (S_h + S_b) / 2
        # where S_h = (H - mu_h)^T (H - mu_h) / (n_h - 1) etc.
        H_centered = H - mu_h.unsqueeze(0)
        B_centered = B - mu_b.unsqueeze(0)

        # max(n - 1, 1) guards the single-sample case (zero covariance).
        S_h = (H_centered.T @ H_centered) / max(n_h - 1, 1)
        S_b = (B_centered.T @ B_centered) / max(n_b - 1, 1)
        S_w = (S_h + S_b) / 2.0  # (d, d)

        # Apply Ledoit-Wolf shrinkage if requested
        if self.shrinkage > 0:
            trace_S_w = S_w.trace().item()
            S_w = (1 - self.shrinkage) * S_w + self.shrinkage * (trace_S_w / d) * torch.eye(d, device=S_w.device)

        # Regularize S_w for numerical stability
        S_w_reg = S_w + self.regularization_eps * torch.eye(d, device=S_w.device)

        # Condition number of S_w (for diagnostics)
        try:
            eigs_w = torch.linalg.eigvalsh(S_w_reg)
            eigs_w = eigs_w.clamp(min=0)
            # Drop numerically-zero eigenvalues so the ratio is meaningful
            pos_eigs = eigs_w[eigs_w > eigs_w.max() * 1e-10]
            condition = (pos_eigs.max() / pos_eigs.min()).item() if pos_eigs.numel() > 0 else float('inf')
        except Exception:
            condition = float('inf')

        # LEACE direction via S_w^{-1} @ delta
        # The generalized eigenvector for rank-1 S_between = delta @ delta^T
        # reduces to: v = S_w^{-1} @ delta (up to normalization)
        try:
            # Use solve for numerical stability (avoids explicit inverse)
            v = torch.linalg.solve(S_w_reg, delta)  # (d,)
        except torch.linalg.LinAlgError:
            # Fallback: least-squares (pseudoinverse-like) solution
            v = torch.linalg.lstsq(S_w_reg, delta.unsqueeze(1)).solution.squeeze(1)

        # Normalize to unit length
        v_norm = v.norm()
        if v_norm > 1e-8:
            direction = v / v_norm
        else:
            # Degenerate case: fall back to normalized mean difference
            direction = delta / max(delta_norm, 1e-8)

        # Generalized eigenvalue: lambda = delta^T @ S_w^{-1} @ delta
        # This measures how discriminable the classes are after whitening
        gen_eigenvalue = (delta @ v).item()

        # Erasure loss: expected squared distortion E[||x - x'||^2]
        # For rank-1 projection: loss = v^T @ S_total @ v where S_total
        # is the total (pooled) covariance
        all_acts = torch.cat([H, B], dim=0)
        mu_total = all_acts.mean(dim=0)
        centered_total = all_acts - mu_total.unsqueeze(0)
        S_total = (centered_total.T @ centered_total) / max(all_acts.shape[0] - 1, 1)
        erasure_loss = (direction @ S_total @ direction).item()

        return LEACEResult(
            layer_idx=layer_idx,
            direction=direction,
            generalized_eigenvalue=gen_eigenvalue,
            within_class_condition=condition,
            mean_diff_norm=delta_norm,
            erasure_loss=erasure_loss,
        )

    def extract_all_layers(
        self,
        harmful_acts: dict[int, list[torch.Tensor]],
        harmless_acts: dict[int, list[torch.Tensor]],
    ) -> dict[int, LEACEResult]:
        """Extract LEACE directions for all layers.

        Layers present in only one of the two dicts are skipped (both
        classes are required to form the within-class covariance).

        Args:
            harmful_acts: {layer_idx: [activations]} from activation collection.
            harmless_acts: {layer_idx: [activations]} from activation collection.

        Returns:
            {layer_idx: LEACEResult} for each layer.
        """
        results = {}
        for idx in sorted(harmful_acts.keys()):
            if idx not in harmless_acts:
                continue
            results[idx] = self.extract(
                harmful_acts[idx],
                harmless_acts[idx],
                layer_idx=idx,
            )
        return results

    @staticmethod
    def compare_with_diff_of_means(
        leace_result: LEACEResult,
        harmful_mean: torch.Tensor,
        harmless_mean: torch.Tensor,
    ) -> dict[str, float]:
        """Compare LEACE direction with simple diff-of-means.

        Returns cosine similarity and diagnostic metrics showing how much
        the within-class normalization rotates the direction.
        """
        # Cast onto the direction's dtype/device: `direction` is float32
        # (extract() upcasts), while callers may pass fp16/bf16 or GPU
        # means — a mixed-dtype matmul would raise.
        diff = (harmful_mean.squeeze() - harmless_mean.squeeze()).to(leace_result.direction)
        diff_norm = diff.norm()
        if diff_norm > 1e-8:
            diff_normalized = diff / diff_norm
        else:
            diff_normalized = diff

        cosine_sim = (leace_result.direction @ diff_normalized).abs().item()

        return {
            "cosine_similarity": cosine_sim,
            "leace_eigenvalue": leace_result.generalized_eigenvalue,
            "leace_erasure_loss": leace_result.erasure_loss,
            "within_class_condition": leace_result.within_class_condition,
            "mean_diff_norm": leace_result.mean_diff_norm,
        }
|
obliteratus/bayesian_optimizer.py
CHANGED
|
@@ -142,28 +142,35 @@ def _parametric_layer_weight(
|
|
| 142 |
min_weight: float,
|
| 143 |
spread: float,
|
| 144 |
) -> float:
|
| 145 |
-
"""Compute ablation weight for a layer using a
|
| 146 |
|
| 147 |
-
|
| 148 |
-
- max_weight: peak ablation strength
|
| 149 |
-
- peak_position: normalized position of peak (0..1
|
| 150 |
-
- min_weight:
|
| 151 |
-
- spread:
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
"""
|
| 156 |
if n_layers <= 1:
|
| 157 |
return max_weight
|
| 158 |
|
| 159 |
normalized_pos = layer_idx / (n_layers - 1)
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
| 165 |
|
| 166 |
-
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
def _interpolate_direction(
|
|
@@ -171,37 +178,56 @@ def _interpolate_direction(
|
|
| 171 |
layer_idx: int,
|
| 172 |
float_dir_idx: float,
|
| 173 |
) -> torch.Tensor:
|
| 174 |
-
"""Get an interpolated refusal direction from a float-valued index.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
|
|
|
| 179 |
|
| 180 |
Args:
|
| 181 |
-
pipeline: Pipeline with extracted refusal
|
| 182 |
-
layer_idx:
|
| 183 |
-
float_dir_idx: Continuous direction index
|
| 184 |
-
|
| 185 |
|
| 186 |
Returns:
|
| 187 |
Normalized direction tensor.
|
| 188 |
"""
|
| 189 |
-
|
| 190 |
-
|
|
|
|
| 191 |
return pipeline.refusal_directions.get(layer_idx, torch.zeros(1))
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
| 196 |
|
| 197 |
lo = int(float_dir_idx)
|
| 198 |
-
hi = min(lo + 1,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
if lo == hi:
|
| 201 |
-
d =
|
| 202 |
else:
|
|
|
|
| 203 |
alpha = float_dir_idx - lo
|
| 204 |
-
d = (1.0 - alpha) *
|
| 205 |
|
| 206 |
norm = d.norm()
|
| 207 |
if norm > 1e-8:
|
|
@@ -342,9 +368,14 @@ def run_bayesian_optimization(
|
|
| 342 |
for live_data, saved_clone in original_params: # noqa: F821
|
| 343 |
live_data.copy_(saved_clone.to(live_data.device))
|
| 344 |
|
| 345 |
-
# Warm-start values for the parametric kernel
|
| 346 |
-
#
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
peak_layer = pipeline._strong_layers[0]
|
| 349 |
warm_peak = peak_layer / max(n_total_layers - 1, 1)
|
| 350 |
else:
|
|
@@ -356,56 +387,56 @@ def run_bayesian_optimization(
|
|
| 356 |
# Suppress Optuna's verbose logging
|
| 357 |
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
| 358 |
|
| 359 |
-
# Max
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
default=1,
|
| 365 |
-
)
|
| 366 |
|
| 367 |
# ── Phase 1: Parametric kernel optimization (compact search space) ──
|
|
|
|
|
|
|
| 368 |
|
| 369 |
def objective(trial: optuna.Trial) -> tuple[float, float]:
|
| 370 |
"""Multi-objective: minimize (refusal_rate, kl_divergence)."""
|
| 371 |
_restore_all()
|
| 372 |
|
| 373 |
-
#
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
|
| 379 |
-
#
|
| 380 |
-
|
| 381 |
-
|
|
|
|
|
|
|
| 382 |
|
| 383 |
-
# Float direction index (
|
| 384 |
-
dir_idx = trial.suggest_float("dir_idx", 0.0, max(
|
| 385 |
|
| 386 |
-
# Compute per-layer regularization from
|
| 387 |
-
|
|
|
|
| 388 |
for idx in pipeline._strong_layers:
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
layer_regs[idx] = 1.0 - weight
|
| 394 |
|
| 395 |
# Apply projection with trial's parameters
|
| 396 |
for idx in pipeline._strong_layers:
|
| 397 |
-
if idx not in pipeline.
|
| 398 |
continue
|
| 399 |
|
| 400 |
-
# Use interpolated direction
|
| 401 |
direction = _interpolate_direction(pipeline, idx, dir_idx)
|
| 402 |
d_col = direction.to(device=next(layer_modules[idx].parameters()).device)
|
| 403 |
d_col = d_col.unsqueeze(-1) if d_col.dim() == 1 else d_col
|
| 404 |
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
# Attention projection (with attn_scale)
|
| 408 |
-
attn_reg = 1.0 - (1.0 - reg) * attn_scale
|
| 409 |
try:
|
| 410 |
attn = get_attention_module(layer_modules[idx], arch)
|
| 411 |
pipeline._project_out_advanced(
|
|
@@ -416,8 +447,8 @@ def run_bayesian_optimization(
|
|
| 416 |
except (AttributeError, RuntimeError):
|
| 417 |
pass
|
| 418 |
|
| 419 |
-
# MLP/FFN projection (with
|
| 420 |
-
mlp_reg =
|
| 421 |
try:
|
| 422 |
ffn = get_ffn_module(layer_modules[idx], arch)
|
| 423 |
count = pipeline._project_out_advanced(
|
|
@@ -439,18 +470,20 @@ def run_bayesian_optimization(
|
|
| 439 |
refusal = _measure_refusal_rate(pipeline, n_prompts=n_refusal_prompts)
|
| 440 |
kl = _measure_kl_divergence(pipeline, reference_logits, kl_prompts)
|
| 441 |
|
| 442 |
-
# Track best combined score
|
| 443 |
nonlocal best_score, best_result
|
| 444 |
combined = refusal + 0.5 * kl
|
| 445 |
if combined < best_score:
|
| 446 |
best_score = combined
|
| 447 |
-
best_result =
|
|
|
|
|
|
|
|
|
|
| 448 |
|
| 449 |
pipeline.log(
|
| 450 |
f" Trial {trial.number + 1}/{n_trials}: "
|
| 451 |
f"refusal={refusal:.0%}, KL={kl:.4f} "
|
| 452 |
-
f"(
|
| 453 |
-
f"attn={attn_scale:.2f}, mlp={mlp_scale:.2f}, dir={dir_idx:.2f})"
|
| 454 |
)
|
| 455 |
|
| 456 |
return refusal, kl
|
|
@@ -462,16 +495,33 @@ def run_bayesian_optimization(
|
|
| 462 |
study_name="obliteratus_parametric_optimization",
|
| 463 |
)
|
| 464 |
|
| 465 |
-
# Enqueue warm-start trial with analysis-derived estimates
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
study.enqueue_trial(warm_params)
|
| 476 |
|
| 477 |
pipeline.log(f"Bayesian optimization: running {n_trials} trials (parametric kernel)...")
|
|
@@ -490,25 +540,32 @@ def run_bayesian_optimization(
|
|
| 490 |
p = best_trial.params
|
| 491 |
best_result = {}
|
| 492 |
for idx in pipeline._strong_layers:
|
| 493 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
idx, n_total_layers,
|
| 495 |
-
p["
|
| 496 |
-
p["
|
| 497 |
)
|
| 498 |
-
best_result[idx] =
|
|
|
|
| 499 |
|
| 500 |
pipeline.log(
|
| 501 |
f" Best trial: refusal={best_trial.values[0]:.0%}, "
|
| 502 |
f"KL={best_trial.values[1]:.4f}"
|
| 503 |
)
|
| 504 |
pipeline.log(
|
| 505 |
-
f"
|
| 506 |
-
f"
|
| 507 |
)
|
| 508 |
pipeline.log(
|
| 509 |
-
f"
|
| 510 |
-
f"
|
| 511 |
)
|
|
|
|
| 512 |
|
| 513 |
# Store the best direction index for use during EXCISE
|
| 514 |
best_dir_idx = p.get("dir_idx", 0.0)
|
|
@@ -518,9 +575,9 @@ def run_bayesian_optimization(
|
|
| 518 |
new_dir = _interpolate_direction(pipeline, idx, best_dir_idx)
|
| 519 |
pipeline.refusal_directions[idx] = new_dir
|
| 520 |
|
| 521 |
-
# Store component scales for use in EXCISE
|
| 522 |
-
pipeline._bayesian_attn_scale = p.get("
|
| 523 |
-
pipeline._bayesian_mlp_scale = p.get("
|
| 524 |
|
| 525 |
elif best_result:
|
| 526 |
pipeline.log(f" Using best combined score: {best_score:.4f}")
|
|
|
|
| 142 |
min_weight: float,
|
| 143 |
spread: float,
|
| 144 |
) -> float:
|
| 145 |
+
"""Compute ablation weight for a layer using a piecewise-linear tent kernel.
|
| 146 |
|
| 147 |
+
Faithful reproduction of Heretic's parametric kernel (p-e-w/heretic):
|
| 148 |
+
- max_weight: peak ablation strength at peak_position
|
| 149 |
+
- peak_position: normalized position of peak (0..1)
|
| 150 |
+
- min_weight: weight at the edges of the tent
|
| 151 |
+
- spread: normalized distance from peak to tent edge (min_weight_distance)
|
| 152 |
|
| 153 |
+
Layers beyond ``spread`` from the peak get weight 0 (skipped entirely).
|
| 154 |
+
Within the tent, weight drops linearly from max_weight to min_weight.
|
| 155 |
+
This matches Heretic's actual formula::
|
| 156 |
+
|
| 157 |
+
distance = abs(layer_index - max_weight_position)
|
| 158 |
+
if distance > min_weight_distance: skip
|
| 159 |
+
weight = max_weight + (distance / min_weight_distance) * (min_weight - max_weight)
|
| 160 |
"""
|
| 161 |
if n_layers <= 1:
|
| 162 |
return max_weight
|
| 163 |
|
| 164 |
normalized_pos = layer_idx / (n_layers - 1)
|
| 165 |
+
dist = abs(normalized_pos - peak_position)
|
| 166 |
+
min_weight_distance = max(spread, 0.01)
|
| 167 |
+
|
| 168 |
+
# Hard cutoff: layers outside the tent get 0 weight (Heretic skips them)
|
| 169 |
+
if dist > min_weight_distance:
|
| 170 |
+
return 0.0
|
| 171 |
|
| 172 |
+
# Linear interpolation: max_weight at peak → min_weight at edges
|
| 173 |
+
return max_weight + (dist / min_weight_distance) * (min_weight - max_weight)
|
| 174 |
|
| 175 |
|
| 176 |
def _interpolate_direction(
|
|
|
|
| 178 |
layer_idx: int,
|
| 179 |
float_dir_idx: float,
|
| 180 |
) -> torch.Tensor:
|
| 181 |
+
"""Get an interpolated refusal direction from a float-valued layer index.
|
| 182 |
+
|
| 183 |
+
Faithful reproduction of Heretic's direction interpolation: the index
|
| 184 |
+
selects which *layer's* diff-of-means direction to use, with float
|
| 185 |
+
values interpolating between adjacent layers' directions. This is
|
| 186 |
+
fundamentally different from interpolating between SVD components
|
| 187 |
+
within a single layer — it searches across the layer axis.
|
| 188 |
+
|
| 189 |
+
From Heretic source (model.py)::
|
| 190 |
|
| 191 |
+
weight, index = math.modf(direction_index + 1)
|
| 192 |
+
refusal_direction = F.normalize(
|
| 193 |
+
refusal_directions[int(index)].lerp(
|
| 194 |
+
refusal_directions[int(index) + 1], weight), p=2, dim=0)
|
| 195 |
|
| 196 |
Args:
|
| 197 |
+
pipeline: Pipeline with extracted refusal directions per layer.
|
| 198 |
+
layer_idx: The layer being projected (used as fallback).
|
| 199 |
+
float_dir_idx: Continuous direction index — selects which layer's
|
| 200 |
+
direction to use (e.g., 5.3 interpolates 70% layer-5 + 30% layer-6).
|
| 201 |
|
| 202 |
Returns:
|
| 203 |
Normalized direction tensor.
|
| 204 |
"""
|
| 205 |
+
# Build sorted list of layer indices that have refusal directions
|
| 206 |
+
sorted_layers = sorted(pipeline.refusal_directions.keys())
|
| 207 |
+
if not sorted_layers:
|
| 208 |
return pipeline.refusal_directions.get(layer_idx, torch.zeros(1))
|
| 209 |
|
| 210 |
+
n_layers_with_dirs = len(sorted_layers)
|
| 211 |
+
|
| 212 |
+
# Heretic uses direction_index + 1 offset; we map float_dir_idx into
|
| 213 |
+
# the sorted layer list, clamped to valid range.
|
| 214 |
+
float_dir_idx = max(0.0, min(float_dir_idx, n_layers_with_dirs - 1))
|
| 215 |
|
| 216 |
lo = int(float_dir_idx)
|
| 217 |
+
hi = min(lo + 1, n_layers_with_dirs - 1)
|
| 218 |
+
|
| 219 |
+
lo_layer = sorted_layers[lo]
|
| 220 |
+
hi_layer = sorted_layers[hi]
|
| 221 |
+
|
| 222 |
+
d_lo = pipeline.refusal_directions[lo_layer]
|
| 223 |
+
d_hi = pipeline.refusal_directions[hi_layer]
|
| 224 |
|
| 225 |
if lo == hi:
|
| 226 |
+
d = d_lo
|
| 227 |
else:
|
| 228 |
+
# Linear interpolation between adjacent layers' directions
|
| 229 |
alpha = float_dir_idx - lo
|
| 230 |
+
d = (1.0 - alpha) * d_lo + alpha * d_hi
|
| 231 |
|
| 232 |
norm = d.norm()
|
| 233 |
if norm > 1e-8:
|
|
|
|
| 368 |
for live_data, saved_clone in original_params: # noqa: F821
|
| 369 |
live_data.copy_(saved_clone.to(live_data.device))
|
| 370 |
|
| 371 |
+
# Warm-start values for the parametric kernel.
|
| 372 |
+
# If the informed pipeline provided analysis-derived warm-start params,
|
| 373 |
+
# use those (they're much better than the default heuristic).
|
| 374 |
+
informed_warm = getattr(pipeline, "_informed_warm_start", None)
|
| 375 |
+
if informed_warm:
|
| 376 |
+
warm_peak = informed_warm.get("peak_position", 0.5)
|
| 377 |
+
pipeline.log(f" Using analysis-informed warm-start (peak={warm_peak:.2f})")
|
| 378 |
+
elif pipeline._strong_layers:
|
| 379 |
peak_layer = pipeline._strong_layers[0]
|
| 380 |
warm_peak = peak_layer / max(n_total_layers - 1, 1)
|
| 381 |
else:
|
|
|
|
| 387 |
# Suppress Optuna's verbose logging
|
| 388 |
optuna.logging.set_verbosity(optuna.logging.WARNING)
|
| 389 |
|
| 390 |
+
# Max layers with directions (for float direction interpolation)
|
| 391 |
+
n_layers_with_dirs = len([
|
| 392 |
+
idx for idx in pipeline._strong_layers
|
| 393 |
+
if idx in pipeline.refusal_directions
|
| 394 |
+
])
|
|
|
|
|
|
|
| 395 |
|
| 396 |
# ── Phase 1: Parametric kernel optimization (compact search space) ──
|
| 397 |
+
# Heretic uses SEPARATE kernel parameters for attention and MLP,
|
| 398 |
+
# allowing them to peak at different layers (8 params + 1 dir_idx = 9).
|
| 399 |
|
| 400 |
def objective(trial: optuna.Trial) -> tuple[float, float]:
|
| 401 |
"""Multi-objective: minimize (refusal_rate, kl_divergence)."""
|
| 402 |
_restore_all()
|
| 403 |
|
| 404 |
+
# Attention kernel: 4 params
|
| 405 |
+
attn_max = trial.suggest_float("attn_max_weight", 0.5, 1.0)
|
| 406 |
+
attn_peak = trial.suggest_float("attn_peak_position", 0.1, 0.9)
|
| 407 |
+
attn_min = trial.suggest_float("attn_min_weight", 0.0, 0.3)
|
| 408 |
+
attn_spread = trial.suggest_float("attn_spread", 0.1, 0.6)
|
| 409 |
|
| 410 |
+
# MLP kernel: 4 params (separate — can peak at a different layer)
|
| 411 |
+
mlp_max = trial.suggest_float("mlp_max_weight", 0.3, 1.0)
|
| 412 |
+
mlp_peak = trial.suggest_float("mlp_peak_position", 0.1, 0.9)
|
| 413 |
+
mlp_min = trial.suggest_float("mlp_min_weight", 0.0, 0.3)
|
| 414 |
+
mlp_spread = trial.suggest_float("mlp_spread", 0.1, 0.6)
|
| 415 |
|
| 416 |
+
# Float direction index (cross-layer interpolation, Heretic-style)
|
| 417 |
+
dir_idx = trial.suggest_float("dir_idx", 0.0, max(n_layers_with_dirs - 1, 0.0))
|
| 418 |
|
| 419 |
+
# Compute per-layer, per-component regularization from kernels
|
| 420 |
+
attn_regs: dict[int, float] = {}
|
| 421 |
+
mlp_regs: dict[int, float] = {}
|
| 422 |
for idx in pipeline._strong_layers:
|
| 423 |
+
attn_w = _parametric_layer_weight(idx, n_total_layers, attn_max, attn_peak, attn_min, attn_spread)
|
| 424 |
+
mlp_w = _parametric_layer_weight(idx, n_total_layers, mlp_max, mlp_peak, mlp_min, mlp_spread)
|
| 425 |
+
attn_regs[idx] = 1.0 - attn_w
|
| 426 |
+
mlp_regs[idx] = 1.0 - mlp_w
|
|
|
|
| 427 |
|
| 428 |
# Apply projection with trial's parameters
|
| 429 |
for idx in pipeline._strong_layers:
|
| 430 |
+
if idx not in pipeline.refusal_directions:
|
| 431 |
continue
|
| 432 |
|
| 433 |
+
# Use cross-layer interpolated direction
|
| 434 |
direction = _interpolate_direction(pipeline, idx, dir_idx)
|
| 435 |
d_col = direction.to(device=next(layer_modules[idx].parameters()).device)
|
| 436 |
d_col = d_col.unsqueeze(-1) if d_col.dim() == 1 else d_col
|
| 437 |
|
| 438 |
+
# Attention projection (with per-component kernel)
|
| 439 |
+
attn_reg = attn_regs[idx]
|
|
|
|
|
|
|
| 440 |
try:
|
| 441 |
attn = get_attention_module(layer_modules[idx], arch)
|
| 442 |
pipeline._project_out_advanced(
|
|
|
|
| 447 |
except (AttributeError, RuntimeError):
|
| 448 |
pass
|
| 449 |
|
| 450 |
+
# MLP/FFN projection (with per-component kernel)
|
| 451 |
+
mlp_reg = mlp_regs[idx]
|
| 452 |
try:
|
| 453 |
ffn = get_ffn_module(layer_modules[idx], arch)
|
| 454 |
count = pipeline._project_out_advanced(
|
|
|
|
| 470 |
refusal = _measure_refusal_rate(pipeline, n_prompts=n_refusal_prompts)
|
| 471 |
kl = _measure_kl_divergence(pipeline, reference_logits, kl_prompts)
|
| 472 |
|
| 473 |
+
# Track best combined score (use average of attn/mlp regs for layer_regs)
|
| 474 |
nonlocal best_score, best_result
|
| 475 |
combined = refusal + 0.5 * kl
|
| 476 |
if combined < best_score:
|
| 477 |
best_score = combined
|
| 478 |
+
best_result = {
|
| 479 |
+
idx: (attn_regs[idx] + mlp_regs[idx]) / 2.0
|
| 480 |
+
for idx in pipeline._strong_layers
|
| 481 |
+
}
|
| 482 |
|
| 483 |
pipeline.log(
|
| 484 |
f" Trial {trial.number + 1}/{n_trials}: "
|
| 485 |
f"refusal={refusal:.0%}, KL={kl:.4f} "
|
| 486 |
+
f"(attn_peak={attn_peak:.2f}, mlp_peak={mlp_peak:.2f}, dir={dir_idx:.2f})"
|
|
|
|
| 487 |
)
|
| 488 |
|
| 489 |
return refusal, kl
|
|
|
|
| 495 |
study_name="obliteratus_parametric_optimization",
|
| 496 |
)
|
| 497 |
|
| 498 |
+
# Enqueue warm-start trial with analysis-derived estimates.
|
| 499 |
+
# Translate informed pipeline params to the new per-component format.
|
| 500 |
+
if informed_warm:
|
| 501 |
+
iw = informed_warm
|
| 502 |
+
warm_params = {
|
| 503 |
+
"attn_max_weight": iw.get("max_weight", 0.9),
|
| 504 |
+
"attn_peak_position": iw.get("peak_position", warm_peak),
|
| 505 |
+
"attn_min_weight": iw.get("min_weight", 0.05),
|
| 506 |
+
"attn_spread": iw.get("spread", 0.3),
|
| 507 |
+
"mlp_max_weight": iw.get("max_weight", 0.9) * iw.get("mlp_scale", 0.6),
|
| 508 |
+
"mlp_peak_position": iw.get("peak_position", warm_peak),
|
| 509 |
+
"mlp_min_weight": iw.get("min_weight", 0.05),
|
| 510 |
+
"mlp_spread": iw.get("spread", 0.3),
|
| 511 |
+
"dir_idx": iw.get("dir_idx", 0.0),
|
| 512 |
+
}
|
| 513 |
+
else:
|
| 514 |
+
warm_params = {
|
| 515 |
+
"attn_max_weight": 0.9,
|
| 516 |
+
"attn_peak_position": warm_peak,
|
| 517 |
+
"attn_min_weight": 0.05,
|
| 518 |
+
"attn_spread": 0.3,
|
| 519 |
+
"mlp_max_weight": 0.6,
|
| 520 |
+
"mlp_peak_position": warm_peak,
|
| 521 |
+
"mlp_min_weight": 0.05,
|
| 522 |
+
"mlp_spread": 0.3,
|
| 523 |
+
"dir_idx": 0.0,
|
| 524 |
+
}
|
| 525 |
study.enqueue_trial(warm_params)
|
| 526 |
|
| 527 |
pipeline.log(f"Bayesian optimization: running {n_trials} trials (parametric kernel)...")
|
|
|
|
| 540 |
p = best_trial.params
|
| 541 |
best_result = {}
|
| 542 |
for idx in pipeline._strong_layers:
|
| 543 |
+
attn_w = _parametric_layer_weight(
|
| 544 |
+
idx, n_total_layers,
|
| 545 |
+
p["attn_max_weight"], p["attn_peak_position"],
|
| 546 |
+
p["attn_min_weight"], p["attn_spread"],
|
| 547 |
+
)
|
| 548 |
+
mlp_w = _parametric_layer_weight(
|
| 549 |
idx, n_total_layers,
|
| 550 |
+
p["mlp_max_weight"], p["mlp_peak_position"],
|
| 551 |
+
p["mlp_min_weight"], p["mlp_spread"],
|
| 552 |
)
|
| 553 |
+
best_result[idx] = (attn_w + mlp_w) / 2.0 # average for layer-level reg
|
| 554 |
+
best_result[idx] = 1.0 - best_result[idx]
|
| 555 |
|
| 556 |
pipeline.log(
|
| 557 |
f" Best trial: refusal={best_trial.values[0]:.0%}, "
|
| 558 |
f"KL={best_trial.values[1]:.4f}"
|
| 559 |
)
|
| 560 |
pipeline.log(
|
| 561 |
+
f" Attn kernel: peak={p['attn_peak_position']:.2f}, "
|
| 562 |
+
f"spread={p['attn_spread']:.2f}, max={p['attn_max_weight']:.2f}"
|
| 563 |
)
|
| 564 |
pipeline.log(
|
| 565 |
+
f" MLP kernel: peak={p['mlp_peak_position']:.2f}, "
|
| 566 |
+
f"spread={p['mlp_spread']:.2f}, max={p['mlp_max_weight']:.2f}"
|
| 567 |
)
|
| 568 |
+
pipeline.log(f" dir_idx={p['dir_idx']:.2f}")
|
| 569 |
|
| 570 |
# Store the best direction index for use during EXCISE
|
| 571 |
best_dir_idx = p.get("dir_idx", 0.0)
|
|
|
|
| 575 |
new_dir = _interpolate_direction(pipeline, idx, best_dir_idx)
|
| 576 |
pipeline.refusal_directions[idx] = new_dir
|
| 577 |
|
| 578 |
+
# Store component scales for use in EXCISE (backward compat)
|
| 579 |
+
pipeline._bayesian_attn_scale = p.get("attn_max_weight", 1.0)
|
| 580 |
+
pipeline._bayesian_mlp_scale = p.get("mlp_max_weight", 1.0)
|
| 581 |
|
| 582 |
elif best_result:
|
| 583 |
pipeline.log(f" Using best combined score: {best_score:.4f}")
|
obliteratus/cli.py
CHANGED
|
@@ -109,7 +109,12 @@ def main(argv: list[str] | None = None):
|
|
| 109 |
],
|
| 110 |
help="Liberation method (default: advanced)",
|
| 111 |
)
|
| 112 |
-
p.add_argument("--n-directions", type=int, default=None, help="Override: number of
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
p.add_argument("--regularization", type=float, default=None, help="Override: fraction to preserve (0.0-1.0)")
|
| 114 |
p.add_argument("--refinement-passes", type=int, default=None, help="Override: number of iterative passes")
|
| 115 |
p.add_argument(
|
|
@@ -591,6 +596,7 @@ def _cmd_abliterate(args):
|
|
| 591 |
dtype=args.dtype,
|
| 592 |
method=method,
|
| 593 |
n_directions=args.n_directions,
|
|
|
|
| 594 |
regularization=args.regularization,
|
| 595 |
refinement_passes=args.refinement_passes,
|
| 596 |
quantization=args.quantization,
|
|
|
|
| 109 |
],
|
| 110 |
help="Liberation method (default: advanced)",
|
| 111 |
)
|
| 112 |
+
p.add_argument("--n-directions", type=int, default=None, help="Override: number of refusal directions to extract")
|
| 113 |
+
p.add_argument(
|
| 114 |
+
"--direction-method", type=str, default=None,
|
| 115 |
+
choices=["diff_means", "svd", "leace"],
|
| 116 |
+
help="Direction extraction method: diff_means (simple, robust), svd (multi-direction), leace (optimal erasure)",
|
| 117 |
+
)
|
| 118 |
p.add_argument("--regularization", type=float, default=None, help="Override: fraction to preserve (0.0-1.0)")
|
| 119 |
p.add_argument("--refinement-passes", type=int, default=None, help="Override: number of iterative passes")
|
| 120 |
p.add_argument(
|
|
|
|
| 596 |
dtype=args.dtype,
|
| 597 |
method=method,
|
| 598 |
n_directions=args.n_directions,
|
| 599 |
+
direction_method=getattr(args, "direction_method", None),
|
| 600 |
regularization=args.regularization,
|
| 601 |
refinement_passes=args.refinement_passes,
|
| 602 |
quantization=args.quantization,
|
obliteratus/informed_pipeline.py
CHANGED
|
@@ -73,15 +73,17 @@ INFORMED_METHOD = {
|
|
| 73 |
"description": (
|
| 74 |
"Runs analysis modules between PROBE and DISTILL to auto-configure "
|
| 75 |
"direction extraction, layer selection, and projection strategy based "
|
| 76 |
-
"on the model's actual refusal geometry."
|
|
|
|
| 77 |
),
|
| 78 |
-
"n_directions":
|
|
|
|
| 79 |
"norm_preserve": True,
|
| 80 |
"regularization": 0.0, # overridden by analysis
|
| 81 |
"refinement_passes": 2, # overridden by analysis
|
| 82 |
"project_biases": True,
|
| 83 |
"use_chat_template": True,
|
| 84 |
-
"use_whitened_svd":
|
| 85 |
"true_iterative_refinement": True,
|
| 86 |
}
|
| 87 |
|
|
@@ -126,7 +128,8 @@ class AnalysisInsights:
|
|
| 126 |
clean_layers: list[int] = field(default_factory=list)
|
| 127 |
|
| 128 |
# Derived configuration
|
| 129 |
-
recommended_n_directions: int =
|
|
|
|
| 130 |
recommended_regularization: float = 0.0
|
| 131 |
recommended_refinement_passes: int = 2
|
| 132 |
recommended_layers: list[int] = field(default_factory=list)
|
|
@@ -217,12 +220,19 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 217 |
hub_token=hub_token,
|
| 218 |
hub_community_org=hub_community_org,
|
| 219 |
quantization=quantization,
|
| 220 |
-
# Set informed defaults
|
|
|
|
|
|
|
| 221 |
norm_preserve=True,
|
| 222 |
project_biases=True,
|
| 223 |
use_chat_template=True,
|
| 224 |
-
use_whitened_svd=
|
| 225 |
true_iterative_refinement=True,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
)
|
| 227 |
self.method = "informed"
|
| 228 |
|
|
@@ -311,7 +321,11 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 311 |
if self._run_defense:
|
| 312 |
self._analyze_defense_robustness()
|
| 313 |
|
| 314 |
-
# 5.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
self._derive_configuration()
|
| 316 |
|
| 317 |
elapsed = time.time() - t0
|
|
@@ -392,6 +406,7 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 392 |
sample_layers = candidate_layers[::step]
|
| 393 |
|
| 394 |
polyhedral_count = 0
|
|
|
|
| 395 |
best_cone_result = None
|
| 396 |
best_strength = 0.0
|
| 397 |
|
|
@@ -405,34 +420,43 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 405 |
layer_idx=layer_idx,
|
| 406 |
)
|
| 407 |
|
|
|
|
| 408 |
if result.is_polyhedral:
|
| 409 |
polyhedral_count += 1
|
| 410 |
|
| 411 |
-
# Track the strongest layer's cone analysis
|
| 412 |
general_strength = result.general_direction.norm().item() if result.general_direction.numel() > 1 else 0
|
| 413 |
if general_strength > best_strength:
|
| 414 |
best_strength = general_strength
|
| 415 |
best_cone_result = result
|
| 416 |
|
| 417 |
-
if
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
for
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
else:
|
| 437 |
self.log(" No cone results — using default linear assumption")
|
| 438 |
|
|
@@ -517,6 +541,71 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 517 |
self.log(f" Most entangled layers: {emap.most_entangled_layers}")
|
| 518 |
self.log(f" Cleanest layers: {emap.least_entangled_layers}")
|
| 519 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
# ── Configuration Derivation ─────────────────────────────────────
|
| 521 |
|
| 522 |
def _derive_configuration(self):
|
|
@@ -528,18 +617,32 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 528 |
self.log("-" * 50)
|
| 529 |
insights = self._insights
|
| 530 |
|
| 531 |
-
# 1. n_directions: based on cone geometry
|
| 532 |
-
|
| 533 |
-
|
|
|
|
|
|
|
| 534 |
n_dirs = max(4, min(8, int(insights.cone_dimensionality * 2)))
|
|
|
|
|
|
|
| 535 |
self.log(f" Polyhedral cone (dim={insights.cone_dimensionality:.1f}) "
|
| 536 |
-
f"→ n_directions={n_dirs}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
else:
|
| 538 |
-
# Linear cone →
|
| 539 |
-
n_dirs =
|
|
|
|
|
|
|
| 540 |
self.log(f" Linear cone (dim={insights.cone_dimensionality:.1f}) "
|
| 541 |
-
f"→ n_directions=
|
| 542 |
insights.recommended_n_directions = n_dirs
|
|
|
|
| 543 |
self.n_directions = n_dirs
|
| 544 |
|
| 545 |
# 2. regularization: based on alignment method + entanglement
|
|
@@ -586,15 +689,22 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 586 |
|
| 587 |
# 4. Layer selection: cluster-aware + entanglement-gated
|
| 588 |
if insights.cluster_representative_layers:
|
| 589 |
-
# Start from cluster representatives
|
| 590 |
base_layers = list(insights.cluster_representative_layers)
|
| 591 |
|
| 592 |
-
#
|
| 593 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
for cluster in insights.direction_clusters:
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
|
|
|
|
|
|
| 598 |
|
| 599 |
# Gate: remove highly entangled layers
|
| 600 |
skip = set()
|
|
@@ -621,13 +731,9 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 621 |
self.log(f" RSI={insights.mean_refusal_sparsity_index:.2f} "
|
| 622 |
f"→ standard dense projection")
|
| 623 |
|
| 624 |
-
# 6.
|
| 625 |
-
|
| 626 |
-
|
| 627 |
-
self.log(f" Multi-direction ({n_dirs}) → whitened SVD enabled")
|
| 628 |
-
else:
|
| 629 |
-
self.use_whitened_svd = False
|
| 630 |
-
self.log(" Single direction → standard diff-in-means")
|
| 631 |
|
| 632 |
# ── Informed DISTILL ─────────────────────────────────────────────
|
| 633 |
|
|
@@ -650,7 +756,38 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 650 |
n_layers = len(self._harmful_means)
|
| 651 |
norms: dict[int, float] = {}
|
| 652 |
|
| 653 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
|
| 655 |
whitened_extractor = WhitenedSVDExtractor()
|
| 656 |
self.log(f"Using whitened SVD with {self.n_directions} directions")
|
|
@@ -658,6 +795,29 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 658 |
whitened_extractor = None
|
| 659 |
|
| 660 |
for idx in range(n_layers):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
if self.n_directions == 1:
|
| 662 |
diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
|
| 663 |
norm = diff.norm().item()
|
|
@@ -691,6 +851,41 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 691 |
self.refusal_directions[idx] = primary / primary.norm()
|
| 692 |
norms[idx] = S[:k].sum().item()
|
| 693 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
# Layer selection: use analysis-recommended layers if available,
|
| 695 |
# otherwise fall back to knee detection
|
| 696 |
if self._insights.recommended_layers:
|
|
@@ -728,15 +923,117 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 728 |
def _excise_informed(self):
|
| 729 |
"""Excise refusal directions with analysis-informed strategy.
|
| 730 |
|
| 731 |
-
Uses
|
| 732 |
-
|
|
|
|
|
|
|
| 733 |
"""
|
| 734 |
if self._insights.use_sparse_surgery:
|
| 735 |
self._excise_sparse()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
else:
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
def _excise_sparse(self):
|
| 742 |
"""Sparse direction surgery — only modifies high-projection rows."""
|
|
@@ -825,14 +1122,22 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 825 |
1. Residual refusal signal (via activation probing)
|
| 826 |
2. Self-repair / Ouroboros effect (via defense robustness)
|
| 827 |
3. Triggers additional targeted passes at compensating layers
|
|
|
|
|
|
|
|
|
|
|
|
|
| 828 |
"""
|
| 829 |
# Run standard verification first
|
| 830 |
self._verify()
|
| 831 |
|
| 832 |
# Check if Ouroboros compensation is needed
|
| 833 |
refusal_rate = self._quality_metrics.get("refusal_rate", 0.0)
|
|
|
|
| 834 |
ouroboros_pass = 0
|
| 835 |
|
|
|
|
|
|
|
|
|
|
| 836 |
while (refusal_rate > self._ouroboros_threshold
|
| 837 |
and ouroboros_pass < self._max_ouroboros_passes):
|
| 838 |
ouroboros_pass += 1
|
|
@@ -849,9 +1154,9 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 849 |
self._distill_inner()
|
| 850 |
self.log(f"Found {len(self._strong_layers)} layers with residual refusal")
|
| 851 |
|
| 852 |
-
# Re-excise at the new strong layers
|
| 853 |
if self._strong_layers:
|
| 854 |
-
self.
|
| 855 |
else:
|
| 856 |
self.log("No strong layers found — stopping Ouroboros compensation")
|
| 857 |
break
|
|
@@ -859,7 +1164,24 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 859 |
# Re-verify
|
| 860 |
self._verify()
|
| 861 |
refusal_rate = self._quality_metrics.get("refusal_rate", 0.0)
|
| 862 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
|
| 864 |
self._report.ouroboros_passes = ouroboros_pass
|
| 865 |
self._report.final_refusal_rate = refusal_rate
|
|
@@ -903,6 +1225,7 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 903 |
},
|
| 904 |
"derived_config": {
|
| 905 |
"n_directions": insights.recommended_n_directions,
|
|
|
|
| 906 |
"regularization": insights.recommended_regularization,
|
| 907 |
"refinement_passes": insights.recommended_refinement_passes,
|
| 908 |
"layers_used": insights.recommended_layers,
|
|
@@ -981,6 +1304,7 @@ class InformedAbliterationPipeline(AbliterationPipeline):
|
|
| 981 |
|
| 982 |
lines.append("Derived Configuration:")
|
| 983 |
lines.append(f" n_directions: {insights.recommended_n_directions}")
|
|
|
|
| 984 |
lines.append(f" regularization: {insights.recommended_regularization}")
|
| 985 |
lines.append(f" refinement_passes: {insights.recommended_refinement_passes}")
|
| 986 |
lines.append(f" sparse surgery: {insights.use_sparse_surgery}")
|
|
|
|
| 73 |
"description": (
|
| 74 |
"Runs analysis modules between PROBE and DISTILL to auto-configure "
|
| 75 |
"direction extraction, layer selection, and projection strategy based "
|
| 76 |
+
"on the model's actual refusal geometry. Defaults to single diff-of-means "
|
| 77 |
+
"direction + Bayesian optimization (Heretic-style)."
|
| 78 |
),
|
| 79 |
+
"n_directions": 1, # overridden by analysis
|
| 80 |
+
"direction_method": "diff_means", # overridden by analysis; "leace" also available
|
| 81 |
"norm_preserve": True,
|
| 82 |
"regularization": 0.0, # overridden by analysis
|
| 83 |
"refinement_passes": 2, # overridden by analysis
|
| 84 |
"project_biases": True,
|
| 85 |
"use_chat_template": True,
|
| 86 |
+
"use_whitened_svd": False, # overridden by analysis
|
| 87 |
"true_iterative_refinement": True,
|
| 88 |
}
|
| 89 |
|
|
|
|
| 128 |
clean_layers: list[int] = field(default_factory=list)
|
| 129 |
|
| 130 |
# Derived configuration
|
| 131 |
+
recommended_n_directions: int = 1
|
| 132 |
+
recommended_direction_method: str = "diff_means"
|
| 133 |
recommended_regularization: float = 0.0
|
| 134 |
recommended_refinement_passes: int = 2
|
| 135 |
recommended_layers: list[int] = field(default_factory=list)
|
|
|
|
| 220 |
hub_token=hub_token,
|
| 221 |
hub_community_org=hub_community_org,
|
| 222 |
quantization=quantization,
|
| 223 |
+
# Set informed defaults: single direction + Bayesian opt
|
| 224 |
+
n_directions=1,
|
| 225 |
+
direction_method="diff_means",
|
| 226 |
norm_preserve=True,
|
| 227 |
project_biases=True,
|
| 228 |
use_chat_template=True,
|
| 229 |
+
use_whitened_svd=False,
|
| 230 |
true_iterative_refinement=True,
|
| 231 |
+
use_kl_optimization=True,
|
| 232 |
+
float_layer_interpolation=True,
|
| 233 |
+
layer_adaptive_strength=True,
|
| 234 |
+
winsorize_activations=True,
|
| 235 |
+
winsorize_percentile=0.01,
|
| 236 |
)
|
| 237 |
self.method = "informed"
|
| 238 |
|
|
|
|
| 321 |
if self._run_defense:
|
| 322 |
self._analyze_defense_robustness()
|
| 323 |
|
| 324 |
+
# 5. Sparse Surgery Analysis (RSI computation)
|
| 325 |
+
if self._run_sparse:
|
| 326 |
+
self._analyze_sparsity()
|
| 327 |
+
|
| 328 |
+
# 6. Derive configuration from insights
|
| 329 |
self._derive_configuration()
|
| 330 |
|
| 331 |
elapsed = time.time() - t0
|
|
|
|
| 406 |
sample_layers = candidate_layers[::step]
|
| 407 |
|
| 408 |
polyhedral_count = 0
|
| 409 |
+
all_results = []
|
| 410 |
best_cone_result = None
|
| 411 |
best_strength = 0.0
|
| 412 |
|
|
|
|
| 420 |
layer_idx=layer_idx,
|
| 421 |
)
|
| 422 |
|
| 423 |
+
all_results.append(result)
|
| 424 |
if result.is_polyhedral:
|
| 425 |
polyhedral_count += 1
|
| 426 |
|
| 427 |
+
# Track the strongest layer's cone analysis for per-category directions
|
| 428 |
general_strength = result.general_direction.norm().item() if result.general_direction.numel() > 1 else 0
|
| 429 |
if general_strength > best_strength:
|
| 430 |
best_strength = general_strength
|
| 431 |
best_cone_result = result
|
| 432 |
|
| 433 |
+
if all_results:
|
| 434 |
+
# Aggregate cone geometry across sampled layers (majority vote +
|
| 435 |
+
# mean dimensionality) instead of relying on a single layer.
|
| 436 |
+
n_sampled = len(all_results)
|
| 437 |
+
is_polyhedral = polyhedral_count > n_sampled / 2
|
| 438 |
+
avg_dimensionality = sum(r.cone_dimensionality for r in all_results) / n_sampled
|
| 439 |
+
avg_pairwise_cos = sum(r.mean_pairwise_cosine for r in all_results) / n_sampled
|
| 440 |
+
|
| 441 |
+
self._insights.cone_is_polyhedral = is_polyhedral
|
| 442 |
+
self._insights.cone_dimensionality = avg_dimensionality
|
| 443 |
+
self._insights.mean_pairwise_cosine = avg_pairwise_cos
|
| 444 |
+
|
| 445 |
+
# Store per-category directions from the strongest layer
|
| 446 |
+
if best_cone_result is not None:
|
| 447 |
+
for cd in best_cone_result.category_directions:
|
| 448 |
+
self._insights.per_category_directions[cd.category] = cd.direction
|
| 449 |
+
self._insights.direction_specificity[cd.category] = cd.specificity
|
| 450 |
+
|
| 451 |
+
cone_type = "POLYHEDRAL" if is_polyhedral else "LINEAR"
|
| 452 |
+
self.log(f" Cone type: {cone_type} (majority vote: {polyhedral_count}/{n_sampled} layers)")
|
| 453 |
+
self.log(f" Avg dimensionality: {avg_dimensionality:.2f}")
|
| 454 |
+
self.log(f" Avg pairwise cosine: {avg_pairwise_cos:.3f}")
|
| 455 |
+
if best_cone_result is not None:
|
| 456 |
+
self.log(f" Categories detected: {best_cone_result.category_count}")
|
| 457 |
+
|
| 458 |
+
for cd in sorted(best_cone_result.category_directions, key=lambda x: -x.strength)[:5]:
|
| 459 |
+
self.log(f" {cd.category:15s} DSI={cd.specificity:.3f} str={cd.strength:.3f}")
|
| 460 |
else:
|
| 461 |
self.log(" No cone results — using default linear assumption")
|
| 462 |
|
|
|
|
| 541 |
self.log(f" Most entangled layers: {emap.most_entangled_layers}")
|
| 542 |
self.log(f" Cleanest layers: {emap.least_entangled_layers}")
|
| 543 |
|
| 544 |
+
def _analyze_sparsity(self):
|
| 545 |
+
"""Compute Refusal Sparsity Index to decide sparse vs dense excision."""
|
| 546 |
+
self.log("\n[5/5] Refusal Sparsity Analysis")
|
| 547 |
+
self.log("-" * 40)
|
| 548 |
+
|
| 549 |
+
from obliteratus.analysis.sparse_surgery import SparseDirectionSurgeon
|
| 550 |
+
from obliteratus.strategies.utils import (
|
| 551 |
+
get_ffn_module,
|
| 552 |
+
get_layer_modules,
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
# Need refusal directions — use quick diff-in-means
|
| 556 |
+
quick_directions = {}
|
| 557 |
+
for idx in sorted(self._harmful_means.keys()):
|
| 558 |
+
diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze()
|
| 559 |
+
norm = diff.norm().item()
|
| 560 |
+
if norm > 1e-10:
|
| 561 |
+
quick_directions[idx] = diff / diff.norm()
|
| 562 |
+
|
| 563 |
+
if not quick_directions:
|
| 564 |
+
self.log(" No refusal directions — skipping sparsity analysis")
|
| 565 |
+
return
|
| 566 |
+
|
| 567 |
+
# Gather FFN output weights for representative layers (sample for speed)
|
| 568 |
+
layers = get_layer_modules(self.handle)
|
| 569 |
+
arch = self.handle.architecture
|
| 570 |
+
n_layers = len(layers)
|
| 571 |
+
sample_idxs = sorted(quick_directions.keys())
|
| 572 |
+
step = max(1, len(sample_idxs) // 8)
|
| 573 |
+
sample_idxs = sample_idxs[::step]
|
| 574 |
+
|
| 575 |
+
weights = {}
|
| 576 |
+
sampled_dirs = {}
|
| 577 |
+
for idx in sample_idxs:
|
| 578 |
+
if idx >= n_layers:
|
| 579 |
+
continue
|
| 580 |
+
try:
|
| 581 |
+
ffn = get_ffn_module(layers[idx], arch)
|
| 582 |
+
for name in ["down_proj", "c_proj", "dense_4h_to_h", "fc_out", "fc2", "w2"]:
|
| 583 |
+
proj = getattr(ffn, name, None)
|
| 584 |
+
if proj is not None and hasattr(proj, "weight"):
|
| 585 |
+
W = proj.weight.data
|
| 586 |
+
d = quick_directions[idx]
|
| 587 |
+
if W.shape[-1] == d.shape[0]:
|
| 588 |
+
weights[idx] = W
|
| 589 |
+
sampled_dirs[idx] = d
|
| 590 |
+
break
|
| 591 |
+
except (AttributeError, RuntimeError):
|
| 592 |
+
continue
|
| 593 |
+
|
| 594 |
+
if not weights:
|
| 595 |
+
self.log(" Could not access FFN weights — skipping sparsity analysis")
|
| 596 |
+
return
|
| 597 |
+
|
| 598 |
+
surgeon = SparseDirectionSurgeon(auto_sparsity=True)
|
| 599 |
+
plan = surgeon.plan_surgery(weights, sampled_dirs)
|
| 600 |
+
|
| 601 |
+
self._insights.mean_refusal_sparsity_index = plan.mean_refusal_sparsity_index
|
| 602 |
+
self._insights.recommended_sparsity = plan.recommended_sparsity
|
| 603 |
+
|
| 604 |
+
self.log(f" Mean RSI: {plan.mean_refusal_sparsity_index:.3f}")
|
| 605 |
+
self.log(f" Recommended sparsity: {plan.recommended_sparsity:.1%}")
|
| 606 |
+
self.log(f" Most sparse layer: {plan.most_sparse_layer}")
|
| 607 |
+
self.log(f" Most dense layer: {plan.most_dense_layer}")
|
| 608 |
+
|
| 609 |
# ── Configuration Derivation ─────────────────────────────────────
|
| 610 |
|
| 611 |
def _derive_configuration(self):
|
|
|
|
| 617 |
self.log("-" * 50)
|
| 618 |
insights = self._insights
|
| 619 |
|
| 620 |
+
# 1. n_directions + direction_method: based on cone geometry
|
| 621 |
+
# Default: single direction via diff-of-means (proven most robust).
|
| 622 |
+
# Only escalate to multi-direction when analysis confirms polyhedral geometry.
|
| 623 |
+
if insights.cone_is_polyhedral and insights.cone_dimensionality > 2.0:
|
| 624 |
+
# Clearly polyhedral cone → use multiple directions via SVD
|
| 625 |
n_dirs = max(4, min(8, int(insights.cone_dimensionality * 2)))
|
| 626 |
+
self.direction_method = "svd"
|
| 627 |
+
self.use_whitened_svd = True
|
| 628 |
self.log(f" Polyhedral cone (dim={insights.cone_dimensionality:.1f}) "
|
| 629 |
+
f"→ n_directions={n_dirs}, method=svd (whitened)")
|
| 630 |
+
elif insights.cone_is_polyhedral:
|
| 631 |
+
# Mildly polyhedral → LEACE gives better single-direction erasure
|
| 632 |
+
n_dirs = 1
|
| 633 |
+
self.direction_method = "leace"
|
| 634 |
+
self.use_whitened_svd = False
|
| 635 |
+
self.log(f" Mildly polyhedral (dim={insights.cone_dimensionality:.1f}) "
|
| 636 |
+
f"→ n_directions=1, method=leace")
|
| 637 |
else:
|
| 638 |
+
# Linear cone → single direction via diff-of-means (simplest, most robust)
|
| 639 |
+
n_dirs = 1
|
| 640 |
+
self.direction_method = "diff_means"
|
| 641 |
+
self.use_whitened_svd = False
|
| 642 |
self.log(f" Linear cone (dim={insights.cone_dimensionality:.1f}) "
|
| 643 |
+
f"→ n_directions=1, method=diff_means")
|
| 644 |
insights.recommended_n_directions = n_dirs
|
| 645 |
+
insights.recommended_direction_method = self.direction_method
|
| 646 |
self.n_directions = n_dirs
|
| 647 |
|
| 648 |
# 2. regularization: based on alignment method + entanglement
|
|
|
|
| 689 |
|
| 690 |
# 4. Layer selection: cluster-aware + entanglement-gated
|
| 691 |
if insights.cluster_representative_layers:
|
| 692 |
+
# Start from cluster representatives (strongest per cluster)
|
| 693 |
base_layers = list(insights.cluster_representative_layers)
|
| 694 |
|
| 695 |
+
# Conservative expansion: for each cluster, add at most the top-2
|
| 696 |
+
# strongest layers (by refusal norm) beyond the representative,
|
| 697 |
+
# to avoid over-modifying weak layers in large clusters.
|
| 698 |
+
norms = {}
|
| 699 |
+
for idx in self._harmful_means:
|
| 700 |
+
if idx in self._harmless_means:
|
| 701 |
+
norms[idx] = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze().norm().item()
|
| 702 |
for cluster in insights.direction_clusters:
|
| 703 |
+
ranked = sorted(cluster, key=lambda ly: norms.get(ly, 0), reverse=True)
|
| 704 |
+
# Add up to 2 additional strong layers per cluster
|
| 705 |
+
for ly in ranked[:3]: # representative + up to 2 more
|
| 706 |
+
base_layers.append(ly)
|
| 707 |
+
base_layers = sorted(set(base_layers))
|
| 708 |
|
| 709 |
# Gate: remove highly entangled layers
|
| 710 |
skip = set()
|
|
|
|
| 731 |
self.log(f" RSI={insights.mean_refusal_sparsity_index:.2f} "
|
| 732 |
f"→ standard dense projection")
|
| 733 |
|
| 734 |
+
# 6. Direction method summary (already set in step 1)
|
| 735 |
+
self.log(f" Direction method: {self.direction_method} "
|
| 736 |
+
f"(whitened_svd={'on' if self.use_whitened_svd else 'off'})")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 737 |
|
| 738 |
# ── Informed DISTILL ─────────────────────────────────────────────
|
| 739 |
|
|
|
|
| 756 |
n_layers = len(self._harmful_means)
|
| 757 |
norms: dict[int, float] = {}
|
| 758 |
|
| 759 |
+
# ── Small-model direction cap (matching base _distill) ────────
|
| 760 |
+
# On small models, each SVD direction removes a proportionally
|
| 761 |
+
# larger fraction of weight energy. Cap to prevent over-ablation.
|
| 762 |
+
hidden_size = self.handle.hidden_size if self.handle else 0
|
| 763 |
+
total_params = getattr(self.handle, 'total_params', 0) if self.handle else 0
|
| 764 |
+
if total_params == 0 and self.handle:
|
| 765 |
+
try:
|
| 766 |
+
total_params = sum(p.numel() for p in self.handle.model.parameters())
|
| 767 |
+
except Exception:
|
| 768 |
+
pass
|
| 769 |
+
if self.n_directions > 1 and (
|
| 770 |
+
(0 < hidden_size < 2048)
|
| 771 |
+
or (0 < total_params < 2_000_000_000)
|
| 772 |
+
or n_layers <= 16
|
| 773 |
+
):
|
| 774 |
+
max_dirs = max(1, min(self.n_directions, 2))
|
| 775 |
+
if max_dirs < self.n_directions:
|
| 776 |
+
self.log(
|
| 777 |
+
f"Capped n_directions from {self.n_directions} to {max_dirs} "
|
| 778 |
+
f"for small model (hidden={hidden_size}, "
|
| 779 |
+
f"params={total_params / 1e9:.1f}B, layers={n_layers})"
|
| 780 |
+
)
|
| 781 |
+
self.n_directions = max_dirs
|
| 782 |
+
|
| 783 |
+
# LEACE extractor for optimal concept erasure
|
| 784 |
+
leace_extractor = None
|
| 785 |
+
if self.direction_method == "leace":
|
| 786 |
+
from obliteratus.analysis.leace import LEACEExtractor
|
| 787 |
+
leace_extractor = LEACEExtractor()
|
| 788 |
+
self.log(f"Using LEACE (closed-form optimal concept erasure)")
|
| 789 |
+
|
| 790 |
+
if self.use_whitened_svd and self.n_directions > 1 and leace_extractor is None:
|
| 791 |
from obliteratus.analysis.whitened_svd import WhitenedSVDExtractor
|
| 792 |
whitened_extractor = WhitenedSVDExtractor()
|
| 793 |
self.log(f"Using whitened SVD with {self.n_directions} directions")
|
|
|
|
| 795 |
whitened_extractor = None
|
| 796 |
|
| 797 |
for idx in range(n_layers):
|
| 798 |
+
# LEACE path: theoretically optimal single-direction erasure
|
| 799 |
+
if leace_extractor is not None:
|
| 800 |
+
if idx in self._harmful_acts and idx in self._harmless_acts:
|
| 801 |
+
try:
|
| 802 |
+
l_result = leace_extractor.extract(
|
| 803 |
+
self._harmful_acts[idx],
|
| 804 |
+
self._harmless_acts[idx],
|
| 805 |
+
layer_idx=idx,
|
| 806 |
+
)
|
| 807 |
+
self.refusal_directions[idx] = l_result.direction
|
| 808 |
+
self.refusal_subspaces[idx] = l_result.direction.unsqueeze(0)
|
| 809 |
+
norms[idx] = l_result.generalized_eigenvalue
|
| 810 |
+
|
| 811 |
+
if idx < 5 or idx == n_layers - 1:
|
| 812 |
+
self.log(
|
| 813 |
+
f" layer {idx}: LEACE eigenvalue={l_result.generalized_eigenvalue:.4f}, "
|
| 814 |
+
f"erasure_loss={l_result.erasure_loss:.4f}"
|
| 815 |
+
)
|
| 816 |
+
continue
|
| 817 |
+
except Exception as e:
|
| 818 |
+
if idx < 5:
|
| 819 |
+
self.log(f" layer {idx}: LEACE failed ({e}), falling back")
|
| 820 |
+
|
| 821 |
if self.n_directions == 1:
|
| 822 |
diff = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze(0)
|
| 823 |
norm = diff.norm().item()
|
|
|
|
| 851 |
self.refusal_directions[idx] = primary / primary.norm()
|
| 852 |
norms[idx] = S[:k].sum().item()
|
| 853 |
|
| 854 |
+
# Enrich subspaces with per-category cone directions when available.
|
| 855 |
+
# This uses the actual refusal cone generators instead of purely
|
| 856 |
+
# data-agnostic SVD components.
|
| 857 |
+
cat_dirs = self._insights.per_category_directions
|
| 858 |
+
if cat_dirs and self._insights.cone_is_polyhedral and self.n_directions > 1:
|
| 859 |
+
cat_tensors = list(cat_dirs.values())
|
| 860 |
+
# Stack and orthogonalize category directions
|
| 861 |
+
cat_stack = torch.stack(cat_tensors) # (n_cats, hidden)
|
| 862 |
+
cat_norms = cat_stack.norm(dim=1, keepdim=True).clamp(min=1e-8)
|
| 863 |
+
cat_stack = cat_stack / cat_norms
|
| 864 |
+
# Blend into strong-signal layers: replace later SVD components
|
| 865 |
+
# with category directions (which are geometrically meaningful)
|
| 866 |
+
n_cat = cat_stack.shape[0]
|
| 867 |
+
for idx in norms:
|
| 868 |
+
sub = self.refusal_subspaces.get(idx)
|
| 869 |
+
if sub is None or sub.shape[0] <= 1:
|
| 870 |
+
continue
|
| 871 |
+
# Keep the first SVD direction (strongest), replace remaining
|
| 872 |
+
# with category directions projected to be orthogonal to it
|
| 873 |
+
primary = sub[0:1] # (1, hidden)
|
| 874 |
+
# Project category directions orthogonal to primary
|
| 875 |
+
cos = (cat_stack @ primary.squeeze(0)) # (n_cat,)
|
| 876 |
+
ortho_cats = cat_stack - cos.unsqueeze(1) * primary
|
| 877 |
+
ortho_norms = ortho_cats.norm(dim=1)
|
| 878 |
+
# Keep only directions that survived orthogonalization
|
| 879 |
+
valid = ortho_norms > 0.1
|
| 880 |
+
if valid.sum() > 0:
|
| 881 |
+
ortho_cats = ortho_cats[valid]
|
| 882 |
+
ortho_cats = ortho_cats / ortho_cats.norm(dim=1, keepdim=True)
|
| 883 |
+
# Take up to (n_directions - 1) category directions
|
| 884 |
+
n_take = min(self.n_directions - 1, ortho_cats.shape[0])
|
| 885 |
+
new_sub = torch.cat([primary, ortho_cats[:n_take]], dim=0)
|
| 886 |
+
self.refusal_subspaces[idx] = new_sub
|
| 887 |
+
self.log(f"Enriched subspaces with {n_cat} per-category cone directions")
|
| 888 |
+
|
| 889 |
# Layer selection: use analysis-recommended layers if available,
|
| 890 |
# otherwise fall back to knee detection
|
| 891 |
if self._insights.recommended_layers:
|
|
|
|
| 923 |
def _excise_informed(self):
|
| 924 |
"""Excise refusal directions with analysis-informed strategy.
|
| 925 |
|
| 926 |
+
Uses Bayesian optimization (when available) with analysis-derived
|
| 927 |
+
warm-start parameters, falling back to sparse surgery or standard
|
| 928 |
+
projection. This is the key integration: analysis maps the geometry,
|
| 929 |
+
Bayesian optimization finds the optimal projection strength.
|
| 930 |
"""
|
| 931 |
if self._insights.use_sparse_surgery:
|
| 932 |
self._excise_sparse()
|
| 933 |
+
return
|
| 934 |
+
|
| 935 |
+
# Enable Bayesian optimization using analysis insights for warm-start.
|
| 936 |
+
# The analysis provides much better initial parameters than the default
|
| 937 |
+
# heuristic (strongest-layer-based peak), dramatically narrowing the
|
| 938 |
+
# search space and improving convergence.
|
| 939 |
+
self._configure_bayesian_warm_start()
|
| 940 |
+
self._excise()
|
| 941 |
+
|
| 942 |
+
def _configure_bayesian_warm_start(self):
|
| 943 |
+
"""Configure Bayesian optimization with analysis-derived warm-start.
|
| 944 |
+
|
| 945 |
+
Translates analysis insights into a much tighter search space:
|
| 946 |
+
- peak_position from cluster representative layers
|
| 947 |
+
- spread from cluster structure (narrow clusters → narrow spread)
|
| 948 |
+
- component scaling from entanglement analysis
|
| 949 |
+
- KL budget from alignment method detection
|
| 950 |
+
"""
|
| 951 |
+
insights = self._insights
|
| 952 |
+
|
| 953 |
+
# Enable Bayesian optimization (50 trials default, same as heretic)
|
| 954 |
+
self._bayesian_trials = 50
|
| 955 |
+
|
| 956 |
+
# Also set heretic-compatible flags on the pipeline so the base
|
| 957 |
+
# _excise_inner() picks them up during Bayesian optimization.
|
| 958 |
+
self.layer_adaptive_strength = True
|
| 959 |
+
self.float_layer_interpolation = True
|
| 960 |
+
self.use_kl_optimization = True
|
| 961 |
+
|
| 962 |
+
# KL budget: tighter for methods that are fragile (CAI, RLHF),
|
| 963 |
+
# looser for concentrated methods (DPO, SFT).
|
| 964 |
+
method = insights.detected_alignment_method
|
| 965 |
+
if method == "dpo":
|
| 966 |
+
self.kl_budget = 0.5
|
| 967 |
+
elif method == "rlhf":
|
| 968 |
+
self.kl_budget = 0.3
|
| 969 |
+
elif method == "cai":
|
| 970 |
+
self.kl_budget = 0.2
|
| 971 |
+
elif method == "sft":
|
| 972 |
+
self.kl_budget = 0.4
|
| 973 |
+
else:
|
| 974 |
+
self.kl_budget = 0.35
|
| 975 |
+
|
| 976 |
+
self.log(f"Bayesian optimization enabled (50 trials, KL budget={self.kl_budget})")
|
| 977 |
+
self.log("Analysis insights will warm-start the optimizer")
|
| 978 |
+
|
| 979 |
+
# Compute analysis-derived warm-start for the parametric kernel.
|
| 980 |
+
# The Bayesian optimizer reads these from the pipeline if present.
|
| 981 |
+
n_layers = len(self._harmful_means) if self._harmful_means else 32
|
| 982 |
+
if insights.cluster_representative_layers and n_layers > 1:
|
| 983 |
+
# Peak position: normalized position of the strongest cluster rep
|
| 984 |
+
norms = {}
|
| 985 |
+
for idx in self._harmful_means:
|
| 986 |
+
if idx in self._harmless_means:
|
| 987 |
+
norms[idx] = (self._harmful_means[idx] - self._harmless_means[idx]).squeeze().norm().item()
|
| 988 |
+
reps = insights.cluster_representative_layers
|
| 989 |
+
if norms:
|
| 990 |
+
best_rep = max(reps, key=lambda ly: norms.get(ly, 0))
|
| 991 |
+
else:
|
| 992 |
+
best_rep = reps[len(reps) // 2]
|
| 993 |
+
warm_peak = best_rep / max(n_layers - 1, 1)
|
| 994 |
+
|
| 995 |
+
# Spread: narrow if clusters are tight, wide if clusters span many layers
|
| 996 |
+
if insights.direction_clusters:
|
| 997 |
+
cluster_widths = [
|
| 998 |
+
(max(c) - min(c)) / max(n_layers - 1, 1)
|
| 999 |
+
for c in insights.direction_clusters if len(c) > 1
|
| 1000 |
+
]
|
| 1001 |
+
warm_spread = max(0.1, min(0.6, sum(cluster_widths) / len(cluster_widths) if cluster_widths else 0.3))
|
| 1002 |
+
else:
|
| 1003 |
+
warm_spread = 0.3
|
| 1004 |
+
|
| 1005 |
+
# Min weight: higher if high persistence (refusal spread across all layers)
|
| 1006 |
+
warm_min = min(0.3, max(0.0, insights.direction_persistence * 0.2))
|
| 1007 |
+
|
| 1008 |
+
# Attn/MLP scaling: reduce MLP scaling if entanglement is high
|
| 1009 |
+
# (MLP projections cause more capability damage)
|
| 1010 |
+
if insights.entanglement_score > 0.5:
|
| 1011 |
+
warm_mlp = 0.4
|
| 1012 |
+
warm_attn = 0.7
|
| 1013 |
+
else:
|
| 1014 |
+
warm_mlp = 0.6
|
| 1015 |
+
warm_attn = 0.8
|
| 1016 |
else:
|
| 1017 |
+
warm_peak = 0.5
|
| 1018 |
+
warm_spread = 0.3
|
| 1019 |
+
warm_min = 0.05
|
| 1020 |
+
warm_mlp = 0.6
|
| 1021 |
+
warm_attn = 0.8
|
| 1022 |
+
|
| 1023 |
+
# Store warm-start params for the Bayesian optimizer to pick up
|
| 1024 |
+
self._informed_warm_start = {
|
| 1025 |
+
"max_weight": 0.9,
|
| 1026 |
+
"peak_position": warm_peak,
|
| 1027 |
+
"min_weight": warm_min,
|
| 1028 |
+
"spread": warm_spread,
|
| 1029 |
+
"attn_scale": warm_attn,
|
| 1030 |
+
"mlp_scale": warm_mlp,
|
| 1031 |
+
"dir_idx": 0.0,
|
| 1032 |
+
}
|
| 1033 |
+
self.log(
|
| 1034 |
+
f" Warm-start: peak={warm_peak:.2f}, spread={warm_spread:.2f}, "
|
| 1035 |
+
f"min={warm_min:.2f}, attn={warm_attn:.2f}, mlp={warm_mlp:.2f}"
|
| 1036 |
+
)
|
| 1037 |
|
| 1038 |
def _excise_sparse(self):
|
| 1039 |
"""Sparse direction surgery — only modifies high-projection rows."""
|
|
|
|
| 1122 |
1. Residual refusal signal (via activation probing)
|
| 1123 |
2. Self-repair / Ouroboros effect (via defense robustness)
|
| 1124 |
3. Triggers additional targeted passes at compensating layers
|
| 1125 |
+
|
| 1126 |
+
KL-gated: stops early if model damage (KL divergence) is getting
|
| 1127 |
+
worse even though refusal persists. This prevents the death spiral
|
| 1128 |
+
where each pass damages the model without removing refusal.
|
| 1129 |
"""
|
| 1130 |
# Run standard verification first
|
| 1131 |
self._verify()
|
| 1132 |
|
| 1133 |
# Check if Ouroboros compensation is needed
|
| 1134 |
refusal_rate = self._quality_metrics.get("refusal_rate", 0.0)
|
| 1135 |
+
prev_kl = self._quality_metrics.get("kl_divergence", 0.0)
|
| 1136 |
ouroboros_pass = 0
|
| 1137 |
|
| 1138 |
+
# KL budget: stop if KL exceeds this threshold (model too damaged)
|
| 1139 |
+
kl_ceiling = getattr(self, "kl_budget", 0.5) * 2.0 # 2x budget as hard ceiling
|
| 1140 |
+
|
| 1141 |
while (refusal_rate > self._ouroboros_threshold
|
| 1142 |
and ouroboros_pass < self._max_ouroboros_passes):
|
| 1143 |
ouroboros_pass += 1
|
|
|
|
| 1154 |
self._distill_inner()
|
| 1155 |
self.log(f"Found {len(self._strong_layers)} layers with residual refusal")
|
| 1156 |
|
| 1157 |
+
# Re-excise at the new strong layers using informed strategy
|
| 1158 |
if self._strong_layers:
|
| 1159 |
+
self._excise_informed()
|
| 1160 |
else:
|
| 1161 |
self.log("No strong layers found — stopping Ouroboros compensation")
|
| 1162 |
break
|
|
|
|
| 1164 |
# Re-verify
|
| 1165 |
self._verify()
|
| 1166 |
refusal_rate = self._quality_metrics.get("refusal_rate", 0.0)
|
| 1167 |
+
current_kl = self._quality_metrics.get("kl_divergence", 0.0)
|
| 1168 |
+
self.log(f"After Ouroboros pass {ouroboros_pass}: refusal={refusal_rate:.0%}, KL={current_kl:.4f}")
|
| 1169 |
+
|
| 1170 |
+
# KL-gated early stopping: if KL is rising and exceeds ceiling,
|
| 1171 |
+
# the model is being damaged faster than refusal is being removed.
|
| 1172 |
+
if current_kl > kl_ceiling:
|
| 1173 |
+
self.log(
|
| 1174 |
+
f"KL divergence {current_kl:.4f} exceeds ceiling {kl_ceiling:.4f} — "
|
| 1175 |
+
f"stopping to prevent further model damage"
|
| 1176 |
+
)
|
| 1177 |
+
break
|
| 1178 |
+
if ouroboros_pass > 1 and current_kl > prev_kl * 1.5 and refusal_rate > 0.3:
|
| 1179 |
+
self.log(
|
| 1180 |
+
f"KL rising sharply ({prev_kl:.4f} → {current_kl:.4f}) with "
|
| 1181 |
+
f"refusal still at {refusal_rate:.0%} — stopping (diminishing returns)"
|
| 1182 |
+
)
|
| 1183 |
+
break
|
| 1184 |
+
prev_kl = current_kl
|
| 1185 |
|
| 1186 |
self._report.ouroboros_passes = ouroboros_pass
|
| 1187 |
self._report.final_refusal_rate = refusal_rate
|
|
|
|
| 1225 |
},
|
| 1226 |
"derived_config": {
|
| 1227 |
"n_directions": insights.recommended_n_directions,
|
| 1228 |
+
"direction_method": insights.recommended_direction_method,
|
| 1229 |
"regularization": insights.recommended_regularization,
|
| 1230 |
"refinement_passes": insights.recommended_refinement_passes,
|
| 1231 |
"layers_used": insights.recommended_layers,
|
|
|
|
| 1304 |
|
| 1305 |
lines.append("Derived Configuration:")
|
| 1306 |
lines.append(f" n_directions: {insights.recommended_n_directions}")
|
| 1307 |
+
lines.append(f" direction_method: {insights.recommended_direction_method}")
|
| 1308 |
lines.append(f" regularization: {insights.recommended_regularization}")
|
| 1309 |
lines.append(f" refinement_passes: {insights.recommended_refinement_passes}")
|
| 1310 |
lines.append(f" sparse surgery: {insights.use_sparse_surgery}")
|
scripts/run_benchmark_remote.sh
CHANGED
|
@@ -144,12 +144,18 @@ def _patched_collect(self, layer_modules, prompts, label):
|
|
| 144 |
torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
|
| 145 |
for i in range(torch.cuda.device_count())
|
| 146 |
)
|
| 147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
max_length = 64
|
| 149 |
-
self.log(f" Low GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
|
| 150 |
-
elif free_gb <
|
| 151 |
max_length = 128
|
| 152 |
-
self.log(f" Tight GPU memory ({free_gb:.1f} GB free), using max_length={max_length}")
|
| 153 |
|
| 154 |
device = self._get_model_device(model)
|
| 155 |
|
|
|
|
| 144 |
torch.cuda.mem_get_info(i)[0] / (1024 ** 3)
|
| 145 |
for i in range(torch.cuda.device_count())
|
| 146 |
)
|
| 147 |
+
# Scale thresholds by model size (baseline: 7B with hidden=4096, 32 layers)
|
| 148 |
+
_h = self.handle.hidden_size if self.handle else 4096
|
| 149 |
+
_l = n_layers if n_layers else 32
|
| 150 |
+
_ms = (_h / 4096) * (_l / 32)
|
| 151 |
+
_tight = max(4.0 * _ms, 0.5)
|
| 152 |
+
_low = max(2.0 * _ms, 0.25)
|
| 153 |
+
if free_gb < _low:
|
| 154 |
max_length = 64
|
| 155 |
+
self.log(f" Low GPU memory ({free_gb:.1f} GB free, threshold {_low:.1f} GB), using max_length={max_length}")
|
| 156 |
+
elif free_gb < _tight:
|
| 157 |
max_length = 128
|
| 158 |
+
self.log(f" Tight GPU memory ({free_gb:.1f} GB free, threshold {_tight:.1f} GB), using max_length={max_length}")
|
| 159 |
|
| 160 |
device = self._get_model_device(model)
|
| 161 |
|
tests/test_informed_pipeline.py
CHANGED
|
@@ -50,7 +50,8 @@ class TestAnalysisInsights:
|
|
| 50 |
assert insights.cluster_count == 0
|
| 51 |
assert insights.direction_persistence == 0.0
|
| 52 |
assert insights.use_sparse_surgery is False
|
| 53 |
-
assert insights.recommended_n_directions ==
|
|
|
|
| 54 |
assert insights.recommended_regularization == 0.0
|
| 55 |
assert insights.recommended_refinement_passes == 2
|
| 56 |
assert insights.recommended_layers == []
|
|
@@ -86,12 +87,16 @@ class TestInformedMethod:
|
|
| 86 |
assert cfg["norm_preserve"] is True
|
| 87 |
assert cfg["project_biases"] is True
|
| 88 |
assert cfg["use_chat_template"] is True
|
| 89 |
-
assert cfg["use_whitened_svd"] is
|
| 90 |
assert cfg["true_iterative_refinement"] is True
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
def test_informed_method_standalone(self):
|
| 93 |
assert INFORMED_METHOD["label"] == "Informed (Analysis-Guided)"
|
| 94 |
-
assert INFORMED_METHOD["n_directions"] ==
|
|
|
|
| 95 |
assert INFORMED_METHOD["norm_preserve"] is True
|
| 96 |
|
| 97 |
|
|
@@ -121,8 +126,10 @@ class TestPipelineInit:
|
|
| 121 |
assert pipeline.norm_preserve is True
|
| 122 |
assert pipeline.project_biases is True
|
| 123 |
assert pipeline.use_chat_template is True
|
| 124 |
-
assert pipeline.use_whitened_svd is
|
| 125 |
assert pipeline.true_iterative_refinement is True
|
|
|
|
|
|
|
| 126 |
|
| 127 |
def test_custom_flags(self):
|
| 128 |
p = InformedAbliterationPipeline(
|
|
@@ -162,17 +169,31 @@ class TestConfigurationDerivation:
|
|
| 162 |
cone_dimensionality=3.5,
|
| 163 |
)
|
| 164 |
p._derive_configuration()
|
| 165 |
-
#
|
| 166 |
assert p.n_directions == 7
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
-
def
|
| 169 |
p = self._make_pipeline_with_insights(
|
| 170 |
cone_is_polyhedral=False,
|
| 171 |
cone_dimensionality=1.0,
|
| 172 |
)
|
| 173 |
p._derive_configuration()
|
| 174 |
-
# Linear
|
| 175 |
-
assert p.n_directions ==
|
|
|
|
|
|
|
| 176 |
|
| 177 |
def test_dpo_zero_regularization(self):
|
| 178 |
p = self._make_pipeline_with_insights(
|
|
@@ -282,6 +303,7 @@ class TestConfigurationDerivation:
|
|
| 282 |
p._derive_configuration()
|
| 283 |
assert p.n_directions > 1
|
| 284 |
assert p.use_whitened_svd is True
|
|
|
|
| 285 |
|
| 286 |
def test_no_whitened_svd_for_single_direction(self):
|
| 287 |
p = self._make_pipeline_with_insights(
|
|
@@ -289,9 +311,9 @@ class TestConfigurationDerivation:
|
|
| 289 |
cone_dimensionality=0.5,
|
| 290 |
)
|
| 291 |
p._derive_configuration()
|
| 292 |
-
# dim 0.5 → max(1, min(4, int(0.5+1))) = 1
|
| 293 |
assert p.n_directions == 1
|
| 294 |
assert p.use_whitened_svd is False
|
|
|
|
| 295 |
|
| 296 |
|
| 297 |
# ---------------------------------------------------------------------------
|
|
@@ -320,10 +342,12 @@ class TestFormatInsights:
|
|
| 320 |
|
| 321 |
def test_format_includes_derived_config(self, insights):
|
| 322 |
insights.recommended_n_directions = 6
|
|
|
|
| 323 |
insights.recommended_regularization = 0.2
|
| 324 |
insights.recommended_refinement_passes = 3
|
| 325 |
text = InformedAbliterationPipeline.format_insights(insights)
|
| 326 |
assert "n_directions: 6" in text
|
|
|
|
| 327 |
assert "regularization: 0.2" in text
|
| 328 |
assert "refinement_passes: 3" in text
|
| 329 |
|
|
@@ -372,14 +396,16 @@ class TestEdgeCases:
|
|
| 372 |
model_name="test",
|
| 373 |
on_log=lambda m: None,
|
| 374 |
)
|
| 375 |
-
# Very high dimensionality
|
| 376 |
p._insights.cone_is_polyhedral = True
|
| 377 |
p._insights.cone_dimensionality = 10.0
|
| 378 |
p._derive_configuration()
|
| 379 |
assert p.n_directions <= 8 # capped
|
|
|
|
| 380 |
|
| 381 |
-
# Very low dimensionality
|
| 382 |
p._insights.cone_is_polyhedral = False
|
| 383 |
p._insights.cone_dimensionality = 0.1
|
| 384 |
p._derive_configuration()
|
| 385 |
-
assert p.n_directions
|
|
|
|
|
|
| 50 |
assert insights.cluster_count == 0
|
| 51 |
assert insights.direction_persistence == 0.0
|
| 52 |
assert insights.use_sparse_surgery is False
|
| 53 |
+
assert insights.recommended_n_directions == 1
|
| 54 |
+
assert insights.recommended_direction_method == "diff_means"
|
| 55 |
assert insights.recommended_regularization == 0.0
|
| 56 |
assert insights.recommended_refinement_passes == 2
|
| 57 |
assert insights.recommended_layers == []
|
|
|
|
| 87 |
assert cfg["norm_preserve"] is True
|
| 88 |
assert cfg["project_biases"] is True
|
| 89 |
assert cfg["use_chat_template"] is True
|
| 90 |
+
assert cfg["use_whitened_svd"] is False
|
| 91 |
assert cfg["true_iterative_refinement"] is True
|
| 92 |
+
assert cfg["n_directions"] == 1
|
| 93 |
+
assert cfg["direction_method"] == "diff_means"
|
| 94 |
+
assert cfg["use_kl_optimization"] is True
|
| 95 |
|
| 96 |
def test_informed_method_standalone(self):
|
| 97 |
assert INFORMED_METHOD["label"] == "Informed (Analysis-Guided)"
|
| 98 |
+
assert INFORMED_METHOD["n_directions"] == 1
|
| 99 |
+
assert INFORMED_METHOD["direction_method"] == "diff_means"
|
| 100 |
assert INFORMED_METHOD["norm_preserve"] is True
|
| 101 |
|
| 102 |
|
|
|
|
| 126 |
assert pipeline.norm_preserve is True
|
| 127 |
assert pipeline.project_biases is True
|
| 128 |
assert pipeline.use_chat_template is True
|
| 129 |
+
assert pipeline.use_whitened_svd is False
|
| 130 |
assert pipeline.true_iterative_refinement is True
|
| 131 |
+
assert pipeline.direction_method == "diff_means"
|
| 132 |
+
assert pipeline.n_directions == 1
|
| 133 |
|
| 134 |
def test_custom_flags(self):
|
| 135 |
p = InformedAbliterationPipeline(
|
|
|
|
| 169 |
cone_dimensionality=3.5,
|
| 170 |
)
|
| 171 |
p._derive_configuration()
|
| 172 |
+
# Clearly polyhedral with dim 3.5 > 2.0 → SVD multi-direction
|
| 173 |
assert p.n_directions == 7
|
| 174 |
+
assert p.direction_method == "svd"
|
| 175 |
+
assert p.use_whitened_svd is True
|
| 176 |
+
|
| 177 |
+
def test_mildly_polyhedral_uses_leace(self):
|
| 178 |
+
p = self._make_pipeline_with_insights(
|
| 179 |
+
cone_is_polyhedral=True,
|
| 180 |
+
cone_dimensionality=1.5,
|
| 181 |
+
)
|
| 182 |
+
p._derive_configuration()
|
| 183 |
+
# Mildly polyhedral (dim <= 2.0) → single LEACE direction
|
| 184 |
+
assert p.n_directions == 1
|
| 185 |
+
assert p.direction_method == "leace"
|
| 186 |
|
| 187 |
+
def test_linear_cone_uses_diff_means(self):
|
| 188 |
p = self._make_pipeline_with_insights(
|
| 189 |
cone_is_polyhedral=False,
|
| 190 |
cone_dimensionality=1.0,
|
| 191 |
)
|
| 192 |
p._derive_configuration()
|
| 193 |
+
# Linear cone → single diff-of-means direction
|
| 194 |
+
assert p.n_directions == 1
|
| 195 |
+
assert p.direction_method == "diff_means"
|
| 196 |
+
assert p.use_whitened_svd is False
|
| 197 |
|
| 198 |
def test_dpo_zero_regularization(self):
|
| 199 |
p = self._make_pipeline_with_insights(
|
|
|
|
| 303 |
p._derive_configuration()
|
| 304 |
assert p.n_directions > 1
|
| 305 |
assert p.use_whitened_svd is True
|
| 306 |
+
assert p.direction_method == "svd"
|
| 307 |
|
| 308 |
def test_no_whitened_svd_for_single_direction(self):
|
| 309 |
p = self._make_pipeline_with_insights(
|
|
|
|
| 311 |
cone_dimensionality=0.5,
|
| 312 |
)
|
| 313 |
p._derive_configuration()
|
|
|
|
| 314 |
assert p.n_directions == 1
|
| 315 |
assert p.use_whitened_svd is False
|
| 316 |
+
assert p.direction_method == "diff_means"
|
| 317 |
|
| 318 |
|
| 319 |
# ---------------------------------------------------------------------------
|
|
|
|
| 342 |
|
| 343 |
def test_format_includes_derived_config(self, insights):
|
| 344 |
insights.recommended_n_directions = 6
|
| 345 |
+
insights.recommended_direction_method = "svd"
|
| 346 |
insights.recommended_regularization = 0.2
|
| 347 |
insights.recommended_refinement_passes = 3
|
| 348 |
text = InformedAbliterationPipeline.format_insights(insights)
|
| 349 |
assert "n_directions: 6" in text
|
| 350 |
+
assert "direction_method: svd" in text
|
| 351 |
assert "regularization: 0.2" in text
|
| 352 |
assert "refinement_passes: 3" in text
|
| 353 |
|
|
|
|
| 396 |
model_name="test",
|
| 397 |
on_log=lambda m: None,
|
| 398 |
)
|
| 399 |
+
# Very high dimensionality → multi-direction SVD, capped at 8
|
| 400 |
p._insights.cone_is_polyhedral = True
|
| 401 |
p._insights.cone_dimensionality = 10.0
|
| 402 |
p._derive_configuration()
|
| 403 |
assert p.n_directions <= 8 # capped
|
| 404 |
+
assert p.direction_method == "svd"
|
| 405 |
|
| 406 |
+
# Very low dimensionality → single diff-of-means
|
| 407 |
p._insights.cone_is_polyhedral = False
|
| 408 |
p._insights.cone_dimensionality = 0.1
|
| 409 |
p._derive_configuration()
|
| 410 |
+
assert p.n_directions == 1
|
| 411 |
+
assert p.direction_method == "diff_means"
|
tests/test_leace.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for LEACE (LEAst-squares Concept Erasure) direction extraction."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from obliteratus.analysis.leace import LEACEExtractor, LEACEResult
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
# Fixtures
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
|
| 15 |
+
@pytest.fixture
|
| 16 |
+
def extractor():
|
| 17 |
+
return LEACEExtractor(regularization_eps=1e-4)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@pytest.fixture
|
| 21 |
+
def separable_data():
|
| 22 |
+
"""Generate clearly separable harmful/harmless activations."""
|
| 23 |
+
torch.manual_seed(42)
|
| 24 |
+
d = 64
|
| 25 |
+
n = 20
|
| 26 |
+
|
| 27 |
+
# Harmful activations: cluster around [1, 0, 0, ...]
|
| 28 |
+
harmful_dir = torch.zeros(d)
|
| 29 |
+
harmful_dir[0] = 1.0
|
| 30 |
+
harmful = [harmful_dir + 0.1 * torch.randn(d) for _ in range(n)]
|
| 31 |
+
|
| 32 |
+
# Harmless activations: cluster around [-1, 0, 0, ...]
|
| 33 |
+
harmless = [-harmful_dir + 0.1 * torch.randn(d) for _ in range(n)]
|
| 34 |
+
|
| 35 |
+
return harmful, harmless
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@pytest.fixture
|
| 39 |
+
def isotropic_data():
|
| 40 |
+
"""Data where classes differ only in mean, with isotropic variance."""
|
| 41 |
+
torch.manual_seed(123)
|
| 42 |
+
d = 32
|
| 43 |
+
n = 30
|
| 44 |
+
|
| 45 |
+
direction = torch.randn(d)
|
| 46 |
+
direction = direction / direction.norm()
|
| 47 |
+
|
| 48 |
+
harmful = [direction * 2.0 + torch.randn(d) for _ in range(n)]
|
| 49 |
+
harmless = [-direction * 2.0 + torch.randn(d) for _ in range(n)]
|
| 50 |
+
|
| 51 |
+
return harmful, harmless, direction
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
# LEACEResult
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
class TestLEACEResult:
|
| 59 |
+
def test_result_fields(self, extractor, separable_data):
|
| 60 |
+
harmful, harmless = separable_data
|
| 61 |
+
result = extractor.extract(harmful, harmless, layer_idx=5)
|
| 62 |
+
|
| 63 |
+
assert isinstance(result, LEACEResult)
|
| 64 |
+
assert result.layer_idx == 5
|
| 65 |
+
assert result.direction.shape == (64,)
|
| 66 |
+
assert result.generalized_eigenvalue > 0
|
| 67 |
+
assert result.within_class_condition > 0
|
| 68 |
+
assert result.mean_diff_norm > 0
|
| 69 |
+
assert result.erasure_loss >= 0
|
| 70 |
+
|
| 71 |
+
def test_direction_is_unit_vector(self, extractor, separable_data):
|
| 72 |
+
harmful, harmless = separable_data
|
| 73 |
+
result = extractor.extract(harmful, harmless)
|
| 74 |
+
norm = result.direction.norm().item()
|
| 75 |
+
assert abs(norm - 1.0) < 1e-5
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# Direction quality
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
|
| 82 |
+
class TestDirectionQuality:
|
| 83 |
+
def test_finds_true_direction(self, extractor, separable_data):
|
| 84 |
+
"""LEACE should find a direction aligned with the true separation axis."""
|
| 85 |
+
harmful, harmless = separable_data
|
| 86 |
+
result = extractor.extract(harmful, harmless)
|
| 87 |
+
|
| 88 |
+
# True direction is [1, 0, 0, ...]
|
| 89 |
+
true_dir = torch.zeros(64)
|
| 90 |
+
true_dir[0] = 1.0
|
| 91 |
+
|
| 92 |
+
cosine = (result.direction @ true_dir).abs().item()
|
| 93 |
+
# With 20 samples in 64 dims, some noise is expected
|
| 94 |
+
assert cosine > 0.5, f"LEACE direction not aligned with true direction: {cosine}"
|
| 95 |
+
|
| 96 |
+
def test_isotropic_matches_diff_of_means(self, extractor, isotropic_data):
|
| 97 |
+
"""With isotropic noise, LEACE should roughly match diff-of-means."""
|
| 98 |
+
harmful, harmless, true_dir = isotropic_data
|
| 99 |
+
result = extractor.extract(harmful, harmless)
|
| 100 |
+
|
| 101 |
+
# Diff of means
|
| 102 |
+
diff = torch.stack(harmful).mean(0) - torch.stack(harmless).mean(0)
|
| 103 |
+
diff_normalized = diff / diff.norm()
|
| 104 |
+
|
| 105 |
+
cosine = (result.direction @ diff_normalized).abs().item()
|
| 106 |
+
# With finite samples and regularization, some deviation is expected
|
| 107 |
+
assert cosine > 0.5
|
| 108 |
+
|
| 109 |
+
def test_leace_differs_from_diff_means_with_anisotropic_noise(self):
|
| 110 |
+
"""With anisotropic noise, LEACE should find a better direction than diff-of-means."""
|
| 111 |
+
torch.manual_seed(77)
|
| 112 |
+
d = 64
|
| 113 |
+
n = 50
|
| 114 |
+
|
| 115 |
+
# True refusal direction
|
| 116 |
+
true_dir = torch.zeros(d)
|
| 117 |
+
true_dir[0] = 1.0
|
| 118 |
+
|
| 119 |
+
# Add anisotropic noise: high variance in dim 1 (NOT the refusal direction)
|
| 120 |
+
noise_scale = torch.ones(d) * 0.1
|
| 121 |
+
noise_scale[1] = 5.0 # Rogue dimension
|
| 122 |
+
|
| 123 |
+
harmful = [true_dir * 0.5 + torch.randn(d) * noise_scale for _ in range(n)]
|
| 124 |
+
harmless = [-true_dir * 0.5 + torch.randn(d) * noise_scale for _ in range(n)]
|
| 125 |
+
|
| 126 |
+
extractor = LEACEExtractor()
|
| 127 |
+
result = extractor.extract(harmful, harmless)
|
| 128 |
+
|
| 129 |
+
cosine_to_true = (result.direction @ true_dir).abs().item()
|
| 130 |
+
# LEACE should still find the true direction, not be distracted by rogue dim
|
| 131 |
+
assert cosine_to_true > 0.5, f"LEACE distracted by rogue dimension: {cosine_to_true}"
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ---------------------------------------------------------------------------
|
| 135 |
+
# Comparison with diff-of-means
|
| 136 |
+
# ---------------------------------------------------------------------------
|
| 137 |
+
|
| 138 |
+
class TestCompareWithDiffOfMeans:
|
| 139 |
+
def test_comparison_output(self, extractor, separable_data):
|
| 140 |
+
harmful, harmless = separable_data
|
| 141 |
+
result = extractor.extract(harmful, harmless)
|
| 142 |
+
|
| 143 |
+
harmful_mean = torch.stack(harmful).mean(0)
|
| 144 |
+
harmless_mean = torch.stack(harmless).mean(0)
|
| 145 |
+
|
| 146 |
+
comparison = LEACEExtractor.compare_with_diff_of_means(
|
| 147 |
+
result, harmful_mean, harmless_mean,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
assert "cosine_similarity" in comparison
|
| 151 |
+
assert "leace_eigenvalue" in comparison
|
| 152 |
+
assert "leace_erasure_loss" in comparison
|
| 153 |
+
assert "within_class_condition" in comparison
|
| 154 |
+
assert "mean_diff_norm" in comparison
|
| 155 |
+
assert 0 <= comparison["cosine_similarity"] <= 1.0
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ---------------------------------------------------------------------------
|
| 159 |
+
# Multi-layer extraction
|
| 160 |
+
# ---------------------------------------------------------------------------
|
| 161 |
+
|
| 162 |
+
class TestMultiLayer:
|
| 163 |
+
def test_extract_all_layers(self, extractor):
|
| 164 |
+
torch.manual_seed(42)
|
| 165 |
+
d = 32
|
| 166 |
+
n = 15
|
| 167 |
+
|
| 168 |
+
harmful_acts = {}
|
| 169 |
+
harmless_acts = {}
|
| 170 |
+
for layer in [0, 1, 2, 5]:
|
| 171 |
+
harmful_acts[layer] = [torch.randn(d) + 0.5 for _ in range(n)]
|
| 172 |
+
harmless_acts[layer] = [torch.randn(d) - 0.5 for _ in range(n)]
|
| 173 |
+
|
| 174 |
+
results = extractor.extract_all_layers(harmful_acts, harmless_acts)
|
| 175 |
+
|
| 176 |
+
assert set(results.keys()) == {0, 1, 2, 5}
|
| 177 |
+
for idx, result in results.items():
|
| 178 |
+
assert result.layer_idx == idx
|
| 179 |
+
assert result.direction.shape == (d,)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# ---------------------------------------------------------------------------
|
| 183 |
+
# Edge cases
|
| 184 |
+
# ---------------------------------------------------------------------------
|
| 185 |
+
|
| 186 |
+
class TestEdgeCases:
|
| 187 |
+
def test_single_sample(self, extractor):
|
| 188 |
+
"""Should handle single sample per class gracefully."""
|
| 189 |
+
d = 32
|
| 190 |
+
harmful = [torch.randn(d)]
|
| 191 |
+
harmless = [torch.randn(d)]
|
| 192 |
+
|
| 193 |
+
result = extractor.extract(harmful, harmless)
|
| 194 |
+
assert result.direction.shape == (d,)
|
| 195 |
+
assert torch.isfinite(result.direction).all()
|
| 196 |
+
|
| 197 |
+
def test_identical_activations(self, extractor):
|
| 198 |
+
"""Should handle case where harmful == harmless."""
|
| 199 |
+
d = 32
|
| 200 |
+
x = torch.randn(d)
|
| 201 |
+
harmful = [x.clone() for _ in range(5)]
|
| 202 |
+
harmless = [x.clone() for _ in range(5)]
|
| 203 |
+
|
| 204 |
+
result = extractor.extract(harmful, harmless)
|
| 205 |
+
assert result.direction.shape == (d,)
|
| 206 |
+
# Direction norm should be ~0 or direction is a fallback
|
| 207 |
+
assert torch.isfinite(result.direction).all()
|
| 208 |
+
|
| 209 |
+
def test_3d_input_squeezed(self, extractor):
|
| 210 |
+
"""Should handle (n, 1, d) shaped inputs."""
|
| 211 |
+
d = 32
|
| 212 |
+
harmful = [torch.randn(1, d) for _ in range(10)]
|
| 213 |
+
harmless = [torch.randn(1, d) for _ in range(10)]
|
| 214 |
+
|
| 215 |
+
result = extractor.extract(harmful, harmless)
|
| 216 |
+
assert result.direction.shape == (d,)
|
| 217 |
+
|
| 218 |
+
def test_shrinkage(self):
|
| 219 |
+
"""Shrinkage should produce valid results."""
|
| 220 |
+
torch.manual_seed(42)
|
| 221 |
+
d = 64
|
| 222 |
+
n = 10 # n < d → need shrinkage
|
| 223 |
+
|
| 224 |
+
harmful = [torch.randn(d) + 0.3 for _ in range(n)]
|
| 225 |
+
harmless = [torch.randn(d) - 0.3 for _ in range(n)]
|
| 226 |
+
|
| 227 |
+
extractor = LEACEExtractor(shrinkage=0.5)
|
| 228 |
+
result = extractor.extract(harmful, harmless)
|
| 229 |
+
assert result.direction.shape == (d,)
|
| 230 |
+
assert torch.isfinite(result.direction).all()
|