🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
297244f verified | #!/usr/bin/env python3 | |
| """ | |
| CROSS-ARCHITECTURE REPLICATION v2: Qwen2.5-3B Repetition Detection | |
| ==================================================================== | |
| FIX: Use 3 specific probe layers [9, 18, 27] instead of all 36. | |
| Matches Pipeline 02 methodology which achieved 125x-168x on LLaMA-8B. | |
| Changes from v1: | |
| - probe_layers = [9, 18, 27] (25%, 50%, 75% of 36 layers) | |
| - 3 fiber projections instead of 36 | |
| - Gradient signal concentrated, not diluted | |
| Author: Logan Napolitano / Proprioception AI | |
| Date: February 2026 | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training | |
| from datasets import load_dataset | |
| import os | |
| import time | |
| import random | |
| import json | |
| from dataclasses import dataclass, field | |
| from typing import Tuple, List | |
@dataclass
class Config:
    """Hyperparameters for the Qwen2.5-3B repetition-detection run.

    The class must be a dataclass: ``probe_layers`` is declared with
    ``field(default_factory=...)``, which only resolves to a fresh list per
    instance when the ``@dataclass`` machinery processes the class. Without
    the decorator the attribute would be a raw ``dataclasses.Field`` object.
    """

    # Hugging Face model id loaded by main().
    model_path: str = "Qwen/Qwen2.5-3B"
    # Directory that receives checkpoints and replication_log.json.
    output_dir: str = "./results/qwen3b_repetition_v2_fixed"
    # Probe layers: 25%, 50%, 75% of 36 layers (matches Pipeline 02 methodology)
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    # Identical to Pipeline 01/02
    # Output width of each per-layer fiber projection.
    d_fiber: int = 16
    # Hidden width of the risk predictor MLP.
    d_control: int = 64
    # Number of training iterations (micro-batches) run by main().
    max_steps: int = 10000
    # Texts per micro-batch.
    batch_size: int = 1
    # Micro-batches accumulated before each optimizer step.
    grad_accum: int = 8
    # Tokenizer truncation / padding length.
    max_length: int = 256
    # Learning rate for the LoRA parameter group.
    lr_lora: float = 2e-5
    # Learning rate for the risk predictor parameter group.
    lr_predictor: float = 1e-4
    # AdamW weight decay applied to both groups.
    weight_decay: float = 0.01
    # Look-back window (in tokens) used when labelling repetitions.
    rep_window: int = 32
    # Iteration intervals for logging, checkpointing and separation eval.
    log_every: int = 10
    save_every: int = 500
    eval_every: int = 200
class RiskPredictor(nn.Module):
    """Per-token risk head that reads a small set of probe layers.

    FIXED: Only 3 probe layers instead of all 36.

    Each probed hidden state is projected to ``config.d_fiber`` dimensions,
    the projections are mixed with softmax-normalised learnable weights, and
    a small MLP maps the mixture to one logit per token.

    Args:
        d_model: Width of the hidden states fed to the projections.
        probe_layers: Indices into the ``hidden_states`` tuple to probe.
        config: Object exposing ``d_fiber`` and ``d_control``.
    """

    def __init__(self, d_model: int, probe_layers: List[int], config: "Config"):
        super().__init__()
        self.config = config
        self.probe_layers = probe_layers
        n_probes = len(probe_layers)
        # One projection per probe layer: d_model -> d_fiber
        # (2048 -> 16 each, per the original note for this model).
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, config.d_fiber, bias=False)
            for _ in range(n_probes)
        ])
        # Uniform initial mixing weights; softmax is applied in forward().
        self.layer_weights = nn.Parameter(torch.ones(n_probes) / n_probes)
        self.predictor = nn.Sequential(
            nn.Linear(config.d_fiber, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, 1)
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return one risk logit per token.

        Args:
            hidden_states: Tuple of per-layer hidden states; each probed
                entry is cast to float and must have ``d_model`` as its
                last dimension.

        Returns:
            Logits with the last (size-1) dimension squeezed away.

        Raises:
            ValueError: If none of ``probe_layers`` is a valid index into
                ``hidden_states``.
        """
        n_available = len(hidden_states)
        # Keep (slot, layer) pairs so each surviving fiber is mixed with its
        # own weight even when an earlier probe layer is out of range.
        usable = [
            (slot, layer_idx)
            for slot, layer_idx in enumerate(self.probe_layers)
            if layer_idx < n_available
        ]
        if not usable:
            raise ValueError(
                f"None of probe_layers={list(self.probe_layers)} is available "
                f"in {n_available} hidden states"
            )

        slots = [slot for slot, _ in usable]
        fibers = [
            self.fiber_projs[slot](hidden_states[layer_idx].float())
            for slot, layer_idx in usable
        ]
        weights = F.softmax(self.layer_weights[slots], dim=0)
        aggregated = sum(w * f for w, f in zip(weights, fibers))
        logits = self.predictor(aggregated).squeeze(-1)
        return logits
