#!/usr/bin/env python3
"""
Evaluate SpiderPortal v5-Dense checkpoint with side-by-side MoE comparison.

Usage:
    python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-final-ep1.pt --moe checkpoints/spiderportal-v5-final-ep1.pt --all
    python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-ep1-step1000.pt --prompts "The cat sat on the"
"""

import argparse
import math
import sys
import time
from dataclasses import dataclass, fields
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers import AutoTokenizer
except ImportError:  # keep the model classes importable without transformers; main() reports it
    AutoTokenizer = None


@dataclass
class SpiderPortalConfig:
    """Hyper-parameters shared by the dense and MoE SpiderPortal variants."""

    vocab_size: int = 50257
    hidden_size: int = 384
    num_hidden_layers: int = 8
    num_attention_heads: int = 8
    num_key_value_heads: int = 2
    intermediate_size: int = 1024
    num_experts: int = 64
    num_experts_per_tok: int = 1
    router_aux_loss_coef: float = 0.05
    max_loop_iters: int = 1
    act_threshold: float = 0.5
    max_position_embeddings: int = 131072
    rope_theta: float = 10000000.0
    # Read by SpiderPortalGQA: {"type": "yarn", "factor": ..., "original_max_position_embeddings": ...}
    rope_scaling: Optional[dict] = None
    sliding_window: int = 4096
    attention_dropout: float = 0.0
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    tie_word_embeddings: bool = True
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 32
    loop_embed_dim: int = 48


def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Return ``h`` plus a sinusoidal code for loop index ``loop_t`` in its first ``loop_dim`` channels.

    The code is a 1-D vector over the last dim of ``h``, unsqueezed twice, so it
    broadcasts over two leading dims (presumably (batch, seq) — TODO confirm).
    """
    freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
    angles = loop_t * freqs
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
    emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
    emb_full[:loop_dim] = emb
    return h + emb_full.unsqueeze(0).unsqueeze(0)


def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return RoPE inverse frequencies rescaled by a YaRN-style ramp.

    Frequencies whose ``beta`` is below ``beta_slow`` are kept. Those above
    ``beta_fast`` are divided by ``factor``. Those in between are linearly
    interpolated. NOTE(review): ``beta`` here is
    ``pos_freqs * log(theta) / log(orig_max)``, which appears to differ from
    the reference YaRN correction-range formula. It is kept as-is so trained
    checkpoints reproduce.
    """
    dim = head_dim
    orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
    pos_freqs = torch.arange(0, dim, 2).float() / dim
    beta = pos_freqs * math.log(rope_theta) / math.log(orig_max)
    scale = torch.where(
        beta < beta_slow,
        torch.ones_like(beta),
        torch.where(
            beta > beta_fast,
            torch.ones_like(beta) / factor,
            1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor),
        ),
    )
    return orig_inv_freq * scale


class SpiderPortalRMSNorm(nn.Module):
    """RMS normalisation over the last dim, computed in float32 and cast back."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(input_dtype) * hidden_states.to(input_dtype)


class SpiderPortalGQA(nn.Module):
    """Grouped-query self-attention with rotary position embeddings (optionally YaRN-scaled)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.attention_dropout = config.attention_dropout
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get(
                "original_max_position_embeddings", config.max_position_embeddings
            )
            inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
        else:
            inv_freq = 1.0 / (
                config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim)
            )
        # Non-persistent: rebuilt from config, never read from checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate_half(self, x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        return (x * cos) + (self._rotate_half(x) * sin)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None,
                use_cache=False):
        """Attend over ``hidden_states``; returns ``(output, (k, v) if use_cache else None)``.

        ``attention_mask`` is passed to SDPA as an additive/boolean ``attn_mask``.
        When it is None, SDPA's ``is_causal`` is used instead. NOTE(review): with
        ``past_key_value`` set, query and key lengths differ, so ``is_causal``
        alignment should be re-checked before relying on the cache path.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if position_ids is None:
            position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
        max_pos = position_ids.max().item() + 1
        seq_len = max(max_pos, q_len)
        t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos, sin = emb.cos(), emb.sin()
        cos = cos[position_ids].unsqueeze(1)  # broadcast over the head dim
        sin = sin[position_ids].unsqueeze(1)
        query_states = self._apply_rotary(query_states, cos, sin)
        key_states = self._apply_rotary(key_states, cos, sin)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_kv = (key_states, value_states) if use_cache else None

        # Expand KV heads so each query head has a matching key/value head.
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=attention_mask is None,
        )
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output), past_kv


class SpiderPortalExpert(nn.Module):
    """SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        inter_size = intermediate_size or config.intermediate_size
        self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
        self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))


