# cfhot-weights / code / training_pipelines / 07b_qwen3b_repetition_FIXED.py
# Source: HuggingFace repo "cfhot-weights" by LoganResearch, commit 297244f (verified).
# Release: 9 probes × 3 architectures + production adapter + training code.
#!/usr/bin/env python3
"""
CROSS-ARCHITECTURE REPLICATION v2: Qwen2.5-3B Repetition Detection
====================================================================
FIX: Use 3 specific probe layers [9, 18, 27] instead of all 36.
Matches Pipeline 02 methodology which achieved 125x-168x on LLaMA-8B.
Changes from v1:
- probe_layers = [9, 18, 27] (25%, 50%, 75% of 36 layers)
- 3 fiber projections instead of 36
- Gradient signal concentrated, not diluted
Author: Logan Napolitano / Proprioception AI
Date: February 2026
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import os
import time
import random
import json
from dataclasses import dataclass, field
from typing import Tuple, List
@dataclass
class Config:
    """Hyperparameters for the Qwen2.5-3B repetition-detection replication run."""
    model_path: str = "Qwen/Qwen2.5-3B"
    output_dir: str = "./results/qwen3b_repetition_v2_fixed"
    # Probe layers: 25%, 50%, 75% of 36 layers (matches Pipeline 02 methodology)
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    # Identical to Pipeline 01/02
    d_fiber: int = 16           # width of each per-layer fiber projection
    d_control: int = 64         # hidden width of the risk-predictor MLP
    max_steps: int = 10000      # counts micro-batches, not optimizer steps
    batch_size: int = 1
    grad_accum: int = 8         # micro-batches accumulated per optimizer step
    max_length: int = 256       # tokenizer truncation/padding length
    lr_lora: float = 2e-5       # learning rate for the LoRA adapter group
    lr_predictor: float = 1e-4  # learning rate for the fresh risk-predictor group
    weight_decay: float = 0.01
    rep_window: int = 32        # look-back window (tokens) for repetition labels
    log_every: int = 10         # steps between console/log entries
    save_every: int = 500       # steps between checkpoints
    eval_every: int = 200       # steps between separation evaluations
class RiskPredictor(nn.Module):
    """Per-token repetition-risk head reading 3 probe layers instead of all 36.

    Each probed hidden state is projected to a narrow ``d_fiber`` vector, the
    fibers are blended with a learned softmax mixture over layers, and a small
    MLP maps the blend to one logit per token.
    """

    def __init__(self, d_model: int, probe_layers: List[int], config: Config):
        super().__init__()
        self.config = config
        self.probe_layers = probe_layers
        n_probes = len(probe_layers)
        # One narrow projection (d_model -> d_fiber) per probed layer.
        self.fiber_projs = nn.ModuleList(
            nn.Linear(d_model, config.d_fiber, bias=False) for _ in range(n_probes)
        )
        # Learnable per-layer mixing weights, initialized uniform.
        self.layer_weights = nn.Parameter(torch.full((n_probes,), 1.0 / n_probes))
        self.predictor = nn.Sequential(
            nn.Linear(config.d_fiber, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, 1),
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return one risk logit per token from the tuple of all hidden states."""
        n_states = len(hidden_states)
        # Project each available probe layer to its fiber (fp32 for stability).
        fibers = [
            self.fiber_projs[idx](hidden_states[layer].float())
            for idx, layer in enumerate(self.probe_layers)
            if layer < n_states
        ]
        # Softmax over however many probes were in range.
        mix = F.softmax(self.layer_weights[: len(fibers)], dim=0)
        blended = sum(w * f for w, f in zip(mix, fibers))
        return self.predictor(blended).squeeze(-1)
def compute_repetition_labels_fast(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Label each token 1.0 if it equals any token up to `window` positions back.

    Args:
        input_ids: (B, S) integer token-id tensor.
        window: maximum look-back distance (in tokens) that counts as repetition.

    Returns:
        (B, S) float tensor of {0., 1.} on input_ids' device: position t is 1
        iff input_ids[b, t] == input_ids[b, t - k] for some 1 <= k <= window.

    Note: pad tokens match each other, so padded regions get positive labels;
    callers mask those out with the attention mask.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    # One vectorized shifted comparison per offset, OR-ed in via torch.maximum.
    # The range cap min(window + 1, S) already guarantees offset < S, so the
    # original redundant `if offset < S` guard is dropped.
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)
    return labels
def compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50):
    """Estimate how well risk scores separate repeated from novel tokens.

    Samples `n_samples` short continuations from a fixed prompt pool
    (stochastic: do_sample with temperature 0.9, so results vary run to run),
    scores every token of each full sequence with `risk_predictor`, and buckets
    the sigmoid scores by the labels from compute_repetition_labels_fast.

    Returns:
        (mean_pos, mean_neg, mean_pos / mean_neg, n_pos, n_neg), or all zeros
        when either bucket is empty.

    NOTE: switches `model` and `risk_predictor` to eval mode and does NOT
    restore train mode — callers are responsible for that.
    """
    model.eval()
    risk_predictor.eval()
    all_pos_scores = []
    all_neg_scores = []
    # Fixed pool of neutral prompts, cycled when n_samples > len(prompts).
    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]
    with torch.no_grad():
        for i in range(n_samples):
            prompt = prompts[i % len(prompts)]
            inp = tokenizer(prompt, return_tensors='pt')
            input_ids = inp['input_ids'].to(device)
            attn_mask = inp['attention_mask'].to(device)
            out = model.generate(
                input_ids, attention_mask=attn_mask, max_new_tokens=80,
                do_sample=True, temperature=0.9, top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )
            # Re-run the completed sequence to get hidden states at every position.
            gen_outputs = model(out, output_hidden_states=True)
            gen_logits = risk_predictor(gen_outputs.hidden_states)
            gen_risk = torch.sigmoid(gen_logits)
            risk_vals = gen_risk[0].cpu().numpy()
            rep_labels = compute_repetition_labels_fast(out, config.rep_window)
            labels = rep_labels[0].cpu().numpy()
            # Bucket each token's score by its repetition label.
            for t in range(len(risk_vals)):
                if labels[t] > 0.5:
                    all_pos_scores.append(float(risk_vals[t]))
                else:
                    all_neg_scores.append(float(risk_vals[t]))
    if all_pos_scores and all_neg_scores:
        p_pos = sum(all_pos_scores) / len(all_pos_scores)
        p_neg = sum(all_neg_scores) / len(all_neg_scores)
        # Floor the denominator so near-zero negatives can't blow up the ratio.
        separation = p_pos / max(p_neg, 1e-8)
        return p_pos, p_neg, separation, len(all_pos_scores), len(all_neg_scores)
    return 0.0, 0.0, 0.0, 0, 0
