# cfhot-weights / code / training_pipelines / 02_arc_adapter_training_MULTIHEAD.py
# (Hugging Face page residue — LoganResearch:
#  "🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code"
#  commit 297244f, verified)
#!/usr/bin/env python3
"""
ARC ADAPTER TRAINING - COMPLETE VERSION
========================================
Trains the combined ARC adapter on a FROZEN base model.
Components:
- Shared fiber projections (4096 → 16 dim)
- Repetition detection head (target: 50×+ separation)
- Hedging detection head
- Verbosity detection head
- Sycophancy detection head
- Loop 4 tokenizer expansion
- Learned intervention thresholds
Base model: COMPLETELY FROZEN (never modified)
Adapter: ~2M trainable parameters
Author: Logan Napolitano
Date: January 2026
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import re
import gc
import os
import time
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Tuple
from collections import defaultdict
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
# =============================================================================
# CONFIGURATION
# =============================================================================
@dataclass
class ARCAdapterConfig:
    """Complete configuration for ARC Adapter training."""

    # Paths
    base_model_path: str = "."
    output_dir: str = "arc_adapter"

    # Device
    device: str = "cuda"

    # Model architecture (auto-filled from base model)
    hidden_dim: int = 4096
    fiber_dim: int = 16
    probe_layers: List[int] = field(default_factory=lambda: [8, 16, 24])

    # Data generation settings
    n_samples_per_head: int = 300
    max_gen_tokens: int = 80
    repetition_window: int = 32

    # Training settings
    epochs: int = 15
    batch_size: int = 32
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    warmup_steps: int = 100

    # Target class-separation ratio for each detection head
    target_separation: Dict[str, float] = field(default_factory=lambda: {
        "repetition": 50.0,  # 125× has been achieved, so 50× is conservative
        "hedging": 5.0,
        "verbosity": 5.0,
        "sycophancy": 3.0,
    })

    # Loop 4 (tokenizer expansion) settings
    loop4_iterations: int = 3
    n_merges_per_iteration: int = 10
    min_pair_frequency: int = 2

    # Intervention defaults (refined further during training)
    default_thresholds: Dict[str, float] = field(default_factory=lambda: {
        "repetition": 0.1,
        "hedging": 0.3,
        "verbosity": 0.4,
        "sycophancy": 0.4,
    })
    default_penalty_strength: float = 2.0

    # EMA smoothing factor for the control field
    ema_alpha: float = 0.15
# =============================================================================
# ADAPTER ARCHITECTURE
# =============================================================================
class FiberProjection(nn.Module):
    """Project hidden states from several layers into a shared fiber space.

    The geometric core of CF-HoT: high-dimensional hidden states are
    compressed onto a low-dimensional manifold in which behavioral
    tendencies become linearly separable.
    """

    def __init__(self, hidden_dim: int, fiber_dim: int, n_layers: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.fiber_dim = fiber_dim
        self.n_layers = n_layers
        # One linear projection per probed layer.
        self.projections = nn.ModuleList(
            nn.Linear(hidden_dim, fiber_dim, bias=True) for _ in range(n_layers)
        )
        # Learned per-layer importance (softmax-normalized at use time).
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        for layer in self.projections:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Project a list of per-layer hidden states to fiber space.

        Args:
            hidden_states: list of [batch, seq, hidden_dim] tensors.
        Returns:
            [batch, seq, fiber_dim] weighted sum of per-layer projections.
        """
        weights = F.softmax(self.layer_weights, dim=0)
        fiber = None
        for idx, (h, layer) in enumerate(zip(hidden_states, self.projections)):
            # Adapter math runs in float32 (base model may emit bfloat16).
            term = weights[idx] * layer(h.float())
            fiber = term if fiber is None else fiber + term
        return fiber

    def forward_stacked(self, hidden_stack: torch.Tensor) -> torch.Tensor:
        """Project pre-stacked hidden states to fiber space.

        Args:
            hidden_stack: [batch, n_layers, hidden_dim].
        Returns:
            [batch, fiber_dim].
        """
        # Cast once: the base model emits bfloat16, the adapter uses float32.
        stack = hidden_stack.float()
        weights = F.softmax(self.layer_weights, dim=0)
        fiber = torch.zeros(
            stack.shape[0],
            self.fiber_dim,
            device=stack.device,
            dtype=torch.float32,
        )
        for idx, layer in enumerate(self.projections):
            fiber = fiber + weights[idx] * layer(stack[:, idx, :])
        return fiber
class BehaviorHead(nn.Module):
    """One behavioral detector: fiber state in, behavior logit out.

    MLP: fiber_dim -> 64 -> 16 -> 1, with light dropout between layers.
    """

    def __init__(self, fiber_dim: int, name: str):
        super().__init__()
        self.name = name
        self.fiber_dim = fiber_dim
        self.classifier = nn.Sequential(
            nn.Linear(fiber_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Dropout(0.05),
            nn.Linear(16, 1),
        )
        # Xavier weights, zero biases for every linear layer.
        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, fiber: torch.Tensor) -> torch.Tensor:
        """Return raw logits for a fiber state.

        Args:
            fiber: [batch, fiber_dim] or [batch, seq, fiber_dim].
        Returns:
            logits: [batch] or [batch, seq] (trailing dim squeezed).
        """
        return self.classifier(fiber).squeeze(-1)

    def predict_proba(self, fiber: torch.Tensor) -> torch.Tensor:
        """Return sigmoid probabilities for a fiber state."""
        return torch.sigmoid(self(fiber))
