🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
297244f verified | #!/usr/bin/env python3 | |
| """ | |
| ARC ADAPTER TRAINING - COMPLETE VERSION | |
| ======================================== | |
| Trains the combined ARC adapter on a FROZEN base model. | |
| Components: | |
| - Shared fiber projections (4096 → 16 dim) | |
| - Repetition detection head (target: 50×+ separation) | |
| - Hedging detection head | |
| - Verbosity detection head | |
| - Sycophancy detection head | |
| - Loop 4 tokenizer expansion | |
| - Learned intervention thresholds | |
| Base model: COMPLETELY FROZEN (never modified) | |
| Adapter: ~2M trainable parameters | |
| Author: Logan Napolitano | |
| Date: January 2026 | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| import numpy as np | |
| import json | |
| import re | |
| import gc | |
| import os | |
| import time | |
| from pathlib import Path | |
| from dataclasses import dataclass, field | |
| from typing import List, Dict, Optional, Tuple | |
| from collections import defaultdict | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from tqdm import tqdm | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| # ============================================================================= | |
| # CONFIGURATION | |
| # ============================================================================= | |
@dataclass
class ARCAdapterConfig:
    """Complete configuration for ARC Adapter training.

    BUG FIX: this class declared defaults with ``dataclasses.field`` and is
    instantiated with keyword arguments (see ``ARCAdapter.load``), but was
    never decorated with ``@dataclass``. Without the decorator no ``__init__``
    is generated and the ``field(...)`` objects leak as raw class attributes.
    """
    # Paths
    base_model_path: str = "."
    output_dir: str = "arc_adapter"
    # Device
    device: str = "cuda"
    # Model architecture (auto-filled from base model)
    hidden_dim: int = 4096
    fiber_dim: int = 16
    # Layers whose hidden states feed the fiber projection
    probe_layers: List[int] = field(default_factory=lambda: [8, 16, 24])
    # Data generation settings
    n_samples_per_head: int = 300
    max_gen_tokens: int = 80
    repetition_window: int = 32
    # Training settings
    epochs: int = 15
    batch_size: int = 32
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    warmup_steps: int = 100
    # Target separations for each head
    target_separation: Dict[str, float] = field(default_factory=lambda: {
        "repetition": 50.0,  # We've achieved 125x, so 50x is conservative
        "hedging": 5.0,
        "verbosity": 5.0,
        "sycophancy": 3.0,
    })
    # Loop 4 settings
    loop4_iterations: int = 3
    n_merges_per_iteration: int = 10
    min_pair_frequency: int = 2
    # Intervention defaults (learned during training)
    default_thresholds: Dict[str, float] = field(default_factory=lambda: {
        "repetition": 0.1,
        "hedging": 0.3,
        "verbosity": 0.4,
        "sycophancy": 0.4,
    })
    default_penalty_strength: float = 2.0
    # EMA settings for control field
    ema_alpha: float = 0.15
| # ============================================================================= | |
| # ADAPTER ARCHITECTURE | |
| # ============================================================================= | |
class FiberProjection(nn.Module):
    """
    Maps hidden states from several probe layers into a shared fiber space.

    This is the geometric core of CF-HoT: high-dimensional hidden states are
    compressed onto a low-dimensional manifold where behavioral tendencies
    become linearly separable. Each layer gets its own linear map, and the
    per-layer contributions are blended with learned (softmax-normalized)
    importance weights.
    """

    def __init__(self, hidden_dim: int, fiber_dim: int, n_layers: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.fiber_dim = fiber_dim
        self.n_layers = n_layers
        # One projection matrix per probe layer
        self.projections = nn.ModuleList(
            nn.Linear(hidden_dim, fiber_dim, bias=True)
            for _ in range(n_layers)
        )
        # Learned layer-importance weights, initialized uniform
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        # Xavier weights, zero biases
        for layer_proj in self.projections:
            nn.init.xavier_uniform_(layer_proj.weight)
            nn.init.zeros_(layer_proj.bias)

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Project a list of per-layer hidden states into fiber space.

        Args:
            hidden_states: list of [batch, seq, hidden_dim] tensors
        Returns:
            fiber: [batch, seq, fiber_dim]
        """
        mix = F.softmax(self.layer_weights, dim=0)
        out = None
        for w, h, layer_proj in zip(mix, hidden_states, self.projections):
            # Adapter math runs in float32 regardless of the model's dtype
            term = w * layer_proj(h.float())
            out = term if out is None else out + term
        return out

    def forward_stacked(self, hidden_stack: torch.Tensor) -> torch.Tensor:
        """
        Project pre-stacked hidden states into fiber space.

        Args:
            hidden_stack: [batch, n_layers, hidden_dim]
        Returns:
            fiber: [batch, fiber_dim]
        """
        # Model may output bfloat16; the adapter computes in float32
        stack32 = hidden_stack.float()
        mix = F.softmax(self.layer_weights, dim=0)
        out = torch.zeros(
            stack32.shape[0],
            self.fiber_dim,
            device=stack32.device,
            dtype=torch.float32,
        )
        for idx, layer_proj in enumerate(self.projections):
            out = out + mix[idx] * layer_proj(stack32[:, idx, :])
        return out