def compute_repetition_labels_fast(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Label each token that repeats a token seen within the last ``window`` positions.

    Position ``t`` is labelled 1.0 when ``input_ids[:, t]`` equals
    ``input_ids[:, t - k]`` for some lag ``1 <= k <= window`` and 0.0
    otherwise.

    Args:
        input_ids: 2-D tensor of token ids.
        window: Maximum look-back distance, in tokens.

    Returns:
        Float tensor with the same two dimensions as ``input_ids``, on the
        same device.
    """
    n_rows, n_cols = input_ids.shape
    labels = torch.zeros(n_rows, n_cols, device=input_ids.device)

    # A lag can never reach the full sequence length, so cap it there.
    largest_lag = min(window, n_cols - 1)
    for lag in range(1, largest_lag + 1):
        current = input_ids[:, lag:]
        earlier = input_ids[:, :-lag]
        repeated = current.eq(earlier).float()
        labels[:, lag:] = torch.maximum(labels[:, lag:], repeated)

    return labels
def compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50):
    """Measure how well predicted risk separates repeated from fresh tokens.

    Samples ``n_samples`` continuations from a fixed prompt list, scores
    every token of each sampled sequence with ``risk_predictor``, and splits
    the sigmoid scores by the repetition label of the same position.

    Both ``model`` and ``risk_predictor`` are switched to eval mode and are
    not switched back; the caller restores train mode.

    Returns:
        ``(mean_pos, mean_neg, mean_pos / max(mean_neg, 1e-8), n_pos, n_neg)``,
        or ``(0.0, 0.0, 0.0, 0, 0)`` when either group is empty.
    """
    model.eval()
    risk_predictor.eval()

    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]
    repeated_scores = []
    fresh_scores = []

    with torch.no_grad():
        for sample_idx in range(n_samples):
            # Cycle through the prompt list when n_samples exceeds it.
            encoded = tokenizer(prompts[sample_idx % len(prompts)], return_tensors='pt')
            prompt_ids = encoded['input_ids'].to(device)
            prompt_mask = encoded['attention_mask'].to(device)

            generated = model.generate(
                prompt_ids, attention_mask=prompt_mask, max_new_tokens=80,
                do_sample=True, temperature=0.9, top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )

            # Re-run the full sequence (prompt + continuation) to get hidden states.
            hidden = model(generated, output_hidden_states=True).hidden_states
            risk = torch.sigmoid(risk_predictor(hidden))[0].cpu().numpy()
            token_labels = compute_repetition_labels_fast(generated, config.rep_window)[0].cpu().numpy()

            for position in range(len(risk)):
                bucket = repeated_scores if token_labels[position] > 0.5 else fresh_scores
                bucket.append(float(risk[position]))

    if not (repeated_scores and fresh_scores):
        return 0.0, 0.0, 0.0, 0, 0

    mean_repeated = sum(repeated_scores) / len(repeated_scores)
    mean_fresh = sum(fresh_scores) / len(fresh_scores)
    # Guard the ratio against a vanishing denominator.
    ratio = mean_repeated / max(mean_fresh, 1e-8)
    return mean_repeated, mean_fresh, ratio, len(repeated_scores), len(fresh_scores)
def _render_box(sections, inner_width=57):
    """Draw ``sections`` (lists of text rows) as one box with dividers between sections."""
    rule = "─" * inner_width
    lines = ["┌" + rule + "┐"]
    for index, rows in enumerate(sections):
        if index:
            lines.append("├" + rule + "┤")
        for row in rows:
            lines.append("│ " + row.ljust(inner_width - 1) + "│")
    lines.append("└" + rule + "┘")
    return "\n".join(lines)


def main():
    """Train LoRA + risk predictor on wikitext and report repetition separation.

    Loads the 4-bit quantised model named in ``Config``, attaches LoRA
    adapters and a ``RiskPredictor``, trains both on a joint LM + weighted
    BCE objective, periodically checkpoints and evaluates separation, and
    finally writes ``replication_log.json`` plus a ``final`` checkpoint under
    ``config.output_dir``.
    """
    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)

    print("=" * 70)
    print("CROSS-ARCHITECTURE REPLICATION v2 (FIXED PROBE LAYERS)")
    print("=" * 70)
    print(f"Model: {config.model_path}")
    print(f"Probe layers: {config.probe_layers} (25%, 50%, 75%)")
    print(f"d_fiber: {config.d_fiber}, d_control: {config.d_control}")
    print(f"FIX: 3 focused projections instead of 36 diluted ones")
    print()

    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading Qwen2.5-3B in 4-bit...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4"
    )
    model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb,
        device_map='auto', torch_dtype=torch.float16
    )
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    device = next(model.parameters()).device
    d_model = model.config.hidden_size
    n_layers = model.config.num_hidden_layers
    print(f"Architecture: Qwen2ForCausalLM")
    print(f"Hidden dim: {d_model}, Layers: {n_layers}")
    print(f"Probing layers: {config.probe_layers}")
    print()

    print("Adding LoRA...")
    model = get_peft_model(model, LoraConfig(
        r=64, lora_alpha=128,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
    ))
    model.print_trainable_parameters()

    print("Adding Risk Predictor (3 probe layers)...")
    risk_predictor = RiskPredictor(d_model, config.probe_layers, config).to(device).float()
    rp_params = sum(p.numel() for p in risk_predictor.parameters())
    print(f"Risk Predictor params: {rp_params:,}")
    print()

    print("Loading wikitext data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    # Drop empty / very short lines.
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")

    # Separate learning rates for the LoRA adapters and the risk head.
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': risk_predictor.parameters(), 'lr': config.lr_predictor}
    ], weight_decay=config.weight_decay)
    # NOTE(review): scheduler.step() runs once per grad_accum iterations, so
    # with T_max=max_steps only max_steps / grad_accum of the cosine period
    # is traversed and the LR never approaches eta_min. Left unchanged here;
    # confirm whether T_max should be max_steps // grad_accum.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.max_steps, eta_min=1e-6
    )

    training_log = {
        "experiment": "cross_architecture_replication_v2_fixed",
        "fix": "3 probe layers [9,18,27] instead of all 36",
        "source_model": "LLaMA-3.1-8B (4096d, 32L, probe [8,16,24])",
        "target_model": f"Qwen2.5-3B ({d_model}d, {n_layers}L, probe {config.probe_layers})",
        "d_fiber": config.d_fiber,
        "baseline_separation": "125x (LLaMA-8B repetition)",
        "steps": [],
        "separations": []
    }

    print("=" * 70)
    print("TRAINING")
    print("=" * 70)
    model.train()
    risk_predictor.train()

    step = 0
    data_idx = 0
    acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
    acc_precision, acc_recall, acc_f1 = 0, 0, 0
    start_time = time.time()

    while step < config.max_steps:
        # Walk the shuffled texts cyclically.
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size

        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)

        # Every sequence is padded to max_length; label padding with -100 so
        # the LM loss ignores it instead of learning to predict pad tokens.
        lm_labels = input_ids.masked_fill(attention_mask == 0, -100)
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=lm_labels,
            output_hidden_states=True
        )
        lm_loss = outputs.loss

        # Pass full hidden_states — RiskPredictor indexes into specific layers
        risk_logits = risk_predictor(outputs.hidden_states)
        rep_labels = compute_repetition_labels_fast(input_ids, config.rep_window)

        # Class-balance the BCE over real (non-pad) tokens, capped at 10x.
        mask = attention_mask.float()
        n_pos = (rep_labels * mask).sum().clamp(min=1)
        n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)
        bce_loss = F.binary_cross_entropy_with_logits(
            risk_logits, rep_labels,
            pos_weight=torch.ones_like(rep_labels) * pos_weight,
            reduction='none'
        )
        risk_loss = (bce_loss * mask).sum() / mask.sum()

        loss = lm_loss + risk_loss
        (loss / config.grad_accum).backward()

        # Running classification metrics at a 0.5 threshold (non-pad tokens only).
        with torch.no_grad():
            risk_pred = torch.sigmoid(risk_logits)
            pred_binary = (risk_pred > 0.5).float()
            tp = ((pred_binary == 1) & (rep_labels == 1) & (mask == 1)).sum()
            fp = ((pred_binary == 1) & (rep_labels == 0) & (mask == 1)).sum()
            fn = ((pred_binary == 0) & (rep_labels == 1) & (mask == 1)).sum()
            precision = tp / (tp + fp + 1e-8)
            recall = tp / (tp + fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)

        acc_loss += loss.item()
        acc_lm += lm_loss.item()
        acc_risk_loss += risk_loss.item()
        acc_precision += precision.item()
        acc_recall += recall.item()
        acc_f1 += f1.item()

        step += 1

        # `step` counts micro-batches; the optimizer moves every grad_accum of them.
        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(
                list(lora_params) + list(risk_predictor.parameters()), 1.0
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % config.log_every == 0:
            eta = (config.max_steps - step) / (step / (time.time() - start_time)) / 3600
            n = config.log_every
            print(
                f"Step {step:5d} | "
                f"Loss: {acc_loss/n:.4f} | "
                f"LM: {acc_lm/n:.4f} | "
                f"Risk: {acc_risk_loss/n:.4f} | "
                f"P: {acc_precision/n:.3f} | "
                f"R: {acc_recall/n:.3f} | "
                f"F1: {acc_f1/n:.3f} | "
                f"ETA: {eta:.1f}h"
            )
            training_log["steps"].append({
                "step": step, "loss": acc_loss/n, "lm_loss": acc_lm/n,
                "risk_loss": acc_risk_loss/n, "precision": acc_precision/n,
                "recall": acc_recall/n, "f1": acc_f1/n
            })
            acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
            acc_precision, acc_recall, acc_f1 = 0, 0, 0

        if step % config.save_every == 0:
            ckpt = os.path.join(config.output_dir, f"ckpt_{step}")
            os.makedirs(ckpt, exist_ok=True)
            model.save_pretrained(ckpt)
            torch.save({
                'risk_predictor': risk_predictor.state_dict(),
                'step': step
            }, os.path.join(ckpt, "risk_predictor.pt"))
            print(f">>> Saved: {ckpt}")

        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"SEPARATION EVAL @ Step {step}")
            print(f"{'='*50}")
            p_pos, p_neg, separation, n_pos_s, n_neg_s = \
                compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=30)
            print(f" P(+) = {p_pos:.4f} (n={n_pos_s})")
            print(f" P(-) = {p_neg:.4f} (n={n_neg_s})")
            print(f" SEPARATION = {separation:.1f}x")
            print(f" [LLaMA-8B baseline: 125x]")
            training_log["separations"].append({
                "step": step, "p_pos": p_pos, "p_neg": p_neg,
                "separation": separation, "n_pos": n_pos_s, "n_neg": n_neg_s
            })
            with open(os.path.join(config.output_dir, "replication_log.json"), 'w') as f:
                json.dump(training_log, f, indent=2)
            print(f"{'='*50}\n")
            # compute_separation leaves both modules in eval mode.
            model.train()
            risk_predictor.train()

    # FINAL
    print("\n" + "=" * 70)
    print("FINAL CROSS-ARCHITECTURE COMPARISON")
    print("=" * 70)
    p_pos, p_neg, separation, n_pos, n_neg = \
        compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50)

    verdict = "✅ GENERALIZES" if separation > 10 else "⚠️ INVESTIGATE"
    banner = _render_box([
        ["CROSS-ARCHITECTURE REPLICATION v2 (FIXED)"],
        [
            "LLaMA-3.1-8B: 125x (P+=0.998, P-=0.008)",
            f"Qwen2.5-3B: {separation:>5.1f}x (P+={p_pos:.3f}, P-={p_neg:.3f})",
        ],
        [
            f"Architecture: Qwen2 ({d_model}d, {n_layers}L) vs LLaMA (4096d, 32L)",
            f"Probe layers: {config.probe_layers}",
            "d_fiber: 16 (identical)",
            "Method: IDENTICAL",
            f"Conclusion: {verdict}",
        ],
    ])
    print("\n" + banner + "\n")

    training_log["final"] = {
        "p_pos": p_pos, "p_neg": p_neg, "separation": separation,
        "n_pos": n_pos, "n_neg": n_neg,
        "conclusion": "generalizes" if separation > 10 else "needs_investigation"
    }
    with open(os.path.join(config.output_dir, "replication_log.json"), 'w') as f:
        json.dump(training_log, f, indent=2)

    final = os.path.join(config.output_dir, "final")
    os.makedirs(final, exist_ok=True)
    model.save_pretrained(final)
    torch.save({
        'risk_predictor': risk_predictor.state_dict(),
        'step': step, 'separation': separation,
        'p_pos': p_pos, 'p_neg': p_neg
    }, os.path.join(final, "risk_predictor.pt"))
    print(f"Done! Log: {config.output_dir}/replication_log.json")
| if __name__ == "__main__": | |
| main() | |