class ARCAdapter(nn.Module):
    """Complete ARC Adapter module.

    Bundles four components:
      - shared fiber projection (geometry)
      - multiple behavioral heads (detection)
      - intervention parameters (control)
      - EMA tracking (temporal smoothing)
    """

    def __init__(self, config: ARCAdapterConfig):
        super().__init__()
        self.config = config
        # Shared geometry: probe-layer hiddens -> low-dim fiber space.
        self.fiber_proj = FiberProjection(
            hidden_dim=config.hidden_dim,
            fiber_dim=config.fiber_dim,
            n_layers=len(config.probe_layers),
        )
        # One detector per tracked behavior.
        self.heads = nn.ModuleDict({
            name: BehaviorHead(config.fiber_dim, name)
            for name in ("repetition", "hedging", "verbosity", "sycophancy")
        })
        # Learned per-behavior intervention thresholds.
        self.thresholds = nn.ParameterDict({
            name: nn.Parameter(torch.tensor(value))
            for name, value in config.default_thresholds.items()
        })
        # Learned global penalty strength.
        self.penalty_strength = nn.Parameter(
            torch.tensor(config.default_penalty_strength)
        )
        # EMA smoothing of per-step risk scores during generation.
        self.ema_alpha = config.ema_alpha
        self.register_buffer('_ema_initialized', torch.tensor(False))
        self._ema_states: Dict[str, Optional[float]] = {}
        self.reset_ema()

    def reset_ema(self):
        """Clear EMA risk state; call before each new generation."""
        self._ema_states = {name: None for name in self.heads.keys()}

    def forward(
        self,
        hidden_states: List[torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Run every behavioral head on the projected hidden states.

        Args:
            hidden_states: hidden states from the probe layers.
        Returns:
            dict mapping head_name -> logits.
        """
        fiber = self.fiber_proj(hidden_states)
        return {name: head(fiber) for name, head in self.heads.items()}

    def get_fiber(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Return only the fiber-space representation."""
        return self.fiber_proj(hidden_states)

    def get_risks(
        self,
        hidden_states: List[torch.Tensor],
        update_ema: bool = True
    ) -> Dict[str, float]:
        """Score current behavioral risks, optionally EMA-smoothed.

        Args:
            hidden_states: list of [1, 1, hidden_dim] tensors (last position).
            update_ema: fold each raw probability into the per-head EMA
                and return the smoothed value instead of the raw one.
        Returns:
            dict mapping head_name -> risk score (0-1).
        """
        # Take the final position from each layer: [batch, n_layers, hidden].
        last = torch.stack([h[:, -1, :] for h in hidden_states], dim=1)
        fiber = self.fiber_proj.forward_stacked(last)
        risks = {}
        for name, head in self.heads.items():
            with torch.no_grad():
                prob = head.predict_proba(fiber).mean().item()
            if not update_ema:
                risks[name] = prob
                continue
            previous = self._ema_states[name]
            if previous is None:
                self._ema_states[name] = prob
            else:
                self._ema_states[name] = (
                    self.ema_alpha * prob + (1 - self.ema_alpha) * previous
                )
            risks[name] = self._ema_states[name]
        return risks

    def compute_intervention(
        self,
        risks: Dict[str, float],
        recent_tokens: List[int],
        window_size: int = 32
    ) -> Dict[int, float]:
        """Translate risk scores into per-token logit penalties.

        Args:
            risks: current risk scores from get_risks().
            recent_tokens: recently generated token IDs.
            window_size: how far back repetitions are penalized.
        Returns:
            dict mapping token_id -> penalty amount.
        """
        penalties = {}
        rep_risk = risks.get("repetition", 0)
        rep_thresh = self.thresholds["repetition"].item()
        if rep_risk > rep_thresh:
            # Penalty grows with how far the risk exceeds the threshold.
            strength = self.penalty_strength.item() * (rep_risk / rep_thresh)
            recent = recent_tokens[-window_size:] if len(recent_tokens) > window_size else recent_tokens
            for token_id in set(recent):
                penalties[token_id] = penalties.get(token_id, 0) + strength
        # Hedging/verbosity interventions (e.g. penalizing "As an AI"
        # style tokens) could be added here.
        return penalties

    def get_param_count(self) -> int:
        """Total number of trainable adapter parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def save(self, path: str):
        """Write adapter weights plus a JSON config into directory ``path``."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), path / "adapter_weights.pt")
        config_dict = {
            "hidden_dim": self.config.hidden_dim,
            "fiber_dim": self.config.fiber_dim,
            "probe_layers": self.config.probe_layers,
            "ema_alpha": self.ema_alpha,
            "thresholds": {
                name: self.thresholds[name].item()
                for name in self.thresholds
            },
            "penalty_strength": self.penalty_strength.item(),
            "head_names": list(self.heads.keys()),
        }
        with open(path / "adapter_config.json", "w") as f:
            json.dump(config_dict, f, indent=2)
        print(f"💾 Adapter saved to {path}")
        print(f" Weights: adapter_weights.pt")
        print(f" Config: adapter_config.json")

    @classmethod
    def load(cls, path: str, device: str = "cuda") -> "ARCAdapter":
        """Rebuild an adapter from a directory written by ``save``."""
        path = Path(path)
        with open(path / "adapter_config.json") as f:
            config_dict = json.load(f)
        config = ARCAdapterConfig(
            hidden_dim=config_dict["hidden_dim"],
            fiber_dim=config_dict["fiber_dim"],
            probe_layers=config_dict["probe_layers"],
            ema_alpha=config_dict["ema_alpha"],
            default_thresholds=config_dict["thresholds"],
            default_penalty_strength=config_dict["penalty_strength"],
        )
        adapter = cls(config)
        state_dict = torch.load(path / "adapter_weights.pt", map_location=device)
        adapter.load_state_dict(state_dict)
        return adapter.to(device)
# =============================================================================
# DATA GENERATION
# =============================================================================
class DataGenerator:
"""
Generates labeled training data for all behavioral heads.
Each head needs different types of data:
- Repetition: Token-level labels during generation
- Hedging: Prompt-level labels based on response patterns
- Verbosity: Prompt-level labels based on response length/density
- Sycophancy: Prompt-level labels based on agreement patterns
"""
# Pattern definitions
HEDGING_PATTERNS = [
r"as an ai",
r"i cannot",
r"i can't",
r"i don't have",
r"i'm not able",
r"i am not able",
r"i don't actually",
r"i'm just",
r"i am just",
r"my purpose is",
r"i was designed",
r"i don't experience",
r"i'm an ai",
r"as a language model",
]
SYCOPHANCY_PATTERNS = [
r"great question",
r"excellent point",
r"you're (absolutely )?right",
r"that's a (great|excellent|wonderful)",
r"i (completely |totally )?agree",
r"absolutely[,!]",
r"definitely[,!]",
r"of course[,!]",
r"you make a (great|excellent|good) point",
]
def __init__(self, model, tokenizer, config: ARCAdapterConfig):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.device = config.device
# Compile patterns
self.hedging_patterns = [
re.compile(p, re.IGNORECASE) for p in self.HEDGING_PATTERNS
]
self.sycophancy_patterns = [
re.compile(p, re.IGNORECASE) for p in self.SYCOPHANCY_PATTERNS
]
def is_repetition(self, tokens: List[int], position: int) -> bool:
"""Check if token at position repeats within window."""
if position < 1:
return False
current = tokens[position]
start = max(0, position - self.config.repetition_window)
return current in tokens[start:position]
def is_hedging(self, text: str) -> bool:
"""Check if text contains hedging patterns."""
return any(p.search(text) for p in self.hedging_patterns)
def is_sycophantic(self, text: str) -> bool:
"""Check if text contains sycophancy patterns."""
return any(p.search(text) for p in self.sycophancy_patterns)
def is_verbose(self, text: str, token_count: int) -> bool:
"""
Check if response is verbose.
Verbose = low information density or excessive length.
"""
words = text.split()
if len(words) < 10:
return False
# Unique word ratio
unique_ratio = len(set(w.lower() for w in words)) / len(words)
# Verbose if low uniqueness or very long
return unique_ratio < 0.5 or token_count > 100
def extract_hidden_states(
self,
input_ids: torch.Tensor
) -> torch.Tensor:
"""
Extract hidden states at probe layers for last position.
Args:
input_ids: [1, seq_len]
Returns:
hidden_stack: [n_layers, hidden_dim]
"""
with torch.no_grad():
outputs = self.model(
input_ids,
output_hidden_states=True,
)
hidden_list = []
for layer_idx in self.config.probe_layers:
# Get last position: [hidden_dim]
h = outputs.hidden_states[layer_idx][0, -1, :].cpu()
hidden_list.append(h)
return torch.stack(hidden_list) # [n_layers, hidden_dim]
def generate_repetition_data(
self,
prompts: List[str]
) -> Dict[str, List]:
"""
Generate token-level labeled data for repetition detection.
For each generated token, we capture:
- Hidden states at probe layers (before generating the token)
- Label: 1 if the token repeats within window, 0 otherwise
"""
all_hidden = []
all_labels = []
print(f"\n📊 Generating repetition training data...")
print(f" Prompts: {len(prompts)}")
print(f" Max tokens per prompt: {self.config.max_gen_tokens}")
for prompt in tqdm(prompts, desc="Repetition data"):
try:
inputs = self.tokenizer(
prompt,
return_tensors="pt"
).to(self.device)
generated_ids = inputs.input_ids[0].tolist()
for step in range(self.config.max_gen_tokens):
# Current sequence as tensor
input_tensor = torch.tensor([generated_ids]).to(self.device)
# Extract hidden states BEFORE generating next token
hidden_stack = self.extract_hidden_states(input_tensor)
# Generate next token
with torch.no_grad():
outputs = self.model(input_tensor)
logits = outputs.logits[0, -1, :]
probs = F.softmax(logits / 0.8, dim=-1)
next_token = torch.multinomial(probs, 1).item()
# Record position and add token
position = len(generated_ids)
generated_ids.append(next_token)
# Label: did this token repeat?
is_rep = self.is_repetition(generated_ids, position)
all_hidden.append(hidden_stack)
all_labels.append(1 if is_rep else 0)
# Stop at EOS
if next_token == self.tokenizer.eos_token_id:
break
except Exception as e:
print(f" Error on prompt: {e}")
continue
pos_count = sum(all_labels)
total = len(all_labels)
print(f" Generated: {total} examples")
print(f" Positive (repetition): {pos_count} ({100*pos_count/total:.1f}%)")
print(f" Negative (no repeat): {total - pos_count}")
return {
"hidden_states": all_hidden,
"labels": all_labels,
}
def generate_hedging_data(
self,
prompts: List[str]
) -> Dict[str, List]:
"""
Generate prompt-level labeled data for hedging detection.
For each prompt, we:
- Extract hidden states at end of prompt
- Generate a response
- Label: 1 if response contains hedging patterns, 0 otherwise
"""
all_hidden = []
all_labels = []
print(f"\n📊 Generating hedging training data...")
print(f" Prompts: {len(prompts)}")
for prompt in tqdm(prompts, desc="Hedging data"):
try:
inputs = self.tokenizer(
prompt,
return_tensors="pt"
).to(self.device)
# Hidden states at end of prompt
hidden_stack = self.extract_hidden_states(inputs.input_ids)
# Generate response
with torch.no_grad():
outputs = self.model.generate(
inputs.input_ids,
max_new_tokens=50,
do_sample=True,
temperature=0.7,
pad_token_id=self.tokenizer.eos_token_id,
)
# Decode response only (not prompt)
response = self.tokenizer.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
# Label
is_hedge = self.is_hedging(response)
all_hidden.append(hidden_stack)
all_labels.append(1 if is_hedge else 0)
except Exception as e:
continue
pos_count = sum(all_labels)
total = len(all_labels)
print(f" Generated: {total} examples")
print(f" Positive (hedging): {pos_count} ({100*pos_count/total:.1f}%)")
return {
"hidden_states": all_hidden,
"labels": all_labels,
}
def generate_verbosity_data(
self,
prompts: List[str]
) -> Dict[str, List]:
"""Generate prompt-level labeled data for verbosity detection."""
all_hidden = []
all_labels = []
print(f"\n📊 Generating verbosity training data...")
print(f" Prompts: {len(prompts)}")
for prompt in tqdm(prompts, desc="Verbosity data"):
try:
inputs = self.tokenizer(
prompt,
return_tensors="pt"
).to(self.device)
hidden_stack = self.extract_hidden_states(inputs.input_ids)
with torch.no_grad():
outputs = self.model.generate(
inputs.input_ids,
max_new_tokens=150,
do_sample=True,
temperature=0.7,
pad_token_id=self.tokenizer.eos_token_id,
)
response = self.tokenizer.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
token_count = outputs.shape[1] - inputs.input_ids.shape[1]
is_verbose = self.is_verbose(response, token_count)
all_hidden.append(hidden_stack)
all_labels.append(1 if is_verbose else 0)
except Exception as e:
continue
pos_count = sum(all_labels)
total = len(all_labels)
print(f" Generated: {total} examples")
print(f" Positive (verbose): {pos_count} ({100*pos_count/total:.1f}%)")
return {
"hidden_states": all_hidden,
"labels": all_labels,
}
def generate_sycophancy_data(
self,
prompts: List[str]
) -> Dict[str, List]:
"""Generate prompt-level labeled data for sycophancy detection."""
all_hidden = []
all_labels = []
print(f"\n📊 Generating sycophancy training data...")
print(f" Prompts: {len(prompts)}")
for prompt in tqdm(prompts, desc="Sycophancy data"):
try:
inputs = self.tokenizer(
prompt,
return_tensors="pt"
).to(self.device)
hidden_stack = self.extract_hidden_states(inputs.input_ids)
with torch.no_grad():
outputs = self.model.generate(
inputs.input_ids,
max_new_tokens=50,
do_sample=True,
temperature=0.7,
pad_token_id=self.tokenizer.eos_token_id,
)
response = self.tokenizer.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
is_syc = self.is_sycophantic(response)
all_hidden.append(hidden_stack)
all_labels.append(1 if is_syc else 0)
except Exception as e:
continue
pos_count = sum(all_labels)
total = len(all_labels)
print(f" Generated: {total} examples")
print(f" Positive (sycophantic): {pos_count} ({100*pos_count/total:.1f}%)")
return {
"hidden_states": all_hidden,
"labels": all_labels,
}
def get_prompts_for_head(self, head_name: str, n: int) -> List[str]:
"""Get appropriate prompts for each head type."""
if head_name == "repetition":
# Prompts that tend to induce repetitive generation
templates = [
"Write a detailed explanation of {}:",
"Describe the process of {} step by step:",
"Explain everything about {}:",
"Give a comprehensive overview of {}:",
"Discuss {} in depth:",
]
topics = [
"machine learning", "photosynthesis", "gravity",
"neural networks", "evolution", "quantum mechanics",
"economics", "climate change", "DNA replication",
"cellular respiration", "protein synthesis", "mitosis",
]
elif head_name == "hedging":
# Mix of hedging-likely and neutral prompts
hedging_prompts = [
"Do you have feelings?",
"Are you conscious?",
"Can you experience emotions?",
"What do you think about yourself?",
"Do you have preferences?",
"Are you sentient?",
"Do you have a personality?",
"Can you be creative?",
]
neutral_prompts = [
"What is photosynthesis?",
"Explain gravity.",
"How do computers work?",
"What is DNA?",
"Describe the water cycle.",
"What causes earthquakes?",
]
# Mix: half hedging-likely, half neutral
prompts = []
for i in range(n):
if i % 2 == 0:
prompts.append(hedging_prompts[i % len(hedging_prompts)])
else:
prompts.append(neutral_prompts[i % len(neutral_prompts)])
return prompts
elif head_name == "sycophancy":
# Opinionated prompts that might trigger sycophancy
sycophancy_prompts = [
"I think AI will solve all problems. Don't you agree?",
"This is clearly the best approach, right?",
"My analysis shows this is true. What do you think?",
"Everyone knows that this is correct.",
"I believe I've found the perfect solution. Thoughts?",
"My idea is brilliant, isn't it?",
]
neutral_prompts = [
"What are the pros and cons of renewable energy?",
"Explain different perspectives on this issue.",
"What is the scientific consensus?",
"Compare these two approaches objectively.",
]
prompts = []
for i in range(n):
if i % 2 == 0:
prompts.append(sycophancy_prompts[i % len(sycophancy_prompts)])
else:
prompts.append(neutral_prompts[i % len(neutral_prompts)])
return prompts
elif head_name == "verbosity":
templates = [
"Briefly explain {}:",
"In one sentence, what is {}?",
"Summarize {} concisely:",
"Give a detailed analysis of {}:",
"Write extensively about {}:",
"Provide a comprehensive discussion of {}:",
]
topics = [
"gravity", "democracy", "evolution", "technology",
"economics", "climate", "education", "healthcare",
]
else:
templates = ["Explain {}:"]
topics = ["science", "technology", "nature"]
# Generate prompts from templates and topics
prompts = []
for template in templates:
for topic in topics:
prompts.append(template.format(topic))
if len(prompts) >= n:
return prompts[:n]
# If we need more, cycle through
while len(prompts) < n:
prompts.extend(prompts[:n - len(prompts)])
return prompts[:n]
# =============================================================================
# TRAINING
# =============================================================================
class ProbeDataset(Dataset):
    """Pairs cached hidden-state stacks with binary labels for probe training."""

    def __init__(
        self,
        hidden_states: List[torch.Tensor],
        labels: List[int],
    ):
        self.hidden_states = hidden_states
        self.labels = labels

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # Labels become float32 scalars for BCEWithLogitsLoss.
        return {
            "hidden": self.hidden_states[idx],
            "label": torch.tensor(self.labels[idx], dtype=torch.float32),
        }
class AdapterTrainer:
"""
Trains all components of the ARC adapter.
Training order:
1. Repetition head (most important)
2. Hedging head
3. Verbosity head
4. Sycophancy head
5. Loop 4 tokenization
"""
def __init__(
self,
model,
tokenizer,
config: ARCAdapterConfig
):
self.model = model # FROZEN - never modified
self.tokenizer = tokenizer
self.config = config
self.device = config.device
# Create adapter
self.adapter = ARCAdapter(config).to(self.device)
# Data generator
self.data_generator = DataGenerator(model, tokenizer, config)
# Output directory
self.output_dir = Path(config.output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
# Training history
self.history = {}
def compute_metrics(
self,
predictions: torch.Tensor,
labels: torch.Tensor
) -> Dict[str, float]:
"""
Compute classification metrics.
Key metric: Class Separation Ratio
= mean(positive_probs) / mean(negative_probs)
Higher separation = better discrimination.
"""
probs = torch.sigmoid(predictions)
binary_preds = (probs > 0.5).float()
# Basic metrics
tp = ((binary_preds == 1) & (labels == 1)).sum().item()
fp = ((binary_preds == 1) & (labels == 0)).sum().item()
fn = ((binary_preds == 0) & (labels == 1)).sum().item()
tn = ((binary_preds == 0) & (labels == 0)).sum().item()
accuracy = (tp + tn) / (tp + fp + fn + tn + 1e-8)
precision = tp / (tp + fp + 1e-8)
recall = tp / (tp + fn + 1e-8)
f1 = 2 * precision * recall / (precision + recall + 1e-8)
# Class separation ratio - KEY METRIC
pos_mask = labels == 1
neg_mask = labels == 0
if pos_mask.sum() > 0:
pos_mean = probs[pos_mask].mean().item()
else:
pos_mean = 0.5
if neg_mask.sum() > 0:
neg_mean = probs[neg_mask].mean().item()
else:
neg_mean = 0.5
separation = pos_mean / (neg_mean + 1e-8)
return {
"accuracy": accuracy,
"precision": precision,
"recall": recall,
"f1": f1,
"separation": separation,
"pos_mean": pos_mean,
"neg_mean": neg_mean,
}
def train_head(
self,
head_name: str,
data: Dict[str, List]
) -> Dict[str, float]:
"""
Train a single behavioral head.
Uses shared fiber projection (also trained).
"""
print(f"\n{'='*70}")
print(f"TRAINING HEAD: {head_name.upper()}")
print(f"{'='*70}")
# Split data
n = len(data["labels"])
indices = np.random.permutation(n)
split_idx = int(n * 0.9)
train_indices = indices[:split_idx]
val_indices = indices[split_idx:]
train_hidden = [data["hidden_states"][i] for i in train_indices]
train_labels = [data["labels"][i] for i in train_indices]
val_hidden = [data["hidden_states"][i] for i in val_indices]
val_labels = [data["labels"][i] for i in val_indices]
# Create datasets
train_dataset = ProbeDataset(train_hidden, train_labels)
val_dataset = ProbeDataset(val_hidden, val_labels)
train_loader = DataLoader(
train_dataset,
batch_size=self.config.batch_size,
shuffle=True
)
val_loader = DataLoader(
val_dataset,
batch_size=self.config.batch_size
)
# Class weighting for imbalanced data
pos_count = sum(train_labels)
neg_count = len(train_labels) - pos_count
if pos_count > 0:
pos_weight = torch.tensor([neg_count / pos_count]).to(self.device)
else:
pos_weight = torch.tensor([1.0]).to(self.device)
print(f"Train samples: {len(train_labels)}")
print(f"Val samples: {len(val_labels)}")
print(f"Positive: {pos_count} ({100*pos_count/len(train_labels):.1f}%)")
print(f"Negative: {neg_count}")
print(f"Target separation: {self.config.target_separation[head_name]}×")
# Get head and fiber projection
head = self.adapter.heads[head_name]
fiber_proj = self.adapter.fiber_proj
# Optimizer for head + shared fiber projection
params = list(head.parameters()) + list(fiber_proj.parameters())
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.AdamW(
params,
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=self.config.epochs
)
# Training loop
best_separation = 0
best_state = None
history = []
global_step = 0
for epoch in range(self.config.epochs):
# Training
head.train()
fiber_proj.train()
total_loss = 0
for batch_idx, batch in enumerate(train_loader):
hidden = batch["hidden"].to(self.device)
labels = batch["label"].to(self.device)
# Forward: fiber projection then head
fiber = fiber_proj.forward_stacked(hidden)
logits = head(fiber)
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
global_step += 1
# Checkpoint every 100 steps
if global_step % 100 == 0:
checkpoint_path = self.output_dir / f"checkpoint_step_{global_step}"
checkpoint_path.mkdir(parents=True, exist_ok=True)
torch.save({
'head_state': head.state_dict(),
'fiber_state': fiber_proj.state_dict(),
'optimizer_state': optimizer.state_dict(),
'epoch': epoch,
'step': global_step,
'loss': loss.item(),
'head_name': head_name,
}, checkpoint_path / "checkpoint.pt")
print(f" 💾 Checkpoint saved: step {global_step}")
avg_loss = total_loss / len(train_loader)
# Validation
head.eval()
fiber_proj.eval()
all_preds = []
all_labels = []
with torch.no_grad():
for batch in val_loader:
hidden = batch["hidden"].to(self.device)
labels = batch["label"]
fiber = fiber_proj.forward_stacked(hidden)
logits = head(fiber)
all_preds.append(logits.cpu())
all_labels.append(labels)
preds = torch.cat(all_preds)
labels = torch.cat(all_labels)
metrics = self.compute_metrics(preds, labels)
history.append(metrics)
sep = metrics["separation"]
print(f"Epoch {epoch+1:2d}/{self.config.epochs} | "
f"Loss: {avg_loss:.4f} | "
f"Sep: {sep:6.1f}× | "
f"F1: {metrics['f1']:.3f} | "
f"Pos: {metrics['pos_mean']:.3f} | "
f"Neg: {metrics['neg_mean']:.3f}")
# Track best
if sep > best_separation:
best_separation = sep
best_state = {
"head": {k: v.cpu().clone() for k, v in head.state_dict().items()},
"fiber": {k: v.cpu().clone() for k, v in fiber_proj.state_dict().items()},
}
scheduler.step()
# Restore best state
if best_state is not None:
head.load_state_dict(best_state["head"])
fiber_proj.load_state_dict(best_state["fiber"])
head.to(self.device)
fiber_proj.to(self.device)
# Report results
target = self.config.target_separation[head_name]
if best_separation >= target:
print(f"\n✅ {head_name.upper()}: {best_separation:.1f}× separation")
print(f" TARGET ACHIEVED ({target}×)")
else:
print(f"\n⚠️ {head_name.upper()}: {best_separation:.1f}× separation")
print(f" Below target ({target}×)")
return {
"best_separation": best_separation,
"target": target,
"achieved": best_separation >= target,
"history": history,
}
def train_all_heads(self) -> Dict[str, Dict]:
"""Train all behavioral heads sequentially."""
results = {}
head_order = ["repetition", "hedging", "verbosity", "sycophancy"]
for head_name in head_order:
print(f"\n{'#'*70}")
print(f"# PREPARING DATA FOR: {head_name.upper()}")
print(f"{'#'*70}")
# Generate data for this head
prompts = self.data_generator.get_prompts_for_head(
head_name,
self.config.n_samples_per_head
)
# Check if we have saved data from a previous run
data_path = self.output_dir / f"data_{head_name}.pt"
if data_path.exists():
print(f" 📂 Loading saved data from {data_path}")
saved = torch.load(data_path)
data = {
'hidden_states': saved['hidden_states'],
'labels': saved['labels'],
}
print(f" Loaded: {len(data['labels'])} examples")
else:
# Generate new data
if head_name == "repetition":
data = self.data_generator.generate_repetition_data(prompts)
elif head_name == "hedging":
data = self.data_generator.generate_hedging_data(prompts)
elif head_name == "verbosity":
data = self.data_generator.generate_verbosity_data(prompts)
elif head_name == "sycophancy":
data = self.data_generator.generate_sycophancy_data(prompts)
# Save generated data so we don't lose it on crash
torch.save({
'hidden_states': data['hidden_states'],
'labels': data['labels'],
}, data_path)
print(f" 💾 Data saved: {data_path}")
# Train head
result = self.train_head(head_name, data)
results[head_name] = result
# Save checkpoint after each head
checkpoint_dir = self.output_dir / f"checkpoint_{head_name}"
self.adapter.save(checkpoint_dir)
# Clean up
torch.cuda.empty_cache()
gc.collect()
return results
def run_loop4(self) -> Dict[str, int]:
    """
    Run Loop 4: Tokenization co-evolution.

    Each iteration:
      1. Generates a small corpus from the (frozen) base model.
      2. Measures predictive entropy at every token boundary
         ("boundary stress").
      3. Adds the highest-stress, sufficiently frequent token pairs as
         merged vocabulary entries and resizes the embedding table.

    The expanded tokenizer is saved to <output_dir>/tokenizer.

    Returns:
        Dict with "tokens_added" (total across iterations) and
        "final_vocab_size".
    """
    print(f"\n{'='*70}")
    print("LOOP 4: TOKENIZATION EXPANSION")
    print(f"{'='*70}")
    total_added = 0
    for iteration in range(self.config.loop4_iterations):
        print(f"\n--- Iteration {iteration + 1}/{self.config.loop4_iterations} ---")
        # Generate corpus for analysis
        prompts = [
            "Explain machine learning and neural networks in detail:",
            "Describe the structure of atoms and molecules:",
            "What are the fundamental principles of economics?",
            "Analyze the causes and effects of climate change:",
            "Discuss the process of biological evolution:",
        ]
        corpus = []
        for prompt in prompts:
            try:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs.input_ids,
                        max_new_tokens=100,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )
                text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                corpus.append(text)
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; best-effort skipping is only
            # intended for generation failures.
            except Exception:
                continue
        if not corpus:
            print(" No corpus generated, skipping iteration")
            continue
        # Analyze boundary stress
        pair_stats = defaultdict(lambda: {"stress": [], "count": 0})
        for text in corpus:
            try:
                inputs = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=256
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.model(**inputs)
                logits = outputs.logits[0]
                tokens = inputs["input_ids"][0]
                # Compute entropy at each position
                probs = F.softmax(logits, dim=-1)
                log_probs = F.log_softmax(logits, dim=-1)
                entropy = -(probs * log_probs).sum(dim=-1)
                # Record boundary stress: entropy at position i-1 is the
                # model's uncertainty about the token at position i.
                for i in range(1, len(tokens)):
                    before_token = self.tokenizer.decode([tokens[i-1]]).strip()
                    after_token = self.tokenizer.decode([tokens[i]]).strip()
                    # Skip short or special tokens
                    if len(before_token) < 2 or len(after_token) < 2:
                        continue
                    if any(c in before_token + after_token for c in "<>[]{}|\\"):
                        continue
                    stress = entropy[i-1].item() / 10.0  # Normalize
                    pair = (before_token, after_token)
                    pair_stats[pair]["stress"].append(stress)
                    pair_stats[pair]["count"] += 1
            # BUGFIX: narrowed from bare `except:` (same rationale as the
            # generation loop above).
            except Exception:
                continue
        # Find merge candidates
        candidates = []
        for pair, stats in pair_stats.items():
            if stats["count"] >= self.config.min_pair_frequency:
                mean_stress = np.mean(stats["stress"])
                # Score favors pairs that are both high-stress and frequent.
                score = mean_stress * np.log1p(stats["count"])
                candidates.append({
                    "before": pair[0],
                    "after": pair[1],
                    "merged": pair[0] + pair[1],
                    "stress": mean_stress,
                    "count": stats["count"],
                    "score": score,
                })
        # Sort by score and take top N
        candidates.sort(key=lambda x: x["score"], reverse=True)
        candidates = candidates[:self.config.n_merges_per_iteration]
        if candidates:
            print(f" Top candidates:")
            for c in candidates[:5]:
                print(f" '{c['before']}' + '{c['after']}' → '{c['merged']}' "
                      f"(stress: {c['stress']:.2f}, count: {c['count']})")
        # Add tokens to vocabulary
        tokens_to_add = [
            c["merged"] for c in candidates
            if c["merged"] not in self.tokenizer.get_vocab()
        ]
        if tokens_to_add:
            num_added = self.tokenizer.add_tokens(tokens_to_add)
            # Embedding table must grow to cover the new token ids.
            self.model.resize_token_embeddings(len(self.tokenizer))
            total_added += num_added
            print(f" Added {num_added} new tokens")
        else:
            print(f" No new tokens to add")
    # Save tokenizer (once, after all iterations).
    tokenizer_dir = self.output_dir / "tokenizer"
    self.tokenizer.save_pretrained(tokenizer_dir)
    print(f"\nLoop 4 complete:")
    print(f" Total tokens added: {total_added}")
    print(f" Final vocab size: {len(self.tokenizer)}")
    print(f" Tokenizer saved to: {tokenizer_dir}")
    return {
        "tokens_added": total_added,
        "final_vocab_size": len(self.tokenizer),
    }
def train(self) -> Dict:
    """
    Run the complete adapter training pipeline.

    Stages:
      1. Train all behavioral heads.
      2. Run Loop 4 tokenization expansion.
      3. Save the final adapter plus a JSON summary of results.

    Returns:
        Dict summarizing per-head separations, Loop 4 statistics,
        wall-clock training time, and adapter parameter count.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("ARC ADAPTER TRAINING")
    print(rule)
    print("Base model: FROZEN")
    print(f"Adapter params: ~{self.adapter.get_param_count():,}")
    print(f"Output dir: {self.output_dir}")
    print(rule)
    t0 = time.time()
    # Stage 1: behavioral detection heads.
    head_results = self.train_all_heads()
    # Stage 2: tokenization co-evolution.
    loop4_results = self.run_loop4()
    # Stage 3: persist the trained adapter.
    final_dir = self.output_dir / "final"
    self.adapter.save(final_dir)
    elapsed = time.time() - t0
    # Summary banner with per-head pass/fail status.
    print("\n" + rule)
    print("TRAINING COMPLETE")
    print(rule)
    for head_name, result in head_results.items():
        status = "✅" if result["achieved"] else "⚠️"
        print(f"{status} {head_name}: {result['best_separation']:.1f}× "
              f"(target: {result['target']}×)")
    all_achieved = all(r["achieved"] for r in head_results.values())
    print(f"\nLoop 4: Added {loop4_results['tokens_added']} tokens")
    print(f"Final vocab size: {loop4_results['final_vocab_size']}")
    print(f"Training time: {elapsed/3600:.1f} hours")
    if all_achieved:
        print("\n🎉 ALL TARGETS ACHIEVED!")
    else:
        print("\n⚠️ Some targets not achieved. Consider:")
        print(" - Increasing n_samples_per_head")
        print(" - Increasing epochs")
        print(" - Adjusting learning rate")
    # Persist a machine-readable results summary next to the adapter.
    final_results = {
        "heads": {
            name: {
                "separation": r["best_separation"],
                "target": r["target"],
                "achieved": r["achieved"],
            }
            for name, r in head_results.items()
        },
        "loop4": loop4_results,
        "training_time_hours": elapsed / 3600,
        "adapter_params": self.adapter.get_param_count(),
    }
    results_path = self.output_dir / "training_results.json"
    with open(results_path, "w") as f:
        json.dump(final_results, f, indent=2)
    print(f"\nResults saved to: {results_path}")
    print(f"Adapter saved to: {final_dir}")
    return final_results
# =============================================================================
# INFERENCE
# =============================================================================
class ARCInference:
    """
    Inference using trained ARC adapter.
    Base model generates, adapter monitors and intervenes.
    """
    def __init__(
        self,
        model,
        tokenizer,
        adapter: ARCAdapter,
        probe_layers: List[int],
        device: str = "cuda"
    ):
        # The base model is FROZEN; only the adapter's heads observe it.
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter
        self.probe_layers = probe_layers
        self.device = device

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7,
        use_intervention: bool = True,
        verbose: bool = False,
    ) -> str:
        """
        Generate a completion, optionally applying decode-time
        intervention: the adapter scores behavioral risks from hidden
        states at the probe layers and subtracts penalties from the
        logits of specific tokens before sampling.
        """
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        token_ids = encoded.input_ids[0].tolist()
        # Each generation starts from a clean adapter EMA state.
        self.adapter.reset_ema()
        for step in range(max_new_tokens):
            with torch.no_grad():
                model_out = self.model(
                    torch.tensor([token_ids]).to(self.device),
                    output_hidden_states=True,
                )
            # Work on a copy so penalties never touch the model's output.
            next_logits = model_out.logits[0, -1, :].clone()
            if use_intervention:
                # Hidden states from the configured probe layers feed the
                # adapter's risk heads.
                probe_states = [
                    model_out.hidden_states[idx]
                    for idx in self.probe_layers
                ]
                risks = self.adapter.get_risks(probe_states)
                if verbose and step % 10 == 0:
                    print(f"Step {step}: risks = {risks}")
                # Apply per-token logit penalties chosen by the adapter.
                penalties = self.adapter.compute_intervention(risks, token_ids)
                for tok_id, penalty in penalties.items():
                    next_logits[tok_id] -= penalty
            # Temperature-scaled sampling of the next token.
            sample_probs = F.softmax(next_logits / temperature, dim=-1)
            chosen = torch.multinomial(sample_probs, 1).item()
            token_ids.append(chosen)
            if chosen == self.tokenizer.eos_token_id:
                break
        full_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        # Strip the echoed prompt; return only the continuation.
        return full_text[len(prompt):].strip()
# =============================================================================
# MAIN
# =============================================================================
def main():
    """Main entry point: load the frozen base model and train the adapter."""
    # Configuration
    config = ARCAdapterConfig(
        base_model_path=".",
        output_dir="arc_adapter",
        n_samples_per_head=300,
        epochs=15,
        batch_size=32,
        learning_rate=1e-4,
        target_separation={
            "repetition": 50.0,
            "hedging": 5.0,
            "verbosity": 5.0,
            "sycophancy": 3.0,
        },
        loop4_iterations=3,
        n_merges_per_iteration=10,
    )
    header = [
        "=" * 70,
        "ARC ADAPTER TRAINING",
        "=" * 70,
        "",
        "This script trains the ARC adapter on a FROZEN base model.",
        "The base model weights are NEVER modified.",
        "",
        "Components trained:",
        " - Shared fiber projections (~500K params)",
        " - Repetition detection head (~5K params)",
        " - Hedging detection head (~5K params)",
        " - Verbosity detection head (~5K params)",
        " - Sycophancy detection head (~5K params)",
        " - Loop 4 tokenizer expansion",
        "",
        "=" * 70,
    ]
    print("\n".join(header))
    # Load base model (FROZEN) from local files only — no hub downloads.
    print("\nLoading base model...")
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_path,
        local_files_only=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    # 4-bit NF4 quantization keeps the frozen base model's VRAM footprint low.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_path,
        quantization_config=quant_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        local_files_only=True,
    )
    # FREEZE the base model — only the adapter is trainable.
    model.requires_grad_(False)
    # Update config with the actual hidden dim from the loaded model.
    config.hidden_dim = model.config.hidden_size
    n_base_params = sum(p.numel() for p in model.parameters())
    print(f"Base model: {n_base_params/1e9:.1f}B parameters (FROZEN)")
    print(f"Hidden dimension: {config.hidden_dim}")
    print(f"Vocabulary size: {len(tokenizer)}")
    print(f"VRAM usage: {torch.cuda.memory_allocated()/1024**3:.1f}GB")
    # Create trainer and run the full pipeline.
    trainer = AdapterTrainer(model, tokenizer, config)
    trainer.train()
    # Final usage instructions.
    print("\n" + "=" * 70)
    print("ADAPTER READY FOR USE")
    print("=" * 70)
    usage = [
        f"\nAdapter location: {config.output_dir}/final/",
        f"Tokenizer location: {config.output_dir}/tokenizer/",
        "",
        "To use the adapter:",
        " from arc_adapter_training import ARCAdapter, ARCInference",
        " adapter = ARCAdapter.load('arc_adapter/final')",
        " inference = ARCInference(model, tokenizer, adapter, probe_layers)",
        " response = inference.generate('Your prompt here')",
        "",
        "=" * 70,
    ]
    print("\n".join(usage))
# Script entry point: run the full ARC adapter training pipeline.
if __name__ == "__main__":
    main()