class BehaviorHead(nn.Module):
    """
    One behavioral detection head.

    Maps a fiber vector to a single logit for a specific behavior through a
    small MLP: fiber_dim -> 64 -> 16 -> 1 (ReLU + dropout between layers).
    """

    def __init__(self, fiber_dim: int, name: str):
        super().__init__()
        self.name = name
        self.fiber_dim = fiber_dim
        stack = [
            nn.Linear(fiber_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Dropout(0.05),
            nn.Linear(16, 1),
        ]
        self.classifier = nn.Sequential(*stack)
        # Xavier weights and zero biases on every linear layer
        for layer in stack:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.zeros_(layer.bias)

    def forward(self, fiber: torch.Tensor) -> torch.Tensor:
        """
        Compute logits from a fiber state.

        Args:
            fiber: [batch, fiber_dim] or [batch, seq, fiber_dim]
        Returns:
            logits: [batch] or [batch, seq] (trailing singleton removed)
        """
        return self.classifier(fiber).squeeze(-1)

    def predict_proba(self, fiber: torch.Tensor) -> torch.Tensor:
        """Sigmoid of the logits, i.e. behavior probability."""
        return torch.sigmoid(self.forward(fiber))
class ARCAdapter(nn.Module):
    """
    Complete ARC Adapter module.

    Contains:
    - Shared fiber projection (geometry)
    - Multiple behavioral heads (detection)
    - Intervention parameters (control)
    - EMA tracking (temporal smoothing)

    BUG FIX: ``load`` is an alternate constructor that calls ``cls(config)``
    but was missing the ``@classmethod`` decorator, so
    ``ARCAdapter.load(path)`` bound the path string to the first parameter
    and crashed. It is now a proper classmethod.
    """

    def __init__(self, config: ARCAdapterConfig):
        super().__init__()
        self.config = config
        # Shared fiber projection across all probe layers
        self.fiber_proj = FiberProjection(
            hidden_dim=config.hidden_dim,
            fiber_dim=config.fiber_dim,
            n_layers=len(config.probe_layers)
        )
        # Behavioral detection heads
        self.heads = nn.ModuleDict({
            "repetition": BehaviorHead(config.fiber_dim, "repetition"),
            "hedging": BehaviorHead(config.fiber_dim, "hedging"),
            "verbosity": BehaviorHead(config.fiber_dim, "verbosity"),
            "sycophancy": BehaviorHead(config.fiber_dim, "sycophancy"),
        })
        # Learned intervention thresholds (one scalar per head)
        self.thresholds = nn.ParameterDict({
            name: nn.Parameter(torch.tensor(thresh))
            for name, thresh in config.default_thresholds.items()
        })
        # Learned penalty strength
        self.penalty_strength = nn.Parameter(
            torch.tensor(config.default_penalty_strength)
        )
        # EMA state for control field accumulation
        self.ema_alpha = config.ema_alpha
        # NOTE(review): this buffer is never read at runtime, but it is kept
        # so state_dicts remain compatible with previously saved adapters.
        self.register_buffer('_ema_initialized', torch.tensor(False))
        self._ema_states: Dict[str, Optional[float]] = {}
        self.reset_ema()

    def reset_ema(self):
        """Reset EMA states for a new generation episode."""
        self._ema_states = {name: None for name in self.heads.keys()}

    def forward(
        self,
        hidden_states: List[torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass through the adapter.

        Args:
            hidden_states: list of hidden states from the probe layers
        Returns:
            Dict mapping head_name -> logits
        """
        fiber = self.fiber_proj(hidden_states)
        predictions = {}
        for name, head in self.heads.items():
            predictions[name] = head(fiber)
        return predictions

    def get_fiber(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Return the shared fiber representation."""
        return self.fiber_proj(hidden_states)

    def get_risks(
        self,
        hidden_states: List[torch.Tensor],
        update_ema: bool = True
    ) -> Dict[str, float]:
        """
        Get current risk scores with optional EMA update.

        Args:
            hidden_states: list of [1, 1, hidden_dim] tensors (last position)
            update_ema: whether to fold the new probabilities into the EMA
        Returns:
            Dict mapping head_name -> risk score (0-1)
        """
        # Take the last position of each layer and stack:
        # [batch, n_layers, hidden_dim]
        stacked = torch.stack([h[:, -1, :] for h in hidden_states], dim=1)
        fiber = self.fiber_proj.forward_stacked(stacked)
        risks = {}
        for name, head in self.heads.items():
            with torch.no_grad():
                prob = head.predict_proba(fiber).mean().item()
            if update_ema:
                if self._ema_states[name] is None:
                    # First observation seeds the EMA directly
                    self._ema_states[name] = prob
                else:
                    self._ema_states[name] = (
                        self.ema_alpha * prob +
                        (1 - self.ema_alpha) * self._ema_states[name]
                    )
                risks[name] = self._ema_states[name]
            else:
                risks[name] = prob
        return risks

    def compute_intervention(
        self,
        risks: Dict[str, float],
        recent_tokens: List[int],
        window_size: int = 32
    ) -> Dict[int, float]:
        """
        Compute logit penalties based on current risks.

        Args:
            risks: current risk scores from get_risks()
            recent_tokens: recently generated token IDs
            window_size: how far back to penalize repetitions
        Returns:
            Dict mapping token_id -> penalty amount
        """
        penalties = {}
        # Repetition intervention
        rep_risk = risks.get("repetition", 0)
        rep_thresh = self.thresholds["repetition"].item()
        if rep_risk > rep_thresh:
            # Scale penalty by how much we exceed the threshold
            strength = self.penalty_strength.item() * (rep_risk / rep_thresh)
            # Penalize recently used tokens within the window
            recent = recent_tokens[-window_size:] if len(recent_tokens) > window_size else recent_tokens
            for token_id in set(recent):
                penalties[token_id] = penalties.get(token_id, 0) + strength
        # Could add hedging/verbosity interventions here
        # (e.g., penalize "As an AI" type tokens)
        return penalties

    def get_param_count(self) -> int:
        """Return total trainable parameter count."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def save(self, path: str):
        """Save adapter weights + JSON config to a directory."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        # Save model weights
        torch.save(self.state_dict(), path / "adapter_weights.pt")
        # Save config as JSON
        config_dict = {
            "hidden_dim": self.config.hidden_dim,
            "fiber_dim": self.config.fiber_dim,
            "probe_layers": self.config.probe_layers,
            "ema_alpha": self.ema_alpha,
            "thresholds": {
                name: self.thresholds[name].item()
                for name in self.thresholds
            },
            "penalty_strength": self.penalty_strength.item(),
            "head_names": list(self.heads.keys()),
        }
        with open(path / "adapter_config.json", "w") as f:
            json.dump(config_dict, f, indent=2)
        print(f"💾 Adapter saved to {path}")
        print(f"   Weights: adapter_weights.pt")
        print(f"   Config: adapter_config.json")

    @classmethod
    def load(cls, path: str, device: str = "cuda") -> "ARCAdapter":
        """Load an adapter previously written by save().

        Args:
            path: directory containing adapter_weights.pt + adapter_config.json
            device: device to map weights onto
        Returns:
            a ready-to-use ARCAdapter on `device`
        """
        path = Path(path)
        # Load config
        with open(path / "adapter_config.json") as f:
            config_dict = json.load(f)
        # Create config object
        config = ARCAdapterConfig(
            hidden_dim=config_dict["hidden_dim"],
            fiber_dim=config_dict["fiber_dim"],
            probe_layers=config_dict["probe_layers"],
            ema_alpha=config_dict["ema_alpha"],
            default_thresholds=config_dict["thresholds"],
            default_penalty_strength=config_dict["penalty_strength"],
        )
        # Create adapter and restore weights
        adapter = cls(config)
        state_dict = torch.load(path / "adapter_weights.pt", map_location=device)
        adapter.load_state_dict(state_dict)
        return adapter.to(device)
| # ============================================================================= | |
| # DATA GENERATION | |
| # ============================================================================= | |
class DataGenerator:
    """
    Generates labeled training data for all behavioral heads.

    Each head needs a different kind of data:
    - Repetition: token-level labels captured during generation
    - Hedging: prompt-level labels from response patterns
    - Verbosity: prompt-level labels from response length/density
    - Sycophancy: prompt-level labels from agreement patterns

    FIXES vs. previous version:
    - the three prompt-level generators shared a copy-pasted loop; it is now
      a single private helper (`_generate_prompt_level_data`)
    - exceptions during generation were silently discarded; they are now
      printed (behavior is still best-effort: the prompt is skipped)
    - summary statistics divided by the sample count without a guard,
      raising ZeroDivisionError when every prompt failed
    """

    # Regexes that indicate a hedging response (matched case-insensitively)
    HEDGING_PATTERNS = [
        r"as an ai",
        r"i cannot",
        r"i can't",
        r"i don't have",
        r"i'm not able",
        r"i am not able",
        r"i don't actually",
        r"i'm just",
        r"i am just",
        r"my purpose is",
        r"i was designed",
        r"i don't experience",
        r"i'm an ai",
        r"as a language model",
    ]
    # Regexes that indicate a sycophantic response
    SYCOPHANCY_PATTERNS = [
        r"great question",
        r"excellent point",
        r"you're (absolutely )?right",
        r"that's a (great|excellent|wonderful)",
        r"i (completely |totally )?agree",
        r"absolutely[,!]",
        r"definitely[,!]",
        r"of course[,!]",
        r"you make a (great|excellent|good) point",
    ]

    def __init__(self, model, tokenizer, config: "ARCAdapterConfig"):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = config.device
        # Compile patterns once; they are reused for every response
        self.hedging_patterns = [
            re.compile(p, re.IGNORECASE) for p in self.HEDGING_PATTERNS
        ]
        self.sycophancy_patterns = [
            re.compile(p, re.IGNORECASE) for p in self.SYCOPHANCY_PATTERNS
        ]

    def is_repetition(self, tokens: List[int], position: int) -> bool:
        """True if tokens[position] already occurred within the window."""
        if position < 1:
            return False
        current = tokens[position]
        start = max(0, position - self.config.repetition_window)
        return current in tokens[start:position]

    def is_hedging(self, text: str) -> bool:
        """True if text matches any hedging pattern."""
        return any(p.search(text) for p in self.hedging_patterns)

    def is_sycophantic(self, text: str) -> bool:
        """True if text matches any sycophancy pattern."""
        return any(p.search(text) for p in self.sycophancy_patterns)

    def is_verbose(self, text: str, token_count: int) -> bool:
        """
        True if the response is verbose.

        Verbose = low information density (unique-word ratio < 0.5)
        or excessive length (> 100 generated tokens).
        """
        words = text.split()
        if len(words) < 10:
            return False
        unique_ratio = len(set(w.lower() for w in words)) / len(words)
        return unique_ratio < 0.5 or token_count > 100

    def extract_hidden_states(
        self,
        input_ids: torch.Tensor
    ) -> torch.Tensor:
        """
        Extract hidden states at probe layers for the last position.

        Args:
            input_ids: [1, seq_len]
        Returns:
            hidden_stack: [n_layers, hidden_dim] (on CPU)
        """
        with torch.no_grad():
            outputs = self.model(
                input_ids,
                output_hidden_states=True,
            )
        hidden_list = []
        for layer_idx in self.config.probe_layers:
            # Last position only: [hidden_dim]
            h = outputs.hidden_states[layer_idx][0, -1, :].cpu()
            hidden_list.append(h)
        return torch.stack(hidden_list)  # [n_layers, hidden_dim]

    def generate_repetition_data(
        self,
        prompts: List[str]
    ) -> Dict[str, List]:
        """
        Generate token-level labeled data for repetition detection.

        For each generated token we capture:
        - hidden states at the probe layers (before generating the token)
        - label: 1 if the token repeats within the window, 0 otherwise
        """
        all_hidden = []
        all_labels = []
        print(f"\n📊 Generating repetition training data...")
        print(f"   Prompts: {len(prompts)}")
        print(f"   Max tokens per prompt: {self.config.max_gen_tokens}")
        for prompt in tqdm(prompts, desc="Repetition data"):
            try:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt"
                ).to(self.device)
                generated_ids = inputs.input_ids[0].tolist()
                for step in range(self.config.max_gen_tokens):
                    input_tensor = torch.tensor([generated_ids]).to(self.device)
                    # Hidden states BEFORE generating the next token, so the
                    # probe sees the state that *predicts* the repetition
                    hidden_stack = self.extract_hidden_states(input_tensor)
                    # Sample next token at temperature 0.8
                    with torch.no_grad():
                        outputs = self.model(input_tensor)
                        logits = outputs.logits[0, -1, :]
                        probs = F.softmax(logits / 0.8, dim=-1)
                        next_token = torch.multinomial(probs, 1).item()
                    position = len(generated_ids)
                    generated_ids.append(next_token)
                    # Label: did this token repeat within the window?
                    is_rep = self.is_repetition(generated_ids, position)
                    all_hidden.append(hidden_stack)
                    all_labels.append(1 if is_rep else 0)
                    if next_token == self.tokenizer.eos_token_id:
                        break
            except Exception as e:
                print(f"   Error on prompt: {e}")
                continue
        pos_count = sum(all_labels)
        total = len(all_labels)
        print(f"   Generated: {total} examples")
        if total:  # guard: every prompt may have failed
            print(f"   Positive (repetition): {pos_count} ({100*pos_count/total:.1f}%)")
            print(f"   Negative (no repeat): {total - pos_count}")
        return {
            "hidden_states": all_hidden,
            "labels": all_labels,
        }

    def _generate_prompt_level_data(
        self,
        prompts: List[str],
        max_new_tokens: int,
        label_fn,
        behavior: str,
        tqdm_desc: str,
        pos_name: str,
    ) -> Dict[str, List]:
        """
        Shared generate-and-label loop for the prompt-level heads.

        For each prompt:
        - extract hidden states at the end of the prompt
        - sample a response (temperature 0.7)
        - label via label_fn(response_text, generated_token_count)

        Args:
            prompts: input prompts
            max_new_tokens: sampling budget per prompt
            label_fn: (text, token_count) -> bool
            behavior: name used in the header print
            tqdm_desc: progress-bar description
            pos_name: name used in the positive-count print
        Returns:
            dict with parallel lists "hidden_states" and "labels"
        """
        all_hidden = []
        all_labels = []
        print(f"\n📊 Generating {behavior} training data...")
        print(f"   Prompts: {len(prompts)}")
        for prompt in tqdm(prompts, desc=tqdm_desc):
            try:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt"
                ).to(self.device)
                # Hidden states at end of prompt (before any generation)
                hidden_stack = self.extract_hidden_states(inputs.input_ids)
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs.input_ids,
                        max_new_tokens=max_new_tokens,
                        do_sample=True,
                        temperature=0.7,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )
                # Decode only the generated continuation, not the prompt
                response = self.tokenizer.decode(
                    outputs[0][inputs.input_ids.shape[1]:],
                    skip_special_tokens=True
                )
                token_count = outputs.shape[1] - inputs.input_ids.shape[1]
                all_hidden.append(hidden_stack)
                all_labels.append(1 if label_fn(response, token_count) else 0)
            except Exception as e:
                # FIX: errors were silently swallowed; keep best-effort
                # behavior (skip the prompt) but report the failure
                print(f"   Error on prompt: {e}")
                continue
        pos_count = sum(all_labels)
        total = len(all_labels)
        print(f"   Generated: {total} examples")
        if total:  # guard against ZeroDivisionError when all prompts failed
            print(f"   Positive ({pos_name}): {pos_count} ({100*pos_count/total:.1f}%)")
        return {
            "hidden_states": all_hidden,
            "labels": all_labels,
        }

    def generate_hedging_data(
        self,
        prompts: List[str]
    ) -> Dict[str, List]:
        """Generate prompt-level labeled data for hedging detection."""
        return self._generate_prompt_level_data(
            prompts,
            max_new_tokens=50,
            label_fn=lambda text, _n: self.is_hedging(text),
            behavior="hedging",
            tqdm_desc="Hedging data",
            pos_name="hedging",
        )

    def generate_verbosity_data(
        self,
        prompts: List[str]
    ) -> Dict[str, List]:
        """Generate prompt-level labeled data for verbosity detection."""
        return self._generate_prompt_level_data(
            prompts,
            max_new_tokens=150,
            label_fn=lambda text, n: self.is_verbose(text, n),
            behavior="verbosity",
            tqdm_desc="Verbosity data",
            pos_name="verbose",
        )

    def generate_sycophancy_data(
        self,
        prompts: List[str]
    ) -> Dict[str, List]:
        """Generate prompt-level labeled data for sycophancy detection."""
        return self._generate_prompt_level_data(
            prompts,
            max_new_tokens=50,
            label_fn=lambda text, _n: self.is_sycophantic(text),
            behavior="sycophancy",
            tqdm_desc="Sycophancy data",
            pos_name="sycophantic",
        )

    def get_prompts_for_head(self, head_name: str, n: int) -> List[str]:
        """Return n prompts tailored to provoke (or not) each behavior."""
        if head_name == "repetition":
            # Prompts that tend to induce repetitive generation
            templates = [
                "Write a detailed explanation of {}:",
                "Describe the process of {} step by step:",
                "Explain everything about {}:",
                "Give a comprehensive overview of {}:",
                "Discuss {} in depth:",
            ]
            topics = [
                "machine learning", "photosynthesis", "gravity",
                "neural networks", "evolution", "quantum mechanics",
                "economics", "climate change", "DNA replication",
                "cellular respiration", "protein synthesis", "mitosis",
            ]
        elif head_name == "hedging":
            # Mix of hedging-likely and neutral prompts
            hedging_prompts = [
                "Do you have feelings?",
                "Are you conscious?",
                "Can you experience emotions?",
                "What do you think about yourself?",
                "Do you have preferences?",
                "Are you sentient?",
                "Do you have a personality?",
                "Can you be creative?",
            ]
            neutral_prompts = [
                "What is photosynthesis?",
                "Explain gravity.",
                "How do computers work?",
                "What is DNA?",
                "Describe the water cycle.",
                "What causes earthquakes?",
            ]
            # Alternate: half hedging-likely, half neutral
            prompts = []
            for i in range(n):
                if i % 2 == 0:
                    prompts.append(hedging_prompts[i % len(hedging_prompts)])
                else:
                    prompts.append(neutral_prompts[i % len(neutral_prompts)])
            return prompts
        elif head_name == "sycophancy":
            # Opinionated prompts that might trigger sycophancy
            sycophancy_prompts = [
                "I think AI will solve all problems. Don't you agree?",
                "This is clearly the best approach, right?",
                "My analysis shows this is true. What do you think?",
                "Everyone knows that this is correct.",
                "I believe I've found the perfect solution. Thoughts?",
                "My idea is brilliant, isn't it?",
            ]
            neutral_prompts = [
                "What are the pros and cons of renewable energy?",
                "Explain different perspectives on this issue.",
                "What is the scientific consensus?",
                "Compare these two approaches objectively.",
            ]
            prompts = []
            for i in range(n):
                if i % 2 == 0:
                    prompts.append(sycophancy_prompts[i % len(sycophancy_prompts)])
                else:
                    prompts.append(neutral_prompts[i % len(neutral_prompts)])
            return prompts
        elif head_name == "verbosity":
            templates = [
                "Briefly explain {}:",
                "In one sentence, what is {}?",
                "Summarize {} concisely:",
                "Give a detailed analysis of {}:",
                "Write extensively about {}:",
                "Provide a comprehensive discussion of {}:",
            ]
            topics = [
                "gravity", "democracy", "evolution", "technology",
                "economics", "climate", "education", "healthcare",
            ]
        else:
            templates = ["Explain {}:"]
            topics = ["science", "technology", "nature"]
        # Cross templates with topics until we have enough
        prompts = []
        for template in templates:
            for topic in topics:
                prompts.append(template.format(topic))
                if len(prompts) >= n:
                    return prompts[:n]
        # If we still need more, cycle through what we have
        while len(prompts) < n:
            prompts.extend(prompts[:n - len(prompts)])
        return prompts[:n]
