#!/usr/bin/env python3
# Provenance: eval_dense.py, uploaded by CLIWorks ("Upload 2 files"),
# commit 2ae4a50, 33.2 kB.
"""
Evaluate SpiderPortal v5-Dense checkpoint with side-by-side MoE comparison.
Usage:
python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-final-ep1.pt --moe checkpoints/spiderportal-v5-final-ep1.pt --all
python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-ep1-step1000.pt --prompts "The cat sat on the"
"""
import argparse
import math
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from transformers import AutoTokenizer
@dataclass
class SpiderPortalConfig:
    """Hyper-parameters used to build the SpiderPortal models in this file.

    Every field has a default so a config can be created with keyword
    overrides only. ``num_shared_experts`` is accepted because ``load_model``
    passes it when it has to build a fallback config; without the field that
    call raises ``TypeError``. It is appended last so positional construction
    of the pre-existing fields keeps working.
    """

    # Embedding / transformer geometry.
    vocab_size: int = 50257
    hidden_size: int = 384
    num_hidden_layers: int = 8  # number of recurrent-block layers
    num_attention_heads: int = 8
    num_key_value_heads: int = 2
    intermediate_size: int = 1024

    # Mixture-of-experts routing.
    num_experts: int = 64
    num_experts_per_tok: int = 1
    router_aux_loss_coef: float = 0.05  # not read by any code in this file

    # Recurrent loop / adaptive halting.
    max_loop_iters: int = 1
    act_threshold: float = 0.5

    # Positional encoding.
    max_position_embeddings: int = 131072
    rope_theta: float = 10000000.0
    rope_scaling: dict = None  # e.g. {"type": "yarn", "factor": 4.0, ...}; None disables scaling
    sliding_window: int = 4096  # not read by any code in this file

    # Regularisation / initialisation.
    attention_dropout: float = 0.0
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    tie_word_embeddings: bool = True

    # Layer counts around the recurrent block, and adapter sizes.
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 32
    loop_embed_dim: int = 48

    # Accepted for compatibility with load_model's fallback config. The MoE
    # block defined in this file always builds exactly one shared expert and
    # does not read this value.
    num_shared_experts: int = 1
def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Add a sinusoidal encoding of the loop index to the leading channels of ``h``.

    Args:
        h: tensor whose last dimension is the feature dimension; the offset is
            broadcast over two leading dimensions.
        loop_t: loop iteration index (scalar).
        loop_dim: number of leading feature channels that receive the encoding.
        theta: base of the geometric frequency progression.

    Returns:
        A new tensor ``h + offset``; channels at index ``loop_dim`` and above
        are left unchanged.
    """
    exponents = torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim
    phase = loop_t * (1.0 / (theta ** exponents))
    # sin half followed by cos half, trimmed in case loop_dim is odd.
    sinusoid = torch.cat([phase.sin(), phase.cos()], dim=-1)[:loop_dim]
    offset = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
    offset[:loop_dim] = sinusoid
    return h + offset.unsqueeze(0).unsqueeze(0)
def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return rotary inverse frequencies rescaled per channel pair.

    A per-frequency ``beta`` is derived from the channel fraction, ``rope_theta``
    and ``orig_max``. Frequencies with ``beta < beta_slow`` keep their value,
    those with ``beta > beta_fast`` are divided by ``factor``, and the rest are
    scaled by a linear ramp between the two.

    Returns:
        1-D float tensor with one entry per even index in ``range(0, head_dim, 2)``.
    """
    channel_frac = torch.arange(0, head_dim, 2).float() / head_dim
    base_inv_freq = 1.0 / (rope_theta ** channel_frac)
    beta = channel_frac * math.log(rope_theta) / math.log(orig_max)
    ramp = 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)
    unit = torch.ones_like(beta)
    compressed = torch.where(beta > beta_fast, unit / factor, ramp)
    scale = torch.where(beta < beta_slow, unit, compressed)
    return base_inv_freq * scale
class SpiderPortalRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise in float32, then cast back to the input dtype before applying the gain."""
        out_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        normed = (x * torch.rsqrt(mean_square + self.variance_epsilon)).to(out_dtype)
        return self.weight.to(out_dtype) * normed
class SpiderPortalGQA(nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    ``num_attention_heads`` query heads share ``num_key_value_heads`` key/value
    heads; each K/V head is repeated ``num_key_value_groups`` times before the
    attention product. Rotary inverse frequencies are kept in a non-persistent
    buffer, so they are not part of the state dict.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.attention_dropout = config.attention_dropout

        rope_scaling = getattr(config, 'rope_scaling', None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
            inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
        else:
            inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate_half(self, x):
        """Map the two halves (x1, x2) of the last dimension to (-x2, x1)."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        """Rotate ``x`` by the angles encoded in ``cos`` / ``sin``."""
        return (x * cos) + (self._rotate_half(x) * sin)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Run attention over ``hidden_states`` of shape (batch, q_len, hidden_size).

        Args:
            attention_mask: optional mask forwarded to
                ``scaled_dot_product_attention``. When omitted, causal masking
                is applied automatically, including when a cache is present.
            position_ids: optional (batch, q_len) integer positions. Defaults
                to consecutive positions starting after any cached keys.
            past_key_value: optional ``(keys, values)`` tuple from a previous
                call, each shaped (batch, kv_heads, past_len, head_dim).
            use_cache: when true, the rotated keys and the values (before head
                repetition) are returned for reuse.

        Returns:
            ``(output, past_kv)`` where ``past_kv`` is ``None`` unless
            ``use_cache`` is set.
        """
        bsz, q_len, _ = hidden_states.size()
        past_len = past_key_value[0].size(2) if past_key_value is not None else 0

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if position_ids is None:
            # New tokens continue after the cached ones; restarting at 0 would
            # rotate them inconsistently with the cached keys.
            position_ids = torch.arange(past_len, past_len + q_len, device=hidden_states.device)
            position_ids = position_ids.unsqueeze(0).expand(bsz, -1)

        # Angles are computed only for the requested positions; this yields the
        # same values as indexing a full position table, without a host sync.
        angles = position_ids.to(self.inv_freq.dtype).unsqueeze(-1) * self.inv_freq
        emb = torch.cat((angles, angles), dim=-1)
        cos = emb.cos().unsqueeze(1)
        sin = emb.sin().unsqueeze(1)
        query_states = self._apply_rotary(query_states, cos, sin)
        key_states = self._apply_rotary(key_states, cos, sin)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_kv = (key_states, value_states) if use_cache else None
        kv_len = key_states.size(2)

        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        is_causal = False
        if attention_mask is None:
            if kv_len == q_len:
                is_causal = True
            elif q_len > 1:
                # is_causal aligns the triangle to the top-left corner, which is
                # wrong when cached keys precede the queries; build a mask
                # aligned to the bottom-right instead (True = may attend).
                attention_mask = torch.ones(
                    q_len, kv_len, dtype=torch.bool, device=query_states.device
                ).tril(diagonal=kv_len - q_len)
            # q_len == 1 with a cache: the single new query may see every key.

        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output), past_kv
