#!/usr/bin/env python3
"""
CROSS-ARCHITECTURE REPLICATION v2: Qwen2.5-3B Repetition Detection
====================================================================

FIX: Use 3 specific probe layers [9, 18, 27] instead of all 36.
Matches Pipeline 02 methodology which achieved 125x-168x on LLaMA-8B.

Changes from v1:
- probe_layers = [9, 18, 27] (25%, 50%, 75% of 36 layers)
- 3 fiber projections instead of 36
- Gradient signal concentrated, not diluted

Review fixes:
- LM labels ignore padding positions (-100), so the LM loss is no longer
  inflated by predicting pad/EOS tokens.
- The cosine LR schedule is sized in optimizer steps, not micro-batches.
- Gradients from a final partial accumulation window are applied.
- RiskPredictor pairs every fiber with its own layer weight and raises a
  clear error when none of the probe layers is available.

Author: Logan Napolitano / Proprioception AI
Date: February 2026
"""

import json
import os
import random
import time
from dataclasses import dataclass, field
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# transformers / peft / datasets are imported inside main(): only the actual
# training run needs them, and this keeps the model-free helpers below
# importable (e.g. for unit tests) without the Hugging Face stack.

# Separation (P(+) / P(-)) above which the result is reported as generalizing.
GENERALIZATION_THRESHOLD = 10


@dataclass
class Config:
    """Hyper-parameters for the replication run."""

    model_path: str = "Qwen/Qwen2.5-3B"
    output_dir: str = "./results/qwen3b_repetition_v2_fixed"

    # Probe layers: 25%, 50%, 75% of 36 layers (matches Pipeline 02 methodology)
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])

    # Identical to Pipeline 01/02
    d_fiber: int = 16
    d_control: int = 64

    # Counted in micro-batches (forward/backward passes), not optimizer steps.
    max_steps: int = 10000
    batch_size: int = 1
    grad_accum: int = 8
    max_length: int = 256
    lr_lora: float = 2e-5
    lr_predictor: float = 1e-4
    weight_decay: float = 0.01
    rep_window: int = 32
    log_every: int = 10
    save_every: int = 500
    eval_every: int = 200
    # Seeds Python's `random` (data shuffle) and torch (init, sampling).
    seed: int = 42


class RiskPredictor(nn.Module):
    """Per-token repetition-risk head reading a few probed hidden layers.

    Each probe layer's hidden state is projected to a small "fiber"
    (``d_model -> d_fiber``). The fibers are mixed with softmax-normalised
    learned layer weights, and a small MLP maps the mixture to one logit per
    position.

    FIXED: Only 3 probe layers instead of all 36.
    """

    def __init__(self, d_model: int, probe_layers: List[int], config: Config):
        super().__init__()
        self.config = config
        self.probe_layers = probe_layers
        n_probes = len(probe_layers)

        # One projection per probe layer (2048 -> 16 each for Qwen2.5-3B).
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, config.d_fiber, bias=False)
            for _ in range(n_probes)
        ])
        self.layer_weights = nn.Parameter(torch.ones(n_probes) / n_probes)

        self.predictor = nn.Sequential(
            nn.Linear(config.d_fiber, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, 1),
        )

        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return one risk logit per position.

        Args:
            hidden_states: per-layer hidden states, e.g. ``outputs.hidden_states``
                from a model called with ``output_hidden_states=True``. Probe
                layers whose index is ``>= len(hidden_states)`` are skipped.

        Returns:
            Logits shaped like a probed hidden state without its last dimension.

        Raises:
            ValueError: if none of ``self.probe_layers`` is present.
        """
        used = [
            i for i, layer_idx in enumerate(self.probe_layers)
            if layer_idx < len(hidden_states)
        ]
        if not used:
            raise ValueError(
                f"None of the probe layers {self.probe_layers} exist in "
                f"hidden_states (got {len(hidden_states)} entries)"
            )

        fibers = [
            self.fiber_projs[i](hidden_states[self.probe_layers[i]].float())
            for i in used
        ]
        # Select weights by probe position (not a prefix slice) so every fiber
        # keeps its own weight even if a probe layer in the middle is skipped.
        weights = F.softmax(self.layer_weights[used], dim=0)
        aggregated = sum(w * f for w, f in zip(weights, fibers))
        return self.predictor(aggregated).squeeze(-1)


def compute_repetition_labels_fast(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Mark tokens that repeat a token from the previous ``window`` positions.

    Args:
        input_ids: ``(batch, seq)`` token ids.
        window: how far back (in positions) to look for an identical token.

    Returns:
        Float tensor ``(batch, seq)`` on ``input_ids.device``. The value at
        position ``t`` is ``1.0`` if ``input_ids[:, t] == input_ids[:, t - k]``
        for some ``1 <= k <= window``, otherwise ``0.0``. Position 0 is always
        ``0.0``.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)
    return labels


def build_lm_labels(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return causal-LM labels with padded positions set to -100.

    -100 is the ``ignore_index`` used by Hugging Face causal-LM losses. Without
    it, pad tokens (here the EOS token) would be trained as targets.
    """
    return input_ids.masked_fill(attention_mask == 0, -100)