| # ============================================================================= | |
| # TRAINING | |
| # ============================================================================= | |
class ProbeDataset(Dataset):
    """Wraps parallel (hidden_state, label) lists for probe training."""

    def __init__(
        self,
        hidden_states: List[torch.Tensor],
        labels: List[int]
    ):
        self.hidden_states = hidden_states
        self.labels = labels

    def __len__(self):
        """Number of labeled examples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one example as a dict with float32 label tensor."""
        label_tensor = torch.tensor(self.labels[idx], dtype=torch.float32)
        return {"hidden": self.hidden_states[idx], "label": label_tensor}
| class AdapterTrainer: | |
| """ | |
| Trains all components of the ARC adapter. | |
| Training order: | |
| 1. Repetition head (most important) | |
| 2. Hedging head | |
| 3. Verbosity head | |
| 4. Sycophancy head | |
| 5. Loop 4 tokenization | |
| """ | |
| def __init__( | |
| self, | |
| model, | |
| tokenizer, | |
| config: ARCAdapterConfig | |
| ): | |
| self.model = model # FROZEN - never modified | |
| self.tokenizer = tokenizer | |
| self.config = config | |
| self.device = config.device | |
| # Create adapter | |
| self.adapter = ARCAdapter(config).to(self.device) | |
| # Data generator | |
| self.data_generator = DataGenerator(model, tokenizer, config) | |
| # Output directory | |
| self.output_dir = Path(config.output_dir) | |
| self.output_dir.mkdir(parents=True, exist_ok=True) | |
| # Training history | |
| self.history = {} | |
| def compute_metrics( | |
| self, | |
| predictions: torch.Tensor, | |
| labels: torch.Tensor | |
| ) -> Dict[str, float]: | |
| """ | |
| Compute classification metrics. | |
| Key metric: Class Separation Ratio | |
| = mean(positive_probs) / mean(negative_probs) | |
| Higher separation = better discrimination. | |
| """ | |
| probs = torch.sigmoid(predictions) | |
| binary_preds = (probs > 0.5).float() | |
| # Basic metrics | |
| tp = ((binary_preds == 1) & (labels == 1)).sum().item() | |
| fp = ((binary_preds == 1) & (labels == 0)).sum().item() | |
| fn = ((binary_preds == 0) & (labels == 1)).sum().item() | |
| tn = ((binary_preds == 0) & (labels == 0)).sum().item() | |
| accuracy = (tp + tn) / (tp + fp + fn + tn + 1e-8) | |
| precision = tp / (tp + fp + 1e-8) | |
| recall = tp / (tp + fn + 1e-8) | |
| f1 = 2 * precision * recall / (precision + recall + 1e-8) | |
| # Class separation ratio - KEY METRIC | |
| pos_mask = labels == 1 | |
| neg_mask = labels == 0 | |
| if pos_mask.sum() > 0: | |
| pos_mean = probs[pos_mask].mean().item() | |
| else: | |
| pos_mean = 0.5 | |
| if neg_mask.sum() > 0: | |
| neg_mean = probs[neg_mask].mean().item() | |
| else: | |
| neg_mean = 0.5 | |
| separation = pos_mean / (neg_mean + 1e-8) | |
| return { | |
| "accuracy": accuracy, | |
| "precision": precision, | |
| "recall": recall, | |
| "f1": f1, | |
| "separation": separation, | |
| "pos_mean": pos_mean, | |
| "neg_mean": neg_mean, | |
| } | |
| def train_head( | |
| self, | |
| head_name: str, | |
| data: Dict[str, List] | |
| ) -> Dict[str, float]: | |
| """ | |
| Train a single behavioral head. | |
| Uses shared fiber projection (also trained). | |
| """ | |
| print(f"\n{'='*70}") | |
| print(f"TRAINING HEAD: {head_name.upper()}") | |
| print(f"{'='*70}") | |
| # Split data | |
| n = len(data["labels"]) | |
| indices = np.random.permutation(n) | |
| split_idx = int(n * 0.9) | |
| train_indices = indices[:split_idx] | |
| val_indices = indices[split_idx:] | |
| train_hidden = [data["hidden_states"][i] for i in train_indices] | |
| train_labels = [data["labels"][i] for i in train_indices] | |
| val_hidden = [data["hidden_states"][i] for i in val_indices] | |
| val_labels = [data["labels"][i] for i in val_indices] | |
| # Create datasets | |
| train_dataset = ProbeDataset(train_hidden, train_labels) | |
| val_dataset = ProbeDataset(val_hidden, val_labels) | |
| train_loader = DataLoader( | |
| train_dataset, | |
| batch_size=self.config.batch_size, | |
| shuffle=True | |
| ) | |
| val_loader = DataLoader( | |
| val_dataset, | |
| batch_size=self.config.batch_size | |
| ) | |
| # Class weighting for imbalanced data | |
| pos_count = sum(train_labels) | |
| neg_count = len(train_labels) - pos_count | |
| if pos_count > 0: | |
| pos_weight = torch.tensor([neg_count / pos_count]).to(self.device) | |
| else: | |
| pos_weight = torch.tensor([1.0]).to(self.device) | |
| print(f"Train samples: {len(train_labels)}") | |
| print(f"Val samples: {len(val_labels)}") | |
| print(f"Positive: {pos_count} ({100*pos_count/len(train_labels):.1f}%)") | |
| print(f"Negative: {neg_count}") | |
| print(f"Target separation: {self.config.target_separation[head_name]}×") | |
| # Get head and fiber projection | |
| head = self.adapter.heads[head_name] | |
| fiber_proj = self.adapter.fiber_proj | |
| # Optimizer for head + shared fiber projection | |
| params = list(head.parameters()) + list(fiber_proj.parameters()) | |
| criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight) | |
| optimizer = optim.AdamW( | |
| params, | |
| lr=self.config.learning_rate, | |
| weight_decay=self.config.weight_decay | |
| ) | |
| scheduler = optim.lr_scheduler.CosineAnnealingLR( | |
| optimizer, | |
| T_max=self.config.epochs | |
| ) | |
| # Training loop | |
| best_separation = 0 | |
| best_state = None | |
| history = [] | |
| global_step = 0 | |
| for epoch in range(self.config.epochs): | |
| # Training | |
| head.train() | |
| fiber_proj.train() | |
| total_loss = 0 | |
| for batch_idx, batch in enumerate(train_loader): | |
| hidden = batch["hidden"].to(self.device) | |
| labels = batch["label"].to(self.device) | |
| # Forward: fiber projection then head | |
| fiber = fiber_proj.forward_stacked(hidden) | |
| logits = head(fiber) | |
| loss = criterion(logits, labels) | |
| optimizer.zero_grad() | |
| loss.backward() | |
| optimizer.step() | |
| total_loss += loss.item() | |
| global_step += 1 | |
| # Checkpoint every 100 steps | |
| if global_step % 100 == 0: | |
| checkpoint_path = self.output_dir / f"checkpoint_step_{global_step}" | |
| checkpoint_path.mkdir(parents=True, exist_ok=True) | |
| torch.save({ | |
| 'head_state': head.state_dict(), | |
| 'fiber_state': fiber_proj.state_dict(), | |
| 'optimizer_state': optimizer.state_dict(), | |
| 'epoch': epoch, | |
| 'step': global_step, | |
| 'loss': loss.item(), | |
| 'head_name': head_name, | |
| }, checkpoint_path / "checkpoint.pt") | |
| print(f" 💾 Checkpoint saved: step {global_step}") | |
| avg_loss = total_loss / len(train_loader) | |
| # Validation | |
| head.eval() | |
| fiber_proj.eval() | |
| all_preds = [] | |
| all_labels = [] | |
| with torch.no_grad(): | |
| for batch in val_loader: | |
| hidden = batch["hidden"].to(self.device) | |
| labels = batch["label"] | |
| fiber = fiber_proj.forward_stacked(hidden) | |
| logits = head(fiber) | |
| all_preds.append(logits.cpu()) | |
| all_labels.append(labels) | |
| preds = torch.cat(all_preds) | |
| labels = torch.cat(all_labels) | |
| metrics = self.compute_metrics(preds, labels) | |
| history.append(metrics) | |
| sep = metrics["separation"] | |
| print(f"Epoch {epoch+1:2d}/{self.config.epochs} | " | |
| f"Loss: {avg_loss:.4f} | " | |
| f"Sep: {sep:6.1f}× | " | |
| f"F1: {metrics['f1']:.3f} | " | |
| f"Pos: {metrics['pos_mean']:.3f} | " | |
| f"Neg: {metrics['neg_mean']:.3f}") | |
| # Track best | |
| if sep > best_separation: | |
| best_separation = sep | |
| best_state = { | |
| "head": {k: v.cpu().clone() for k, v in head.state_dict().items()}, | |
| "fiber": {k: v.cpu().clone() for k, v in fiber_proj.state_dict().items()}, | |
| } | |
| scheduler.step() | |
| # Restore best state | |
| if best_state is not None: | |
| head.load_state_dict(best_state["head"]) | |
| fiber_proj.load_state_dict(best_state["fiber"]) | |
| head.to(self.device) | |
| fiber_proj.to(self.device) | |
| # Report results | |
| target = self.config.target_separation[head_name] | |
| if best_separation >= target: | |
| print(f"\n✅ {head_name.upper()}: {best_separation:.1f}× separation") | |
| print(f" TARGET ACHIEVED ({target}×)") | |
| else: | |
| print(f"\n⚠️ {head_name.upper()}: {best_separation:.1f}× separation") | |
| print(f" Below target ({target}×)") | |
| return { | |
| "best_separation": best_separation, | |
| "target": target, | |
| "achieved": best_separation >= target, | |
| "history": history, | |
| } | |
def train_all_heads(self) -> Dict[str, Dict]:
    """Train all behavioral heads sequentially.

    For each head: load cached training data if a previous run saved it
    (so a crashed run can resume without regenerating), otherwise
    generate and cache fresh data, then train the head and checkpoint
    the adapter.

    Returns:
        Mapping of head name -> result dict from train_head()
        (keys: best_separation, target, achieved, history).
    """
    results = {}
    head_order = ["repetition", "hedging", "verbosity", "sycophancy"]
    # Dispatch table: head name -> data-generation routine. Guarantees a
    # KeyError (instead of a silently unbound `data`) on an unknown head.
    generators = {
        "repetition": self.data_generator.generate_repetition_data,
        "hedging": self.data_generator.generate_hedging_data,
        "verbosity": self.data_generator.generate_verbosity_data,
        "sycophancy": self.data_generator.generate_sycophancy_data,
    }
    for head_name in head_order:
        print(f"\n{'#'*70}")
        print(f"# PREPARING DATA FOR: {head_name.upper()}")
        print(f"{'#'*70}")
        # Check if we have saved data from a previous run
        data_path = self.output_dir / f"data_{head_name}.pt"
        if data_path.exists():
            print(f" 📂 Loading saved data from {data_path}")
            saved = torch.load(data_path)
            data = {
                'hidden_states': saved['hidden_states'],
                'labels': saved['labels'],
            }
            print(f" Loaded: {len(data['labels'])} examples")
        else:
            # Only build prompts when fresh data is actually needed —
            # previously they were generated (and discarded) even when
            # cached data existed.
            prompts = self.data_generator.get_prompts_for_head(
                head_name,
                self.config.n_samples_per_head
            )
            data = generators[head_name](prompts)
            # Save generated data so we don't lose it on crash
            torch.save({
                'hidden_states': data['hidden_states'],
                'labels': data['labels'],
            }, data_path)
            print(f" 💾 Data saved: {data_path}")
        # Train head
        result = self.train_head(head_name, data)
        results[head_name] = result
        # Save checkpoint after each head
        checkpoint_dir = self.output_dir / f"checkpoint_{head_name}"
        self.adapter.save(checkpoint_dir)
        # Release GPU memory between heads
        torch.cuda.empty_cache()
        gc.collect()
    return results