def main():
    """Run the full replication: load Qwen2.5-3B in 4-bit, attach LoRA and the
    risk predictor, train jointly on wikitext, and periodically measure
    repetition-risk separation against the LLaMA-8B baseline."""
    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)
    print("=" * 70)
    print("CROSS-ARCHITECTURE REPLICATION v2 (FIXED PROBE LAYERS)")
    print("=" * 70)
    print(f"Model: {config.model_path}")
    print(f"Probe layers: {config.probe_layers} (25%, 50%, 75%)")
    print(f"d_fiber: {config.d_fiber}, d_control: {config.d_control}")
    print(f"FIX: 3 focused projections instead of 36 diluted ones")
    print()
    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("Loading Qwen2.5-3B in 4-bit...")
    # NF4 double quantization with fp16 compute keeps the base model small.
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4"
    )
    model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb,
        device_map='auto', torch_dtype=torch.float16
    )
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    device = next(model.parameters()).device
    d_model = model.config.hidden_size
    n_layers = model.config.num_hidden_layers
    print(f"Architecture: Qwen2ForCausalLM")
    print(f"Hidden dim: {d_model}, Layers: {n_layers}")
    print(f"Probing layers: {config.probe_layers}")
    print()
    print("Adding LoRA...")
    # LoRA on the attention projections only; quantized base weights stay frozen.
    model = get_peft_model(model, LoraConfig(
        r=64, lora_alpha=128,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
    ))
    model.print_trainable_parameters()
    print("Adding Risk Predictor (3 probe layers)...")
    # Risk head is kept in fp32 regardless of the fp16/4-bit base model.
    risk_predictor = RiskPredictor(d_model, config.probe_layers, config).to(device).float()
    rp_params = sum(p.numel() for p in risk_predictor.parameters())
    print(f"Risk Predictor params: {rp_params:,}")
    print()
    print("Loading wikitext data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")
    # Two param groups: slow LR for LoRA, faster LR for the fresh risk head.
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': risk_predictor.parameters(), 'lr': config.lr_predictor}
    ], weight_decay=config.weight_decay)
    # NOTE(review): scheduler.step() fires only once per grad_accum micro-steps
    # below, so with T_max=max_steps the cosine anneal only traverses 1/grad_accum
    # of its schedule — confirm whether T_max should be max_steps // grad_accum.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.max_steps, eta_min=1e-6
    )
    training_log = {
        "experiment": "cross_architecture_replication_v2_fixed",
        "fix": "3 probe layers [9,18,27] instead of all 36",
        "source_model": "LLaMA-3.1-8B (4096d, 32L, probe [8,16,24])",
        "target_model": f"Qwen2.5-3B ({d_model}d, {n_layers}L, probe {config.probe_layers})",
        "d_fiber": config.d_fiber,
        "baseline_separation": "125x (LLaMA-8B repetition)",
        "steps": [],
        "separations": []
    }
    print("=" * 70)
    print("TRAINING")
    print("=" * 70)
    model.train()
    risk_predictor.train()
    step = 0      # counts micro-batches, not optimizer steps
    data_idx = 0
    # Running sums flushed every log_every steps.
    acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
    acc_precision, acc_recall, acc_f1 = 0, 0, 0
    start_time = time.time()
    while step < config.max_steps:
        # Cycle through the shuffled corpus indefinitely via modular indexing.
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size
        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)
        # Single forward pass yields both the LM loss and all hidden states.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
            output_hidden_states=True
        )
        lm_loss = outputs.loss
        # Pass full hidden_states — RiskPredictor indexes into specific layers
        risk_logits = risk_predictor(outputs.hidden_states)
        rep_labels = compute_repetition_labels_fast(input_ids, config.rep_window)
        mask = attention_mask.float()
        # Class re-balancing: repeats are rare, so up-weight positives (cap 10x).
        n_pos = (rep_labels * mask).sum().clamp(min=1)
        n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)
        bce_loss = F.binary_cross_entropy_with_logits(
            risk_logits, rep_labels,
            pos_weight=torch.ones_like(rep_labels) * pos_weight,
            reduction='none'
        )
        # Average only over real (unpadded) tokens.
        risk_loss = (bce_loss * mask).sum() / mask.sum()
        loss = lm_loss + risk_loss
        (loss / config.grad_accum).backward()
        with torch.no_grad():
            # Token-level precision/recall/F1 at a 0.5 threshold, real tokens only.
            risk_pred = torch.sigmoid(risk_logits)
            pred_binary = (risk_pred > 0.5).float()
            tp = ((pred_binary == 1) & (rep_labels == 1) & (mask == 1)).sum()
            fp = ((pred_binary == 1) & (rep_labels == 0) & (mask == 1)).sum()
            fn = ((pred_binary == 0) & (rep_labels == 1) & (mask == 1)).sum()
            precision = tp / (tp + fp + 1e-8)
            recall = tp / (tp + fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
        acc_loss += loss.item()
        acc_lm += lm_loss.item()
        acc_risk_loss += risk_loss.item()
        acc_precision += precision.item()
        acc_recall += recall.item()
        acc_f1 += f1.item()
        step += 1
        if step % config.grad_accum == 0:
            # Optimizer step once per accumulation window, clipped jointly.
            torch.nn.utils.clip_grad_norm_(
                list(lora_params) + list(risk_predictor.parameters()), 1.0
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        if step % config.log_every == 0:
            # remaining steps / (steps per second) -> hours remaining
            eta = (config.max_steps - step) / (step / (time.time() - start_time)) / 3600
            n = config.log_every
            print(
                f"Step {step:5d} | "
                f"Loss: {acc_loss/n:.4f} | "
                f"LM: {acc_lm/n:.4f} | "
                f"Risk: {acc_risk_loss/n:.4f} | "
                f"P: {acc_precision/n:.3f} | "
                f"R: {acc_recall/n:.3f} | "
                f"F1: {acc_f1/n:.3f} | "
                f"ETA: {eta:.1f}h"
            )
            training_log["steps"].append({
                "step": step, "loss": acc_loss/n, "lm_loss": acc_lm/n,
                "risk_loss": acc_risk_loss/n, "precision": acc_precision/n,
                "recall": acc_recall/n, "f1": acc_f1/n
            })
            # Reset the window accumulators.
            acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
            acc_precision, acc_recall, acc_f1 = 0, 0, 0
        if step % config.save_every == 0:
            ckpt = os.path.join(config.output_dir, f"ckpt_{step}")
            os.makedirs(ckpt, exist_ok=True)
            model.save_pretrained(ckpt)
            torch.save({
                'risk_predictor': risk_predictor.state_dict(),
                'step': step
            }, os.path.join(ckpt, "risk_predictor.pt"))
            print(f">>> Saved: {ckpt}")
        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"SEPARATION EVAL @ Step {step}")
            print(f"{'='*50}")
            p_pos, p_neg, separation, n_pos_s, n_neg_s = \
                compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=30)
            print(f" P(+) = {p_pos:.4f} (n={n_pos_s})")
            print(f" P(-) = {p_neg:.4f} (n={n_neg_s})")
            print(f" SEPARATION = {separation:.1f}x")
            print(f" [LLaMA-8B baseline: 125x]")
            training_log["separations"].append({
                "step": step, "p_pos": p_pos, "p_neg": p_neg,
                "separation": separation, "n_pos": n_pos_s, "n_neg": n_neg_s
            })
            # Persist the log at every eval so progress survives interruption.
            with open(os.path.join(config.output_dir, "replication_log.json"), 'w') as f:
                json.dump(training_log, f, indent=2)
            print(f"{'='*50}\n")
            # compute_separation switched both modules to eval; restore train mode.
            model.train()
            risk_predictor.train()
    # FINAL
    print("\n" + "=" * 70)
    print("FINAL CROSS-ARCHITECTURE COMPARISON")
    print("=" * 70)
    p_pos, p_neg, separation, n_pos, n_neg = \
        compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50)
    d = d_model
    nl = n_layers
    print(f"""
┌──────────────────────────────────────────────────────────┐
│ CROSS-ARCHITECTURE REPLICATION v2 (FIXED) │
├──────────────────────────────────────────────────────────
│ LLaMA-3.1-8B: 125x (P+=0.998, P-=0.008) │
│ Qwen2.5-3B: {separation:>5.1f}x (P+={p_pos:.3f}, P-={p_neg:.3f}) │
├──────────────────────────────────────────────────────────
│ Architecture: Qwen2 ({d}d, {nl}L) vs LLaMA (4096d, 32L) │
│ Probe layers: {config.probe_layers} │
│ d_fiber: 16 (identical) │
│ Method: IDENTICAL │
│ Conclusion: {"✅ GENERALIZES" if separation > 10 else "⚠️ INVESTIGATE"} │
└──────────────────────────────────────────────────────────┘
""")
    training_log["final"] = {
        "p_pos": p_pos, "p_neg": p_neg, "separation": separation,
        "n_pos": n_pos, "n_neg": n_neg,
        "conclusion": "generalizes" if separation > 10 else "needs_investigation"
    }
    with open(os.path.join(config.output_dir, "replication_log.json"), 'w') as f:
        json.dump(training_log, f, indent=2)
    final = os.path.join(config.output_dir, "final")
    os.makedirs(final, exist_ok=True)
    model.save_pretrained(final)
    torch.save({
        'risk_predictor': risk_predictor.state_dict(),
        'step': step, 'separation': separation,
        'p_pos': p_pos, 'p_neg': p_neg
    }, os.path.join(final, "risk_predictor.pt"))
    print(f"Done! Log: {config.output_dir}/replication_log.json")
if __name__ == "__main__":
    main()