def masked_risk_loss(risk_logits: torch.Tensor, rep_labels: torch.Tensor,
                     mask: torch.Tensor) -> torch.Tensor:
    """Class-balanced BCE averaged over unmasked positions.

    Positives are up-weighted by ``n_neg / n_pos`` (capped at 10) over the
    unmasked positions. Returns 0 when ``mask`` is all zeros; it does not
    return NaN in that case.
    """
    n_pos = (rep_labels * mask).sum().clamp(min=1)
    n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
    pos_weight = (n_neg / n_pos).clamp(max=10.0)
    bce = F.binary_cross_entropy_with_logits(
        risk_logits, rep_labels,
        pos_weight=torch.ones_like(rep_labels) * pos_weight,
        reduction='none',
    )
    return (bce * mask).sum() / mask.sum().clamp(min=1)


@torch.no_grad()
def risk_metrics(risk_logits: torch.Tensor, rep_labels: torch.Tensor,
                 mask: torch.Tensor, threshold: float = 0.5) -> Tuple[float, float, float]:
    """Precision, recall and F1 of ``sigmoid(logits) > threshold`` over unmasked positions."""
    predicted = torch.sigmoid(risk_logits) > threshold
    positive = rep_labels == 1
    valid = mask == 1
    tp = (predicted & positive & valid).sum()
    fp = (predicted & ~positive & valid).sum()
    fn = (~predicted & positive & valid).sum()
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    return precision.item(), recall.item(), f1.item()


def compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50):
    """Measure how well predicted risk separates repeated from fresh tokens.

    Samples ``n_samples`` continuations (80 new tokens, temperature 0.9,
    top-p 0.95) from a fixed cycle of 10 prompts. Every token of each full
    sequence (prompt + continuation) is scored with ``risk_predictor``, and the
    sigmoid scores are split by the label from
    ``compute_repetition_labels_fast``.

    Leaves ``model`` and ``risk_predictor`` in eval mode; callers that continue
    training must switch them back.

    Returns:
        ``(p_pos, p_neg, separation, n_pos, n_neg)``: the mean risk on repeated
        and non-repeated tokens, ``p_pos / max(p_neg, 1e-8)``, and the token
        count of each class. All zeros if either class is empty.
    """
    model.eval()
    risk_predictor.eval()

    all_pos_scores = []
    all_neg_scores = []

    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]

    with torch.no_grad():
        for i in range(n_samples):
            prompt = prompts[i % len(prompts)]
            inp = tokenizer(prompt, return_tensors='pt')
            input_ids = inp['input_ids'].to(device)
            attn_mask = inp['attention_mask'].to(device)

            out = model.generate(
                input_ids, attention_mask=attn_mask, max_new_tokens=80,
                do_sample=True, temperature=0.9, top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
            )

            gen_outputs = model(out, output_hidden_states=True)
            risk_vals = torch.sigmoid(risk_predictor(gen_outputs.hidden_states))[0].tolist()
            labels = compute_repetition_labels_fast(out, config.rep_window)[0].tolist()

            for risk, label in zip(risk_vals, labels):
                if label > 0.5:
                    all_pos_scores.append(float(risk))
                else:
                    all_neg_scores.append(float(risk))

    if all_pos_scores and all_neg_scores:
        p_pos = sum(all_pos_scores) / len(all_pos_scores)
        p_neg = sum(all_neg_scores) / len(all_neg_scores)
        separation = p_pos / max(p_neg, 1e-8)
        return p_pos, p_neg, separation, len(all_pos_scores), len(all_neg_scores)
    return 0.0, 0.0, 0.0, 0, 0


def _save_checkpoint(path, model, risk_predictor, **extra):
    """Save the PEFT adapter and the RiskPredictor state (plus ``extra`` entries) under ``path``."""
    os.makedirs(path, exist_ok=True)
    model.save_pretrained(path)
    torch.save({'risk_predictor': risk_predictor.state_dict(), **extra},
               os.path.join(path, "risk_predictor.pt"))


