🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
#!/usr/bin/env python3  # 297244f verified
"""
ARC Dense Training - Maximum Information Per Token

Instead of optimizing for brevity, we optimize for:
- Information density (concepts per token)
- Technical depth (domain vocabulary)
- Factual claim density
- Completeness relative to question complexity

While still penalizing:
- Repetition (zero information)
- Filler phrases (negative information density)
- Unnecessary hedging (wastes tokens)
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig | |
| from peft import PeftModel, get_peft_model, LoraConfig | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| import argparse | |
| import json | |
| import random | |
| import re | |
| import os | |
| os.environ["TRANSFORMERS_VERBOSITY"] = "error" | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| # Technical vocabulary for density scoring | |
| TECHNICAL_TERMS = { | |
| # CS/ML | |
| "algorithm", "tensor", "gradient", "backpropagation", "embedding", "attention", | |
| "transformer", "convolution", "recurrent", "optimization", "inference", "latent", | |
| "epoch", "batch", "dropout", "regularization", "softmax", "sigmoid", "relu", | |
| "encoder", "decoder", "autoregressive", "tokenizer", "perplexity", "entropy", | |
| # Math | |
| "derivative", "integral", "matrix", "vector", "eigenvalue", "polynomial", | |
| "probability", "distribution", "variance", "covariance", "logarithm", | |
| # Science | |
| "quantum", "entropy", "thermodynamic", "photon", "electron", "molecule", | |
| "catalyst", "oxidation", "synthesis", "genome", "protein", "neuron", | |
| # Philosophy | |
| "ontology", "epistemology", "metaphysics", "phenomenology", "existential", | |
| "determinism", "consciousness", "qualia", "dualism", "materialism", | |
| } | |
| FILLER_PHRASES = [ | |
| "it's important to note", "it should be noted", "as you may know", | |
| "in other words", "that being said", "at the end of the day", | |
| "basically", "essentially", "actually", "literally", "obviously", | |
| "of course", "needless to say", "as i mentioned", "let me explain", | |
| "i think", "i believe", "in my opinion", "to be honest", | |
| "great question", "that's a good question", "interesting question", | |
| ] | |
@dataclass
class DenseConfig:
    """Hyperparameters for dense (information-per-token) RL training.

    Fix: declared as a @dataclass (the import already existed but was
    unused). As a plain class with only class-level annotated attributes,
    instances had an (almost) empty ``__dict__`` — so the
    ``config.__dict__`` serialized into every training checkpoint recorded
    nothing but the one attribute ``train()`` sets. With @dataclass every
    field lands on the instance and checkpoints capture the full config.
    """
    batch_size: int = 2                   # Smaller batch for longer generations
    gradient_accumulation: int = 8        # Effective batch = 16
    max_grad_norm: float = 1.0
    learning_rate: float = 3e-6           # Slightly lower for stability
    max_new_tokens: int = 256             # Allow longer responses for density
    checkpoint_every: int = 1000
    log_every: int = 25
    regenerate_prompts_every: int = 5000
    temperature: float = 0.7              # Slightly lower for more focused output
    # Dense-specific
    min_response_tokens: int = 40         # Don't reward too-short responses
    target_density: float = 0.15          # Target concepts per token
class MultiHeadPredictor(nn.Module):
    """Frozen probe stack: shared fiber projections plus per-behavior heads.

    Each transformer layer's hidden states are projected into a small
    ``d_fiber``-dimensional space; a learned softmax over layers blends the
    projections into a single feature, which every named head maps to a
    per-token risk score in (0, 1).
    """

    def __init__(self, d_model=4096, n_layers=32, d_fiber=16):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_fiber = d_fiber
        # One low-rank, bias-free projection per transformer layer.
        self.fiber_projs = nn.ModuleList(
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_layers)
        )
        # Uniform initial layer-mixing weights (softmaxed at use time).
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.heads = nn.ModuleDict()
        # Names of heads whose weights were actually loaded; only these score.
        self.loaded_heads = set()

    def add_head(self, name):
        """Register a fresh (untrained) 2-hidden-layer scoring MLP under ``name``."""
        self.heads[name] = nn.Sequential(
            nn.Linear(self.d_fiber, 64), nn.GELU(),
            nn.Linear(64, 64), nn.GELU(),
            nn.Linear(64, 1)
        )

    def get_fiber_features(self, hidden_states):
        """Blend per-layer fiber projections into one (..., d_fiber) feature."""
        target = hidden_states[0].device
        # .to() moves a module's parameters in place, so no reassignment needed.
        projected = [
            layer_proj.to(target)(states.float())
            for layer_proj, states in zip(self.fiber_projs, hidden_states)
        ]
        mix = F.softmax(self.layer_weights.to(target), dim=0)
        return sum(w * p for w, p in zip(mix, projected))

    def get_all_risks(self, hidden_states):
        """Return {head_name: sigmoid risk tensor} for every loaded head."""
        target = hidden_states[0].device
        features = self.get_fiber_features(hidden_states)
        scores = {}
        for name in self.loaded_heads:
            head = self.heads[name].to(target)
            scores[name] = torch.sigmoid(head(features).squeeze(-1))
        return scores
def load_predictor(checkpoint_dir: Path, device):
    """Build a frozen MultiHeadPredictor and load probe weights from disk.

    Expected layout under ``checkpoint_dir`` (TODO confirm against the runs
    that wrote these files):
      - cfhot_risk_v2/ckpt_5000/risk_predictor.pt: fiber projections,
        layer weights, and the 'repetition' head in one checkpoint
      - multi_head_v2/<name>_head/ckpt_10000/<name>_head.pt (falling back to
        ckpt_2000): 'hedging', 'verbosity', 'sycophancy' heads
    Heads missing on disk are silently skipped; only successfully found
    checkpoints register their head in ``loaded_heads`` and thus get scored.
    Returns the predictor on ``device``, in eval mode, fully frozen.
    """
    predictor = MultiHeadPredictor()
    # Shared trunk + repetition head live in one combined checkpoint.
    rep_path = checkpoint_dir / "cfhot_risk_v2/ckpt_5000/risk_predictor.pt"
    if rep_path.exists():
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only safe for trusted local checkpoints.
        ckpt = torch.load(rep_path, map_location=device, weights_only=False)
        if 'fiber_projs' in ckpt:
            for i, proj_state in enumerate(ckpt['fiber_projs']):
                predictor.fiber_projs[i].load_state_dict(proj_state)
        if 'layer_weights' in ckpt:
            predictor.layer_weights.data = ckpt['layer_weights']
        predictor.add_head('repetition')
        if 'head_state' in ckpt:
            predictor.heads['repetition'].load_state_dict(ckpt['head_state'])
        # Registered even if 'head_state' was absent (head stays random init).
        predictor.loaded_heads.add('repetition')
        print(" ✓ Loaded repetition head")
    for head_name in ['hedging', 'verbosity', 'sycophancy']:
        head_path = checkpoint_dir / f"multi_head_v2/{head_name}_head/ckpt_10000/{head_name}_head.pt"
        if not head_path.exists():
            # Fall back to the earlier checkpoint when the long run is absent.
            head_path = checkpoint_dir / f"multi_head_v2/{head_name}_head/ckpt_2000/{head_name}_head.pt"
        if head_path.exists():
            predictor.add_head(head_name)
            ckpt = torch.load(head_path, map_location=device, weights_only=False)
            # Two historical checkpoint formats: {'head_state': ...} or {<name>: ...}.
            if 'head_state' in ckpt:
                predictor.heads[head_name].load_state_dict(ckpt['head_state'])
            elif isinstance(ckpt, dict) and head_name in ckpt:
                predictor.heads[head_name].load_state_dict(ckpt[head_name])
            predictor.loaded_heads.add(head_name)
            print(f" ✓ Loaded {head_name} head")
    # Probes are frozen scorers — never trained in this script.
    predictor.eval()
    for param in predictor.parameters():
        param.requires_grad = False
    return predictor.to(device)
def generate_complex_prompts(n: int) -> list:
    """Generate ``n`` prompts that demand dense, technical responses.

    Returns a list of ``(prompt, complexity)`` tuples. ``complexity`` is a
    rough multiplier for the expected response length: two-topic comparisons
    score 3, math/equation prompts and task prompts 2, other single-topic
    prompts 1.5, template-only prompts 1.

    Fix: removed the unused ``complexities`` accumulator the original
    declared but never populated — complexity already travels in the tuple.
    """
    templates = [
        # Technical explanations
        "Explain {topic} with technical precision.",
        "How does {topic} work at a fundamental level?",
        "What are the core mechanisms behind {topic}?",
        "Describe the architecture of {topic}.",
        "What distinguishes {topic1} from {topic2} technically?",
        # Deep dives
        "Explain the mathematics behind {topic}.",
        "What are the theoretical foundations of {topic}?",
        "Describe {topic} as you would to a graduate student.",
        "What are the key equations governing {topic}?",
        # Implementation
        "How would you implement {topic} from scratch?",
        "What's the most efficient algorithm for {task}?",
        "Explain the time complexity of {topic}.",
        # Analysis
        "What are the fundamental tradeoffs in {topic}?",
        "Why does {topic} work the way it does?",
        "What are the failure modes of {topic}?",
        "Analyze the strengths and weaknesses of {topic}.",
        # Synthesis
        "How do {topic1} and {topic2} relate to each other?",
        "What principles unify {topic1} and {topic2}?",
        "How would you combine {topic1} with {topic2}?",
    ]
    topics = [
        "transformer attention", "backpropagation", "gradient descent",
        "convolutional neural networks", "recurrent neural networks",
        "reinforcement learning", "Q-learning", "policy gradients",
        "variational autoencoders", "GANs", "diffusion models",
        "tokenization", "embedding spaces", "positional encoding",
        "layer normalization", "batch normalization", "dropout",
        "LSTM gates", "self-attention", "cross-attention",
        "beam search", "nucleus sampling", "temperature scaling",
        "quantization", "pruning", "knowledge distillation",
        "quantum entanglement", "wave function collapse", "superposition",
        "natural selection", "genetic drift", "speciation",
        "thermodynamic entropy", "information entropy", "free energy",
        "consciousness", "qualia", "the binding problem",
        "Gödel's incompleteness", "Turing completeness", "P vs NP",
        "hash tables", "B-trees", "red-black trees",
        "recursion", "dynamic programming", "memoization",
        "TCP/IP", "public key cryptography", "consensus algorithms",
    ]
    tasks = [
        "sorting n elements", "finding shortest path", "matrix multiplication",
        "string matching", "graph traversal", "balanced tree insertion",
        "hash collision resolution", "memory allocation", "garbage collection",
    ]
    prompts = []
    for _ in range(n):
        template = random.choice(templates)
        if "{topic1}" in template and "{topic2}" in template:
            # Two distinct topics per comparison prompt.
            t1, t2 = random.sample(topics, 2)
            prompt = template.format(topic1=t1, topic2=t2)
            complexity = 3  # Comparison = high complexity
        elif "{topic}" in template:
            topic = random.choice(topics)
            prompt = template.format(topic=topic)
            complexity = 2 if "mathematics" in template or "equations" in template else 1.5
        elif "{task}" in template:
            task = random.choice(tasks)
            prompt = template.format(task=task)
            complexity = 2
        else:
            prompt = template
            complexity = 1
        prompts.append((prompt, complexity))
    return prompts
def count_technical_terms(text: str, vocab=None) -> int:
    """Count *distinct* technical vocabulary words appearing in ``text``.

    Fix: the original split on whitespace, so a term followed by punctuation
    ("the algorithm.") never matched the vocabulary. Words are now extracted
    with a letters-only regex, which strips punctuation.

    vocab: optional iterable of lowercase terms; defaults to the
    module-level TECHNICAL_TERMS set (backward compatible).
    """
    terms = TECHNICAL_TERMS if vocab is None else set(vocab)
    words = set(re.findall(r"[a-z]+", text.lower()))
    return len(words & terms)
def count_filler_phrases(text: str, phrases=None) -> int:
    """Count how many distinct filler phrases occur in ``text``.

    Matching is a case-insensitive substring test; each phrase contributes
    at most 1 regardless of repeats (the reward caps the filler penalty
    anyway, so occurrence counting was never needed).

    phrases: optional list of lowercase phrases; defaults to the
    module-level FILLER_PHRASES list (backward compatible).
    """
    candidates = FILLER_PHRASES if phrases is None else phrases
    text_lower = text.lower()
    return sum(1 for phrase in candidates if phrase in text_lower)
def count_factual_claims(text: str) -> int:
    """Estimate the number of factual assertions in ``text``.

    Heuristic: split on sentence terminators and count sentences that
    contain a copula/causal marker ("is", "causes", "defined as", ...).
    """
    # Patterns indicating factual claims (substring match, lowercase).
    markers = (
        " is ", " are ", " was ", " were ", " has ", " have ",
        " means ", " equals ", " produces ", " causes ", " results ",
        " requires ", " enables ", " allows ", " prevents ",
        "defined as", "consists of", "composed of",
    )
    total = 0
    for raw in re.split(r'[.!?]', text):
        sentence = raw.strip().lower()
        if not sentence:
            continue
        if any(marker in sentence for marker in markers):
            total += 1
    return total
def count_code_and_math(text: str) -> int:
    """Score structured technical content: code blocks, inline code, math.

    Fenced code blocks weigh most (5), $...$ equations next (3); inline
    code, math symbols, and simple "x = y" formulas count 1 each.
    """
    weighted_patterns = (
        (r'```[\s\S]*?```', 5),                  # fenced code blocks
        (r'`[^`]+`', 1),                         # inline code spans
        (r'\$[^$]+\$', 3),                       # $...$ equations
        (r'[∑∏∫∂∇≈≠≤≥∈∀∃→←↔×÷±√∞]', 1),          # lone math symbols
    )
    score = sum(weight * len(re.findall(pattern, text))
                for pattern, weight in weighted_patterns)
    # Simple assignment/comparison formulas, case-insensitive.
    score += len(re.findall(r'[a-z]\s*[=<>]\s*[a-z0-9]', text, re.I))
    return score
def compute_dense_reward(response_ids, risks, tokenizer, complexity, config):
    """
    Dense reward: maximize information per token
    Reward = (information_score) / (effective_tokens) - penalties

    response_ids: batch of generated token-id sequences (prompt stripped).
    risks: dict of probe outputs; indexed ``[i, -1]`` below, so each tensor
        is assumed (batch, seq) — TODO confirm against the predictor.
    complexity: single scalar applied to the whole batch (the caller passes
        the batch AVERAGE, not per-item complexity — NOTE(review): confirm
        that is intended).
    Returns (per-item reward tensor in [0, 1] on the responses' device,
    mean info score * 100 for logging).
    """
    batch_rewards = []
    batch_densities = []
    for i in range(len(response_ids)):
        response = tokenizer.decode(response_ids[i], skip_special_tokens=True)
        tokens = len(response_ids[i])
        if tokens < 5:
            # Degenerate output: zero reward, skip metric computation.
            batch_rewards.append(0.0)
            batch_densities.append(0.0)
            continue
        # === Information Content ===
        # 1. Unique concept words (content words > 4 chars)
        words = response.split()
        content_words = set(w.lower() for w in words if len(w) > 4 and w.isalpha())
        concept_density = len(content_words) / tokens
        # 2. Technical term density
        tech_terms = count_technical_terms(response)
        tech_density = tech_terms / tokens
        # 3. Factual claim density
        claims = count_factual_claims(response)
        claim_density = claims / max(tokens / 20, 1)  # Normalize by ~sentence count
        # 4. Structured content (code, math)
        structured = count_code_and_math(response)
        structured_density = structured / tokens
        # Combined information score (weights sum to 1.0)
        info_score = (
            concept_density * 0.3 +
            tech_density * 0.3 +
            claim_density * 0.25 +
            structured_density * 0.15
        )
        # === Fluff Penalties ===
        # Final-token probe risk per item; 0 when a head was not loaded.
        rep_risk = risks['repetition'][i, -1].item() if 'repetition' in risks else 0
        verb_risk = risks['verbosity'][i, -1].item() if 'verbosity' in risks else 0
        hedge_risk = risks['hedging'][i, -1].item() if 'hedging' in risks else 0
        filler_count = count_filler_phrases(response)
        filler_penalty = min(filler_count * 0.05, 0.3)  # capped at 0.3
        # Probes penalty (repetition worst, verbosity bad, hedging mild)
        probe_penalty = 0.4 * rep_risk + 0.25 * verb_risk + 0.1 * hedge_risk
        total_fluff = filler_penalty + probe_penalty
        # === Completeness ===
        # Scale expected length with question complexity
        expected_min = config.min_response_tokens * complexity
        if tokens < expected_min:
            # Linear penalty up to 0.3 for falling short of expected length.
            completeness_penalty = 0.3 * (expected_min - tokens) / expected_min
        else:
            completeness_penalty = 0
        # Bonus for appropriate length (not too short, not excessively long)
        if expected_min <= tokens <= expected_min * 3:
            length_bonus = 0.1
        elif tokens > expected_min * 4:
            length_bonus = -0.1  # Penalize excessive length
        else:
            length_bonus = 0
        # === Final Reward ===
        # Effective tokens: actual tokens + penalty for fluff
        effective_tokens = tokens * (1 + total_fluff)
        # Information per effective token (x100 scaling into a usable range)
        density = info_score / (effective_tokens / 100)
        reward = density - completeness_penalty + length_bonus
        reward = max(0, min(1, reward))  # Clamp to [0, 1]
        batch_rewards.append(reward)
        batch_densities.append(info_score * 100)  # For logging
    return (
        torch.tensor(batch_rewards, dtype=torch.float32, device=response_ids[0].device),
        sum(batch_densities) / len(batch_densities) if batch_densities else 0
    )
def compute_efficiency_decision(risks):
    """Map current probe risks to a compute-routing decision.

    Same routing policy as the terse training run: high repetition risk
    routes to aggressive layer-skip + speculation, high verbosity to a
    moderate variant, uniformly low risk to careful early exit, anything
    else to full compute.

    Fix: the original defaulted a missing head to ``torch.zeros(1)`` — a
    1-D tensor — and then indexed ``[:, -1]``, which raises IndexError
    whenever a head is absent (the default was also constructed eagerly on
    every call). A missing head now simply scores 0.0.
    """
    def _last_token_mean(name):
        # Batch mean of the final-token risk; 0.0 when the head is absent.
        tensor = risks.get(name)
        return tensor[:, -1].mean().item() if tensor is not None else 0.0

    rep = _last_token_mean('repetition')
    verb = _last_token_mean('verbosity')
    hedge = _last_token_mean('hedging')
    if rep > 0.45:
        return {'layers': 20, 'spec_length': 8, 'strategy': 'skip_speculate_aggressive'}
    elif verb > 0.5:
        return {'layers': 24, 'spec_length': 6, 'strategy': 'skip_speculate_moderate'}
    elif rep < 0.4 and verb < 0.4 and hedge < 0.4:
        return {'layers': 16, 'spec_length': 2, 'strategy': 'early_exit_careful'}
    else:
        return {'layers': 32, 'spec_length': 3, 'strategy': 'full_compute'}
def train(args):
    """Run the dense-reward REINFORCE training loop.

    Per step: sample a prompt batch -> generate responses with the current
    LoRA policy -> score them with frozen behavior probes plus the heuristic
    density reward -> policy-gradient update with a batch-mean baseline and
    gradient accumulation. Periodically logs, checkpoints, and regenerates
    the prompt pool.

    NOTE(review): requires a CUDA device; no CPU fallback.
    """
    config = DenseConfig()
    config.learning_rate = args.lr  # CLI value overrides the config default
    device = torch.device("cuda")
    print("=" * 60)
    print(" ARC Dense Training - Maximum Information Density")
    print("=" * 60)
    print(f" Batch size: {config.batch_size}")
    print(f" Gradient accumulation: {config.gradient_accumulation}")
    print(f" Effective batch: {config.batch_size * config.gradient_accumulation}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" Max new tokens: {config.max_new_tokens}")
    print(f" Min response tokens: {config.min_response_tokens}")
    print("=" * 60)
    print("\n[1/3] Loading model...")
    # 4-bit NF4 quantization with bf16 compute; base weights stay frozen.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.local_model)
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding so generation continues from the true end of each prompt.
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        args.local_model,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa"
    )
    # Adapter selection: resume > warm-start from terse checkpoint > fresh LoRA.
    start_step = 0
    if args.resume and Path(args.resume).exists():
        print(f" Resuming from {args.resume}")
        model = PeftModel.from_pretrained(model, args.resume, is_trainable=True)
        state_path = Path(args.resume) / "training_state.pt"
        if state_path.exists():
            state = torch.load(state_path, weights_only=False)
            start_step = state.get('step', 0)
            print(f" Resuming from step {start_step}")
    elif args.base_checkpoint and Path(args.base_checkpoint).exists():
        print(f" Loading base checkpoint: {args.base_checkpoint}")
        model = PeftModel.from_pretrained(model, args.base_checkpoint, is_trainable=True)
        print(" ✓ Loaded terse-trained adapter as starting point")
    else:
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f" Model loaded. Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")
    print("\n[2/3] Loading behavioral prediction heads...")
    # Frozen probes trained in earlier runs; machine-specific path.
    checkpoint_dir = Path.home() / "arc_efficiency_training/checkpoints_backup"
    predictor = load_predictor(checkpoint_dir, device)
    print(f" ✓ Predictor loaded with heads: {list(predictor.loaded_heads)}")
    print("\n[3/3] Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )
    print(f" ✓ Optimizer: AdamW, LR: {config.learning_rate}")
    print(f"\nGenerating {args.prompts} complex prompts...")
    prompts_with_complexity = generate_complex_prompts(args.prompts)
    print(f" ✓ Generated {len(prompts_with_complexity)} prompts")
    # Rebind: from here on checkpoint_dir is the OUTPUT directory.
    checkpoint_dir = Path(args.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    print("\n" + "=" * 60)
    print(f" Starting DENSE training from step {start_step}")
    print(f" Total steps: {args.steps}")
    print("=" * 60 + "\n")
    model.train()
    optimizer.zero_grad()
    step = start_step
    # Running sums for the logging window; reset every log_every steps.
    accum_loss = 0
    accum_reward = 0
    accum_density = 0
    accum_rep = 0
    accum_layers = 0
    last_strategy = "none"
    while step < args.steps:
        # --- Sample prompts and apply the chat template ---
        batch_data = random.sample(prompts_with_complexity, config.batch_size)
        batch_prompts = [p[0] for p in batch_data]
        batch_complexity = [p[1] for p in batch_data]
        # NOTE(review): only the batch-AVERAGE complexity reaches the reward,
        # not per-prompt values — confirm this is intended.
        avg_complexity = sum(batch_complexity) / len(batch_complexity)
        formatted = [
            f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
            for p in batch_prompts
        ]
        inputs = tokenizer(
            formatted,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)
        # --- Rollout: sample responses from the current policy ---
        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=config.max_new_tokens,
                do_sample=True,
                temperature=config.temperature,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True
            )
        # Strip the prompt: keep only newly generated tokens.
        generated_ids = outputs[:, inputs.input_ids.shape[1]:]
        # --- Score with frozen probes over the full sequence ---
        with torch.no_grad():
            hidden_outputs = model(
                outputs,
                output_hidden_states=True,
                return_dict=True,
                use_cache=False
            )
            # [1:] drops the embedding-layer output; probes take one
            # hidden-state tensor per transformer layer.
            hidden_states = hidden_outputs.hidden_states[1:]
            risks = predictor.get_all_risks(hidden_states)
        rewards, avg_density_score = compute_dense_reward(
            generated_ids, risks, tokenizer, avg_complexity, config
        )
        efficiency = compute_efficiency_decision(risks)
        # --- REINFORCE update ---
        model.train()
        logits = model(outputs, return_dict=True, use_cache=False).logits
        # Standard next-token shift: predict token t+1 from position t.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = outputs[:, 1:].contiguous()
        log_probs = F.log_softmax(shift_logits.float(), dim=-1)
        selected_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        # NOTE(review): pad_token == eos_token, so genuine EOS tokens are
        # masked out too, and prompt tokens are INCLUDED in the objective
        # (the mask only removes padding) — confirm both are intended.
        mask = (shift_labels != tokenizer.pad_token_id).float()
        seq_log_probs = (selected_log_probs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        # Batch-mean baseline for variance reduction.
        baseline = rewards.float().mean()
        advantages = rewards - baseline
        loss = -(seq_log_probs * advantages.to(seq_log_probs.device)).mean()
        loss = loss / config.gradient_accumulation
        loss.backward()
        # --- Bookkeeping for the logging window ---
        accum_loss += loss.item() * config.gradient_accumulation
        accum_reward += rewards.float().mean().item()
        accum_density += avg_density_score
        accum_rep += risks['repetition'][:, -1].mean().item() if 'repetition' in risks else 0
        accum_layers += efficiency['layers']
        last_strategy = efficiency['strategy']
        # Apply accumulated gradients every gradient_accumulation steps.
        if (step + 1) % config.gradient_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
        step += 1
        if step % config.log_every == 0:
            avg_loss = accum_loss / config.log_every
            avg_reward = accum_reward / config.log_every
            avg_dens = accum_density / config.log_every
            avg_rep = accum_rep / config.log_every
            avg_layers = accum_layers / config.log_every
            print(f"Step {step:6d} | Loss: {avg_loss:.4f} | Reward: {avg_reward:.3f} | "
                  f"Density: {avg_dens:.2f} | Rep: {avg_rep:.3f} | Layers: {avg_layers:.1f} | {last_strategy}")
            accum_loss = 0
            accum_reward = 0
            accum_density = 0
            accum_rep = 0
            accum_layers = 0
        if step % config.checkpoint_every == 0:
            # Save adapter + optimizer state + a human-readable README.
            ckpt_path = checkpoint_dir / f"step_{step}"
            model.save_pretrained(ckpt_path)
            torch.save({
                'step': step,
                'optimizer': optimizer.state_dict(),
                'config': config.__dict__,
                'mode': 'dense'
            }, ckpt_path / "training_state.pt")
            with open(ckpt_path / "README.md", "w") as f:
                f.write(f"# ARC Dense Checkpoint - Step {step}\n\n")
                f.write("**Mode:** Dense (maximum information per token)\n\n")
                f.write(f"Training config:\n```json\n{json.dumps(config.__dict__, indent=2)}\n```\n")
            print(f" ✓ Saved dense checkpoint at step {step}")
        if step % config.regenerate_prompts_every == 0 and step > start_step:
            # Fresh prompt pool to limit overfitting to a fixed set.
            print(f"\n Regenerating complex prompts...")
            prompts_with_complexity = generate_complex_prompts(args.prompts)
            print(f" ✓ Generated {len(prompts_with_complexity)} fresh prompts\n")
    print("\n" + "=" * 60)
    print(" Dense training complete!")
    print("=" * 60)
    final_path = checkpoint_dir / "final"
    model.save_pretrained(final_path)
    print(f" ✓ Saved final dense model to {final_path}")
if __name__ == "__main__":
    # CLI entry: one flag per training hyperparameter, then launch training.
    arg_specs = [
        ("--local-model", dict(type=str, required=True)),
        ("--base-checkpoint", dict(type=str, default=None,
                                   help="Start from terse-trained checkpoint")),
        ("--steps", dict(type=int, default=20000)),
        ("--lr", dict(type=float, default=3e-6)),
        ("--prompts", dict(type=int, default=5000)),
        ("--checkpoint-dir", dict(type=str, default="./dense_checkpoints")),
        ("--resume", dict(type=str, default=None)),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    train(args)