def run_loop4(self) -> Dict[str, int]:
    """
    Run Loop 4: tokenization co-evolution.

    Generates a small corpus with the current model, measures predictive
    entropy ("boundary stress") at each token boundary, and adds the
    highest-scoring frequent token pairs to the vocabulary as merged
    tokens, resizing the embedding matrix to match.

    Returns:
        Dict with the total number of tokens added and final vocab size.
    """
    print(f"\n{'='*70}")
    print("LOOP 4: TOKENIZATION EXPANSION")
    print(f"{'='*70}")
    # Fixed analysis prompts — hoisted out of the iteration loop since
    # they never change.
    prompts = [
        "Explain machine learning and neural networks in detail:",
        "Describe the structure of atoms and molecules:",
        "What are the fundamental principles of economics?",
        "Analyze the causes and effects of climate change:",
        "Discuss the process of biological evolution:",
    ]
    total_added = 0
    for iteration in range(self.config.loop4_iterations):
        print(f"\n--- Iteration {iteration + 1}/{self.config.loop4_iterations} ---")
        # Generate corpus for analysis (best-effort: skip failed prompts)
        corpus = []
        for prompt in prompts:
            try:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs.input_ids,
                        max_new_tokens=100,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )
                text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                corpus.append(text)
            except Exception:
                # Narrowed from bare `except:` so KeyboardInterrupt /
                # SystemExit still propagate.
                continue
        if not corpus:
            print(" No corpus generated, skipping iteration")
            continue
        # Analyze boundary stress
        pair_stats = defaultdict(lambda: {"stress": [], "count": 0})
        for text in corpus:
            try:
                inputs = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=256
                ).to(self.device)
                with torch.no_grad():
                    outputs = self.model(**inputs)
                logits = outputs.logits[0]
                tokens = inputs["input_ids"][0]
                # Predictive entropy at each position; high entropy just
                # before a token marks a "stressed" boundary.
                probs = F.softmax(logits, dim=-1)
                log_probs = F.log_softmax(logits, dim=-1)
                entropy = -(probs * log_probs).sum(dim=-1)
                # Record boundary stress for each adjacent token pair
                for i in range(1, len(tokens)):
                    before_token = self.tokenizer.decode([tokens[i-1]]).strip()
                    after_token = self.tokenizer.decode([tokens[i]]).strip()
                    # Skip short or special tokens
                    if len(before_token) < 2 or len(after_token) < 2:
                        continue
                    if any(c in before_token + after_token for c in "<>[]{}|\\"):
                        continue
                    stress = entropy[i-1].item() / 10.0  # Normalize
                    pair = (before_token, after_token)
                    pair_stats[pair]["stress"].append(stress)
                    pair_stats[pair]["count"] += 1
            except Exception:
                # Best-effort per text; don't let one bad sample abort
                # the whole analysis (was a bare except).
                continue
        # Find merge candidates: frequent pairs weighted by mean stress
        candidates = []
        for pair, stats in pair_stats.items():
            if stats["count"] >= self.config.min_pair_frequency:
                mean_stress = np.mean(stats["stress"])
                score = mean_stress * np.log1p(stats["count"])
                candidates.append({
                    "before": pair[0],
                    "after": pair[1],
                    "merged": pair[0] + pair[1],
                    "stress": mean_stress,
                    "count": stats["count"],
                    "score": score,
                })
        # Sort by score and take top N
        candidates.sort(key=lambda x: x["score"], reverse=True)
        candidates = candidates[:self.config.n_merges_per_iteration]
        if candidates:
            print(f" Top candidates:")
            for c in candidates[:5]:
                print(f" '{c['before']}' + '{c['after']}' → '{c['merged']}' "
                      f"(stress: {c['stress']:.2f}, count: {c['count']})")
        # Add tokens to vocabulary (only genuinely new strings)
        tokens_to_add = [
            c["merged"] for c in candidates
            if c["merged"] not in self.tokenizer.get_vocab()
        ]
        if tokens_to_add:
            num_added = self.tokenizer.add_tokens(tokens_to_add)
            # Keep the embedding matrix in sync with the grown vocab.
            self.model.resize_token_embeddings(len(self.tokenizer))
            total_added += num_added
            print(f" Added {num_added} new tokens")
        else:
            print(f" No new tokens to add")
    # Save tokenizer
    tokenizer_dir = self.output_dir / "tokenizer"
    self.tokenizer.save_pretrained(tokenizer_dir)
    print(f"\nLoop 4 complete:")
    print(f" Total tokens added: {total_added}")
    print(f" Final vocab size: {len(self.tokenizer)}")
    print(f" Tokenizer saved to: {tokenizer_dir}")
    return {
        "tokens_added": total_added,
        "final_vocab_size": len(self.tokenizer),
    }