def _print_final_report(separation, p_pos, p_neg, d_model, n_layers, probe_layers):
    """Print the boxed LLaMA-8B vs Qwen2.5-3B separation comparison."""
    verdict = "✅ GENERALIZES" if separation > GENERALIZATION_THRESHOLD else "⚠️ INVESTIGATE"
    width = 55
    rule = "─" * (width + 2)
    sections = [
        ["CROSS-ARCHITECTURE REPLICATION v2 (FIXED)"],
        [
            "LLaMA-3.1-8B: 125x (P+=0.998, P-=0.008)",
            f"Qwen2.5-3B:   {separation:>5.1f}x (P+={p_pos:.3f}, P-={p_neg:.3f})",
        ],
        [
            f"Architecture: Qwen2 ({d_model}d, {n_layers}L) vs LLaMA (4096d, 32L)",
            f"Probe layers: {probe_layers}",
            "d_fiber: 16 (identical)",
            "Method: IDENTICAL",
            f"Conclusion: {verdict}",
        ],
    ]
    lines = ["┌" + rule + "┐"]
    for idx, section in enumerate(sections):
        if idx:
            lines.append("├" + rule + "┤")
        lines.extend(f"│ {text:<{width}} │" for text in section)
    lines.append("└" + rule + "┘")
    print("\n" + "\n".join(lines) + "\n")


