cfhot-weights / code /training_pipelines /08_qwen3b_dimension_sweep_FULL.py
LoganResearch's picture
🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
297244f verified
#!/usr/bin/env python3
"""
FIBER DIMENSION SWEEP + EXTENDED TRAINING: Qwen2.5-3B
======================================================
1. Quick sweep: d_fiber = [8, 16, 32] @ 800 steps each
2. Full training: best dimension @ 5000 steps
3. Target: 70x+ separation
Loads model ONCE, runs all sweeps, then extends training on winner.
Author: Logan Napolitano / Proprioception AI
Date: February 2026
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import os
import time
import random
import json
import gc
from dataclasses import dataclass, field
from typing import Tuple, List, Dict
# ---------------------------------------------------------------------------
# Sweep configuration
# ---------------------------------------------------------------------------
SWEEP_DIMS = [2 ** k for k in (3, 4, 5)]  # d_fiber candidates: 8, 16, 32
SWEEP_STEPS = 800  # micro-batches trained per sweep candidate
FULL_TRAINING_STEPS = 5_000  # total micro-batch budget for the winner (sweep steps included)
TARGET_SEPARATION = 70.0  # P(+)/P(-) ratio reported as "target hit"
@dataclass
class Config:
    """Hyper-parameters shared by the sweep and the extended training phase.

    Note: ``d_fiber`` and ``max_steps`` act only as defaults here; the visible
    training code receives the swept fiber width and the step budget as
    explicit arguments instead of reading these fields.
    """

    # Model and output locations
    model_path: str = "Qwen/Qwen2.5-3B"
    output_dir: str = "./results/qwen3b_dimension_sweep"
    # Indices into the model's ``hidden_states`` tuple that the probe reads.
    probe_layers: List[int] = field(default_factory=[9, 18, 27].copy)
    # Probe geometry
    d_fiber: int = 16  # varied during the sweep
    d_control: int = 64  # hidden width of the probe MLP
    # Optimisation
    max_steps: int = 800
    batch_size: int = 1
    grad_accum: int = 8  # micro-batches per optimizer step
    max_length: int = 256  # tokenizer truncation / padding length
    lr_lora: float = 2.0e-5
    lr_predictor: float = 1.0e-4
    weight_decay: float = 1e-2
    # Labels and reporting
    rep_window: int = 32  # look-back window for repetition labels
    log_every: int = 50  # micro-steps between loss logs
    eval_every: int = 200  # micro-steps between separation evaluations
class RiskPredictor(nn.Module):
    """Per-token repetition-risk probe reading a few transformer layers.

    Each probed layer's hidden state is projected ``d_model -> d_fiber`` by its
    own bias-free linear map (the "fiber"). The fibers are mixed with
    softmax-normalised learnable layer weights, and a small GELU MLP maps the
    mix to one logit per token position.
    """

    def __init__(self, d_model: int, d_fiber: int, probe_layers: List[int], d_control: int = 64):
        super().__init__()
        self.probe_layers = probe_layers
        self.d_fiber = d_fiber
        n_probes = len(probe_layers)
        # One projection per probed layer; fiber_projs[i] belongs to probe_layers[i].
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False)
            for _ in range(n_probes)
        ])
        # Equal logits -> uniform mixing weights after softmax at init.
        self.layer_weights = nn.Parameter(torch.ones(n_probes) / n_probes)
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control),
            nn.GELU(),
            nn.Linear(d_control, d_control),
            nn.GELU(),
            nn.Linear(d_control, 1)
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return one risk logit per token position.

        Args:
            hidden_states: per-layer hidden states indexed by layer (e.g. the
                ``hidden_states`` of a HF model called with
                ``output_hidden_states=True``). Each entry presumably has shape
                (batch, seq, d_model); TODO confirm for other callers.
                Probe layers whose index is out of range are skipped.

        Returns:
            The MLP output with its trailing size-1 dimension squeezed.

        Raises:
            ValueError: if none of ``probe_layers`` is present in ``hidden_states``.
        """
        fibers = []
        used = []  # positions in probe_layers that actually contributed
        for i, layer_idx in enumerate(self.probe_layers):
            if layer_idx < len(hidden_states):
                fibers.append(self.fiber_projs[i](hidden_states[layer_idx].float()))
                used.append(i)
        if not fibers:
            raise ValueError(
                f"none of probe_layers={self.probe_layers} exist in hidden_states "
                f"(got {len(hidden_states)} entries)")
        # Select the mixing logits of the layers that were used, so a skipped
        # middle layer does not shift later layers onto the wrong weight.
        idx = torch.tensor(used, device=self.layer_weights.device)
        weights = F.softmax(self.layer_weights[idx], dim=0)
        aggregated = sum(w * f for w, f in zip(weights, fibers))
        return self.predictor(aggregated).squeeze(-1)