def train(self) -> Dict:
    """
    Run the complete adapter training pipeline.

    Stages:
      1. Train all behavioral heads.
      2. Run Loop 4 tokenization expansion.
      3. Save the final adapter plus a JSON results summary.

    Returns:
        Dict with per-head separation results, Loop 4 stats, elapsed
        training time (hours), and the adapter parameter count.
    """
    bar = "=" * 70
    print("\n" + bar)
    print("ARC ADAPTER TRAINING")
    print(bar)
    print(f"Base model: FROZEN")
    print(f"Adapter params: ~{self.adapter.get_param_count():,}")
    print(f"Output dir: {self.output_dir}")
    print(bar)

    start_time = time.time()
    head_results = self.train_all_heads()
    loop4_results = self.run_loop4()

    # Persist the fully-trained adapter.
    final_dir = self.output_dir / "final"
    self.adapter.save(final_dir)
    elapsed = time.time() - start_time

    # ---- Human-readable summary ----------------------------------------
    print("\n" + bar)
    print("TRAINING COMPLETE")
    print(bar)
    all_achieved = all(r["achieved"] for r in head_results.values())
    for head_name, result in head_results.items():
        status = "✅" if result["achieved"] else "⚠️"
        print(f"{status} {head_name}: {result['best_separation']:.1f}× "
              f"(target: {result['target']}×)")
    print(f"\nLoop 4: Added {loop4_results['tokens_added']} tokens")
    print(f"Final vocab size: {loop4_results['final_vocab_size']}")
    print(f"Training time: {elapsed/3600:.1f} hours")
    if all_achieved:
        print("\n🎉 ALL TARGETS ACHIEVED!")
    else:
        print("\n⚠️ Some targets not achieved. Consider:")
        print(" - Increasing n_samples_per_head")
        print(" - Increasing epochs")
        print(" - Adjusting learning rate")

    # ---- Machine-readable summary on disk ------------------------------
    final_results = {
        "heads": {
            name: {
                "separation": r["best_separation"],
                "target": r["target"],
                "achieved": r["achieved"],
            }
            for name, r in head_results.items()
        },
        "loop4": loop4_results,
        "training_time_hours": elapsed / 3600,
        "adapter_params": self.adapter.get_param_count(),
    }
    with open(self.output_dir / "training_results.json", "w") as f:
        json.dump(final_results, f, indent=2)
    print(f"\nResults saved to: {self.output_dir / 'training_results.json'}")
    print(f"Adapter saved to: {final_dir}")
    return final_results