class SpiderPortalDenseLayer(nn.Module):
    """Pre-norm attention + dense SwiGLU block used for prelude/coda layers."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = SpiderPortalGQA(config)
        dense_intermediate = config.hidden_size * 4 // 3
        self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None,
                use_cache=False):
        """Return ``(hidden_states, past_kv)``."""
        attn_input = self.input_layernorm(hidden_states)
        attn_output, past_kv = self.self_attn(
            attn_input, attention_mask=attention_mask, position_ids=position_ids,
            past_key_value=past_key_value, use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output
        hidden_states = hidden_states + self.ffn(self.post_attention_layernorm(hidden_states))
        return hidden_states, past_kv


class SpiderPortalRecurrentDenseLayer(nn.Module):
    """Dense recurrent layer — matches checkpoint keys."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = SpiderPortalGQA(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None,
                use_cache=False):
        """Return ``(hidden_states, 0.0, past_kv)``; the 0.0 mirrors the MoE layer's aux loss slot."""
        attn_input = self.input_layernorm(hidden_states)
        attn_output, past_kv = self.self_attn(
            attn_input, attention_mask=attention_mask, position_ids=position_ids,
            past_key_value=past_key_value, use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output
        hidden_states = hidden_states + self.ffn(self.post_attention_layernorm(hidden_states))
        return hidden_states, 0.0, past_kv


# MoE layer for comparison model
class SpiderPortalRouter(nn.Module):
    """Top-k router; selection uses bias-shifted probabilities, aux loss uses unbiased ones."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
        self.register_buffer("router_bias", torch.zeros(config.num_experts))

    def forward(self, hidden_states):
        """Return ``(top_weights, top_indices, aux_loss)`` over flattened tokens."""
        router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        biased_weights = F.softmax(router_logits + self.router_bias, dim=-1, dtype=torch.float32)
        top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
        top_weights = top_weights.to(hidden_states.dtype)
        mean_probs = routing_weights.mean(dim=0)
        aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
        return top_weights, top_indices, aux_loss


class SpiderPortalMoE(nn.Module):
    """Routed experts plus an always-on shared expert."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
        self.shared_expert = SpiderPortalExpert(config)
        self.router = SpiderPortalRouter(config)

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape
        top_weights, top_indices, aux_loss = self.router(hidden_states)
        flat_hidden = hidden_states.view(-1, hidden_dim)
        final_output = torch.zeros_like(flat_hidden)
        for expert_idx in range(self.num_experts_per_tok):
            expert_ids = top_indices[:, expert_idx]
            expert_weights = top_weights[:, expert_idx:expert_idx + 1]
            for e in range(self.num_experts):
                mask = expert_ids == e
                if mask.any():
                    expert_output = self.experts[e](flat_hidden[mask])
                    final_output[mask] += expert_output * expert_weights[mask]
        final_output = final_output + self.shared_expert(flat_hidden)
        return final_output.view(batch_size, seq_len, hidden_dim), aux_loss


class SpiderPortalMoELayer(nn.Module):
    """Pre-norm attention + MoE block; returns ``(hidden_states, aux_loss, past_kv)``."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = SpiderPortalGQA(config)
        self.moe = SpiderPortalMoE(config)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None,
                use_cache=False):
        attn_input = self.input_layernorm(hidden_states)
        attn_output, past_kv = self.self_attn(
            attn_input, attention_mask=attention_mask, position_ids=position_ids,
            past_key_value=past_key_value, use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output
        moe_output, aux_loss = self.moe(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + moe_output
        return hidden_states, aux_loss, past_kv


class LTIInjection(nn.Module):
    """Computes ``A * h_t + B(e)`` with ``A = -exp(log_A)`` (elementwise, always negative)."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))  # unused in forward; kept for checkpoint keys
        self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        with torch.no_grad():
            self.B.weight.data.normal_(mean=0.0, std=0.01)

    def get_A(self):
        return -torch.exp(self.log_A)

    def forward(self, h_t, e):
        return self.get_A() * h_t + self.B(e)


class ACTHalting(nn.Module):
    """Per-position halting probability ``sigmoid(Linear(h))``."""

    def __init__(self, config):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1)
        self.threshold = config.act_threshold

    def forward(self, hidden_states):
        return torch.sigmoid(self.halt_predictor(hidden_states))