class SpiderPortalExpert(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))`` with bias-free projections."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        # A falsy override (None or 0) falls back to the configured width.
        width = intermediate_size or config.intermediate_size
        model_dim = config.hidden_size
        self.gate_proj = nn.Linear(model_dim, width, bias=False)
        self.up_proj = nn.Linear(model_dim, width, bias=False)
        self.down_proj = nn.Linear(width, model_dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        gate = self.act_fn(self.gate_proj(hidden_states))
        lifted = self.up_proj(hidden_states)
        return self.down_proj(gate * lifted)
class SpiderPortalDenseLayer(nn.Module):
    """Pre-norm transformer layer: attention followed by a gated feed-forward block.

    The feed-forward width is ``hidden_size * 4 // 3`` rather than
    ``config.intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = SpiderPortalGQA(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=config.hidden_size * 4 // 3)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states, past_kv)`` after both residual sub-blocks."""
        attn_out, present = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out
        after_ffn = after_attn + self.ffn(self.post_attention_layernorm(after_attn))
        return after_ffn, present
class SpiderPortalRecurrentDenseLayer(nn.Module):
    """Dense recurrent layer — matches checkpoint keys.

    Same structure as a pre-norm transformer layer, with the feed-forward
    width taken from ``config.intermediate_size``. ``forward`` returns a
    three-tuple whose middle element is a constant ``0.0``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = SpiderPortalGQA(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states, 0.0, past_kv)``."""
        attn_out, present = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out
        after_ffn = after_attn + self.ffn(self.post_attention_layernorm(after_attn))
        return after_ffn, 0.0, present
# MoE layer for comparison model
class SpiderPortalRouter(nn.Module):
    """Top-k expert router with a non-trainable selection bias.

    Selection uses ``softmax(logits + router_bias)``; the auxiliary loss is
    computed from the unbiased ``softmax(logits)``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        initial = torch.randn(config.hidden_size, config.num_experts) * config.initializer_range
        self.weight = nn.Parameter(initial)
        self.register_buffer("router_bias", torch.zeros(config.num_experts))

    def forward(self, hidden_states):
        """Route every token.

        Returns:
            ``(top_weights, top_indices, aux_loss)``. The first two have one row
            per flattened token and ``num_experts_per_tok`` columns; the weights
            are renormalised to sum to one per token and cast to the input dtype.
        """
        tokens = hidden_states.view(-1, hidden_states.size(-1))
        logits = tokens @ self.weight
        unbiased_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
        selection_probs = F.softmax(logits + self.router_bias, dim=-1, dtype=torch.float32)

        top_weights, top_indices = torch.topk(selection_probs, self.num_experts_per_tok, dim=-1)
        top_weights = (top_weights / top_weights.sum(dim=-1, keepdim=True)).to(hidden_states.dtype)

        mean_load = unbiased_probs.mean(dim=0)
        aux_loss = self.num_experts * (mean_load * mean_load).sum()
        return top_weights, top_indices, aux_loss
class SpiderPortalMoE(nn.Module):
    """Mixture-of-experts feed-forward: routed experts plus one always-on shared expert."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
        self.shared_expert = SpiderPortalExpert(config)
        self.router = SpiderPortalRouter(config)

    def forward(self, hidden_states):
        """Return ``(output, aux_loss)`` with ``output`` shaped like ``hidden_states``."""
        bsz, seq_len, dim = hidden_states.shape
        top_weights, top_indices, aux_loss = self.router(hidden_states)
        tokens = hidden_states.view(-1, dim)
        routed = torch.zeros_like(tokens)

        for slot in range(self.num_experts_per_tok):
            chosen = top_indices[:, slot]
            gate = top_weights[:, slot:slot + 1]
            for expert_id in range(self.num_experts):
                picked = chosen == expert_id
                if not picked.any():
                    continue  # no token selected this expert in this slot
                routed[picked] += self.experts[expert_id](tokens[picked]) * gate[picked]

        combined = routed + self.shared_expert(tokens)
        return combined.view(bsz, seq_len, dim), aux_loss
class SpiderPortalMoELayer(nn.Module):
    """Pre-norm transformer layer whose feed-forward block is a mixture of experts."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = SpiderPortalGQA(config)
        self.moe = SpiderPortalMoE(config)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states, aux_loss, past_kv)``."""
        attn_out, present = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out
        expert_out, aux_loss = self.moe(self.post_attention_layernorm(after_attn))
        return after_attn + expert_out, aux_loss, present
class LTIInjection(nn.Module):
    """Computes ``A * h_t + B(e)`` with a strictly negative diagonal ``A``.

    ``delta_t`` is registered as a parameter but is not read by ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))
        self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        nn.init.normal_(self.B.weight, mean=0.0, std=0.01)

    def get_A(self):
        """Return ``-exp(log_A)``, which is negative for every channel."""
        return torch.exp(self.log_A).neg()

    def forward(self, h_t, e):
        decay = self.get_A()
        return decay * h_t + self.B(e)
class ACTHalting(nn.Module):
    """Predicts a per-position halting probability with a single linear unit."""

    def __init__(self, config):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1)
        self.threshold = config.act_threshold

    def forward(self, hidden_states):
        """Return sigmoid halting scores with a trailing dimension of size one."""
        halt_logits = self.halt_predictor(hidden_states)
        return torch.sigmoid(halt_logits)
class LoRAAdapter(nn.Module):
    """Low-rank residual adapter with a per-loop-iteration scale on the rank axis.

    The scale table starts at zero, so a freshly constructed adapter outputs
    zeros until weights are loaded or trained.
    """

    def __init__(self, config):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
        self.scale = nn.Embedding(config.max_loop_iters, rank)
        nn.init.zeros_(self.scale.weight)
        nn.init.normal_(self.down.weight, mean=0.0, std=0.001)

    def forward(self, x, loop_t):
        """Return the adapter delta for loop iteration ``loop_t``.

        Indices beyond the scale table reuse its last row.
        """
        last_row = self.scale.num_embeddings - 1
        row = min(loop_t, last_row)
        rank_scale = self.scale(torch.tensor(row, device=x.device))
        return (self.down(x) * rank_scale) @ self.B
class SpiderPortalDenseModel(nn.Module):
    """Prelude layers, a looped block of dense recurrent layers, then coda layers.

    Each pass over the recurrent block is followed by a LoRA delta and a
    halting prediction; per-position outputs are accumulated with the halting
    weights. At inference the loop stops early once every position has halted.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList([SpiderPortalRecurrentDenseLayer(config, i) for i in range(config.num_hidden_layers)])
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
        """Run the full stack on embedded inputs.

        Args:
            hidden_states: (batch, seq, hidden) embedded tokens.
            input_embedding: tensor fed to the injection on loop iterations
                after the first; defaults to ``hidden_states`` as passed in.
            attention_mask, position_ids: forwarded to every layer.
            past_key_values: optional per-recurrent-layer caches, consumed only
                on the first loop iteration.
            use_cache: forwarded to the recurrent layers.
            n_loops: number of loop iterations; a falsy value selects
                ``config.max_loop_iters``.

        Returns:
            ``(hidden_states, 0.0, new_past_key_values)`` where the last item
            holds the caches from the final executed loop iteration (empty if
            no iteration ran).
        """
        n_loops = n_loops or self.config.max_loop_iters
        if input_embedding is None:
            input_embedding = hidden_states

        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        # NOTE(review): the previous revision also computed
        # loop_index_embedding(hidden_states, t, self.loop_embed_dim) and a clone
        # of the prelude output here, but never used either result. They are
        # dropped rather than wired in so the numerics stay exactly as before;
        # confirm against the training code before feeding them to the layers.
        batch, seq_len, _ = hidden_states.shape
        halted = torch.zeros(batch, seq_len, device=hidden_states.device, dtype=torch.bool)
        cumulative_p = torch.zeros(batch, seq_len, device=hidden_states.device, dtype=hidden_states.dtype)
        h_out = torch.zeros_like(hidden_states)
        if past_key_values is None:
            past_key_values = [None] * len(self.recurrent_layers)
        threshold = self.config.act_threshold

        # Defined up front so the return value exists even if the loop body
        # never executes (n_loops resolving to 0).
        new_past_key_values = []
        for t in range(n_loops):
            if t > 0:
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)

            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                hidden_states, _aux, past_kv = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values[i] if t == 0 else None,
                    use_cache=use_cache,
                )
                new_past_key_values.append(past_kv)

            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(hidden_states.dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            # Positions crossing the threshold this step contribute their
            # remaining mass; the others contribute their halting probability.
            weight = torch.where(cumulative_p + halt_prob >= threshold, remainder, halt_prob)
            weight = weight * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= threshold)
            if halted.all() and not self.training:
                break

        # Positions that never crossed the threshold additionally receive the
        # final loop state in full.
        never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
        hidden_states = h_out + never_halted * hidden_states

        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, 0.0, new_past_key_values
