#!/usr/bin/env python3
"""
ARC Dense Training - Maximum Information Per Token

Instead of optimizing for brevity, we optimize for:
- Information density (concepts per token)
- Technical depth (domain vocabulary)
- Factual claim density
- Completeness relative to question complexity

While still penalizing:
- Repetition (zero information)
- Filler phrases (negative information density)
- Unnecessary hedging (wastes tokens)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, get_peft_model, LoraConfig
from dataclasses import dataclass
from pathlib import Path
import argparse
import json
import random
import re
import os

os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Technical vocabulary for density scoring
TECHNICAL_TERMS = {
    # CS/ML
    "algorithm", "tensor", "gradient", "backpropagation", "embedding",
    "attention", "transformer", "convolution", "recurrent", "optimization",
    "inference", "latent", "epoch", "batch", "dropout", "regularization",
    "softmax", "sigmoid", "relu", "encoder", "decoder", "autoregressive",
    "tokenizer", "perplexity", "entropy",
    # Math
    "derivative", "integral", "matrix", "vector", "eigenvalue", "polynomial",
    "probability", "distribution", "variance", "covariance", "logarithm",
    # Science ("entropy" also appears under CS/ML; set literals dedupe, harmless)
    "quantum", "entropy", "thermodynamic", "photon", "electron", "molecule",
    "catalyst", "oxidation", "synthesis", "genome", "protein", "neuron",
    # Philosophy
    "ontology", "epistemology", "metaphysics", "phenomenology", "existential",
    "determinism", "consciousness", "qualia", "dualism", "materialism",
}

# Phrases that consume tokens without adding information.
# Matched as lowercase substrings in count_filler_phrases().
FILLER_PHRASES = [
    "it's important to note", "it should be noted", "as you may know",
    "in other words", "that being said", "at the end of the day",
    "basically", "essentially", "actually", "literally", "obviously",
    "of course", "needless to say", "as i mentioned", "let me explain",
    "i think", "i believe", "in my opinion", "to be honest",
    "great question", "that's a good question", "interesting question",
]


@dataclass
class DenseConfig:
    """Hyperparameters for dense-reward training."""
    batch_size: int = 2                  # Smaller batch for longer generations
    gradient_accumulation: int = 8       # Effective batch = 16
    max_grad_norm: float = 1.0
    learning_rate: float = 3e-6          # Slightly lower for stability
    max_new_tokens: int = 256            # Allow longer responses for density
    checkpoint_every: int = 1000
    log_every: int = 25
    regenerate_prompts_every: int = 5000
    temperature: float = 0.7             # Slightly lower for more focused output
    # Dense-specific
    min_response_tokens: int = 40        # Don't reward too-short responses
    target_density: float = 0.15         # Target concepts per token


class MultiHeadPredictor(nn.Module):
    """Frozen probe that scores behavioral risks (repetition, hedging, ...)
    from a model's per-layer hidden states.

    Each layer's hidden states are projected into a small "fiber" space; a
    learned softmax over layers mixes them, and one small MLP head per
    behavior maps the mixed features to a per-token risk in [0, 1].
    """

    def __init__(self, d_model=4096, n_layers=32, d_fiber=16):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_fiber = d_fiber
        # One low-rank projection per transformer layer.
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_layers)
        ])
        # Learned mixing weights over layers (uniform init).
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.heads = nn.ModuleDict()
        # Names of heads whose weights were actually loaded from disk;
        # get_all_risks() only evaluates these.
        self.loaded_heads = set()

    def add_head(self, name):
        """Register an (untrained) risk head under `name`."""
        self.heads[name] = nn.Sequential(
            nn.Linear(self.d_fiber, 64), nn.GELU(),
            nn.Linear(64, 64), nn.GELU(),
            nn.Linear(64, 1)
        )

    def get_fiber_features(self, hidden_states):
        """Project each layer's hidden states and mix them with softmaxed
        layer weights. `hidden_states` is a sequence of [batch, seq, d_model]
        tensors, one per layer (embedding layer excluded by the caller)."""
        device = hidden_states[0].device
        fibers = []
        for i, (proj, h) in enumerate(zip(self.fiber_projs, hidden_states)):
            proj = proj.to(device)
            fibers.append(proj(h.float()))
        weights = F.softmax(self.layer_weights.to(device), dim=0)
        return sum(w * f for w, f in zip(weights, fibers))

    def get_all_risks(self, hidden_states):
        """Return {head_name: sigmoid risk tensor [batch, seq]} for every
        loaded head."""
        device = hidden_states[0].device
        features = self.get_fiber_features(hidden_states)
        risks = {}
        for name in self.loaded_heads:
            self.heads[name] = self.heads[name].to(device)
            logits = self.heads[name](features).squeeze(-1)
            risks[name] = torch.sigmoid(logits)
        return risks


def load_predictor(checkpoint_dir: Path, device):
    """Build a MultiHeadPredictor and load whatever head checkpoints exist
    under `checkpoint_dir`. Missing files are skipped silently, so the
    returned predictor may have any subset of heads. The predictor is frozen
    (eval mode, requires_grad=False) and moved to `device`."""
    predictor = MultiHeadPredictor()

    # Repetition head ships together with the shared fiber projections.
    rep_path = checkpoint_dir / "cfhot_risk_v2/ckpt_5000/risk_predictor.pt"
    if rep_path.exists():
        ckpt = torch.load(rep_path, map_location=device, weights_only=False)
        if 'fiber_projs' in ckpt:
            for i, proj_state in enumerate(ckpt['fiber_projs']):
                predictor.fiber_projs[i].load_state_dict(proj_state)
        if 'layer_weights' in ckpt:
            predictor.layer_weights.data = ckpt['layer_weights']
        predictor.add_head('repetition')
        if 'head_state' in ckpt:
            predictor.heads['repetition'].load_state_dict(ckpt['head_state'])
        predictor.loaded_heads.add('repetition')
        print("  ✓ Loaded repetition head")

    # Remaining heads: prefer the 10k-step checkpoint, fall back to 2k.
    for head_name in ['hedging', 'verbosity', 'sycophancy']:
        head_path = checkpoint_dir / f"multi_head_v2/{head_name}_head/ckpt_10000/{head_name}_head.pt"
        if not head_path.exists():
            head_path = checkpoint_dir / f"multi_head_v2/{head_name}_head/ckpt_2000/{head_name}_head.pt"
        if head_path.exists():
            predictor.add_head(head_name)
            ckpt = torch.load(head_path, map_location=device, weights_only=False)
            if 'head_state' in ckpt:
                predictor.heads[head_name].load_state_dict(ckpt['head_state'])
            elif isinstance(ckpt, dict) and head_name in ckpt:
                predictor.heads[head_name].load_state_dict(ckpt[head_name])
            predictor.loaded_heads.add(head_name)
            print(f"  ✓ Loaded {head_name} head")

    # Freeze: the predictor is a fixed reward signal, never trained here.
    predictor.eval()
    for param in predictor.parameters():
        param.requires_grad = False
    return predictor.to(device)


def generate_complex_prompts(n: int) -> list:
    """Generate prompts that demand dense, technical responses.

    Returns a list of (prompt, complexity) tuples, where complexity is a
    multiplier (1–3) used to scale the expected minimum response length.
    """
    templates = [
        # Technical explanations
        "Explain {topic} with technical precision.",
        "How does {topic} work at a fundamental level?",
        "What are the core mechanisms behind {topic}?",
        "Describe the architecture of {topic}.",
        "What distinguishes {topic1} from {topic2} technically?",
        # Deep dives
        "Explain the mathematics behind {topic}.",
        "What are the theoretical foundations of {topic}?",
        "Describe {topic} as you would to a graduate student.",
        "What are the key equations governing {topic}?",
        # Implementation
        "How would you implement {topic} from scratch?",
        "What's the most efficient algorithm for {task}?",
        "Explain the time complexity of {topic}.",
        # Analysis
        "What are the fundamental tradeoffs in {topic}?",
        "Why does {topic} work the way it does?",
        "What are the failure modes of {topic}?",
        "Analyze the strengths and weaknesses of {topic}.",
        # Synthesis
        "How do {topic1} and {topic2} relate to each other?",
        "What principles unify {topic1} and {topic2}?",
        "How would you combine {topic1} with {topic2}?",
    ]

    topics = [
        "transformer attention", "backpropagation", "gradient descent",
        "convolutional neural networks", "recurrent neural networks",
        "reinforcement learning", "Q-learning", "policy gradients",
        "variational autoencoders", "GANs", "diffusion models",
        "tokenization", "embedding spaces", "positional encoding",
        "layer normalization", "batch normalization", "dropout",
        "LSTM gates", "self-attention", "cross-attention",
        "beam search", "nucleus sampling", "temperature scaling",
        "quantization", "pruning", "knowledge distillation",
        "quantum entanglement", "wave function collapse", "superposition",
        "natural selection", "genetic drift", "speciation",
        "thermodynamic entropy", "information entropy", "free energy",
        "consciousness", "qualia", "the binding problem",
        "Gödel's incompleteness", "Turing completeness", "P vs NP",
        "hash tables", "B-trees", "red-black trees",
        "recursion", "dynamic programming", "memoization",
        "TCP/IP", "public key cryptography", "consensus algorithms",
    ]

    tasks = [
        "sorting n elements", "finding shortest path", "matrix multiplication",
        "string matching", "graph traversal", "balanced tree insertion",
        "hash collision resolution", "memory allocation", "garbage collection",
    ]

    prompts = []
    for _ in range(n):
        template = random.choice(templates)
        if "{topic1}" in template and "{topic2}" in template:
            t1, t2 = random.sample(topics, 2)
            prompt = template.format(topic1=t1, topic2=t2)
            complexity = 3  # Comparison = high complexity
        elif "{topic}" in template:
            topic = random.choice(topics)
            prompt = template.format(topic=topic)
            complexity = 2 if "mathematics" in template or "equations" in template else 1.5
        elif "{task}" in template:
            task = random.choice(tasks)
            prompt = template.format(task=task)
            complexity = 2
        else:
            prompt = template
            complexity = 1
        prompts.append((prompt, complexity))
    return prompts


def count_technical_terms(text: str) -> int:
    """Count domain-specific technical vocabulary (distinct words only)."""
    words = set(text.lower().split())
    return len(words.intersection(TECHNICAL_TERMS))


def count_filler_phrases(text: str) -> int:
    """Count filler phrases that waste tokens (substring match, lowercase)."""
    text_lower = text.lower()
    return sum(1 for phrase in FILLER_PHRASES if phrase in text_lower)


def count_factual_claims(text: str) -> int:
    """Estimate number of factual assertions.

    Heuristic: a sentence counts as one claim if it contains any of a set
    of assertion-style patterns (copulas, causal verbs, definitions).
    """
    sentences = re.split(r'[.!?]', text)
    claims = 0
    for sent in sentences:
        sent = sent.strip().lower()
        if not sent:
            continue
        # Patterns indicating factual claims
        if any(pattern in sent for pattern in [
            " is ", " are ", " was ", " were ", " has ", " have ",
            " means ", " equals ", " produces ", " causes ", " results ",
            " requires ", " enables ", " allows ", " prevents ",
            "defined as", "consists of", "composed of",
        ]):
            claims += 1
    return claims


def count_code_and_math(text: str) -> int:
    """Count structured technical content, weighted by information value:
    fenced code blocks ×5, inline equations ×3, inline code / math symbols /
    simple formulas ×1."""
    code_blocks = len(re.findall(r'```[\s\S]*?```', text))
    inline_code = len(re.findall(r'`[^`]+`', text))
    equations = len(re.findall(r'\$[^$]+\$', text))
    math_symbols = len(re.findall(r'[∑∏∫∂∇≈≠≤≥∈∀∃→←↔×÷±√∞]', text))
    formulas = len(re.findall(r'[a-z]\s*[=<>]\s*[a-z0-9]', text, re.I))
    return code_blocks * 5 + inline_code + equations * 3 + math_symbols + formulas


def compute_dense_reward(response_ids, risks, tokenizer, complexity, config):
    """
    Dense reward: maximize information per token

    Reward = (information_score) / (effective_tokens) - penalties

    Returns (rewards tensor [batch] in [0, 1], mean info_score*100 for logging).
    """
    batch_rewards = []
    batch_densities = []
    for i in range(len(response_ids)):
        response = tokenizer.decode(response_ids[i], skip_special_tokens=True)
        # NOTE(review): len() includes any trailing eos padding produced by
        # batched generate(), slightly deflating density — confirm acceptable.
        tokens = len(response_ids[i])
        if tokens < 5:
            batch_rewards.append(0.0)
            batch_densities.append(0.0)
            continue

        # === Information Content ===
        # 1. Unique concept words (content words > 4 chars)
        words = response.split()
        content_words = set(w.lower() for w in words if len(w) > 4 and w.isalpha())
        concept_density = len(content_words) / tokens

        # 2. Technical term density
        tech_terms = count_technical_terms(response)
        tech_density = tech_terms / tokens

        # 3. Factual claim density
        claims = count_factual_claims(response)
        claim_density = claims / max(tokens / 20, 1)  # Normalize by ~sentence count

        # 4. Structured content (code, math)
        structured = count_code_and_math(response)
        structured_density = structured / tokens

        # Combined information score
        info_score = (
            concept_density * 0.3 +
            tech_density * 0.3 +
            claim_density * 0.25 +
            structured_density * 0.15
        )

        # === Fluff Penalties ===
        # Probe risks are [batch, seq]; use the last-token risk as the
        # sequence-level score.
        rep_risk = risks['repetition'][i, -1].item() if 'repetition' in risks else 0
        verb_risk = risks['verbosity'][i, -1].item() if 'verbosity' in risks else 0
        hedge_risk = risks['hedging'][i, -1].item() if 'hedging' in risks else 0

        filler_count = count_filler_phrases(response)
        filler_penalty = min(filler_count * 0.05, 0.3)

        # Probes penalty (repetition worst, verbosity bad, hedging mild)
        probe_penalty = 0.4 * rep_risk + 0.25 * verb_risk + 0.1 * hedge_risk
        total_fluff = filler_penalty + probe_penalty

        # === Completeness ===
        # Scale expected length with question complexity
        expected_min = config.min_response_tokens * complexity
        if tokens < expected_min:
            completeness_penalty = 0.3 * (expected_min - tokens) / expected_min
        else:
            completeness_penalty = 0

        # Bonus for appropriate length (not too short, not excessively long).
        # Lengths in (3x, 4x] deliberately get neither bonus nor penalty.
        if expected_min <= tokens <= expected_min * 3:
            length_bonus = 0.1
        elif tokens > expected_min * 4:
            length_bonus = -0.1  # Penalize excessive length
        else:
            length_bonus = 0

        # === Final Reward ===
        # Effective tokens: actual tokens + penalty for fluff
        effective_tokens = tokens * (1 + total_fluff)

        # Information per effective token
        density = info_score / (effective_tokens / 100)

        reward = density - completeness_penalty + length_bonus
        reward = max(0, min(1, reward))  # Clamp to [0, 1]

        batch_rewards.append(reward)
        batch_densities.append(info_score * 100)  # For logging

    return (
        torch.tensor(batch_rewards, dtype=torch.float32, device=response_ids[0].device),
        sum(batch_densities) / len(batch_densities) if batch_densities else 0
    )


def compute_efficiency_decision(risks):
    """Same efficiency routing as terse training.

    Maps batch-mean last-token risks to a layer-count / speculation-length
    decision (logged only; does not alter the forward pass here).
    """
    # BUGFIX: the fallback default must be 2-D ([batch, seq]) because it is
    # indexed with [:, -1]; torch.zeros(1) would raise IndexError whenever a
    # head is missing from `risks`.
    rep = risks.get('repetition', torch.zeros(1, 1))[:, -1].mean().item()
    verb = risks.get('verbosity', torch.zeros(1, 1))[:, -1].mean().item()
    hedge = risks.get('hedging', torch.zeros(1, 1))[:, -1].mean().item()

    if rep > 0.45:
        return {'layers': 20, 'spec_length': 8, 'strategy': 'skip_speculate_aggressive'}
    elif verb > 0.5:
        return {'layers': 24, 'spec_length': 6, 'strategy': 'skip_speculate_moderate'}
    elif rep < 0.4 and verb < 0.4 and hedge < 0.4:
        return {'layers': 16, 'spec_length': 2, 'strategy': 'early_exit_careful'}
    else:
        return {'layers': 32, 'spec_length': 3, 'strategy': 'full_compute'}


def train(args):
    """REINFORCE-style training loop rewarding information density.

    Each iteration: sample prompts, generate with the current policy, score
    generations with the frozen risk predictor + dense-reward heuristics,
    then take a policy-gradient step with a batch-mean baseline.
    """
    config = DenseConfig()
    config.learning_rate = args.lr
    device = torch.device("cuda")

    print("=" * 60)
    print(" ARC Dense Training - Maximum Information Density")
    print("=" * 60)
    print(f" Batch size: {config.batch_size}")
    print(f" Gradient accumulation: {config.gradient_accumulation}")
    print(f" Effective batch: {config.batch_size * config.gradient_accumulation}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" Max new tokens: {config.max_new_tokens}")
    print(f" Min response tokens: {config.min_response_tokens}")
    print("=" * 60)

    print("\n[1/3] Loading model...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.local_model)
    # NOTE(review): pad == eos means the loss mask below also drops real EOS
    # tokens — confirm this is intended.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        args.local_model,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa"
    )

    # Adapter selection: resume > warm-start from terse checkpoint > fresh LoRA.
    start_step = 0
    if args.resume and Path(args.resume).exists():
        print(f" Resuming from {args.resume}")
        model = PeftModel.from_pretrained(model, args.resume, is_trainable=True)
        state_path = Path(args.resume) / "training_state.pt"
        if state_path.exists():
            state = torch.load(state_path, weights_only=False)
            start_step = state.get('step', 0)
            print(f" Resuming from step {start_step}")
    elif args.base_checkpoint and Path(args.base_checkpoint).exists():
        print(f" Loading base checkpoint: {args.base_checkpoint}")
        model = PeftModel.from_pretrained(model, args.base_checkpoint, is_trainable=True)
        print("  ✓ Loaded terse-trained adapter as starting point")
    else:
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f" Model loaded. Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    print("\n[2/3] Loading behavioral prediction heads...")
    checkpoint_dir = Path.home() / "arc_efficiency_training/checkpoints_backup"
    predictor = load_predictor(checkpoint_dir, device)
    print(f"  ✓ Predictor loaded with heads: {list(predictor.loaded_heads)}")

    print("\n[3/3] Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )
    print(f"  ✓ Optimizer: AdamW, LR: {config.learning_rate}")

    print(f"\nGenerating {args.prompts} complex prompts...")
    prompts_with_complexity = generate_complex_prompts(args.prompts)
    print(f"  ✓ Generated {len(prompts_with_complexity)} prompts")

    # Output directory for adapters (renamed from `checkpoint_dir` to avoid
    # shadowing the predictor-weights directory above).
    save_dir = Path(args.checkpoint_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    print("\n" + "=" * 60)
    print(f" Starting DENSE training from step {start_step}")
    print(f" Total steps: {args.steps}")
    print("=" * 60 + "\n")

    model.train()
    optimizer.zero_grad()
    step = start_step
    accum_loss = 0
    accum_reward = 0
    accum_density = 0
    accum_rep = 0
    accum_layers = 0
    last_strategy = "none"

    while step < args.steps:
        batch_data = random.sample(prompts_with_complexity, config.batch_size)
        batch_prompts = [p[0] for p in batch_data]
        batch_complexity = [p[1] for p in batch_data]
        avg_complexity = sum(batch_complexity) / len(batch_complexity)

        formatted = [
            f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
            for p in batch_prompts
        ]
        inputs = tokenizer(
            formatted, return_tensors="pt", padding=True,
            truncation=True, max_length=512
        ).to(device)

        # --- Rollout: sample responses from the current policy ---
        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=config.max_new_tokens,
                do_sample=True,
                temperature=config.temperature,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True
            )
        generated_ids = outputs[:, inputs.input_ids.shape[1]:]

        # --- Score: risks from the frozen predictor + dense reward ---
        with torch.no_grad():
            hidden_outputs = model(
                outputs,
                output_hidden_states=True,
                return_dict=True,
                use_cache=False
            )
        hidden_states = hidden_outputs.hidden_states[1:]  # drop embedding layer
        risks = predictor.get_all_risks(hidden_states)
        rewards, avg_density_score = compute_dense_reward(
            generated_ids, risks, tokenizer, avg_complexity, config
        )
        efficiency = compute_efficiency_decision(risks)

        # --- Policy gradient (REINFORCE, batch-mean baseline) ---
        model.train()
        logits = model(outputs, return_dict=True, use_cache=False).logits
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = outputs[:, 1:].contiguous()
        log_probs = F.log_softmax(shift_logits.float(), dim=-1)
        selected_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        # NOTE(review): this mask covers the full sequence (prompt + response)
        # and, since pad == eos, also zeroes real EOS positions — confirm.
        mask = (shift_labels != tokenizer.pad_token_id).float()
        seq_log_probs = (selected_log_probs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        baseline = rewards.float().mean()
        advantages = rewards - baseline
        loss = -(seq_log_probs * advantages.to(seq_log_probs.device)).mean()
        loss = loss / config.gradient_accumulation
        loss.backward()

        accum_loss += loss.item() * config.gradient_accumulation
        accum_reward += rewards.float().mean().item()
        accum_density += avg_density_score
        accum_rep += risks['repetition'][:, -1].mean().item() if 'repetition' in risks else 0
        accum_layers += efficiency['layers']
        last_strategy = efficiency['strategy']

        # Optimizer step every `gradient_accumulation` iterations.
        if (step + 1) % config.gradient_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

        step += 1

        if step % config.log_every == 0:
            avg_loss = accum_loss / config.log_every
            avg_reward = accum_reward / config.log_every
            avg_dens = accum_density / config.log_every
            avg_rep = accum_rep / config.log_every
            avg_layers = accum_layers / config.log_every
            print(f"Step {step:6d} | Loss: {avg_loss:.4f} | Reward: {avg_reward:.3f} | "
                  f"Density: {avg_dens:.2f} | Rep: {avg_rep:.3f} | Layers: {avg_layers:.1f} | {last_strategy}")
            accum_loss = 0
            accum_reward = 0
            accum_density = 0
            accum_rep = 0
            accum_layers = 0

        if step % config.checkpoint_every == 0:
            ckpt_path = save_dir / f"step_{step}"
            model.save_pretrained(ckpt_path)
            torch.save({
                'step': step,
                'optimizer': optimizer.state_dict(),
                'config': config.__dict__,
                'mode': 'dense'
            }, ckpt_path / "training_state.pt")
            with open(ckpt_path / "README.md", "w") as f:
                f.write(f"# ARC Dense Checkpoint - Step {step}\n\n")
                f.write("**Mode:** Dense (maximum information per token)\n\n")
                f.write(f"Training config:\n```json\n{json.dumps(config.__dict__, indent=2)}\n```\n")
            print(f"  ✓ Saved dense checkpoint at step {step}")

        # Refresh the prompt pool periodically to avoid overfitting to it.
        if step % config.regenerate_prompts_every == 0 and step > start_step:
            print(f"\n Regenerating complex prompts...")
            prompts_with_complexity = generate_complex_prompts(args.prompts)
            print(f"  ✓ Generated {len(prompts_with_complexity)} fresh prompts\n")

    print("\n" + "=" * 60)
    print(" Dense training complete!")
    print("=" * 60)
    final_path = save_dir / "final"
    model.save_pretrained(final_path)
    print(f"  ✓ Saved final dense model to {final_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local-model", type=str, required=True)
    parser.add_argument("--base-checkpoint", type=str, default=None,
                        help="Start from terse-trained checkpoint")
    parser.add_argument("--steps", type=int, default=20000)
    parser.add_argument("--lr", type=float, default=3e-6)
    parser.add_argument("--prompts", type=int, default=5000)
    parser.add_argument("--checkpoint-dir", type=str, default="./dense_checkpoints")
    parser.add_argument("--resume", type=str, default=None)
    args = parser.parse_args()
    train(args)