class LoRAAdapter(nn.Module):
    """Low-rank delta ``(down(x) * scale[loop_t]) @ B``; loop indices past the table reuse the last row."""

    def __init__(self, config):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
        self.scale = nn.Embedding(config.max_loop_iters, rank)
        with torch.no_grad():
            self.scale.weight.data.zero_()
            self.down.weight.data.normal_(mean=0.0, std=0.001)

    def forward(self, x, loop_t):
        t_idx = min(loop_t, self.scale.num_embeddings - 1)
        s = self.scale(torch.tensor(t_idx, device=x.device))
        return (self.down(x) * s) @ self.B


class _SpiderPortalLoopedModel(nn.Module):
    """Prelude layers -> ACT-weighted loop over the recurrent stack -> coda layers -> norm.

    Shared by the dense and MoE variants, which differ only in the recurrent layer
    class. Submodule names are those the checkpoints were saved with.
    """

    def __init__(self, config, recurrent_layer_cls):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList(
            [recurrent_layer_cls(config, i) for i in range(config.num_hidden_layers)]
        )
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None,
                past_key_values=None, use_cache=False, n_loops=None):
        """Run the looped stack.

        Returns ``(hidden_states, total_aux_loss, past_key_values_of_last_loop)``.
        The aux loss sums the recurrent layers' losses and is 0.0 for the dense variant.
        """
        n_loops = n_loops or self.config.max_loop_iters
        if input_embedding is None:
            input_embedding = hidden_states
        threshold = self.config.act_threshold
        total_aux_loss = 0.0

        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        batch_size, seq_len, _ = hidden_states.shape
        halted = torch.zeros(batch_size, seq_len, device=hidden_states.device, dtype=torch.bool)
        cumulative_p = torch.zeros(batch_size, seq_len, device=hidden_states.device, dtype=hidden_states.dtype)
        h_out = torch.zeros_like(hidden_states)
        if past_key_values is None:
            past_key_values = [None] * len(self.recurrent_layers)
        new_past_key_values = []  # defined even if the loop body never runs

        # NOTE(review): the previous code computed loop_index_embedding(hidden_states, t, ...)
        # each iteration but discarded the result, so no loop-index signal reaches the
        # recurrent layers. That dead call was dropped without changing outputs; confirm
        # against the training script whether the embedding was meant to be applied.
        for t in range(n_loops):
            if t > 0:
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)
            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                hidden_states, aux_loss, past_kv = layer(
                    hidden_states, attention_mask=attention_mask, position_ids=position_ids,
                    past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache,
                )
                total_aux_loss = total_aux_loss + aux_loss
                new_past_key_values.append(past_kv)
            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            # ACT: accumulate halting mass; when a position crosses the threshold it
            # contributes its remaining mass (1 - cumulative_p) and then stops.
            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(hidden_states.dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            weight = torch.where(cumulative_p + halt_prob >= threshold, remainder, halt_prob) * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= threshold)
            if halted.all() and not self.training:
                break

        # Positions that never halted also add their final state unweighted.
        never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
        hidden_states = h_out + never_halted * hidden_states

        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, total_aux_loss, new_past_key_values


class SpiderPortalDenseModel(_SpiderPortalLoopedModel):
    """Looped backbone whose recurrent layers use a dense SwiGLU FFN."""

    def __init__(self, config):
        super().__init__(config, SpiderPortalRecurrentDenseLayer)


