#!/usr/bin/env python3
"""
CF-HoT Phase 1: Train Risk Predictor (FIXED)
=============================================
Fixed: Class weighting to handle imbalanced data (most tokens aren't repeats)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import os
import time
import random
from dataclasses import dataclass
from typing import Tuple
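# Hyperparameters for the run. d_fiber / d_control size the risk-predictor head,
# rep_window is how far back a matching token counts as a repetition, and the two
# learning rates cover the LoRA adapters and the risk head respectively.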
@dataclass
class Config:
model_path: str = "/mnt/nvme2/ubermesnchetien4/models/merged-final-v5"
output_dir: str = "./results/cfhot_risk_v2"
d_fiber: int = 16
d_control: int = 64
max_steps: int = 3000
batch_size: int = 1
grad_accum: int = 8
max_length: int = 256
lr_lora: float = 2e-5
lr_predictor: float = 1e-4
weight_decay: float = 0.01
rep_window: int = 32
log_every: int = 10
save_every: int = 500
eval_every: int = 200
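# Per-token repetition-risk head. Each transformer layer's hidden state is projected
# into a small d_fiber space, the per-layer projections are mixed with a learned
# softmax weighting, and a small MLP maps the result to one logit per token.
# The sigmoid is applied only at evaluation time; training uses BCEWithLogits.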
class RiskPredictor(nn.Module):
def __init__(self, d_model: int, n_layers: int, config: Config):
super().__init__()
self.config = config
self.n_layers = n_layers
self.fiber_projs = nn.ModuleList([
nn.Linear(d_model, config.d_fiber, bias=False)
for _ in range(n_layers)
])
self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
# Output LOGITS, not probabilities
self.predictor = nn.Sequential(
nn.Linear(config.d_fiber, config.d_control),
nn.GELU(),
nn.Linear(config.d_control, config.d_control),
nn.GELU(),
nn.Linear(config.d_control, 1)
# NO sigmoid here - we'll use BCEWithLogitsLoss
)
for proj in self.fiber_projs:
nn.init.normal_(proj.weight, std=0.02)
def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
fibers = []
        # zip() already truncates to the shorter of the two sequences, so no bounds check is needed
        for proj, h in zip(self.fiber_projs, hidden_states):
            fibers.append(proj(h.float()))
weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
aggregated = sum(w * f for w, f in zip(weights, fibers))
logits = self.predictor(aggregated).squeeze(-1) # [B, S] LOGITS
return logits
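# Binary target: label position t as 1 if its token also occurs anywhere in the
# previous `window` positions. Vectorized over lags rather than looping over tokens:
# for each lag d, compare input_ids[:, d:] against input_ids[:, :-d] and keep the
# element-wise running maximum.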
def compute_repetition_labels_fast(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
B, S = input_ids.shape
device = input_ids.device
labels = torch.zeros(B, S, device=device)
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)
return labels
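# Phase-1 training loop: 4-bit quantized base model + LoRA adapters, trained jointly
# with the risk predictor on WikiText-2 using LM loss plus class-weighted BCE.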
def main():
config = Config()
os.makedirs(config.output_dir, exist_ok=True)
print("=" * 70)
print("CF-HoT RISK PREDICTOR v2 (CLASS-WEIGHTED)")
print("=" * 70)
tokenizer = AutoTokenizer.from_pretrained(config.model_path)
tokenizer.pad_token = tokenizer.eos_token
print("Loading model...")
bnb = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4"
)
model = AutoModelForCausalLM.from_pretrained(
config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16
)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
device = next(model.parameters()).device
print(f"Device: {device}")
print("Adding LoRA...")
model = get_peft_model(model, LoraConfig(
r=64, lora_alpha=128,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
))
model.print_trainable_parameters()
print("Adding Risk Predictor...")
n_layers = model.config.num_hidden_layers
d_model = model.config.hidden_size
risk_predictor = RiskPredictor(d_model, n_layers, config).to(device).float()
print(f"Risk Predictor params: {sum(p.numel() for p in risk_predictor.parameters()):,}")
print("Loading data...")
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
random.shuffle(texts)
print(f"Loaded {len(texts)} samples")
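    # Two parameter groups: LoRA adapters at a lower LR, the risk-predictor head at a higher one.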
lora_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW([
{'params': lora_params, 'lr': config.lr_lora},
{'params': risk_predictor.parameters(), 'lr': config.lr_predictor}
], weight_decay=config.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=config.max_steps, eta_min=1e-6
)
print("\n" + "=" * 70)
print("TRAINING (with class weighting)")
print("=" * 70)
model.train()
risk_predictor.train()
step = 0
data_idx = 0
acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
acc_precision, acc_recall, acc_f1 = 0, 0, 0
start_time = time.time()
while step < config.max_steps:
batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
data_idx += config.batch_size
enc = tokenizer(batch, truncation=True, max_length=config.max_length,
padding='max_length', return_tensors='pt')
input_ids = enc['input_ids'].to(device)
attention_mask = enc['attention_mask'].to(device)
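        # Forward pass through the LoRA-adapted LM; hidden states feed the risk head.
        # Note that labels are not masked at pad positions, so the LM loss also covers
        # padding (pad_token is set to eos above).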
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=input_ids,
output_hidden_states=True
)
lm_loss = outputs.loss
# Get logits from risk predictor
risk_logits = risk_predictor(outputs.hidden_states[1:])
# Compute labels
rep_labels = compute_repetition_labels_fast(input_ids, config.rep_window)
# CLASS-WEIGHTED LOSS
# Compute class weights dynamically based on batch
mask = attention_mask.float()
n_pos = (rep_labels * mask).sum().clamp(min=1)
n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
pos_weight = n_neg / n_pos # Weight positives by ratio of negatives to positives
pos_weight = pos_weight.clamp(max=10.0) # Cap at 10x
# BCEWithLogitsLoss with pos_weight
bce_loss = F.binary_cross_entropy_with_logits(
risk_logits, rep_labels,
pos_weight=torch.ones_like(rep_labels) * pos_weight,
reduction='none'
)
risk_loss = (bce_loss * mask).sum() / mask.sum()
# Total loss
loss = lm_loss + risk_loss
(loss / config.grad_accum).backward()
# Metrics (apply sigmoid for evaluation)
with torch.no_grad():
risk_pred = torch.sigmoid(risk_logits)
pred_binary = (risk_pred > 0.5).float()
tp = ((pred_binary == 1) & (rep_labels == 1) & (mask == 1)).sum()
fp = ((pred_binary == 1) & (rep_labels == 0) & (mask == 1)).sum()
fn = ((pred_binary == 0) & (rep_labels == 1) & (mask == 1)).sum()
precision = tp / (tp + fp + 1e-8)
recall = tp / (tp + fn + 1e-8)
f1 = 2 * precision * recall / (precision + recall + 1e-8)
acc_loss += loss.item()
acc_lm += lm_loss.item()
acc_risk_loss += risk_loss.item()
acc_precision += precision.item()
acc_recall += recall.item()
acc_f1 += f1.item()
step += 1
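        # Gradients are accumulated over grad_accum micro-batches before each optimizer step.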
if step % config.grad_accum == 0:
torch.nn.utils.clip_grad_norm_(
list(lora_params) + list(risk_predictor.parameters()), 1.0
)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % config.log_every == 0:
eta = (config.max_steps - step) / (step / (time.time() - start_time)) / 3600
n = config.log_every
print(
f"Step {step:5d} | "
f"Loss: {acc_loss/n:.4f} | "
f"LM: {acc_lm/n:.4f} | "
f"Risk: {acc_risk_loss/n:.4f} | "
f"P: {acc_precision/n:.3f} | "
f"R: {acc_recall/n:.3f} | "
f"F1: {acc_f1/n:.3f} | "
f"ETA: {eta:.1f}h"
)
acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
acc_precision, acc_recall, acc_f1 = 0, 0, 0
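        # Periodic checkpoint: the LoRA adapter via save_pretrained, the risk head as a plain state_dict.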
if step % config.save_every == 0:
ckpt = os.path.join(config.output_dir, f"ckpt_{step}")
os.makedirs(ckpt, exist_ok=True)
model.save_pretrained(ckpt)
torch.save({
'risk_predictor': risk_predictor.state_dict(),
'step': step
}, os.path.join(ckpt, "risk_predictor.pt"))
print(f">>> Saved: {ckpt}")
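        # Qualitative evaluation: sample a continuation, score per-token risk on the
        # generated sequence, and compare mean risk at repeated vs. non-repeated tokens.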
if step % config.eval_every == 0:
model.eval()
risk_predictor.eval()
print("\n--- Evaluation ---")
prompt = "The will to power, as described by Nietzsche, is"
inp = tokenizer(prompt, return_tensors='pt')
input_ids = inp['input_ids'].to(device)
with torch.no_grad():
out = model.generate(
input_ids, max_new_tokens=60,
do_sample=True, temperature=0.8, top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
gen_outputs = model(out, output_hidden_states=True)
gen_logits = risk_predictor(gen_outputs.hidden_states[1:])
gen_risk = torch.sigmoid(gen_logits)
risk_vals = gen_risk[0].cpu().numpy()
print(f" Generated: {generated_text[:200]}...")
print(f" Risk (first 10): {[f'{r:.2f}' for r in risk_vals[:10]]}")
print(f" Risk (last 10): {[f'{r:.2f}' for r in risk_vals[-10:]]}")
print(f" Mean: {risk_vals.mean():.3f}, Max: {risk_vals.max():.3f}, Min: {risk_vals.min():.3f}")
# Check correlation
gen_ids = out[0].cpu().numpy()
actual_reps = []
for t in range(1, len(gen_ids)):
start = max(0, t - config.rep_window)
is_rep = gen_ids[t] in gen_ids[start:t]
actual_reps.append(1 if is_rep else 0)
# Correlation between risk and actual repeats
if len(actual_reps) > 1:
risk_at_reps = [risk_vals[i+1] for i, r in enumerate(actual_reps) if r == 1 and i+1 < len(risk_vals)]
risk_at_nonreps = [risk_vals[i+1] for i, r in enumerate(actual_reps) if r == 0 and i+1 < len(risk_vals)]
if risk_at_reps and risk_at_nonreps:
print(f" Avg risk at REPEATS: {sum(risk_at_reps)/len(risk_at_reps):.3f}")
print(f" Avg risk at NON-REPS: {sum(risk_at_nonreps)/len(risk_at_nonreps):.3f}")
print("--- End Eval ---\n")
model.train()
risk_predictor.train()
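    # Final save after max_steps.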
final = os.path.join(config.output_dir, "final")
os.makedirs(final, exist_ok=True)
model.save_pretrained(final)
torch.save({
'risk_predictor': risk_predictor.state_dict(),
'step': step
}, os.path.join(final, "risk_predictor.pt"))
print(f"\nDONE! Saved to {final}")
if __name__ == "__main__":
main()