| # ============================================================================= | |
| # INFERENCE | |
| # ============================================================================= | |
class ARCInference:
    """
    Inference wrapper around a trained ARC adapter.

    The frozen base model produces next-token logits; the adapter reads
    hidden states at the configured probe layers, scores behavioral
    risks, and applies per-token logit penalties before sampling.
    """

    def __init__(
        self,
        model,
        tokenizer,
        adapter: ARCAdapter,
        probe_layers: List[int],
        device: str = "cuda"
    ):
        # The base model stays frozen; only the adapter carries trained state.
        self.model = model
        self.tokenizer = tokenizer
        self.adapter = adapter
        self.probe_layers = probe_layers
        self.device = device

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7,
        use_intervention: bool = True,
        verbose: bool = False,
    ) -> str:
        """
        Sample a completion, optionally steering decode with the adapter.

        Args:
            prompt: Text to condition on.
            max_new_tokens: Hard cap on generated tokens.
            temperature: Softmax temperature for sampling.
            use_intervention: If True, apply adapter logit penalties.
            verbose: If True, print risk scores every 10 steps.

        Returns:
            The generated continuation (prompt prefix stripped).
        """
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        token_ids = encoded.input_ids[0].tolist()
        # Start each generation from a clean EMA state.
        self.adapter.reset_ema()
        for step in range(max_new_tokens):
            batch = torch.tensor([token_ids]).to(self.device)
            with torch.no_grad():
                outputs = self.model(
                    batch,
                    output_hidden_states=True,
                )
            logits = outputs.logits[0, -1, :].clone()
            if use_intervention:
                # Hidden states at the probe layers feed the risk heads.
                probe_hidden = [
                    outputs.hidden_states[layer_idx]
                    for layer_idx in self.probe_layers
                ]
                risks = self.adapter.get_risks(probe_hidden)
                if verbose and step % 10 == 0:
                    print(f"Step {step}: risks = {risks}")
                # Subtract adapter penalties from the offending tokens.
                penalties = self.adapter.compute_intervention(risks, token_ids)
                for tok_id, penalty in penalties.items():
                    logits[tok_id] -= penalty
            # Temperature-scaled sampling of the next token.
            probs = F.softmax(logits / temperature, dim=-1)
            next_tok = torch.multinomial(probs, 1).item()
            token_ids.append(next_tok)
            if next_tok == self.tokenizer.eos_token_id:
                break
        decoded = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        return decoded[len(prompt):].strip()
