|
|
| """
|
| Coherence evaluation for language models.
|
|
|
| Measures what standard benchmarks can't see:
|
| Tier 1 — Generation diversity (repetition, collapse detection)
|
| Tier 2 — Multi-distance prediction (context utilization, skip accuracy)
|
| Tier 3 — Semantic consistency (chunk similarity over long generations)
|
|
|
| Usage:
|
| # Custom checkpoint
|
| python -m circuits.coherence_eval --checkpoint circuits/checkpoints/model/best.pt
|
|
|
| # HuggingFace model
|
| python -m circuits.coherence_eval --model gpt2
|
|
|
| # Compare models
|
| python -m circuits.coherence_eval --model EleutherAI/pythia-160m --gpu 0
|
|
|
| # Quick test (fewer prompts, shorter generation)
|
| python -m circuits.coherence_eval --checkpoint path/to/model.pt --num-prompts 5 --gen-length 256
|
|
|
| # Run specific tiers
|
| python -m circuits.coherence_eval --checkpoint path/to/model.pt --tiers 1,3
|
| """
|
|
|
| import argparse
|
| import json
|
| import math
|
| import sys
|
| import time
|
| from pathlib import Path
|
|
|
| import torch
|
| import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
# Built-in open-ended prompts covering varied topics (narrative, scientific,
# historical, expository). main() uses the first --num-prompts of these when
# no --prompts file is given; the resulting generations feed Tiers 1 and 3.
DEFAULT_PROMPTS = [
    "A thought observing itself discovers that it",
    "The history of science shows that",
    "In the middle of the night, the old house",
    "The relationship between language and thought has been",
    "When the first settlers arrived, they found",
    "The mathematical proof begins by assuming",
    "She opened the door to find",
    "The economic implications of this policy",
    "Deep beneath the ocean surface, researchers discovered",
    "The most important lesson from this experiment is",
    "According to recent studies, the human brain",
    "The old library contained books that",
    "As the temperature continued to rise, the effects on",
    "The development of artificial intelligence has raised questions about",
    "In the small village at the foot of the mountain",
    "The fundamental principles of democracy require",
    "Looking through the telescope, the astronomer noticed",
    "The relationship between music and emotion",
    "During the industrial revolution, working conditions",
    "The ancient manuscript revealed secrets about",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelWrapper:
    """Unified interface for custom circuit models and HuggingFace models.

    Exposes what the evaluation tiers need from either backend: sampled
    generation (``generate``), logits (``forward``), and logits plus
    final-layer hidden states and optional skip-head logits
    (``forward_with_hidden``). ``model_type == "hf"`` selects the HuggingFace
    code paths; any other value (``from_checkpoint`` uses ``"circuit"``)
    selects the circuit-model paths.
    """

    def __init__(self, model, tokenizer, device, model_type="hf",
                 skip_head=None, skip_k=0, max_seq_len=1024, name="unknown"):
        # skip_head: optional module applied to final hidden states to produce
        # skip logits; skip_k: the distance it was trained for (0 disables it,
        # see has_skip_head). max_seq_len bounds the sequences used in Tiers 2/3.
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.model_type = model_type
        self.skip_head = skip_head
        self.skip_k = skip_k
        self.max_seq_len = max_seq_len
        self.name = name

    @classmethod
    def from_checkpoint(cls, path, device):
        """Load a custom circuit model from a checkpoint file.

        The checkpoint is a dict holding ``"model"`` (state dict), ``"config"``
        (config dict) and optionally ``"model_type"`` (``"slot_mirrored"``,
        ``"mirrored"``; anything else, including the default ``"standard"``,
        loads a ``CircuitTransformer``).
        """
        from .config import CircuitConfig
        from .model import CircuitTransformer
        from .mirrored import MirroredConfig, MirroredTransformer
        from .slotted_mirrored import SlotMirroredConfig, SlotMirroredTransformer
        from .data import get_tokenizer

        # NOTE: weights_only=False unpickles arbitrary Python objects, which can
        # execute code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location="cpu", weights_only=False)
        model_type = checkpoint.get("model_type", "standard")

        if model_type == "slot_mirrored":
            config = SlotMirroredConfig.from_dict(checkpoint["config"])
            model = SlotMirroredTransformer(config).to(device)
            arch_desc = f"SlotMirrored ({config.n_slots} slots)"
        elif model_type == "mirrored":
            config = MirroredConfig.from_dict(checkpoint["config"])
            model = MirroredTransformer(config).to(device)
            arch_desc = "Mirrored"
        else:
            config = CircuitConfig.from_dict(checkpoint["config"])
            model = CircuitTransformer(config).to(device)
            arch_desc = "Standard"

        state_dict = checkpoint["model"]
        # Strip the "_orig_mod." key prefix that torch.compile's wrapper adds.
        if any(k.startswith("_orig_mod.") for k in state_dict):
            state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        model.eval()

        tokenizer = get_tokenizer()
        skip_head = getattr(model, "skip_head", None)
        skip_k = getattr(config, "aux_skip_k", 0)
        max_seq_len = config.max_seq_len

        params = sum(p.numel() for p in model.parameters()) / 1e6
        name = f"{Path(path).parent.name}/{Path(path).stem} ({arch_desc}, {params:.1f}M)"

        return cls(model, tokenizer, device, model_type="circuit",
                   skip_head=skip_head, skip_k=skip_k,
                   max_seq_len=max_seq_len, name=name)

    @classmethod
    def from_pretrained(cls, model_name, device):
        """Load a HuggingFace causal LM and its tokenizer in float32."""
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # NOTE: trust_remote_code=True runs Python code shipped with the model
        # repository. Only use it with model sources you trust.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(device)
        model.eval()

        max_seq_len = getattr(model.config, "max_position_embeddings", 1024)
        # Some tokenizers (e.g. GPT-2 style) define no pad token; generate()
        # needs one, so reuse EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        params = sum(p.numel() for p in model.parameters()) / 1e6
        name = f"{model_name} ({params:.1f}M)"

        return cls(model, tokenizer, device, model_type="hf",
                   max_seq_len=max_seq_len, name=name)

    @property
    def has_skip_head(self):
        """True when a skip head is attached and a positive skip distance is set."""
        return self.skip_head is not None and self.skip_k > 0

    def generate(self, prompt_text, max_new_tokens=512):
        """Sample a continuation of ``prompt_text``.

        Decoding is stochastic: temperature 0.8, top-k 50, top-p 0.9 and
        repetition penalty 1.2. Repeated calls can therefore return different
        continuations.

        Returns:
            ``(prompt_ids, gen_ids)``: 1-D tensors with the encoded prompt and
            only the newly generated tokens.
        """
        prompt_ids = self.tokenizer.encode(prompt_text, return_tensors="pt").to(self.device)

        with torch.no_grad():
            if self.model_type == "hf":
                output_ids = self.model.generate(
                    prompt_ids,
                    # Pass the mask explicitly. pad_token may equal eos_token
                    # (see from_pretrained), and then HF cannot infer the mask
                    # from the inputs.
                    attention_mask=torch.ones_like(prompt_ids),
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    temperature=0.8,
                    top_k=50,
                    top_p=0.9,
                    repetition_penalty=1.2,
                )
            else:
                output_ids = self.model.generate(
                    prompt_ids,
                    max_new_tokens=max_new_tokens,
                    temperature=0.8,
                    top_k=50,
                    top_p=0.9,
                    repetition_penalty=1.2,
                )

        # Output rows contain prompt + continuation; keep only the new tokens.
        gen_ids = output_ids[0, prompt_ids.shape[1]:]
        return prompt_ids[0], gen_ids

    def forward_with_hidden(self, input_ids):
        """Forward pass returning ``(logits, hidden_states, skip_logits_or_None)``.

        input_ids: [1, L] tensor.

        For HF models, ``hidden_states`` is the last entry of
        ``outputs.hidden_states`` and no skip logits are produced. For circuit
        models, it is the output of ``self.model.norm``, captured with a
        temporary forward hook. ``skip_logits`` is computed from it when a skip
        head is available.
        """
        with torch.no_grad():
            if self.model_type == "hf":
                outputs = self.model(input_ids, output_hidden_states=True)
                return outputs.logits, outputs.hidden_states[-1], None

            captured = {}

            def hook_fn(module, inp, output):
                captured["h"] = output.detach()

            handle = self.model.norm.register_forward_hook(hook_fn)
            try:
                output = self.model(input_ids)
            finally:
                # Always unregister, even if the forward pass raises. Otherwise
                # the hook stays attached and fires on every later call.
                handle.remove()

            logits = output["logits"]
            hidden = captured["h"]
            skip_logits = self.skip_head(hidden) if self.has_skip_head else None
            return logits, hidden, skip_logits

    def forward(self, input_ids):
        """Forward pass returning logits only. input_ids: [1, L] tensor."""
        with torch.no_grad():
            if self.model_type == "hf":
                return self.model(input_ids).logits
            return self.model(input_ids)["logits"]
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_all(wrapper, prompts, gen_length):
    """Call ``wrapper.generate`` once per prompt, in order.

    Args:
        wrapper: object whose ``generate(prompt, max_new_tokens=...)`` returns
            ``(prompt_ids, gen_ids)``.
        prompts: sequence of prompt strings.
        gen_length: value passed as ``max_new_tokens`` on every call.

    Returns:
        list of ``(prompt_text, prompt_ids, gen_ids)`` tuples.
    """
    generations = []
    for done, prompt in enumerate(prompts, start=1):
        prompt_ids, gen_ids = wrapper.generate(prompt, max_new_tokens=gen_length)
        generations.append((prompt, prompt_ids, gen_ids))
        # Carriage return keeps the progress counter on a single console line.
        print(f" [{done}/{len(prompts)}] {len(gen_ids)} tokens", end="\r")
    print()
    return generations
|
|
|
|
|
|
|
|
|
|
|
|
|
def ngrams(tokens, n):
    """Return every contiguous length-``n`` window of ``tokens`` as a tuple.

    Windows overlap and keep their original order. The result is empty when
    ``tokens`` is shorter than ``n``.
    """
    windows = []
    last_start = len(tokens) - n
    start = 0
    while start <= last_start:
        windows.append(tuple(tokens[start:start + n]))
        start += 1
    return windows
|
|
|
|
|
def compute_diversity(gen_ids):
    """Compute repetition / diversity metrics for a single generation.

    Args:
        gen_ids: 1-D token ids. Either a tensor (anything exposing ``tolist()``)
            or a plain Python sequence.

    Returns:
        dict with:
        - ``unique_{k}g`` for k = 1..4: distinct k-grams / total k-grams.
        - ``max_repeat``: longest run of one token repeated back-to-back.
        - ``max_ngram_repeat_span``: longest span, in tokens, covered by a
          pattern of period 2, 3, 4, 5 or 8 repeated back-to-back with at least
          one full copy, e.g. "the cat the cat the cat". Equals 1 when there is
          no such loop.
        - ``collapsed``: True when the unique 4-gram ratio is below 0.5, or the
          repeat span covers more than a quarter of the generation.

        Generations shorter than 4 tokens are reported as collapsed. Their
        ratios are 0 and both repeat fields are set to the token count.
    """
    tokens = gen_ids.tolist() if hasattr(gen_ids, "tolist") else list(gen_ids)
    n = len(tokens)
    if n < 4:
        # Too short to measure. Still return every key that eval_diversity
        # aggregates; a missing key here made a near-empty generation (e.g. an
        # immediate EOS) crash Tier 1 with a KeyError.
        return {"unique_1g": 0, "unique_2g": 0, "unique_3g": 0, "unique_4g": 0,
                "max_repeat": n, "max_ngram_repeat_span": n, "collapsed": True}

    results = {}
    for k in (1, 2, 3, 4):
        grams = ngrams(tokens, k)
        results[f"unique_{k}g"] = len(set(grams)) / len(grams) if grams else 0.0

    # Longest run of a single repeated token.
    max_repeat = current = 1
    for prev, cur in zip(tokens, tokens[1:]):
        current = current + 1 if cur == prev else 1
        max_repeat = max(max_repeat, current)
    results["max_repeat"] = max_repeat

    # Longest back-to-back repetition of a short pattern. If tokens[i] equals
    # tokens[i - period] for `run` consecutive positions, the `run + period`
    # tokens ending at i repeat with that period. Requiring run >= period means
    # at least one full copy was repeated.
    # Comparing each n-gram with its overlapping neighbour would only match
    # single-token runs, and could never detect loops like "a b a b".
    max_span = 1
    for period in (2, 3, 4, 5, 8):
        run = 0
        for i in range(period, n):
            if tokens[i] == tokens[i - period]:
                run += 1
                if run >= period:
                    max_span = max(max_span, run + period)
            else:
                run = 0
    results["max_ngram_repeat_span"] = max_span

    results["collapsed"] = (results["unique_4g"] < 0.5) or (max_span > n * 0.25)
    return results
|
|
|
|
|
def eval_diversity(generations, tokenizer, show_samples=3):
    """Tier 1: score repetition and diversity of pre-generated continuations.

    Args:
        generations: iterable of ``(prompt_text, prompt_ids, gen_ids)`` tuples.
        tokenizer: object with ``decode``. It is only used to render the first
            ``show_samples`` generations as text previews.
        show_samples: number of generations to preview.

    Returns:
        ``{}`` when there are no generations. Otherwise a dict with:
        - ``"per_prompt"``: the ``compute_diversity`` dicts, each extended with
          ``prompt`` and ``gen_length``.
        - ``"aggregate"``: mean/min/max per metric, plus ``collapse_rate``.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TIER 1: Generation Diversity")
    print(banner)

    per_prompt = []
    previews = []
    for idx, (prompt, _prompt_ids, gen_ids) in enumerate(generations):
        record = compute_diversity(gen_ids)
        record["prompt"] = prompt
        record["gen_length"] = len(gen_ids)
        per_prompt.append(record)
        if idx < show_samples:
            previews.append((prompt, tokenizer.decode(gen_ids, skip_special_tokens=True)))

    count = len(per_prompt)
    if not count:
        print(" No generations to evaluate.")
        return {}

    def _stats(values):
        return {"mean": sum(values) / count, "min": min(values), "max": max(values)}

    ratio_keys = ("unique_1g", "unique_2g", "unique_3g", "unique_4g")
    span_keys = ("max_repeat", "max_ngram_repeat_span")
    agg = {key: _stats([rec[key] for rec in per_prompt]) for key in ratio_keys + span_keys}

    collapsed_total = sum(bool(rec["collapsed"]) for rec in per_prompt)
    agg["collapse_rate"] = collapsed_total / count
    mean_len = sum(rec["gen_length"] for rec in per_prompt) / count

    print(f"\n Prompts evaluated: {count}")
    print(f" Avg generation length: {mean_len:.0f} tokens")
    print()
    print(f" {'Metric':<24} {'Mean':>8} {'Min':>8} {'Max':>8}")
    print(f" {'-' * 50}")
    for key in ratio_keys:
        s = agg[key]
        print(f" {key:<24} {s['mean']:>8.3f} {s['min']:>8.3f} {s['max']:>8.3f}")
    for key in span_keys:
        s = agg[key]
        print(f" {key:<24} {s['mean']:>8.1f} {int(s['min']):>8d} {int(s['max']):>8d}")
    print(f"\n Collapse rate: {collapsed_total}/{count} ({agg['collapse_rate']:.1%})")

    if previews:
        print(f"\n --- Sample generations (first {len(previews)}) ---")
        for prompt, text in previews:
            print(f"\n Prompt: \"{prompt}\"")
            # Flatten newlines and truncate so each preview stays on one line.
            snippet = text[:400].replace("\n", " ")
            print(f" Output: {snippet}..." if len(text) > 400 else f" Output: {snippet}")

    return {"per_prompt": per_prompt, "aggregate": agg}
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_eval_sequences(wrapper, num_sequences=50, data_source=None):
    """Prepare ground-truth token sequences for Tier 2.

    Text comes from ``data_source`` (read as UTF-8) when that path exists.
    Otherwise the WikiText-103 validation split is loaded via ``datasets``.
    The token stream is cut into up to ``num_sequences`` non-overlapping
    windows of exactly ``wrapper.max_seq_len`` tokens.

    Returns:
        list of 1-D ``torch.long`` tensors. Returns None if the data cannot be
        loaded or yields fewer than two full windows.
    """
    max_len = wrapper.max_seq_len

    if data_source and Path(data_source).exists():
        with open(data_source, encoding="utf-8") as f:
            text = f.read()
        all_ids = wrapper.tokenizer.encode(text)
    else:
        if data_source:
            # Tell the user when a mistyped --eval-data path is replaced by a
            # different corpus.
            print(f" Eval data not found: {data_source} (falling back to WikiText-103)")
        try:
            from datasets import load_dataset
            print(" Loading WikiText-103 validation...")
            ds = load_dataset("wikitext", "wikitext-103-raw-v1",
                              split="validation", trust_remote_code=True)
            text = "\n".join(row["text"] for row in ds if row["text"].strip())
            all_ids = wrapper.tokenizer.encode(text)
        except Exception as e:
            # Best effort: the caller skips Tier 2 instead of aborting the run.
            print(f" Could not load eval data: {e}")
            print(" Install 'datasets' or use --eval-data to provide a text file.")
            return None

    # Non-overlapping windows of exactly max_len tokens. The "+ 1" keeps a
    # final window that ends exactly at the end of the token stream.
    sequences = []
    for start in range(0, len(all_ids) - max_len + 1, max_len):
        sequences.append(torch.tensor(all_ids[start:start + max_len], dtype=torch.long))
        if len(sequences) >= num_sequences:
            break

    if len(sequences) < 2:
        print(" Not enough text for evaluation sequences.")
        return None

    print(f" Prepared {len(sequences)} sequences of {max_len} tokens")
    return sequences
|
|
|
|
|
def eval_context_utilization(wrapper, sequences):
    """Tier 2a: next-token loss / perplexity grouped by position depth.

    Each sequence is run through ``wrapper.forward``. The per-token
    cross-entropy of predicting token ``t + 1`` is bucketed by ``t`` into
    ``[0, 64)``, ``[64, 128)``, ``[128, 256)``, ``[256, 512)`` and
    ``[512, max_seq_len)``. Boundaries at or beyond ``max_seq_len`` are
    dropped, and ``max_seq_len`` always closes the last bucket.

    Returns:
        dict mapping ``"start-end"`` to ``{"loss", "ppl", "n_tokens"}``, in
        increasing depth order. Returns ``{}`` when ``sequences`` is empty.
    """
    if not sequences:
        return {}

    max_len = wrapper.max_seq_len
    bucket_bounds = sorted({b for b in (0, 64, 128, 256, 512) if b < max_len} | {max_len})
    buckets = list(zip(bucket_bounds, bucket_bounds[1:]))

    all_losses = []
    for seq in sequences:
        input_ids = seq.unsqueeze(0).to(wrapper.device)
        logits = wrapper.forward(input_ids)
        # Position t predicts token t + 1: drop the last logit and the first label.
        per_token_loss = F.cross_entropy(logits[0, :-1], input_ids[0, 1:], reduction="none")
        all_losses.append(per_token_loss.cpu())
        print(f" [{len(all_losses)}/{len(sequences)}]", end="\r")
    print()

    # Shape [num_sequences, L - 1]. Assumes equal-length sequences, which
    # prepare_eval_sequences guarantees by emitting fixed-size windows.
    stacked = torch.stack(all_losses)
    n_positions = stacked.shape[1]

    bucket_results = {}
    for start, end in buckets:
        s, e = min(start, n_positions), min(end, n_positions)
        if s >= e:
            continue
        bucket_losses = stacked[:, s:e]
        avg_loss = bucket_losses.mean().item()
        bucket_results[f"{start}-{end}"] = {
            "loss": avg_loss,
            # Cap the exponent so a diverged model reports a large but finite PPL.
            "ppl": math.exp(min(avg_loss, 20)),
            "n_tokens": bucket_losses.numel(),
        }

    return bucket_results
|
|
|
|
|
def eval_skip_accuracy(wrapper, sequences, distances):
    """Tier 2b: how well skip-head predictions match tokens K steps ahead.

    The same skip-head logits (trained for ``t + wrapper.skip_k``) are scored
    against the ground truth at every distance ``K`` in ``distances``. The
    prediction at position ``i`` is compared with token ``i + K``.

    Returns:
        None when the wrapper has no skip head. Otherwise a dict mapping
        ``"t+K"`` to the mean ``{"top1", "top5"}`` accuracy over sequences,
        ordered by K. Distances that are non-positive, or not shorter than the
        sequences, are omitted.
    """
    if not wrapper.has_skip_head:
        return None

    results = {f"t+{K}": {"top1": [], "top5": []} for K in distances}

    for done, seq in enumerate(sequences, start=1):
        input_ids = seq.unsqueeze(0).to(wrapper.device)
        # forward_with_hidden already applies the skip head under no_grad. Reuse
        # that output instead of re-running the head once per distance outside
        # no_grad (which would also build an autograd graph).
        _, hidden, skip_logits = wrapper.forward_with_hidden(input_ids)
        if skip_logits is None:
            with torch.no_grad():
                skip_logits = wrapper.skip_head(hidden)
        seq_len = input_ids.shape[1]

        for K in distances:
            # Need 0 < K < seq_len. K <= 0 would misalign preds[:-K] and targets.
            if K <= 0 or K >= seq_len:
                continue
            targets = input_ids[0, K:]
            preds = skip_logits[0, :-K]

            top1 = (preds.argmax(-1) == targets).float().mean().item()
            top5_indices = preds.topk(min(5, preds.shape[-1]), dim=-1).indices
            top5 = (top5_indices == targets.unsqueeze(-1)).any(-1).float().mean().item()

            results[f"t+{K}"]["top1"].append(top1)
            results[f"t+{K}"]["top5"].append(top5)

        print(f" [{done}/{len(sequences)}]", end="\r")
    print()

    avg_results = {}
    for key in sorted(results, key=lambda x: int(x.split("+")[1])):
        vals = results[key]
        if vals["top1"]:
            avg_results[key] = {
                "top1": sum(vals["top1"]) / len(vals["top1"]),
                "top5": sum(vals["top5"]) / len(vals["top5"]),
            }

    return avg_results
|
|
|
|
|
def eval_structural(wrapper, eval_data, distances, num_sequences):
    """Run Tier 2: context utilization (2a) and skip-head accuracy (2b).

    Args:
        wrapper: the model wrapper under evaluation.
        eval_data: optional text-file path for ground-truth sequences.
        distances: skip distances K scored in 2b.
        num_sequences: maximum number of evaluation sequences.

    Returns:
        ``{"context_utilization": ..., "skip_accuracy": ...}``. Both values are
        None when evaluation sequences cannot be prepared, and
        ``skip_accuracy`` is None when the model has no skip head.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TIER 2: Structural Prediction")
    print(rule)

    sequences = prepare_eval_sequences(wrapper, num_sequences, eval_data)
    if sequences is None:
        return {"context_utilization": None, "skip_accuracy": None}

    print("\n --- 2a: Context Utilization (PPL by position depth) ---")
    ctx_results = eval_context_utilization(wrapper, sequences)

    if ctx_results:
        print(f"\n {'Depth':<12} {'Loss':>8} {'PPL':>10} {'Tokens':>10}")
        print(f" {'-' * 42}")
        for label, row in ctx_results.items():
            print(f" {label:<12} {row['loss']:>8.3f} {row['ppl']:>10.2f} {row['n_tokens']:>10}")

        rows = list(ctx_results.values())
        if len(rows) > 1:
            # Shallowest-bucket PPL over deepest-bucket PPL.
            first_over_last = rows[0]["ppl"] / rows[-1]["ppl"]
            print(f"\n Context utilization ratio (first/last): {first_over_last:.2f}x")
            print(" (Higher = model benefits more from additional context)")

    if not wrapper.has_skip_head:
        print("\n Skip head: not available")
        return {"context_utilization": ctx_results, "skip_accuracy": None}

    print(f"\n --- 2b: Skip Head Accuracy (trained for t+{wrapper.skip_k}) ---")
    skip_results = eval_skip_accuracy(wrapper, sequences, distances)

    if skip_results:
        print(f"\n {'Distance':<12} {'Top-1':>8} {'Top-5':>8}")
        print(f" {'-' * 30}")
        for label, acc in skip_results.items():
            marker = " *" if int(label.split("+")[1]) == wrapper.skip_k else ""
            print(f" {label:<12} {acc['top1']:>8.4f} {acc['top5']:>8.4f}{marker}")
        print("\n * = trained distance")

    return {"context_utilization": ctx_results, "skip_accuracy": skip_results}
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_chunk_similarity(hidden_states, chunk_size=128):
    """Cosine similarity between mean-pooled chunks of a hidden-state sequence.

    hidden_states: [L, D] tensor.

    The first ``L // chunk_size`` full chunks are each averaged over positions
    and L2-normalised. Any trailing partial chunk is ignored.

    Returns:
        None when fewer than two full chunks fit. Otherwise a dict with:
        - ``mean_sim`` / ``min_sim``: over all distinct chunk pairs.
        - ``adjacent_sim``: mean over consecutive chunk pairs.
        - ``distant_sim``: mean over pairs taken from the first vs. last
          quarter of chunks (0.0 if there are none).
        - ``n_chunks``: number of full chunks used.
    """
    length, _dim = hidden_states.shape
    n_chunks = length // chunk_size
    if n_chunks < 2:
        return None

    pooled = torch.stack([
        hidden_states[start:start + chunk_size].mean(dim=0)
        for start in range(0, n_chunks * chunk_size, chunk_size)
    ])
    unit = F.normalize(pooled, dim=-1)
    sims = unit @ unit.T

    # Strict upper triangle: every unordered pair once, no self-similarity.
    upper = torch.ones_like(sims, dtype=torch.bool).triu(diagonal=1)
    pair_sims = sims[upper]

    adjacent = [sims[k, k + 1].item() for k in range(n_chunks - 1)]

    quarter = max(1, n_chunks // 4)
    distant = [
        sims[i, j].item()
        for i in range(quarter)
        for j in range(n_chunks - quarter, n_chunks)
        if i < j
    ]

    return {
        "mean_sim": pair_sims.mean().item(),
        "min_sim": pair_sims.min().item(),
        "adjacent_sim": sum(adjacent) / len(adjacent),
        "distant_sim": sum(distant) / len(distant) if distant else 0.0,
        "n_chunks": n_chunks,
    }
|
|
|
|
|
def eval_consistency(wrapper, generations, chunk_size=128):
    """Tier 3: semantic consistency of generated text via hidden-state similarity.

    Each generation with at least two chunks of new tokens is run back
    through the model together with its prompt, clipped to
    ``wrapper.max_seq_len``. The hidden states of the generated part are then
    compared chunk by chunk with ``compute_chunk_similarity``.

    Returns:
        ``{}`` when no generation qualifies. Otherwise a dict with:
        - ``"per_prompt"``: the similarity dicts, each with ``prompt`` added.
        - ``"aggregate"``: mean/min/max of each similarity metric, plus
          ``topic_drift`` (adjacent minus distant similarity).
    """
    rule = "=" * 60
    print("\n" + rule)
    print("TIER 3: Semantic Consistency")
    print(rule)

    per_prompt = []
    min_gen_len = 2 * chunk_size
    for prompt, prompt_ids, gen_ids in generations:
        # Need at least two full chunks of generated tokens to compare.
        if gen_ids.shape[0] < min_gen_len:
            continue

        # Run prompt + continuation together, so the hidden states for the
        # generated tokens are conditioned on the prompt. Clip to the window.
        full_ids = torch.cat([prompt_ids, gen_ids]).unsqueeze(0).to(wrapper.device)
        full_ids = full_ids[:, :wrapper.max_seq_len]

        _, hidden, _ = wrapper.forward_with_hidden(full_ids)
        sims = compute_chunk_similarity(hidden[0, prompt_ids.shape[0]:], chunk_size)
        if sims is not None:
            sims["prompt"] = prompt
            per_prompt.append(sims)

        print(f" [{len(per_prompt)}/{len(generations)}]", end="\r")
    print()

    if not per_prompt:
        print(" No valid generations for consistency evaluation.")
        return {}

    count = len(per_prompt)

    def summarize(values):
        return {"mean": sum(values) / count, "min": min(values), "max": max(values)}

    agg = {
        key: summarize([m[key] for m in per_prompt])
        for key in ("mean_sim", "min_sim", "adjacent_sim", "distant_sim")
    }
    # Positive drift: neighbouring chunks agree more than opening vs. closing chunks.
    agg["topic_drift"] = summarize([m["adjacent_sim"] - m["distant_sim"] for m in per_prompt])

    mean_chunks = sum(m["n_chunks"] for m in per_prompt) / count
    print(f"\n Generations evaluated: {count}")
    print(f" Chunk size: {chunk_size} tokens")
    print(f" Avg chunks per generation: {mean_chunks:.1f}")
    print()
    print(f" {'Metric':<24} {'Mean':>8} {'Min':>8} {'Max':>8}")
    print(f" {'-' * 50}")
    for key, stats in agg.items():
        print(f" {key:<24} {stats['mean']:>8.3f} {stats['min']:>8.3f} {stats['max']:>8.3f}")

    return {"per_prompt": per_prompt, "aggregate": agg}
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_summary(results, skip_k=None):
    """Print headline scores from each tier and store them in ``results["summary"]``.

    Args:
        results: dict that may hold ``tier1_diversity``, ``tier2_structural``
            and ``tier3_consistency`` entries from the eval functions. It is
            mutated: ``results["summary"]`` is set to the returned dict.
        skip_k: distance the skip head was trained for. Defaults to
            ``results.get("skip_k")``. Skip accuracy is reported at that
            distance. When it is unknown or was not evaluated, the first
            (smallest) evaluated distance is reported instead.

    Returns:
        dict with whichever of ``diversity``, ``context_util``, ``skip_top5``
        and ``coherence`` could be computed.
    """
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)

    scores = {}

    # Tier 1: mean unique-4-gram ratio, annotated with the collapse rate.
    t1 = results.get("tier1_diversity", {})
    if t1 and "aggregate" in t1:
        div_score = t1["aggregate"].get("unique_4g", {}).get("mean")
        collapse = t1["aggregate"].get("collapse_rate")
        if div_score is not None:
            scores["diversity"] = div_score
            print(f" Diversity (unique 4-gram): {div_score:.3f}", end="")
            if collapse is not None:
                print(f" (collapse: {collapse:.0%})", end="")
            print()

    # Tier 2: shallow/deep PPL ratio and skip-head top-5 accuracy.
    t2 = results.get("tier2_structural", {})
    if t2:
        ctx = t2.get("context_utilization")
        if ctx:
            buckets = list(ctx.values())
            if len(buckets) >= 2:
                ratio = buckets[0]["ppl"] / buckets[-1]["ppl"]
                scores["context_util"] = ratio
                print(f" Context utilization: {ratio:.2f}x")

        skip = t2.get("skip_accuracy")
        if skip:
            if skip_k is None:
                skip_k = results.get("skip_k")
            trained_key = f"t+{skip_k}"
            if trained_key not in skip:
                # Trained distance unknown or not evaluated: fall back to the
                # first key (eval_skip_accuracy orders keys by distance).
                trained_key = next(iter(skip))
            top5 = skip[trained_key]["top5"]
            scores["skip_top5"] = top5
            print(f" Skip accuracy ({trained_key} top-5): {top5:.4f}")

    # Tier 3: mean chunk similarity, annotated with topic drift.
    t3 = results.get("tier3_consistency", {})
    if t3 and "aggregate" in t3:
        coh_score = t3["aggregate"].get("mean_sim", {}).get("mean")
        drift = t3["aggregate"].get("topic_drift", {}).get("mean")
        if coh_score is not None:
            scores["coherence"] = coh_score
            print(f" Coherence (chunk sim): {coh_score:.3f}", end="")
            if drift is not None:
                print(f" (drift: {drift:.3f})", end="")
            print()

    results["summary"] = scores
    return scores
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Parse the command-line options for the coherence evaluation.

    Returns:
        argparse.Namespace. ``--distances`` and ``--tiers`` stay raw
        comma-separated strings; the caller splits them.
    """
    parser = argparse.ArgumentParser(
        description="Coherence evaluation for language models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Exactly one model source: a local circuit checkpoint or a HuggingFace id.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("--checkpoint", type=str, help="Path to circuit model checkpoint")
    source.add_argument("--model", type=str, help="HuggingFace model name or path")

    # (flag, add_argument keyword arguments), registered in this order.
    options = (
        # Evaluation settings.
        ("--prompts", {"type": str, "help": "File with prompts (one per line)"}),
        ("--num-prompts", {"type": int, "default": 20,
                           "help": "Number of prompts to use (default: 20)"}),
        ("--gen-length", {"type": int, "default": 512,
                          "help": "Tokens to generate per prompt (default: 512)"}),
        ("--eval-data", {"type": str,
                         "help": "Text file for Tier 2 (default: WikiText-103 validation)"}),
        ("--num-sequences", {"type": int, "default": 50,
                             "help": "Number of sequences for Tier 2 (default: 50)"}),
        ("--chunk-size", {"type": int, "default": 128,
                          "help": "Chunk size for Tier 3 similarity (default: 128)"}),
        ("--distances", {"type": str, "default": "2,5,10,25,50,100",
                         "help": "Skip distances for Tier 2b (default: 2,5,10,25,50,100)"}),
        ("--tiers", {"type": str, "default": "1,2,3",
                     "help": "Which tiers to run (default: 1,2,3)"}),
        # Hardware.
        ("--gpu", {"type": int, "default": 0, "help": "GPU index (default: 0)"}),
        # Output.
        ("--output", {"type": str, "help": "Save results to JSON file"}),
        ("--samples", {"type": int, "default": 3,
                       "help": "Number of sample generations to display (default: 3)"}),
    )
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()
|
|
|
|
|
def main():
    """CLI entry point.

    Loads a model, runs the selected tiers and prints a summary. With
    ``--output``, all results are also written to a JSON file.
    """
    args = parse_args()

    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    # Skip empty entries (e.g. a trailing comma) in the comma-separated lists.
    tiers = [int(t) for t in args.tiers.split(",") if t.strip()]
    distances = [int(d) for d in args.distances.split(",") if d.strip()]

    print("=" * 60)
    print("Coherence Evaluation")
    print("=" * 60)

    if args.checkpoint:
        print(f"Loading: {args.checkpoint}")
        wrapper = ModelWrapper.from_checkpoint(args.checkpoint, device)
    else:
        print(f"Loading: {args.model}")
        wrapper = ModelWrapper.from_pretrained(args.model, device)

    print(f"Model: {wrapper.name}")
    print(f"Device: {device}")
    print(f"Max seq len: {wrapper.max_seq_len}")
    if wrapper.has_skip_head:
        print(f"Skip head: t+{wrapper.skip_k}")
    print(f"Tiers: {tiers}")

    if args.prompts:
        with open(args.prompts, encoding="utf-8") as f:
            prompts = [line.strip() for line in f if line.strip()]
    else:
        prompts = DEFAULT_PROMPTS
    prompts = prompts[:args.num_prompts]
    print(f"Prompts: {len(prompts)}")

    results = {"model": wrapper.name}
    if wrapper.has_skip_head:
        # Store the trained skip distance next to the skip-accuracy scores.
        results["skip_k"] = wrapper.skip_k
    t0 = time.perf_counter()

    # Tiers 1 and 3 share one set of sampled generations.
    generations = None
    if 1 in tiers or 3 in tiers:
        print(f"\nGenerating {args.gen_length} tokens from {len(prompts)} prompts...")
        generations = generate_all(wrapper, prompts, args.gen_length)

    if 1 in tiers and generations:
        results["tier1_diversity"] = eval_diversity(
            generations, wrapper.tokenizer, show_samples=args.samples)

    if 2 in tiers:
        results["tier2_structural"] = eval_structural(
            wrapper, args.eval_data, distances, args.num_sequences)

    if 3 in tiers and generations:
        results["tier3_consistency"] = eval_consistency(
            wrapper, generations, args.chunk_size)

    print_summary(results)

    elapsed = time.perf_counter() - t0
    print(f"\nTotal time: {elapsed:.0f}s")

    if args.output:
        def make_serializable(obj):
            # Tensors become lists, tuples become lists, and NaN/inf floats
            # become strings, because json.dump would emit them as invalid JSON.
            if isinstance(obj, dict):
                return {k: make_serializable(v) for k, v in obj.items()}
            if isinstance(obj, (list, tuple)):
                return [make_serializable(v) for v in obj]
            if isinstance(obj, torch.Tensor):
                return obj.tolist()
            if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
                return str(obj)
            return obj

        out_path = Path(args.output)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(make_serializable(results), f, indent=2)
        print(f"Results saved to {args.output}")


if __name__ == "__main__":
    main()
|
|
|