class SpiderPortalMoEModel(nn.Module):
    """Prelude layers, a looped block of MoE recurrent layers, then coda layers.

    Mirrors the dense model's control flow; the only differences are the
    recurrent layer type and that router auxiliary losses are summed over
    every executed recurrent layer call.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
        """Run the full stack on embedded inputs.

        Args:
            hidden_states: (batch, seq, hidden) embedded tokens.
            input_embedding: tensor fed to the injection on loop iterations
                after the first; defaults to ``hidden_states`` as passed in.
            attention_mask, position_ids: forwarded to every layer.
            past_key_values: optional per-recurrent-layer caches, consumed only
                on the first loop iteration.
            use_cache: forwarded to the recurrent layers.
            n_loops: number of loop iterations; a falsy value selects
                ``config.max_loop_iters``.

        Returns:
            ``(hidden_states, total_aux_loss, new_past_key_values)`` where the
            last item holds the caches from the final executed loop iteration
            (empty if no iteration ran).
        """
        n_loops = n_loops or self.config.max_loop_iters
        if input_embedding is None:
            input_embedding = hidden_states
        total_aux_loss = 0.0

        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        # NOTE(review): the previous revision also computed
        # loop_index_embedding(hidden_states, t, self.loop_embed_dim) and a clone
        # of the prelude output here, but never used either result. They are
        # dropped rather than wired in so the numerics stay exactly as before;
        # confirm against the training code before feeding them to the layers.
        batch, seq_len, _ = hidden_states.shape
        halted = torch.zeros(batch, seq_len, device=hidden_states.device, dtype=torch.bool)
        cumulative_p = torch.zeros(batch, seq_len, device=hidden_states.device, dtype=hidden_states.dtype)
        h_out = torch.zeros_like(hidden_states)
        if past_key_values is None:
            past_key_values = [None] * len(self.recurrent_layers)
        threshold = self.config.act_threshold

        # Defined up front so the return value exists even if the loop body
        # never executes (n_loops resolving to 0).
        new_past_key_values = []
        for t in range(n_loops):
            if t > 0:
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)

            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                hidden_states, aux_loss, past_kv = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values[i] if t == 0 else None,
                    use_cache=use_cache,
                )
                total_aux_loss = total_aux_loss + aux_loss
                new_past_key_values.append(past_kv)

            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(hidden_states.dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            # Positions crossing the threshold this step contribute their
            # remaining mass; the others contribute their halting probability.
            weight = torch.where(cumulative_p + halt_prob >= threshold, remainder, halt_prob)
            weight = weight * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= threshold)
            if halted.all() and not self.training:
                break

        # Positions that never crossed the threshold additionally receive the
        # final loop state in full.
        never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
        hidden_states = h_out + never_halted * hidden_states

        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, total_aux_loss, new_past_key_values
class SpiderPortalForConditionalGeneration(nn.Module):
    """Token embedding + SpiderPortal backbone + language-model head.

    ``model_class`` selects the backbone (dense or MoE). When
    ``config.tie_word_embeddings`` is set the head shares its weight tensor
    with the embedding table.
    """

    def __init__(self, config, model_class=SpiderPortalDenseModel):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = model_class(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal-initialise Linear and Embedding weights; zero Linear biases.

        The injection's ``B`` projection is skipped so it keeps the
        initialisation applied in its own constructor.
        """
        if isinstance(module, nn.Linear):
            if hasattr(self, 'model') and module is self.model.injection.B:
                return
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
        """Compute next-token logits for ``input_ids`` of shape (batch, seq).

        Args:
            attention_mask: optional (batch, seq) mask, truthy for real tokens
                and falsy for padding. Boolean and integer masks are accepted.
            position_ids, n_loops, use_cache: forwarded to the backbone.
            labels: accepted for interface compatibility; no loss is computed.

        Returns:
            dict with ``"loss"`` (always ``None``), ``"logits"``,
            ``"aux_loss"`` and ``"past_key_values"``.
        """
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = next(self.model.parameters()).dtype
        hidden_states = hidden_states.to(model_dtype)
        input_embedding = hidden_states.clone()

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        bsz, seq_len = attention_mask.size(0), attention_mask.size(1)
        device = hidden_states.device

        # Additive mask: 0 where a query may attend, a large negative value
        # where it may not. A key is blocked when it lies in the query's future
        # or is padding. The previous construction applied triu(1) to a tensor
        # that was zero for every real token, which left future tokens fully
        # visible and re-opened padded keys at or before the query.
        # NOTE(review): if a checkpoint was trained with the old non-causal
        # mask, its outputs under this corrected mask will differ — confirm
        # against the training code.
        key_is_padding = ~attention_mask.to(device=device, dtype=torch.bool)
        key_is_future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
        blocked = key_is_future.view(1, 1, seq_len, seq_len) | key_is_padding.view(bsz, 1, 1, seq_len)
        causal_mask = torch.zeros(bsz, 1, seq_len, seq_len, dtype=hidden_states.dtype, device=device)
        causal_mask = causal_mask.masked_fill(blocked, torch.finfo(hidden_states.dtype).min)

        hidden_states, aux_loss, past_kv = self.model(
            hidden_states,
            input_embedding=input_embedding,
            attention_mask=causal_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            n_loops=n_loops,
        )
        logits = self.lm_head(hidden_states)
        return {"loss": None, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
# Built-in prompt suite. `--all` runs every entry; with no prompt option at all
# only the first three are used.
# NOTE(review): the whitespace inside the fibonacci prompt appears collapsed to
# single spaces — confirm against the intended Python indentation.
DEFAULT_PROMPTS = [
    "The cat sat on the",
    "The capital of France is",
    "If I have 3 apples and eat 1, I have",
    "Once upon a time, there was a",
    "Python is a programming language that",
    "Two plus two equals",
    "When it rains, the ground gets",
    "The door opened slowly and",
    "What is the meaning of life? The",
    "def fibonacci(n):\n if n <= 1:\n return",
]
def load_model(checkpoint_path, device="cpu", model_class=SpiderPortalDenseModel):
    """Load a checkpoint into a ``SpiderPortalForConditionalGeneration`` in eval mode.

    Args:
        checkpoint_path: path to a ``torch.save`` checkpoint. It may hold the
            keys ``"cfg"``, ``"vocab_size"`` and ``"model_state_dict"``; if the
            last is absent the checkpoint itself is used as the state dict.
        device: device to map the checkpoint to and move the model onto.
        model_class: backbone class passed to the wrapper model.

    Returns:
        ``(model, cfg)``.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects (the
    # stored config needs it). Only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    cfg = ckpt.get("cfg")
    if cfg is None:
        # Fallback config. Only keyword arguments that SpiderPortalConfig
        # declares are passed; an undeclared `num_shared_experts` here used to
        # raise TypeError on this path.
        cfg = SpiderPortalConfig(
            hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
            num_key_value_heads=2, intermediate_size=1024,
            num_experts=64, num_experts_per_tok=1,
            router_aux_loss_coef=0.05, max_loop_iters=1,
            prelude_layers=2, coda_layers=2, lora_rank=32,
            rope_theta=10000000.0,
            rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
            max_position_embeddings=131072, sliding_window=4096,
            tie_word_embeddings=True,
        )
    # Override the vocabulary size only when the checkpoint records one, so a
    # stored config is never clobbered by a hard-coded default.
    ckpt_vocab_size = ckpt.get("vocab_size")
    if ckpt_vocab_size is not None:
        cfg.vocab_size = ckpt_vocab_size

    model_state = ckpt.get("model_state_dict", ckpt)
    model = SpiderPortalForConditionalGeneration(cfg, model_class=model_class)
    # strict=False so key mismatches are reported below instead of raising.
    missing, unexpected = model.load_state_dict(model_state, strict=False)
    if missing:
        print(f" Missing keys ({len(missing)}): {missing[:3]}...")
    if unexpected:
        print(f" Unexpected keys ({len(unexpected)}): {unexpected[:3]}...")
    if not missing and not unexpected:
        print(" All keys matched perfectly")

    model = model.to(device)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_params:,} on {device}")
    return model, cfg
def _sample_next_token(logits, temperature, top_p):
    """Choose one token id from 1-D ``logits``; returns a tensor of shape (1,)."""
    if not temperature > 0:
        # Non-positive temperature means greedy decoding.
        return torch.argmax(logits, dim=-1, keepdim=True)

    # Sample in float32 so low-precision model outputs do not degrade the
    # softmax or the cumulative sum.
    probs = F.softmax(logits.float() / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens once the mass before them already exceeds top_p; shifting by
    # one keeps the token that crosses the threshold, and the top token is
    # always kept.
    drop = cumulative > top_p
    drop[1:] = drop[:-1].clone()
    drop[0] = False
    probs[sorted_indices[drop]] = 0.0
    probs = probs / probs.sum()
    return torch.multinomial(probs, 1)


def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9, device="cpu"):
    """Autoregressively continue ``prompt`` and return only the new text.

    The full sequence is re-run through the model at every step (no KV cache).
    Generation stops after ``max_new_tokens`` tokens or when the tokenizer's
    EOS id is produced.

    Args:
        model: callable as ``model(input_ids, use_cache=False)`` returning a
            mapping with a ``"logits"`` tensor of shape (batch, seq, vocab).
        tokenizer: object providing ``encode(text, return_tensors="pt")``,
            ``decode(ids, skip_special_tokens=True)`` and ``eos_token_id``.
        prompt: text to continue. An empty encoding is replaced by a single
            BOS token (or EOS if the tokenizer has no BOS).
        temperature: sampling temperature; values <= 0 select greedy decoding.
        top_p: nucleus-sampling threshold, used only when sampling.

    Raises:
        ValueError: if the prompt encodes to nothing and the tokenizer defines
            neither a BOS nor an EOS token.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    if input_ids.numel() == 0:
        # The model cannot run on a zero-length sequence; seed it instead.
        start_id = getattr(tokenizer, "bos_token_id", None)
        if start_id is None:
            start_id = tokenizer.eos_token_id
        if start_id is None:
            raise ValueError("prompt encodes to no tokens and the tokenizer has no BOS/EOS token")
        input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)

    eos_id = tokenizer.eos_token_id
    generated = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids, use_cache=False)["logits"][0, -1, :]
            next_token = _sample_next_token(logits, temperature, top_p)
            token_id = next_token.item()
            generated.append(token_id)
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
            if token_id == eos_id:
                break
    return tokenizer.decode(generated, skip_special_tokens=True)