class SpiderPortalMoEModel(_SpiderPortalLoopedModel):
    """Looped backbone whose recurrent layers use routed experts plus a shared expert."""

    def __init__(self, config):
        super().__init__(config, SpiderPortalMoELayer)


class SpiderPortalForConditionalGeneration(nn.Module):
    """Token embedding + SpiderPortal backbone + (optionally tied) LM head."""

    def __init__(self, config, model_class=SpiderPortalDenseModel):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = model_class(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Keep LTIInjection.B's own small (std=0.01) initialisation.
            if hasattr(self, "model") and module is self.model.injection.B:
                return
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    @staticmethod
    def _build_attention_mask(input_ids, attention_mask, dtype):
        """Return an additive (batch, 1, seq, seq) mask: 0 where attention is allowed, dtype-min elsewhere.

        A key is blocked if it lies in the future (j > i) or is padding
        (``attention_mask == 0``). The previous construction started from zeros
        and applied ``triu(1)``, which removed the causal part entirely.
        NOTE(review): checkpoints trained with that mask saw future tokens.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        if attention_mask is None:
            key_is_valid = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
        else:
            # .bool() so integer 0/1 masks work (bitwise ~ on ints is not a logical not).
            key_is_valid = attention_mask.to(device=device, dtype=torch.bool)
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
        blocked = future.unsqueeze(0).unsqueeze(0) | ~key_is_valid[:, None, None, :]
        mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=dtype, device=device)
        return mask.masked_fill(blocked, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None,
                use_cache=False):
        """Return a dict with ``logits``, ``aux_loss`` and ``past_key_values``.

        ``labels`` is accepted for API compatibility but ignored; ``loss`` is always None.
        """
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = next(self.model.parameters()).dtype
        hidden_states = hidden_states.to(model_dtype)
        input_embedding = hidden_states.clone()
        causal_mask = self._build_attention_mask(input_ids, attention_mask, hidden_states.dtype)
        hidden_states, aux_loss, past_kv = self.model(
            hidden_states, input_embedding=input_embedding, attention_mask=causal_mask,
            position_ids=position_ids, use_cache=use_cache, n_loops=n_loops,
        )
        logits = self.lm_head(hidden_states)
        return {"loss": None, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}


DEFAULT_PROMPTS = [
    "The cat sat on the",
    "The capital of France is",
    "If I have 3 apples and eat 1, I have",
    "Once upon a time, there was a",
    "Python is a programming language that",
    "Two plus two equals",
    "When it rains, the ground gets",
    "The door opened slowly and",
    "What is the meaning of life? The",
    "def fibonacci(n):\n if n <= 1:\n return",
]


def load_model(checkpoint_path, device="cpu", model_class=SpiderPortalDenseModel):
    """Load a checkpoint into ``SpiderPortalForConditionalGeneration`` and return ``(model, cfg)``.

    The config comes from ``ckpt["cfg"]``. That may be a config object or a plain
    dict; unknown dict keys are ignored. If it is absent, the v5 defaults below
    are used. ``ckpt["vocab_size"]``, when present, overrides the config's vocab size.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # SECURITY: weights_only=False unpickles arbitrary objects (the config is pickled),
    # so only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = ckpt.get("cfg")
    if cfg is None:
        cfg = SpiderPortalConfig(
            hidden_size=384, num_hidden_layers=8, num_attention_heads=8, num_key_value_heads=2,
            intermediate_size=1024, num_experts=64, num_experts_per_tok=1,
            router_aux_loss_coef=0.05, max_loop_iters=1, prelude_layers=2, coda_layers=2, lora_rank=32,
            rope_theta=10000000.0,
            rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
            max_position_embeddings=131072, sliding_window=4096, tie_word_embeddings=True,
        )
    elif isinstance(cfg, dict):
        known = {fld.name for fld in fields(SpiderPortalConfig)}
        cfg = SpiderPortalConfig(**{k: v for k, v in cfg.items() if k in known})
    cfg.vocab_size = ckpt.get("vocab_size", getattr(cfg, "vocab_size", 50257))

    model_state = ckpt.get("model_state_dict", ckpt)
    model = SpiderPortalForConditionalGeneration(cfg, model_class=model_class)
    missing, unexpected = model.load_state_dict(model_state, strict=False)
    if missing:
        print(f" Missing keys ({len(missing)}): {missing[:3]}...")
    if unexpected:
        print(f" Unexpected keys ({len(unexpected)}): {unexpected[:3]}...")
    if not missing and not unexpected:
        print(" All keys matched perfectly")
    model = model.to(device)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_params:,} on {device}")
    return model, cfg


