🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
#!/usr/bin/env python3
"""
FIBER DIMENSION SWEEP + EXTENDED TRAINING: Qwen2.5-3B
======================================================
1. Quick sweep: d_fiber = [8, 16, 32] @ 800 steps each
2. Full training: best dimension @ 5000 steps
3. Target: 70x+ separation

Loads model ONCE, runs all sweeps, then extends training on winner.

Author: Logan Napolitano / Proprioception AI
Date: February 2026
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import os
import time
import random
import json
import gc
from dataclasses import dataclass, field
from typing import Tuple, List, Dict

# Sweep configuration
SWEEP_DIMS = [8, 16, 32]
SWEEP_STEPS = 800
FULL_TRAINING_STEPS = 5000
TARGET_SEPARATION = 70.0
@dataclass
class Config:
    model_path: str = "Qwen/Qwen2.5-3B"
    output_dir: str = "./results/qwen3b_dimension_sweep"
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    d_fiber: int = 16  # Will be varied during sweep
    d_control: int = 64
    max_steps: int = 800
    batch_size: int = 1
    grad_accum: int = 8
    max_length: int = 256
    lr_lora: float = 2e-5
    lr_predictor: float = 1e-4
    weight_decay: float = 0.01
    rep_window: int = 32
    log_every: int = 50
    eval_every: int = 200
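# Note: with batch_size=1 and grad_accum=8, each optimizer update sees an
# effective batch of 8 sequences of up to max_length=256 tokens. probe_layers
# [9, 18, 27] sit at the 1/4, 1/2, and 3/4 depth marks of Qwen2.5-3B's 36 blocks.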
class RiskPredictor(nn.Module):
    def __init__(self, d_model: int, d_fiber: int, probe_layers: List[int], d_control: int = 64):
        super().__init__()
        self.probe_layers = probe_layers
        self.d_fiber = d_fiber
        n_probes = len(probe_layers)
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False)
            for _ in range(n_probes)
        ])
        self.layer_weights = nn.Parameter(torch.ones(n_probes) / n_probes)
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control),
            nn.GELU(),
            nn.Linear(d_control, d_control),
            nn.GELU(),
            nn.Linear(d_control, 1)
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        fibers = []
        for i, layer_idx in enumerate(self.probe_layers):
            if layer_idx < len(hidden_states):
                fiber = self.fiber_projs[i](hidden_states[layer_idx].float())
                fibers.append(fiber)
        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        aggregated = sum(w * f for w, f in zip(weights, fibers))
        return self.predictor(aggregated).squeeze(-1)
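# Shape flow through RiskPredictor.forward (hidden_states is the tuple returned
# under output_hidden_states=True, with the embedding output at index 0):
#   hidden_states[layer]   : (B, S, d_model)
#   fiber_projs[i](...)    : (B, S, d_fiber)  low-dimensional "fiber" per probed layer
#   softmax(layer_weights) : learned mixing weights over the probed layers
#   aggregated             : (B, S, d_fiber)  weighted sum of fibers
#   predictor MLP          : (B, S, 1) -> squeezed to (B, S) per-token risk logits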
def compute_repetition_labels(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for offset in range(1, min(window + 1, S)):
        if offset < S:
            matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
            labels[:, offset:] = torch.maximum(labels[:, offset:], matches)
    return labels
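# Label semantics: labels[b, t] = 1 if token t exactly repeats a token that
# occurred within the previous `window` positions of the same sequence, else 0.
# Example (single sequence, window >= 2):
#   input_ids = [5, 7, 5, 9]  ->  labels = [0, 0, 1, 0]
# because position 2 repeats the token at position 0.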
def compute_separation(predictor, model, tokenizer, device, config, n_samples=30):
    model.eval()
    predictor.eval()
    pos_scores, neg_scores = [], []
    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]
    with torch.no_grad():
        for i in range(n_samples):
            prompt = prompts[i % len(prompts)]
            inp = tokenizer(prompt, return_tensors='pt')
            input_ids = inp['input_ids'].to(device)
            attn = inp['attention_mask'].to(device)
            out = model.generate(input_ids, attention_mask=attn, max_new_tokens=80,
                                 do_sample=True, temperature=0.9, top_p=0.95,
                                 pad_token_id=tokenizer.eos_token_id)
            outputs = model(out, output_hidden_states=True)
            risk = torch.sigmoid(predictor(outputs.hidden_states))[0].cpu().numpy()
            labels = compute_repetition_labels(out, config.rep_window)[0].cpu().numpy()
            for t in range(len(risk)):
                (pos_scores if labels[t] > 0.5 else neg_scores).append(float(risk[t]))
    if pos_scores and neg_scores:
        p_pos, p_neg = sum(pos_scores)/len(pos_scores), sum(neg_scores)/len(neg_scores)
        return p_pos, p_neg, p_pos/max(p_neg, 1e-8), len(pos_scores), len(neg_scores)
    return 0, 0, 0, 0, 0
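# "Separation" as reported below is the ratio p_pos / p_neg: the mean sigmoid
# risk score on tokens labeled as repetitions divided by the mean score on
# non-repeated tokens, measured over freshly sampled generations. Higher values
# mean the probe more cleanly distinguishes repetition from normal continuation.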
def train_probe(model, tokenizer, texts, device, d_model, config, d_fiber, max_steps,
                existing_predictor=None, existing_optimizer=None):
    """Train a probe with the given d_fiber.

    Returns (predictor, optimizer, final_separation, p_pos, p_neg, history).
    """
    if existing_predictor is None:
        predictor = RiskPredictor(d_model, d_fiber, config.probe_layers, config.d_control).to(device).float()
    else:
        predictor = existing_predictor
    lora_params = [p for p in model.parameters() if p.requires_grad]
    if existing_optimizer is None:
        optimizer = torch.optim.AdamW([
            {'params': lora_params, 'lr': config.lr_lora},
            {'params': predictor.parameters(), 'lr': config.lr_predictor}
        ], weight_decay=config.weight_decay)
    else:
        optimizer = existing_optimizer
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps, eta_min=1e-6)
    model.train()
    predictor.train()
    history = {"steps": [], "separations": []}
    step, data_idx = 0, 0
    acc_loss, acc_risk = 0, 0
    start = time.time()
    while step < max_steps:
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size
        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        labels=input_ids, output_hidden_states=True)
        lm_loss = outputs.loss
        risk_logits = predictor(outputs.hidden_states)
        rep_labels = compute_repetition_labels(input_ids, config.rep_window)
        mask = attention_mask.float()
        n_pos = (rep_labels * mask).sum().clamp(min=1)
        n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)
        bce = F.binary_cross_entropy_with_logits(
            risk_logits, rep_labels,
            pos_weight=torch.ones_like(rep_labels) * pos_weight, reduction='none')
        risk_loss = (bce * mask).sum() / mask.sum()
        loss = lm_loss + risk_loss
        (loss / config.grad_accum).backward()
        acc_loss += loss.item()
        acc_risk += risk_loss.item()
        step += 1
        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(list(lora_params) + list(predictor.parameters()), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        if step % config.log_every == 0:
            eta = (max_steps - step) / (step / (time.time() - start)) / 60
            print(f"  Step {step:4d}/{max_steps} | Loss: {acc_loss/config.log_every:.3f} | "
                  f"Risk: {acc_risk/config.log_every:.3f} | ETA: {eta:.1f}m")
            history["steps"].append({"step": step, "loss": acc_loss/config.log_every})
            acc_loss, acc_risk = 0, 0
        if step % config.eval_every == 0:
            p_pos, p_neg, sep, n_p, n_n = compute_separation(predictor, model, tokenizer, device, config)
            print(f"  >>> SEPARATION @ {step}: {sep:.1f}x (P+={p_pos:.3f}, P-={p_neg:.3f})")
            history["separations"].append({"step": step, "separation": sep, "p_pos": p_pos, "p_neg": p_neg})
            model.train()
            predictor.train()
    # Final eval
    p_pos, p_neg, final_sep, _, _ = compute_separation(predictor, model, tokenizer, device, config, n_samples=50)
    return predictor, optimizer, final_sep, p_pos, p_neg, history
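# Training objective recap: each step optimizes lm_loss (causal-LM cross-entropy
# on the batch) plus risk_loss (class-balanced BCE between the probe's per-token
# logits and the repetition labels, with pos_weight capped at 10 and padding
# positions masked out). The LoRA parameters and the probe share one AdamW
# optimizer but use separate learning rates (lr_lora vs. lr_predictor).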
def main():
    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)
    print("=" * 70)
    print("FIBER DIMENSION SWEEP + EXTENDED TRAINING")
    print(f"Target: {TARGET_SEPARATION}x separation on Qwen2.5-3B")
    print("=" * 70)
    print(f"Sweep dimensions: {SWEEP_DIMS}")
    print(f"Sweep steps each: {SWEEP_STEPS}")
    print(f"Full training steps: {FULL_TRAINING_STEPS}")
    print()
    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Loading Qwen2.5-3B...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16)
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    print("Adding LoRA...")
    model = get_peft_model(model, LoraConfig(
        r=64, lora_alpha=128, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"))
    model.print_trainable_parameters()
    device = next(model.parameters()).device
    d_model = model.config.hidden_size
    print("Loading data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples\n")

    # =========================================================================
    # PHASE 1: DIMENSION SWEEP
    # =========================================================================
    print("=" * 70)
    print("PHASE 1: DIMENSION SWEEP")
    print("=" * 70)
    sweep_results = {}
    best_dim, best_sep = None, 0
    for d_fiber in SWEEP_DIMS:
        print(f"\n{'─' * 50}")
        print(f"Testing d_fiber = {d_fiber}")
        print(f"  Projection: {d_model} → {d_fiber} ({d_model // d_fiber}:1 compression)")
        print(f"{'─' * 50}")
        # Reset LoRA weights for a fair comparison across sweep runs: mirror
        # PEFT's default init (Kaiming-uniform lora_A, zeroed lora_B) so each
        # run starts from an identity adapter, just like the very first one.
        for name, param in model.named_parameters():
            if 'lora' in name.lower() and param.requires_grad:
                if 'lora_B' in name:
                    nn.init.zeros_(param)
                elif 'lora_A' in name:
                    nn.init.kaiming_uniform_(param, a=5 ** 0.5)
        predictor, optimizer, sep, p_pos, p_neg, history = train_probe(
            model, tokenizer, texts, device, d_model, config,
            d_fiber=d_fiber, max_steps=SWEEP_STEPS)
        sweep_results[d_fiber] = {
            "separation": sep, "p_pos": p_pos, "p_neg": p_neg, "history": history}
        print(f"\n  d_fiber={d_fiber} RESULT: {sep:.1f}x separation")
        if sep > best_sep:
            best_sep = sep
            best_dim = d_fiber
            best_predictor = predictor
            best_optimizer = optimizer
        # Clear predictor if not best
        if d_fiber != best_dim:
            del predictor
            gc.collect()
            torch.cuda.empty_cache()

    # Sweep summary
    print("\n" + "=" * 70)
    print("SWEEP RESULTS")
    print("=" * 70)
    for d, res in sweep_results.items():
        marker = " ← BEST" if d == best_dim else ""
        print(f"  d_fiber={d:2d}: {res['separation']:6.1f}x (P+={res['p_pos']:.3f}, P-={res['p_neg']:.3f}){marker}")
    print()
    # =========================================================================
    # PHASE 2: EXTENDED TRAINING ON BEST DIMENSION
    # =========================================================================
    print("=" * 70)
    print(f"PHASE 2: EXTENDED TRAINING (d_fiber={best_dim})")
    print(f"Current: {best_sep:.1f}x → Target: {TARGET_SEPARATION}x")
    print("=" * 70)
    remaining_steps = FULL_TRAINING_STEPS - SWEEP_STEPS
    print(f"Running {remaining_steps} more steps...\n")
    config.eval_every = 400  # Less frequent evals for extended training
    config.log_every = 100
    best_predictor, _, final_sep, final_p_pos, final_p_neg, ext_history = train_probe(
        model, tokenizer, texts, device, d_model, config,
        d_fiber=best_dim, max_steps=remaining_steps,
        existing_predictor=best_predictor, existing_optimizer=best_optimizer)

    # =========================================================================
    # FINAL RESULTS
    # =========================================================================
    print("\n" + "=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    target_hit = "✅ TARGET HIT" if final_sep >= TARGET_SEPARATION else f"⚠️ {final_sep:.1f}x < {TARGET_SEPARATION}x target"
| print(f""" | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| β CROSS-ARCHITECTURE REPLICATION RESULTS β | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€ | |
| β β | |
| β LLaMA-3.1-8B baseline: 125x separation β | |
| β β | |
| β Qwen2.5-3B (this run): β | |
| β Best d_fiber: {best_dim} β | |
| β Final separation: {final_sep:.1f}x β | |
| β P(+): {final_p_pos:.4f} β | |
| β P(-): {final_p_neg:.4f} β | |
| β β | |
| β {target_hit:^53} β | |
| β β | |
| β Sweep results: β""") | |
| for d, res in sweep_results.items(): | |
| print(f"β d_fiber={d:2d}: {res['separation']:5.1f}x{' β selected' if d == best_dim else '':>20} β") | |
| print(f"""β β | |
| β Method: Fiber projection (identical to LLaMA-8B) β | |
| β Probe layers: {config.probe_layers} β | |
| β Architecture: Qwen2 (2048d, 36L) β | |
| βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| """) | |
    # Save everything
    full_results = {
        "experiment": "qwen3b_dimension_sweep_extended",
        "target_separation": TARGET_SEPARATION,
        "sweep_dims": SWEEP_DIMS,
        "sweep_steps": SWEEP_STEPS,
        "full_training_steps": FULL_TRAINING_STEPS,
        "best_d_fiber": best_dim,
        "final_separation": final_sep,
        "final_p_pos": final_p_pos,
        "final_p_neg": final_p_neg,
        "target_hit": final_sep >= TARGET_SEPARATION,
        "sweep_results": {str(k): {"separation": v["separation"], "p_pos": v["p_pos"], "p_neg": v["p_neg"]}
                          for k, v in sweep_results.items()},
        "baseline_comparison": {
            "llama_8b_separation": 125.0,
            "qwen_3b_separation": final_sep,
            "ratio": final_sep / 125.0
        }
    }
    with open(os.path.join(config.output_dir, "full_results.json"), 'w') as f:
        json.dump(full_results, f, indent=2)
    # Save best model
    final_dir = os.path.join(config.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    model.save_pretrained(final_dir)
    torch.save({
        'risk_predictor': best_predictor.state_dict(),
        'd_fiber': best_dim,
        'separation': final_sep,
        'p_pos': final_p_pos,
        'p_neg': final_p_neg
    }, os.path.join(final_dir, "risk_predictor.pt"))
    print(f"Results saved to {config.output_dir}/full_results.json")
    print(f"Model saved to {final_dir}/")
    print("\nDONE!")


if __name__ == "__main__":
    main()
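To reproduce the sweep, save the script under any name (for example qwen3b_dimension_sweep.py; the release does not fix a filename) and run it directly with python on a CUDA machine with enough memory for 4-bit Qwen2.5-3B plus the rank-64 LoRA. Outputs land in ./results/qwen3b_dimension_sweep/: full_results.json with the sweep and final metrics, and a final/ directory containing the saved LoRA adapter and risk_predictor.pt.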