def main():
    """Jointly fine-tune Qwen2.5-3B (4-bit + LoRA) and a RiskPredictor, logging separation."""
    from datasets import load_dataset
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)
    random.seed(config.seed)
    torch.manual_seed(config.seed)

    print("=" * 70)
    print("CROSS-ARCHITECTURE REPLICATION v2 (FIXED PROBE LAYERS)")
    print("=" * 70)
    print(f"Model: {config.model_path}")
    print(f"Probe layers: {config.probe_layers} (25%, 50%, 75%)")
    print(f"d_fiber: {config.d_fiber}, d_control: {config.d_control}")
    print("FIX: 3 focused projections instead of 36 diluted ones")
    print()

    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading Qwen2.5-3B in 4-bit...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model = AutoModelForCausalLM.from_pretrained(
        config.model_path,
        quantization_config=bnb,
        device_map='auto',
        torch_dtype=torch.float16,
    )
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    device = next(model.parameters()).device

    d_model = model.config.hidden_size
    n_layers = model.config.num_hidden_layers
    print("Architecture: Qwen2ForCausalLM")
    print(f"Hidden dim: {d_model}, Layers: {n_layers}")
    print(f"Probing layers: {config.probe_layers}")
    print()

    print("Adding LoRA...")
    model = get_peft_model(model, LoraConfig(
        r=64, lora_alpha=128,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
    ))
    model.print_trainable_parameters()

    print("Adding Risk Predictor (3 probe layers)...")
    risk_predictor = RiskPredictor(d_model, config.probe_layers, config).to(device).float()
    rp_params = sum(p.numel() for p in risk_predictor.parameters())
    print(f"Risk Predictor params: {rp_params:,}")
    print()

    print("Loading wikitext data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")

    lora_params = [p for p in model.parameters() if p.requires_grad]
    predictor_params = list(risk_predictor.parameters())
    trainable_params = lora_params + predictor_params
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': predictor_params, 'lr': config.lr_predictor},
    ], weight_decay=config.weight_decay)

    # The scheduler advances once per optimizer step (every grad_accum
    # micro-batches), so its horizon is counted in optimizer steps. Using
    # max_steps here left the cosine only ~1/grad_accum of the way decayed.
    n_optimizer_steps = max(1, -(-config.max_steps // config.grad_accum))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_optimizer_steps, eta_min=1e-6
    )

    def apply_gradients():
        # Clip, step optimizer and LR schedule, then clear accumulated grads.
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    training_log = {
        "experiment": "cross_architecture_replication_v2_fixed",
        "fix": "3 probe layers [9,18,27] instead of all 36",
        "source_model": "LLaMA-3.1-8B (4096d, 32L, probe [8,16,24])",
        "target_model": f"Qwen2.5-3B ({d_model}d, {n_layers}L, probe {config.probe_layers})",
        "d_fiber": config.d_fiber,
        "baseline_separation": "125x (LLaMA-8B repetition)",
        "steps": [],
        "separations": [],
    }
    log_path = os.path.join(config.output_dir, "replication_log.json")

    def write_log():
        with open(log_path, 'w') as f:
            json.dump(training_log, f, indent=2)

    print("=" * 70)
    print("TRAINING")
    print("=" * 70)

    model.train()
    risk_predictor.train()

    step = 0
    data_idx = 0
    acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
    acc_precision, acc_recall, acc_f1 = 0, 0, 0
    start_time = time.time()

    while step < config.max_steps:
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size

        # 'longest' avoids padding every sample to max_length; any padding
        # that remains is masked out of both losses and the metrics.
        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='longest', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=build_lm_labels(input_ids, attention_mask),
            output_hidden_states=True,
        )
        lm_loss = outputs.loss

        # Pass full hidden_states — RiskPredictor indexes into specific layers
        risk_logits = risk_predictor(outputs.hidden_states)
        rep_labels = compute_repetition_labels_fast(input_ids, config.rep_window)
        mask = attention_mask.float()
        risk_loss = masked_risk_loss(risk_logits, rep_labels, mask)

        loss = lm_loss + risk_loss
        (loss / config.grad_accum).backward()

        precision, recall, f1 = risk_metrics(risk_logits, rep_labels, mask)
        acc_loss += loss.item()
        acc_lm += lm_loss.item()
        acc_risk_loss += risk_loss.item()
        acc_precision += precision
        acc_recall += recall
        acc_f1 += f1

        step += 1

        if step % config.grad_accum == 0:
            apply_gradients()

        if step % config.log_every == 0:
            elapsed = time.time() - start_time
            eta = (config.max_steps - step) / (step / elapsed) / 3600
            n = config.log_every
            print(
                f"Step {step:5d} | "
                f"Loss: {acc_loss/n:.4f} | "
                f"LM: {acc_lm/n:.4f} | "
                f"Risk: {acc_risk_loss/n:.4f} | "
                f"P: {acc_precision/n:.3f} | "
                f"R: {acc_recall/n:.3f} | "
                f"F1: {acc_f1/n:.3f} | "
                f"ETA: {eta:.1f}h"
            )
            training_log["steps"].append({
                "step": step,
                "loss": acc_loss/n,
                "lm_loss": acc_lm/n,
                "risk_loss": acc_risk_loss/n,
                "precision": acc_precision/n,
                "recall": acc_recall/n,
                "f1": acc_f1/n,
            })
            acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
            acc_precision, acc_recall, acc_f1 = 0, 0, 0

        if step % config.save_every == 0:
            ckpt = os.path.join(config.output_dir, f"ckpt_{step}")
            _save_checkpoint(ckpt, model, risk_predictor, step=step)
            print(f">>> Saved: {ckpt}")

        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"SEPARATION EVAL @ Step {step}")
            print(f"{'='*50}")

            p_pos, p_neg, separation, n_pos_s, n_neg_s = compute_separation(
                risk_predictor, model, tokenizer, device, config, n_samples=30
            )

            print(f"  P(+) = {p_pos:.4f}  (n={n_pos_s})")
            print(f"  P(-) = {p_neg:.4f}  (n={n_neg_s})")
            print(f"  SEPARATION = {separation:.1f}x")
            print(f"  [LLaMA-8B baseline: 125x]")

            training_log["separations"].append({
                "step": step,
                "p_pos": p_pos,
                "p_neg": p_neg,
                "separation": separation,
                "n_pos": n_pos_s,
                "n_neg": n_neg_s,
            })
            write_log()

            print(f"{'='*50}\n")
            # compute_separation switched both modules to eval mode.
            model.train()
            risk_predictor.train()

    # Apply gradients from a trailing partial accumulation window, if any.
    if step % config.grad_accum != 0:
        apply_gradients()

    # FINAL
    print("\n" + "=" * 70)
    print("FINAL CROSS-ARCHITECTURE COMPARISON")
    print("=" * 70)

    p_pos, p_neg, separation, n_pos, n_neg = compute_separation(
        risk_predictor, model, tokenizer, device, config, n_samples=50
    )

    _print_final_report(separation, p_pos, p_neg, d_model, n_layers, config.probe_layers)

    training_log["final"] = {
        "p_pos": p_pos,
        "p_neg": p_neg,
        "separation": separation,
        "n_pos": n_pos,
        "n_neg": n_neg,
        "conclusion": (
            "generalizes" if separation > GENERALIZATION_THRESHOLD else "needs_investigation"
        ),
    }
    write_log()

    final = os.path.join(config.output_dir, "final")
    _save_checkpoint(final, model, risk_predictor, step=step,
                     separation=separation, p_pos=p_pos, p_neg=p_neg)

    print(f"Done! Log: {config.output_dir}/replication_log.json")


if __name__ == "__main__":
    main()