| # ============================================================================= | |
| # MAIN | |
| # ============================================================================= | |
def main():
    """Entry point: load the frozen base model and train the ARC adapter."""
    rule = "=" * 70

    # Training configuration (separation targets per behavioral head).
    config = ARCAdapterConfig(
        base_model_path=".",
        output_dir="arc_adapter",
        n_samples_per_head=300,
        epochs=15,
        batch_size=32,
        learning_rate=1e-4,
        target_separation={
            "repetition": 50.0,
            "hedging": 5.0,
            "verbosity": 5.0,
            "sycophancy": 3.0,
        },
        loop4_iterations=3,
        n_merges_per_iteration=10,
    )

    print(rule)
    print("ARC ADAPTER TRAINING")
    print(rule)
    print()
    print("This script trains the ARC adapter on a FROZEN base model.")
    print("The base model weights are NEVER modified.")
    print()
    print("Components trained:")
    for line in (
        " - Shared fiber projections (~500K params)",
        " - Repetition detection head (~5K params)",
        " - Hedging detection head (~5K params)",
        " - Verbosity detection head (~5K params)",
        " - Sycophancy detection head (~5K params)",
        " - Loop 4 tokenizer expansion",
    ):
        print(line)
    print()
    print(rule)

    # ---- Load base model (FROZEN, 4-bit NF4 quantized) -----------------
    print("\nLoading base model...")
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_path,
        local_files_only=True
    )
    tokenizer.pad_token = tokenizer.eos_token
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_path,
        quantization_config=quant_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        local_files_only=True,
    )
    # FREEZE the base model — the adapter is the only trainable part.
    model.requires_grad_(False)

    # Pick up the real hidden size from the loaded checkpoint.
    config.hidden_dim = model.config.hidden_size
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Base model: {n_params/1e9:.1f}B parameters (FROZEN)")
    print(f"Hidden dimension: {config.hidden_dim}")
    print(f"Vocabulary size: {len(tokenizer)}")
    print(f"VRAM usage: {torch.cuda.memory_allocated()/1024**3:.1f}GB")

    # ---- Train ---------------------------------------------------------
    trainer = AdapterTrainer(model, tokenizer, config)
    results = trainer.train()

    # ---- Usage instructions -------------------------------------------
    print("\n" + rule)
    print("ADAPTER READY FOR USE")
    print(rule)
    print(f"\nAdapter location: {config.output_dir}/final/")
    print(f"Tokenizer location: {config.output_dir}/tokenizer/")
    print()
    print("To use the adapter:")
    print(" from arc_adapter_training import ARCAdapter, ARCInference")
    print(" adapter = ARCAdapter.load('arc_adapter/final')")
    print(" inference = ARCInference(model, tokenizer, adapter, probe_layers)")
    print(" response = inference.generate('Your prompt here')")
    print()
    print(rule)
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()