def _sample_top_p(logits, temperature, top_p):
    """Sample one token id (shape (1,)) from temperature-scaled nucleus-filtered logits."""
    probs = F.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    remove = cumulative_probs > top_p
    # Shift right so the token that crosses top_p is kept; always keep the top token.
    remove[1:] = remove[:-1].clone()
    remove[0] = False
    probs[sorted_indices[remove]] = 0.0
    probs = probs / probs.sum()
    return torch.multinomial(probs, 1)


def generate_token_ids(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9, device="cpu"):
    """Autoregressively generate up to ``max_new_tokens`` ids after ``prompt``.

    Greedy when ``temperature <= 0``, else top-p sampling. Stops after emitting
    the tokenizer's EOS id, which is included in the returned list. The full
    sequence is re-run every step (no KV cache).

    Raises:
        ValueError: if the prompt encodes to no tokens and the tokenizer has no EOS id to seed with.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    if input_ids.size(1) == 0:
        if tokenizer.eos_token_id is None:
            raise ValueError("prompt encodes to zero tokens and the tokenizer has no EOS token to seed with")
        input_ids = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long, device=device)
    generated = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids, use_cache=False)["logits"][0, -1, :]
            if temperature > 0:
                next_token = _sample_top_p(logits, temperature, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            token_id = next_token.item()
            generated.append(token_id)
            input_ids = torch.cat([input_ids, next_token.view(1, 1)], dim=1)
            if token_id == tokenizer.eos_token_id:
                break
    return generated


def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9, device="cpu"):
    """Return the decoded continuation of ``prompt`` (special tokens stripped)."""
    token_ids = generate_token_ids(model, tokenizer, prompt, max_new_tokens, temperature, top_p, device)
    return tokenizer.decode(token_ids, skip_special_tokens=True)


_ENGLISH_CHARS = frozenset("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ '.,!?;:-\"()")
_REPETITION_PATTERNS = ("... ", "!!!", " and and ", " the the ", " is is ")


def analyze_output(prompt, generated_text):
    """Compute cheap text-quality heuristics for ``prompt + generated_text``.

    Word statistics and the 4-gram repetition rate cover the full text. The
    English character ratio covers ``generated_text`` only.
    """
    full = prompt + generated_text
    words = full.split()
    unique_words = {w.lower() for w in words}
    vocab_diversity = len(unique_words) / max(len(words), 1)

    n = 4
    if len(words) >= n:
        ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
        repetition_rate = 1.0 - len(set(ngrams)) / max(len(ngrams), 1)
    else:
        repetition_rate = 0.0

    lowered = full.lower()
    has_repetition = any(pattern in lowered for pattern in _REPETITION_PATTERNS)
    char_ratio = sum(1 for c in generated_text if c in _ENGLISH_CHARS) / max(len(generated_text), 1)
    return {
        "total_words": len(words),
        "unique_words": len(unique_words),
        "vocab_diversity": vocab_diversity,
        "repetition_rate": repetition_rate,
        "has_obvious_repetition": has_repetition,
        "english_char_ratio": char_ratio,
    }


def _select_prompts(args):
    """Pick prompts: --all, else --prompts, else --file (non-blank lines), else the first 3 defaults."""
    if args.all:
        return list(DEFAULT_PROMPTS)
    if args.prompts:
        return args.prompts
    if args.file:
        with open(args.file, encoding="utf-8") as fh:
            return [line.strip() for line in fh if line.strip()]
    return DEFAULT_PROMPTS[:3]


def _evaluate_prompt(label, model, tokenizer, prompt, args, device):
    """Generate for one prompt, print text and metrics, and return a result record."""
    t0 = time.perf_counter()
    token_ids = generate_token_ids(
        model, tokenizer, prompt, args.max_new_tokens, args.temperature, args.top_p, device
    )
    elapsed = time.perf_counter() - t0
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    metrics = analyze_output(prompt, text)
    # Throughput uses the tokens actually produced (generation may stop early at EOS).
    tok_per_s = len(token_ids) / max(elapsed, 0.001)
    print(f" [{label}] {text}")
    print(f" vocab_div={metrics['vocab_diversity']:.2f}, "
          f"repetition={metrics['repetition_rate']:.2f}, "
          f"english={metrics['english_char_ratio']:.2f}, "
          f"tok/s={tok_per_s:.1f}")
    return {"prompt": prompt, "generated": text, "metrics": metrics}


def _mean_metric(results, key):
    """Average ``results[i]["metrics"][key]``; 0.0 for an empty list."""
    return sum(r["metrics"][key] for r in results) / max(len(results), 1)


def _print_summary(label, results):
    avg_vocab = _mean_metric(results, "vocab_diversity")
    avg_rep = _mean_metric(results, "repetition_rate")
    avg_eng = _mean_metric(results, "english_char_ratio")
    total_rep = sum(1 for r in results if r["metrics"]["has_obvious_repetition"])
    print(f"\n{label}:")
    print(f" Vocab diversity: {avg_vocab:.2f}")
    print(f" Repetition rate: {avg_rep:.2f}")
    print(f" English chars: {avg_eng:.2f}")
    print(f" Repetition hits: {total_rep}/{len(results)}")


def main():
    """CLI entry point: generate with the dense (and optional MoE) model, then print metrics."""
    parser = argparse.ArgumentParser(description="Evaluate SpiderPortal Dense vs MoE")
    parser.add_argument("--dense", required=True, help="Path to dense checkpoint")
    parser.add_argument("--moe", default=None, help="Path to MoE checkpoint for comparison")
    parser.add_argument("--prompts", nargs="*", default=None)
    parser.add_argument("--file", default=None, help="File with prompts")
    parser.add_argument("--all", action="store_true", help="Run default prompt suite")
    parser.add_argument("--max-new-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--device", default=None)
    args = parser.parse_args()

    if AutoTokenizer is None:
        sys.exit("error: the 'transformers' package is required to run this script (pip install transformers)")

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    prompts = _select_prompts(args)
    if not prompts:
        sys.exit("error: no prompts to evaluate (is the --file empty?)")

    dense_model, _ = load_model(args.dense, device, model_class=SpiderPortalDenseModel)
    moe_model = None
    if args.moe:
        print()
        moe_model, _ = load_model(args.moe, device, model_class=SpiderPortalMoEModel)

    print(f"\nRunning {len(prompts)} prompts (max_new_tokens={args.max_new_tokens}, temp={args.temperature})\n")
    print("=" * 80)

    dense_results = []
    moe_results = []
    for i, prompt in enumerate(prompts):
        print(f"\n[Prompt {i+1}/{len(prompts)}]: {prompt}")
        dense_results.append(_evaluate_prompt("DENSE", dense_model, tokenizer, prompt, args, device))
        if moe_model is not None:
            moe_results.append(_evaluate_prompt("MoE ", moe_model, tokenizer, prompt, args, device))

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)
    _print_summary("DENSE", dense_results)
    if moe_results:
        _print_summary("MoE ", moe_results)
        print("\nComparison:")
        d_vocab = _mean_metric(dense_results, "vocab_diversity")
        m_vocab = _mean_metric(moe_results, "vocab_diversity")
        d_eng = _mean_metric(dense_results, "english_char_ratio")
        m_eng = _mean_metric(moe_results, "english_char_ratio")
        if d_vocab > m_vocab:
            print(f" Dense has better vocabulary diversity (+{d_vocab - m_vocab:.2f})")
        else:
            print(f" MoE has better vocabulary diversity (+{m_vocab - d_vocab:.2f})")
        if d_eng > m_eng:
            print(f" Dense produces more English-like text (+{d_eng - m_eng:.2f})")
        else:
            print(f" MoE produces more English-like text (+{m_eng - d_eng:.2f})")


if __name__ == "__main__":
    main()