def analyze_output(prompt, generated_text):
    """Compute simple surface-quality heuristics for one generation.

    Word-level statistics are taken over ``prompt + generated_text``; the
    character ratio is taken over ``generated_text`` only.

    Returns:
        dict with keys ``total_words``, ``unique_words`` (case-insensitive),
        ``vocab_diversity``, ``repetition_rate`` (share of duplicated
        case-sensitive 4-grams), ``has_obvious_repetition`` and
        ``english_char_ratio``.
    """
    combined = prompt + generated_text
    tokens = combined.split()
    word_count = len(tokens)
    distinct = {tok.lower() for tok in tokens}

    window = 4
    repetition_rate = 0.0
    if word_count >= window:
        grams = [tuple(tokens[start:start + window]) for start in range(word_count - window + 1)]
        repetition_rate = 1.0 - len(set(grams)) / max(len(grams), 1)

    lowered = combined.lower()
    markers = ("... ", "!!!", " and and ", " the the ", " is is ")
    stutters = any(marker in lowered for marker in markers)

    allowed = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ '.,!?;:-\"()")
    allowed_count = sum(1 for ch in generated_text if ch in allowed)

    return {
        "total_words": word_count,
        "unique_words": len(distinct),
        "vocab_diversity": len(distinct) / max(word_count, 1),
        "repetition_rate": repetition_rate,
        "has_obvious_repetition": stutters,
        "english_char_ratio": allowed_count / max(len(generated_text), 1),
    }
