# Source: cfhot-weights / code / training_pipelines / 03_arc_dense_train_DENSE.py
# Uploaded by LoganResearch — "Full weight release: 9 probes x 3 architectures
# + production adapter + training code" (commit 297244f, verified)
#!/usr/bin/env python3
"""
ARC Dense Training - Maximum Information Per Token
Instead of optimizing for brevity, we optimize for:
- Information density (concepts per token)
- Technical depth (domain vocabulary)
- Factual claim density
- Completeness relative to question complexity
While still penalizing:
- Repetition (zero information)
- Filler phrases (negative information density)
- Unnecessary hedging (wastes tokens)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, get_peft_model, LoraConfig
from dataclasses import dataclass
from pathlib import Path
import argparse
import json
import random
import re
import os
# Quiet Hugging Face noise before any transformers/tokenizers work happens:
# only surface transformers errors, and disable the tokenizers fork-parallelism
# warning that fires when the process forks after tokenization.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Technical vocabulary for density scoring.
# Matched as lowercase single words (see count_technical_terms), so multi-word
# phrases do not belong here.  NOTE: "entropy" appears in both the CS/ML and
# Science groups — harmless, since set literals deduplicate.
TECHNICAL_TERMS = {
    # CS/ML
    "algorithm", "tensor", "gradient", "backpropagation", "embedding", "attention",
    "transformer", "convolution", "recurrent", "optimization", "inference", "latent",
    "epoch", "batch", "dropout", "regularization", "softmax", "sigmoid", "relu",
    "encoder", "decoder", "autoregressive", "tokenizer", "perplexity", "entropy",
    # Math
    "derivative", "integral", "matrix", "vector", "eigenvalue", "polynomial",
    "probability", "distribution", "variance", "covariance", "logarithm",
    # Science
    "quantum", "entropy", "thermodynamic", "photon", "electron", "molecule",
    "catalyst", "oxidation", "synthesis", "genome", "protein", "neuron",
    # Philosophy
    "ontology", "epistemology", "metaphysics", "phenomenology", "existential",
    "determinism", "consciousness", "qualia", "dualism", "materialism",
}
# Phrases treated as zero-information filler; each one found in a response
# costs reward (see count_filler_phrases / compute_dense_reward).  All entries
# are lowercase because matching is done against text.lower().
FILLER_PHRASES = [
    "it's important to note", "it should be noted", "as you may know",
    "in other words", "that being said", "at the end of the day",
    "basically", "essentially", "actually", "literally", "obviously",
    "of course", "needless to say", "as i mentioned", "let me explain",
    "i think", "i believe", "in my opinion", "to be honest",
    "great question", "that's a good question", "interesting question",
]
@dataclass
class DenseConfig:
    """Hyperparameters for the dense (information-per-token) RL fine-tune."""
    batch_size: int = 2              # Smaller batch for longer generations
    gradient_accumulation: int = 8   # Effective batch = 16
    max_grad_norm: float = 1.0       # Gradient clipping threshold
    learning_rate: float = 3e-6      # Slightly lower for stability
    max_new_tokens: int = 256        # Allow longer responses for density
    checkpoint_every: int = 1000     # Save adapter + optimizer state every N steps
    log_every: int = 25              # Console metric logging cadence (steps)
    regenerate_prompts_every: int = 5000  # Refresh the prompt pool periodically
    temperature: float = 0.7         # Slightly lower for more focused output
    # Dense-specific
    min_response_tokens: int = 40    # Don't reward too-short responses
    target_density: float = 0.15     # Target concepts per token
                                     # NOTE(review): not referenced anywhere in
                                     # this file — verify it is used elsewhere.
class MultiHeadPredictor(nn.Module):
    """Frozen behavioral-risk probe: per-layer fiber projections feed small
    MLP heads (repetition / hedging / verbosity / ...), each emitting a
    per-token risk in [0, 1] via a sigmoid."""

    def __init__(self, d_model=4096, n_layers=32, d_fiber=16):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_fiber = d_fiber
        # One low-rank projection per transformer layer.  Created in layer
        # order so that seeded initialization matches checkpoint layout.
        self.fiber_projs = nn.ModuleList(
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_layers)
        )
        # Learnable mixing weights over layers, initialized uniform.
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.heads = nn.ModuleDict()
        self.loaded_heads = set()

    def add_head(self, name):
        """Attach a fresh 3-layer MLP scoring head under `name`."""
        self.heads[name] = nn.Sequential(
            nn.Linear(self.d_fiber, 64), nn.GELU(),
            nn.Linear(64, 64), nn.GELU(),
            nn.Linear(64, 1)
        )

    def get_fiber_features(self, hidden_states):
        """Project every layer's hidden states into fiber space and blend them
        with softmax-normalized per-layer weights."""
        dev = hidden_states[0].device
        projected = [
            proj.to(dev)(h.float())
            for proj, h in zip(self.fiber_projs, hidden_states)
        ]
        layer_mix = F.softmax(self.layer_weights.to(dev), dim=0)
        blended = layer_mix[0] * projected[0]
        for w, f in zip(layer_mix[1:], projected[1:]):
            blended = blended + w * f
        return blended

    def get_all_risks(self, hidden_states):
        """Return {head_name: sigmoid risk tensor} for every loaded head."""
        dev = hidden_states[0].device
        feats = self.get_fiber_features(hidden_states)
        out = {}
        for head_name in self.loaded_heads:
            head = self.heads[head_name].to(dev)
            self.heads[head_name] = head
            out[head_name] = torch.sigmoid(head(feats).squeeze(-1))
        return out
def load_predictor(checkpoint_dir: Path, device):
    """Build a frozen MultiHeadPredictor from on-disk checkpoints.

    Loads the shared fiber projections / layer weights plus the repetition
    head from the risk-v2 checkpoint, then whichever of the hedging /
    verbosity / sycophancy heads have checkpoints available (preferring
    ckpt_10000 over ckpt_2000).  Missing files are skipped silently, so the
    returned predictor may carry fewer heads than requested.
    """
    predictor = MultiHeadPredictor()

    rep_ckpt = checkpoint_dir / "cfhot_risk_v2/ckpt_5000/risk_predictor.pt"
    if rep_ckpt.exists():
        state = torch.load(rep_ckpt, map_location=device, weights_only=False)
        if 'fiber_projs' in state:
            for idx, proj_state in enumerate(state['fiber_projs']):
                predictor.fiber_projs[idx].load_state_dict(proj_state)
        if 'layer_weights' in state:
            predictor.layer_weights.data = state['layer_weights']
        predictor.add_head('repetition')
        if 'head_state' in state:
            predictor.heads['repetition'].load_state_dict(state['head_state'])
        predictor.loaded_heads.add('repetition')
        print(" ✓ Loaded repetition head")

    for head_name in ('hedging', 'verbosity', 'sycophancy'):
        candidates = (
            checkpoint_dir / f"multi_head_v2/{head_name}_head/ckpt_10000/{head_name}_head.pt",
            checkpoint_dir / f"multi_head_v2/{head_name}_head/ckpt_2000/{head_name}_head.pt",
        )
        head_ckpt = next((c for c in candidates if c.exists()), None)
        if head_ckpt is None:
            continue
        predictor.add_head(head_name)
        state = torch.load(head_ckpt, map_location=device, weights_only=False)
        if 'head_state' in state:
            predictor.heads[head_name].load_state_dict(state['head_state'])
        elif isinstance(state, dict) and head_name in state:
            predictor.heads[head_name].load_state_dict(state[head_name])
        predictor.loaded_heads.add(head_name)
        print(f" ✓ Loaded {head_name} head")

    # The probe is inference-only: freeze everything.
    predictor.eval()
    predictor.requires_grad_(False)
    return predictor.to(device)
def generate_complex_prompts(n: int) -> list[tuple[str, float]]:
    """Generate prompts that demand dense, technical responses.

    Args:
        n: number of prompts to generate (sampled with replacement from the
           template/topic pools, so duplicates are possible).

    Returns:
        A list of ``(prompt, complexity)`` pairs.  ``complexity`` scales the
        minimum expected response length in the reward: 3 for two-topic
        comparisons, 2 for math/equation or task prompts, 1.5 for other
        single-topic prompts, 1 for a bare template.

    Fix vs. original: removed the dead local ``complexities`` list (it was
    initialized and never used — complexity already travels inside the
    returned tuples) and corrected the return annotation/docstring, which
    claimed a plain list of prompts.
    """
    templates = [
        # Technical explanations
        "Explain {topic} with technical precision.",
        "How does {topic} work at a fundamental level?",
        "What are the core mechanisms behind {topic}?",
        "Describe the architecture of {topic}.",
        "What distinguishes {topic1} from {topic2} technically?",
        # Deep dives
        "Explain the mathematics behind {topic}.",
        "What are the theoretical foundations of {topic}?",
        "Describe {topic} as you would to a graduate student.",
        "What are the key equations governing {topic}?",
        # Implementation
        "How would you implement {topic} from scratch?",
        "What's the most efficient algorithm for {task}?",
        "Explain the time complexity of {topic}.",
        # Analysis
        "What are the fundamental tradeoffs in {topic}?",
        "Why does {topic} work the way it does?",
        "What are the failure modes of {topic}?",
        "Analyze the strengths and weaknesses of {topic}.",
        # Synthesis
        "How do {topic1} and {topic2} relate to each other?",
        "What principles unify {topic1} and {topic2}?",
        "How would you combine {topic1} with {topic2}?",
    ]
    topics = [
        "transformer attention", "backpropagation", "gradient descent",
        "convolutional neural networks", "recurrent neural networks",
        "reinforcement learning", "Q-learning", "policy gradients",
        "variational autoencoders", "GANs", "diffusion models",
        "tokenization", "embedding spaces", "positional encoding",
        "layer normalization", "batch normalization", "dropout",
        "LSTM gates", "self-attention", "cross-attention",
        "beam search", "nucleus sampling", "temperature scaling",
        "quantization", "pruning", "knowledge distillation",
        "quantum entanglement", "wave function collapse", "superposition",
        "natural selection", "genetic drift", "speciation",
        "thermodynamic entropy", "information entropy", "free energy",
        "consciousness", "qualia", "the binding problem",
        "Gödel's incompleteness", "Turing completeness", "P vs NP",
        "hash tables", "B-trees", "red-black trees",
        "recursion", "dynamic programming", "memoization",
        "TCP/IP", "public key cryptography", "consensus algorithms",
    ]
    tasks = [
        "sorting n elements", "finding shortest path", "matrix multiplication",
        "string matching", "graph traversal", "balanced tree insertion",
        "hash collision resolution", "memory allocation", "garbage collection",
    ]
    prompts = []
    for _ in range(n):
        template = random.choice(templates)
        if "{topic1}" in template and "{topic2}" in template:
            # Two distinct topics — comparisons demand the longest answers.
            t1, t2 = random.sample(topics, 2)
            prompt = template.format(topic1=t1, topic2=t2)
            complexity = 3
        elif "{topic}" in template:
            topic = random.choice(topics)
            prompt = template.format(topic=topic)
            complexity = 2 if "mathematics" in template or "equations" in template else 1.5
        elif "{task}" in template:
            task = random.choice(tasks)
            prompt = template.format(task=task)
            complexity = 2
        else:
            prompt = template
            complexity = 1
        prompts.append((prompt, complexity))
    return prompts
def count_technical_terms(text: str) -> int:
    """Count distinct technical-vocabulary words appearing in *text*.

    Words are extracted as lowercase letter runs so that punctuation glued to
    a word no longer hides it — the original ``text.lower().split()`` left
    "tensor," or "entropy." intact and they never matched the vocabulary set.
    Each term counts at most once (set intersection).
    """
    words = set(re.findall(r"[a-z]+", text.lower()))
    return len(words & TECHNICAL_TERMS)
def count_filler_phrases(text: str) -> int:
    """Count filler phrases present in *text* (case-insensitive).

    Each phrase contributes at most 1 no matter how often it repeats.
    Matches are anchored on word boundaries so single-word fillers no longer
    fire inside longer words — the original plain substring test counted
    "actually" inside "factually" and "essentially" inside "quintessentially".
    """
    haystack = text.lower()
    hits = 0
    for phrase in FILLER_PHRASES:
        if re.search(r"\b" + re.escape(phrase) + r"\b", haystack):
            hits += 1
    return hits
def count_factual_claims(text: str) -> int:
    """Heuristically estimate the number of factual assertions in *text*.

    Splits on sentence terminators and counts each non-empty sentence that
    contains a copula / causal / definitional marker.  Crude, but cheap
    enough to run inside the per-step reward computation.
    """
    markers = (
        " is ", " are ", " was ", " were ", " has ", " have ",
        " means ", " equals ", " produces ", " causes ", " results ",
        " requires ", " enables ", " allows ", " prevents ",
        "defined as", "consists of", "composed of",
    )
    total = 0
    for raw in re.split(r'[.!?]', text):
        sentence = raw.strip().lower()
        if sentence and any(marker in sentence for marker in markers):
            total += 1
    return total
def count_code_and_math(text: str) -> int:
    """Weighted count of structured technical content in *text*.

    Fenced code blocks weigh 5, $...$ equations weigh 3; inline backtick
    spans, unicode math symbols, and simple "a = b"-style formulas (matched
    case-insensitively) weigh 1 each.
    """
    weighted_patterns = (
        (r'```[\s\S]*?```', 5),   # fenced code blocks
        (r'`[^`]+`', 1),          # inline code spans
        (r'\$[^$]+\$', 3),        # inline LaTeX equations
        (r'[∑∏∫∂∇≈≠≤≥∈∀∃→←↔×÷±√∞]', 1),  # bare math symbols
    )
    score = sum(w * len(re.findall(pat, text)) for pat, w in weighted_patterns)
    score += len(re.findall(r'[a-z]\s*[=<>]\s*[a-z0-9]', text, re.I))
    return score
def compute_dense_reward(response_ids, risks, tokenizer, complexity, config):
    """
    Dense reward: maximize information per token
    Reward = (information_score) / (effective_tokens) - penalties

    Args:
        response_ids: per-sample generated token id sequences (indexable,
            len() gives token count; assumed to live on one device — TODO
            confirm padding tokens are excluded upstream, they are counted
            here as real tokens).
        risks: dict of head name -> risk tensor from the predictor;
            indexed as [i, -1], so assumed shape (batch, seq).
        tokenizer: used only to decode responses to text.
        complexity: batch-average prompt complexity; scales expected length.
        config: DenseConfig (min_response_tokens is the only field read).

    Returns:
        (rewards tensor clamped to [0, 1], mean info_score * 100 for logging).
    """
    batch_rewards = []
    batch_densities = []
    for i in range(len(response_ids)):
        response = tokenizer.decode(response_ids[i], skip_special_tokens=True)
        tokens = len(response_ids[i])
        # Degenerate / empty generations earn nothing.
        if tokens < 5:
            batch_rewards.append(0.0)
            batch_densities.append(0.0)
            continue
        # === Information Content ===
        # 1. Unique concept words (content words > 4 chars)
        words = response.split()
        content_words = set(w.lower() for w in words if len(w) > 4 and w.isalpha())
        concept_density = len(content_words) / tokens
        # 2. Technical term density
        tech_terms = count_technical_terms(response)
        tech_density = tech_terms / tokens
        # 3. Factual claim density
        claims = count_factual_claims(response)
        claim_density = claims / max(tokens / 20, 1)  # Normalize by ~sentence count
        # 4. Structured content (code, math)
        structured = count_code_and_math(response)
        structured_density = structured / tokens
        # Combined information score (weights sum to 1.0)
        info_score = (
            concept_density * 0.3 +
            tech_density * 0.3 +
            claim_density * 0.25 +
            structured_density * 0.15
        )
        # === Fluff Penalties ===
        # Last-token risk per probe head; 0 when a head wasn't loaded.
        rep_risk = risks['repetition'][i, -1].item() if 'repetition' in risks else 0
        verb_risk = risks['verbosity'][i, -1].item() if 'verbosity' in risks else 0
        hedge_risk = risks['hedging'][i, -1].item() if 'hedging' in risks else 0
        filler_count = count_filler_phrases(response)
        filler_penalty = min(filler_count * 0.05, 0.3)  # capped at 0.3
        # Probes penalty (repetition worst, verbosity bad, hedging mild)
        probe_penalty = 0.4 * rep_risk + 0.25 * verb_risk + 0.1 * hedge_risk
        total_fluff = filler_penalty + probe_penalty
        # === Completeness ===
        # Scale expected length with question complexity
        expected_min = config.min_response_tokens * complexity
        if tokens < expected_min:
            # Linear penalty, up to 0.3 for an empty-relative-to-expected answer.
            completeness_penalty = 0.3 * (expected_min - tokens) / expected_min
        else:
            completeness_penalty = 0
        # Bonus for appropriate length (not too short, not excessively long)
        if expected_min <= tokens <= expected_min * 3:
            length_bonus = 0.1
        elif tokens > expected_min * 4:
            length_bonus = -0.1  # Penalize excessive length
        else:
            length_bonus = 0  # between 3x and 4x: neither bonus nor penalty
        # === Final Reward ===
        # Effective tokens: actual tokens + penalty for fluff
        effective_tokens = tokens * (1 + total_fluff)
        # Information per effective token (x100 rescale keeps reward ~O(1))
        density = info_score / (effective_tokens / 100)
        reward = density - completeness_penalty + length_bonus
        reward = max(0, min(1, reward))  # Clamp to [0, 1]
        batch_rewards.append(reward)
        batch_densities.append(info_score * 100)  # For logging
    return (
        torch.tensor(batch_rewards, dtype=torch.float32, device=response_ids[0].device),
        sum(batch_densities) / len(batch_densities) if batch_densities else 0
    )
def compute_efficiency_decision(risks):
    """Map probe risk scores to an inference-efficiency routing decision.

    Same efficiency routing as terse training: the mean last-token risk of
    each head picks a (layers, speculation length, strategy) triple — more
    aggressive skipping when repetition/verbosity risk is high, careful
    early exit when all risks are low, full compute otherwise.

    Fix vs. original: the missing-head default was ``torch.zeros(1)``
    (1-D), so the ``[:, -1]`` indexing raised IndexError whenever a head
    was absent from ``risks``; the default is now 2-D (1, 1).

    Args:
        risks: dict of head name -> (batch, seq) risk tensor; any head may
            be absent and then counts as zero risk.

    Returns:
        dict with 'layers', 'spec_length' and 'strategy'.
    """
    def _mean_last_token_risk(name):
        # Each risk tensor is (batch, seq); average the final-token scores.
        return risks.get(name, torch.zeros(1, 1))[:, -1].mean().item()

    rep = _mean_last_token_risk('repetition')
    verb = _mean_last_token_risk('verbosity')
    hedge = _mean_last_token_risk('hedging')
    if rep > 0.45:
        return {'layers': 20, 'spec_length': 8, 'strategy': 'skip_speculate_aggressive'}
    if verb > 0.5:
        return {'layers': 24, 'spec_length': 6, 'strategy': 'skip_speculate_moderate'}
    if rep < 0.4 and verb < 0.4 and hedge < 0.4:
        return {'layers': 16, 'spec_length': 2, 'strategy': 'early_exit_careful'}
    return {'layers': 32, 'spec_length': 3, 'strategy': 'full_compute'}
def train(args):
    """Run the dense-reward REINFORCE loop on a 4-bit quantized LoRA model.

    Pipeline per step: sample prompts -> generate responses -> score them
    with the frozen risk probes -> compute dense rewards -> one policy
    gradient (REINFORCE with mean-reward baseline) backward pass, with
    gradient accumulation, periodic logging, checkpointing, and prompt-pool
    refresh.  Requires CUDA and the probe checkpoints on disk.
    """
    config = DenseConfig()
    config.learning_rate = args.lr
    device = torch.device("cuda")
    print("=" * 60)
    print(" ARC Dense Training - Maximum Information Density")
    print("=" * 60)
    print(f" Batch size: {config.batch_size}")
    print(f" Gradient accumulation: {config.gradient_accumulation}")
    print(f" Effective batch: {config.batch_size * config.gradient_accumulation}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" Max new tokens: {config.max_new_tokens}")
    print(f" Min response tokens: {config.min_response_tokens}")
    print("=" * 60)
    print("\n[1/3] Loading model...")
    # NF4 double-quantized 4-bit base model; LoRA adapters carry the gradients.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.local_model)
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding so generation starts right after each prompt in the batch.
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        args.local_model,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa"
    )
    # Adapter selection: resume an in-progress run, else warm-start from a
    # terse-trained adapter, else attach a fresh LoRA.
    start_step = 0
    if args.resume and Path(args.resume).exists():
        print(f" Resuming from {args.resume}")
        model = PeftModel.from_pretrained(model, args.resume, is_trainable=True)
        state_path = Path(args.resume) / "training_state.pt"
        if state_path.exists():
            state = torch.load(state_path, weights_only=False)
            start_step = state.get('step', 0)
            print(f" Resuming from step {start_step}")
            # NOTE(review): the optimizer state saved in training_state.pt is
            # not restored here — verify whether that is intentional.
    elif args.base_checkpoint and Path(args.base_checkpoint).exists():
        print(f" Loading base checkpoint: {args.base_checkpoint}")
        model = PeftModel.from_pretrained(model, args.base_checkpoint, is_trainable=True)
        print(" ✓ Loaded terse-trained adapter as starting point")
    else:
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f" Model loaded. Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
    print("\n[2/3] Loading behavioral prediction heads...")
    checkpoint_dir = Path.home() / "arc_efficiency_training/checkpoints_backup"
    predictor = load_predictor(checkpoint_dir, device)
    print(f" ✓ Predictor loaded with heads: {list(predictor.loaded_heads)}")
    print("\n[3/3] Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )
    print(f" ✓ Optimizer: AdamW, LR: {config.learning_rate}")
    print(f"\nGenerating {args.prompts} complex prompts...")
    prompts_with_complexity = generate_complex_prompts(args.prompts)
    print(f" ✓ Generated {len(prompts_with_complexity)} prompts")
    # Reuse of the name: from here on checkpoint_dir is the OUTPUT directory.
    checkpoint_dir = Path(args.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    print("\n" + "=" * 60)
    print(f" Starting DENSE training from step {start_step}")
    print(f" Total steps: {args.steps}")
    print("=" * 60 + "\n")
    model.train()
    optimizer.zero_grad()
    step = start_step
    # Running sums reset at every log interval.
    accum_loss = 0
    accum_reward = 0
    accum_density = 0
    accum_rep = 0
    accum_layers = 0
    last_strategy = "none"
    while step < args.steps:
        # --- Sample a batch of (prompt, complexity) pairs ---
        batch_data = random.sample(prompts_with_complexity, config.batch_size)
        batch_prompts = [p[0] for p in batch_data]
        batch_complexity = [p[1] for p in batch_data]
        avg_complexity = sum(batch_complexity) / len(batch_complexity)
        # Hard-coded ChatML template — assumes the model was trained on it.
        formatted = [
            f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
            for p in batch_prompts
        ]
        inputs = tokenizer(
            formatted,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)
        # --- Rollout: sample responses without gradients ---
        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=config.max_new_tokens,
                do_sample=True,
                temperature=config.temperature,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True
            )
        # Strip the prompt portion; keep only newly generated tokens.
        generated_ids = outputs[:, inputs.input_ids.shape[1]:]
        # --- Score rollouts with the frozen risk probes ---
        with torch.no_grad():
            hidden_outputs = model(
                outputs,
                output_hidden_states=True,
                return_dict=True,
                use_cache=False
            )
            hidden_states = hidden_outputs.hidden_states[1:]
            risks = predictor.get_all_risks(hidden_states)
        rewards, avg_density_score = compute_dense_reward(
            generated_ids, risks, tokenizer, avg_complexity, config
        )
        efficiency = compute_efficiency_decision(risks)
        # --- Policy gradient pass (REINFORCE, mean-reward baseline) ---
        model.train()
        logits = model(outputs, return_dict=True, use_cache=False).logits
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = outputs[:, 1:].contiguous()
        log_probs = F.log_softmax(shift_logits.float(), dim=-1)
        selected_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        # NOTE(review): the mask only removes pad tokens, so prompt tokens
        # also contribute to seq_log_probs — confirm this is intended.
        mask = (shift_labels != tokenizer.pad_token_id).float()
        seq_log_probs = (selected_log_probs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        baseline = rewards.float().mean()
        advantages = rewards - baseline
        loss = -(seq_log_probs * advantages.to(seq_log_probs.device)).mean()
        loss = loss / config.gradient_accumulation
        loss.backward()
        # Track un-scaled loss and batch metrics for logging.
        accum_loss += loss.item() * config.gradient_accumulation
        accum_reward += rewards.float().mean().item()
        accum_density += avg_density_score
        accum_rep += risks['repetition'][:, -1].mean().item() if 'repetition' in risks else 0
        accum_layers += efficiency['layers']
        last_strategy = efficiency['strategy']
        # Optimizer step only every `gradient_accumulation` micro-steps.
        if (step + 1) % config.gradient_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
        step += 1
        if step % config.log_every == 0:
            avg_loss = accum_loss / config.log_every
            avg_reward = accum_reward / config.log_every
            avg_dens = accum_density / config.log_every
            avg_rep = accum_rep / config.log_every
            avg_layers = accum_layers / config.log_every
            print(f"Step {step:6d} | Loss: {avg_loss:.4f} | Reward: {avg_reward:.3f} | "
                  f"Density: {avg_dens:.2f} | Rep: {avg_rep:.3f} | Layers: {avg_layers:.1f} | {last_strategy}")
            accum_loss = 0
            accum_reward = 0
            accum_density = 0
            accum_rep = 0
            accum_layers = 0
        if step % config.checkpoint_every == 0:
            # Save LoRA adapter, optimizer state, and a small README.
            ckpt_path = checkpoint_dir / f"step_{step}"
            model.save_pretrained(ckpt_path)
            torch.save({
                'step': step,
                'optimizer': optimizer.state_dict(),
                'config': config.__dict__,
                'mode': 'dense'
            }, ckpt_path / "training_state.pt")
            with open(ckpt_path / "README.md", "w") as f:
                f.write(f"# ARC Dense Checkpoint - Step {step}\n\n")
                f.write("**Mode:** Dense (maximum information per token)\n\n")
                f.write(f"Training config:\n```json\n{json.dumps(config.__dict__, indent=2)}\n```\n")
            print(f" ✓ Saved dense checkpoint at step {step}")
        if step % config.regenerate_prompts_every == 0 and step > start_step:
            # Refresh the prompt pool so the policy doesn't overfit to it.
            print(f"\n Regenerating complex prompts...")
            prompts_with_complexity = generate_complex_prompts(args.prompts)
            print(f" ✓ Generated {len(prompts_with_complexity)} fresh prompts\n")
    print("\n" + "=" * 60)
    print(" Dense training complete!")
    print("=" * 60)
    final_path = checkpoint_dir / "final"
    model.save_pretrained(final_path)
    print(f" ✓ Saved final dense model to {final_path}")
if __name__ == "__main__":
    # CLI entry point: parse hyperparameters / paths and launch training.
    parser = argparse.ArgumentParser()
    # Path (or hub id) of the base model to quantize and adapt.
    parser.add_argument("--local-model", type=str, required=True)
    parser.add_argument("--base-checkpoint", type=str, default=None,
                        help="Start from terse-trained checkpoint")
    parser.add_argument("--steps", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=3e-6)
    # Size of the generated prompt pool (sampled from each step).
    parser.add_argument("--prompts", type=int, default=5000)
    parser.add_argument("--checkpoint-dir", type=str, default="./dense_checkpoints")
    # Adapter directory to resume from; takes precedence over --base-checkpoint.
    parser.add_argument("--resume", type=str, default=None)
    args = parser.parse_args()
    train(args)