def compute_repetition_labels(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Flag tokens that already occurred within the preceding ``window`` positions.

    Args:
        input_ids: 2-D tensor of token ids (unpacked as batch, sequence).
        window: how many earlier positions to look back over.

    Returns:
        A tensor of the same (batch, sequence) shape, in the default float
        dtype and on ``input_ids``' device, holding 1.0 at position t when
        ``input_ids[:, t]`` equals any of the previous ``window`` tokens and
        0.0 otherwise (position 0 is therefore always 0.0).
    """
    batch, seq_len = input_ids.shape
    result = torch.zeros(batch, seq_len, device=input_ids.device)
    # Offsets beyond seq_len - 1 have nothing to compare against.
    max_offset = min(window, seq_len - 1)
    offset = 1
    while offset <= max_offset:
        same = input_ids[:, offset:].eq(input_ids[:, :-offset]).float()
        result[:, offset:] = torch.maximum(result[:, offset:], same)
        offset += 1
    return result
def compute_separation(predictor, model, tokenizer, device, config, n_samples=30):
    """Measure how well the probe separates repeated tokens from fresh ones.

    Samples ``n_samples`` continuations (80 new tokens, temperature 0.9,
    top-p 0.95) from a fixed rotation of ten prompts. It then scores every token
    of each full sequence (prompt plus continuation) with ``predictor`` and
    splits the sigmoid risks by the label from ``compute_repetition_labels``.

    Returns:
        ``(mean_pos, mean_neg, mean_pos / max(mean_neg, 1e-8), n_pos, n_neg)``,
        or ``(0, 0, 0, 0, 0)`` when either group is empty.

    Leaves both ``model`` and ``predictor`` in eval mode; callers that keep
    training must switch them back.
    """
    model.eval()
    predictor.eval()
    positives, negatives = [], []
    seed_prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]
    with torch.no_grad():
        for sample_idx in range(n_samples):
            encoded = tokenizer(seed_prompts[sample_idx % len(seed_prompts)], return_tensors='pt')
            generated = model.generate(
                encoded['input_ids'].to(device),
                attention_mask=encoded['attention_mask'].to(device),
                max_new_tokens=80,
                do_sample=True,
                temperature=0.9,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
            )
            hidden = model(generated, output_hidden_states=True).hidden_states
            token_risk = torch.sigmoid(predictor(hidden))[0].cpu().numpy()
            is_repeat = compute_repetition_labels(generated, config.rep_window)[0].cpu().numpy() > 0.5
            # Boolean masks keep per-sample token order within each group.
            positives.extend(token_risk[is_repeat].tolist())
            negatives.extend(token_risk[~is_repeat].tolist())
    if not (positives and negatives):
        return 0, 0, 0, 0, 0
    mean_pos = sum(positives) / len(positives)
    mean_neg = sum(negatives) / len(negatives)
    return mean_pos, mean_neg, mean_pos / max(mean_neg, 1e-8), len(positives), len(negatives)
def _masked_risk_loss(risk_logits, rep_labels, mask):
    """Class-reweighted BCE of the probe logits, averaged over unmasked tokens.

    Positives are weighted by ``n_neg / n_pos`` (computed on unmasked tokens,
    capped at 10), so the two label classes contribute more evenly.
    """
    n_pos = (rep_labels * mask).sum().clamp(min=1)
    n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
    pos_weight = (n_neg / n_pos).clamp(max=10.0)
    bce = F.binary_cross_entropy_with_logits(
        risk_logits, rep_labels,
        pos_weight=torch.ones_like(rep_labels) * pos_weight, reduction='none')
    return (bce * mask).sum() / mask.sum().clamp(min=1)


def train_probe(model, tokenizer, texts, device, d_model, config, d_fiber, max_steps,
                existing_predictor=None, existing_optimizer=None):
    """Jointly train the LoRA adapter (LM loss) and a repetition-risk probe.

    ``max_steps`` counts micro-batches of ``config.batch_size`` texts. The
    optimizer steps once every ``config.grad_accum`` micro-batches, and the
    cosine LR schedule is sized in optimizer steps, so it anneals to
    ``eta_min`` over the run.

    Args:
        model: PEFT-wrapped causal LM; its ``requires_grad`` parameters are trained.
        tokenizer: tokenizer matching ``model``; must have a pad token.
        texts: training strings, cycled through in order.
        device: device for inputs and a newly created predictor.
        d_model: hidden size, used only when creating a new predictor.
        config: ``Config`` supplying batch/accumulation/LR/logging settings.
        d_fiber: fiber width for a *new* predictor; ignored when
            ``existing_predictor`` is given.
        max_steps: number of micro-batches to train.
        existing_predictor: continue training this predictor instead of a new one.
        existing_optimizer: continue with this optimizer. Each param group's
            LR is reset to its ``initial_lr`` so the new cosine schedule is a
            warm restart rather than continuing from an annealed value.

    Returns:
        ``(predictor, optimizer, final_separation, p_pos, p_neg, history)``.
        The middle three come from a 50-sample ``compute_separation``.
        ``history`` has ``"steps"`` (logged losses) and ``"separations"``.

    NOTE(review): the label at position t marks whether token t itself is a
    repeat, and the probe reads hidden states at t (which have seen token t).
    This therefore measures repeat *detection* rather than prediction of an
    upcoming repeat. Confirm this is the intended target.
    """
    if existing_predictor is None:
        predictor = RiskPredictor(d_model, d_fiber, config.probe_layers,
                                  config.d_control).to(device).float()
    else:
        predictor = existing_predictor

    lora_params = [p for p in model.parameters() if p.requires_grad]
    if existing_optimizer is None:
        optimizer = torch.optim.AdamW([
            {'params': lora_params, 'lr': config.lr_lora},
            {'params': predictor.parameters(), 'lr': config.lr_predictor}
        ], weight_decay=config.weight_decay)
    else:
        optimizer = existing_optimizer
        # A previous scheduler annealed these groups toward eta_min, and a new
        # CosineAnnealingLR starts from whatever 'lr' it finds. Restore the base
        # rate that the previous scheduler recorded as 'initial_lr'.
        for group in optimizer.param_groups:
            group['lr'] = group.get('initial_lr', group['lr'])

    # scheduler.step() runs once per optimizer step, not per micro-batch.
    n_optimizer_steps = max(1, max_steps // config.grad_accum)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_optimizer_steps, eta_min=1e-6)
    # Drop gradients left over from an earlier, partially accumulated run on
    # the same (shared) LoRA parameters.
    optimizer.zero_grad(set_to_none=True)
    clip_params = lora_params + list(predictor.parameters())

    model.train()
    predictor.train()
    history = {"steps": [], "separations": []}
    step, data_idx = 0, 0
    acc_loss, acc_risk = 0, 0
    start = time.time()
    while step < max_steps:
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size
        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)
        # Exclude padding from the LM loss (label -100 is ignored). Otherwise the
        # adapter is also trained to predict the pad token that fills every
        # sequence to max_length.
        lm_labels = input_ids.masked_fill(attention_mask == 0, -100)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        labels=lm_labels, output_hidden_states=True)
        lm_loss = outputs.loss
        risk_logits = predictor(outputs.hidden_states)
        rep_labels = compute_repetition_labels(input_ids, config.rep_window)
        risk_loss = _masked_risk_loss(risk_logits, rep_labels, attention_mask.float())
        loss = lm_loss + risk_loss
        (loss / config.grad_accum).backward()
        acc_loss += loss.item()
        acc_risk += risk_loss.item()
        step += 1

        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % config.log_every == 0:
            eta = (max_steps - step) / (step / (time.time() - start)) / 60
            print(f" Step {step:4d}/{max_steps} | Loss: {acc_loss/config.log_every:.3f} | "
                  f"Risk: {acc_risk/config.log_every:.3f} | ETA: {eta:.1f}m")
            history["steps"].append({"step": step, "loss": acc_loss/config.log_every})
            acc_loss, acc_risk = 0, 0

        if step % config.eval_every == 0:
            p_pos, p_neg, sep, n_p, n_n = compute_separation(predictor, model, tokenizer, device, config)
            print(f" >>> SEPARATION @ {step}: {sep:.1f}x (P+={p_pos:.3f}, P-={p_neg:.3f})")
            history["separations"].append({"step": step, "separation": sep, "p_pos": p_pos, "p_neg": p_neg})
            # compute_separation leaves both modules in eval mode.
            model.train()
            predictor.train()

    # Final eval
    p_pos, p_neg, final_sep, _, _ = compute_separation(predictor, model, tokenizer, device, config, n_samples=50)
    return predictor, optimizer, final_sep, p_pos, p_neg, history
def _snapshot_trainable(model):
    """Return CPU copies of every trainable parameter of ``model``, keyed by name."""
    return {name: param.detach().to("cpu", copy=True)
            for name, param in model.named_parameters() if param.requires_grad}


def _restore_trainable(model, snapshot):
    """Copy a ``_snapshot_trainable`` result back into ``model`` in place."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in snapshot:
                param.copy_(snapshot[name])


def _print_final_report(sweep_results, best_dim, final_sep, final_p_pos, final_p_neg,
                        probe_layers, d_model, n_layers, width=57):
    """Print the boxed summary of the sweep and the extended training run."""
    def row(text=""):
        return f"│ {text:<{width - 2}} │"

    def rule(left, right):
        return left + "─" * width + right

    if final_sep >= TARGET_SEPARATION:
        target_hit = "✅ TARGET HIT"
    else:
        target_hit = f"⚠️ {final_sep:.1f}x < {TARGET_SEPARATION}x target"
    lines = [
        rule("┌", "┐"),
        row("CROSS-ARCHITECTURE REPLICATION RESULTS".center(width - 2)),
        rule("├", "┤"),
        row(),
        row("LLaMA-3.1-8B baseline: 125x separation"),
        row(),
        row("Qwen2.5-3B (this run):"),
        row(f"  Best d_fiber: {best_dim}"),
        row(f"  Final separation: {final_sep:.1f}x"),
        row(f"  P(+): {final_p_pos:.4f}"),
        row(f"  P(-): {final_p_neg:.4f}"),
        row(),
        row(target_hit.center(width - 2)),
        row(),
        row("Sweep results:"),
    ]
    for d, res in sweep_results.items():
        marker = " ← selected" if d == best_dim else ""
        lines.append(row(f"  d_fiber={d:2d}: {res['separation']:5.1f}x{marker}"))
    lines += [
        row(),
        row("Method: Fiber projection (identical to LLaMA-8B)"),
        row(f"Probe layers: {probe_layers}"),
        row(f"Architecture: Qwen2 ({d_model}d, {n_layers}L)"),
        rule("└", "┘"),
    ]
    print("\n" + "\n".join(lines) + "\n")


def main():
    """Run the fiber-dimension sweep, extend training on the winner, save results.

    Phase 1 trains a fresh probe for each entry of ``SWEEP_DIMS`` for
    ``SWEEP_STEPS`` micro-batches. The shared LoRA adapter is reset to the
    same starting weights for every candidate. Phase 2 restores the winning
    run's LoRA weights and continues its probe and optimizer for the rest of
    ``FULL_TRAINING_STEPS``. The JSON summary, the adapter, and the probe are
    written under ``config.output_dir``.
    """
    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)
    # Make the data shuffle, probe initialisation and sampled evals repeatable.
    random.seed(0)
    torch.manual_seed(0)
    print("=" * 70)
    print("FIBER DIMENSION SWEEP + EXTENDED TRAINING")
    print(f"Target: {TARGET_SEPARATION}x separation on Qwen2.5-3B")
    print("=" * 70)
    print(f"Sweep dimensions: {SWEEP_DIMS}")
    print(f"Sweep steps each: {SWEEP_STEPS}")
    print(f"Full training steps: {FULL_TRAINING_STEPS}")
    print()

    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Loading Qwen2.5-3B...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16)
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    print("Adding LoRA...")
    model = get_peft_model(model, LoraConfig(
        r=64, lora_alpha=128, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"))
    model.print_trainable_parameters()
    device = next(model.parameters()).device
    d_model = model.config.hidden_size
    # Keep the adapter exactly as PEFT initialised it. Each sweep candidate
    # restarts from this copy. Re-randomising both LoRA matrices instead would
    # make their product non-zero and perturb the base model differently per run.
    initial_lora_state = _snapshot_trainable(model)

    print("Loading data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples\n")

    # =========================================================================
    # PHASE 1: DIMENSION SWEEP
    # =========================================================================
    print("=" * 70)
    print("PHASE 1: DIMENSION SWEEP")
    print("=" * 70)
    sweep_results = {}
    best_dim, best_sep = None, 0
    best_predictor = best_optimizer = best_lora_state = None
    for d_fiber in SWEEP_DIMS:
        print(f"\n{'─'*50}")
        print(f"Testing d_fiber = {d_fiber}")
        print(f" Projection: {d_model} → {d_fiber} ({d_model//d_fiber}:1 compression)")
        print(f"{'─'*50}")
        # Reset LoRA weights to the shared starting point for a fair comparison.
        _restore_trainable(model, initial_lora_state)
        predictor, optimizer, sep, p_pos, p_neg, history = train_probe(
            model, tokenizer, texts, device, d_model, config,
            d_fiber=d_fiber, max_steps=SWEEP_STEPS)
        sweep_results[d_fiber] = {
            "separation": sep, "p_pos": p_pos, "p_neg": p_neg, "history": history}
        print(f"\n d_fiber={d_fiber} RESULT: {sep:.1f}x separation")
        # `best_dim is None` guarantees a winner even if every run scores 0.
        if best_dim is None or sep > best_sep:
            best_sep, best_dim = sep, d_fiber
            best_predictor, best_optimizer = predictor, optimizer
            # The LoRA weights are shared across runs; keep the winner's copy.
            best_lora_state = _snapshot_trainable(model)
        del predictor, optimizer
        gc.collect()
        torch.cuda.empty_cache()

    # Sweep summary
    print("\n" + "=" * 70)
    print("SWEEP RESULTS")
    print("=" * 70)
    for d, res in sweep_results.items():
        marker = " ← BEST" if d == best_dim else ""
        print(f" d_fiber={d:2d}: {res['separation']:6.1f}x (P+={res['p_pos']:.3f}, P-={res['p_neg']:.3f}){marker}")
    print()

    # =========================================================================
    # PHASE 2: EXTENDED TRAINING ON BEST DIMENSION
    # =========================================================================
    print("=" * 70)
    print(f"PHASE 2: EXTENDED TRAINING (d_fiber={best_dim})")
    print(f"Current: {best_sep:.1f}x → Target: {TARGET_SEPARATION}x")
    print("=" * 70)
    # Later sweep runs overwrote the shared LoRA weights. Put the winner's back
    # so its predictor and optimizer continue with the adapter they were
    # trained alongside.
    _restore_trainable(model, best_lora_state)
    remaining_steps = max(0, FULL_TRAINING_STEPS - SWEEP_STEPS)
    print(f"Running {remaining_steps} more steps...\n")
    config.eval_every = 400  # Less frequent evals for extended training
    config.log_every = 100
    best_predictor, _, final_sep, final_p_pos, final_p_neg, ext_history = train_probe(
        model, tokenizer, texts, device, d_model, config,
        d_fiber=best_dim, max_steps=remaining_steps,
        existing_predictor=best_predictor, existing_optimizer=best_optimizer)

    # =========================================================================
    # FINAL RESULTS
    # =========================================================================
    print("\n" + "=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    _print_final_report(sweep_results, best_dim, final_sep, final_p_pos, final_p_neg,
                        config.probe_layers, d_model, model.config.num_hidden_layers)

    # Save everything
    full_results = {
        "experiment": "qwen3b_dimension_sweep_extended",
        "target_separation": TARGET_SEPARATION,
        "sweep_dims": SWEEP_DIMS,
        "sweep_steps": SWEEP_STEPS,
        "full_training_steps": FULL_TRAINING_STEPS,
        "best_d_fiber": best_dim,
        "final_separation": final_sep,
        "final_p_pos": final_p_pos,
        "final_p_neg": final_p_neg,
        "target_hit": final_sep >= TARGET_SEPARATION,
        "sweep_results": {str(k): {"separation": v["separation"], "p_pos": v["p_pos"], "p_neg": v["p_neg"]}
                          for k, v in sweep_results.items()},
        "extended_history": ext_history,
        "baseline_comparison": {
            "llama_8b_separation": 125.0,
            "qwen_3b_separation": final_sep,
            "ratio": final_sep / 125.0
        }
    }
    with open(os.path.join(config.output_dir, "full_results.json"), 'w') as f:
        json.dump(full_results, f, indent=2)

    # Save best model (LoRA adapter) and its risk predictor
    final_dir = os.path.join(config.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    model.save_pretrained(final_dir)
    torch.save({
        'risk_predictor': best_predictor.state_dict(),
        'd_fiber': best_dim,
        'separation': final_sep,
        'p_pos': final_p_pos,
        'p_neg': final_p_neg
    }, os.path.join(final_dir, "risk_predictor.pt"))
    print(f"Results saved to {config.output_dir}/full_results.json")
    print(f"Model saved to {final_dir}/")
    print("\nDONE!")


if __name__ == "__main__":
    main()