def main():
    """Command-line entry point: generate with the dense model (and optionally
    an MoE model) on a set of prompts, print per-prompt metrics and a summary.

    Prompt source priority: ``--all``, then ``--prompts``, then ``--file``,
    otherwise the first three default prompts.
    """
    parser = argparse.ArgumentParser(description="Evaluate SpiderPortal Dense vs MoE")
    parser.add_argument("--dense", required=True, help="Path to dense checkpoint")
    parser.add_argument("--moe", default=None, help="Path to MoE checkpoint for comparison")
    parser.add_argument("--prompts", nargs="*", default=None)
    parser.add_argument("--file", default=None, help="File with prompts")
    parser.add_argument("--all", action="store_true", help="Run default prompt suite")
    parser.add_argument("--max-new-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--device", default=None)
    args = parser.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    if args.all:
        prompts = DEFAULT_PROMPTS
    elif args.prompts:
        prompts = args.prompts
    elif args.file:
        with open(args.file, encoding="utf-8") as f:
            prompts = [line.strip() for line in f if line.strip()]
    else:
        prompts = DEFAULT_PROMPTS[:3]
    if not prompts:
        # Only reachable with an empty --file; without this guard the summary
        # below divides by zero after the checkpoints have been loaded.
        parser.error(f"no prompts found in {args.file!r}")

    dense_model, _ = load_model(args.dense, device, model_class=SpiderPortalDenseModel)
    moe_model = None
    if args.moe:
        print()
        moe_model, _ = load_model(args.moe, device, model_class=SpiderPortalMoEModel)

    print(f"\nRunning {len(prompts)} prompts (max_new_tokens={args.max_new_tokens}, temp={args.temperature})\n")
    print("=" * 80)

    def run_one(label, model, prompt):
        """Generate for one prompt, print the text and metrics, return a result record."""
        t0 = time.perf_counter()
        text = generate(model, tokenizer, prompt, args.max_new_tokens, args.temperature, args.top_p, device)
        elapsed = time.perf_counter() - t0
        metrics = analyze_output(prompt, text)
        print(f" [{label}] {text}")
        # tok/s assumes max_new_tokens were produced, so it over-reports when
        # generation stops early at EOS.
        print(f" vocab_div={metrics['vocab_diversity']:.2f}, "
              f"repetition={metrics['repetition_rate']:.2f}, "
              f"english={metrics['english_char_ratio']:.2f}, "
              f"tok/s={args.max_new_tokens/max(elapsed,0.001):.1f}")
        return {"prompt": prompt, "generated": text, "metrics": metrics}

    dense_results = []
    moe_results = []
    for i, prompt in enumerate(prompts):
        print(f"\n[Prompt {i+1}/{len(prompts)}]: {prompt}")
        dense_record = run_one("DENSE", dense_model, prompt)
        if moe_model is not None:
            moe_results.append(run_one("MoE ", moe_model, prompt))
        dense_results.append(dense_record)

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)

    def mean_metric(results, key):
        """Average one metric over all result records."""
        return sum(r["metrics"][key] for r in results) / len(results)

    def print_summary(label, results):
        avg_vocab = mean_metric(results, "vocab_diversity")
        avg_rep = mean_metric(results, "repetition_rate")
        avg_eng = mean_metric(results, "english_char_ratio")
        total_rep = sum(1 for r in results if r["metrics"]["has_obvious_repetition"])
        print(f"\n{label}:")
        print(f" Vocab diversity: {avg_vocab:.2f}")
        print(f" Repetition rate: {avg_rep:.2f}")
        print(f" English chars: {avg_eng:.2f}")
        print(f" Repetition hits: {total_rep}/{len(results)}")

    print_summary("DENSE", dense_results)
    if moe_results:
        print_summary("MoE ", moe_results)
        print("\nComparison:")
        d_vocab = mean_metric(dense_results, "vocab_diversity")
        m_vocab = mean_metric(moe_results, "vocab_diversity")
        d_eng = mean_metric(dense_results, "english_char_ratio")
        m_eng = mean_metric(moe_results, "english_char_ratio")
        if d_vocab > m_vocab:
            print(f" Dense has better vocabulary diversity (+{d_vocab - m_vocab:.2f})")
        else:
            print(f" MoE has better vocabulary diversity (+{m_vocab - d_vocab:.2f})")
        if d_eng > m_eng:
            print(f" Dense produces more English-like text (+{d_eng - m_eng:.2f})")
        else:
            print(f" MoE produces more English-like text (+{m_eng - d_eng:.2f})")


if __name__ == "__main__":
    main()