| |
| """ |
| Evaluate SpiderPortal v5-Dense checkpoint with side-by-side MoE comparison. |
| |
| Usage: |
| python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-final-ep1.pt --moe checkpoints/spiderportal-v5-final-ep1.pt --all |
| python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-ep1-step1000.pt --prompts "The cat sat on the" |
| """ |
|
|
| import argparse |
| import math |
| import sys |
| import time |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from dataclasses import dataclass |
| from transformers import AutoTokenizer |
|
|
|
|
@dataclass
class SpiderPortalConfig:
    """Hyper-parameters shared by the dense and MoE SpiderPortal models.

    Field order defines the positional constructor, so new fields are appended at the end.
    """

    vocab_size: int = 50257
    hidden_size: int = 384
    num_hidden_layers: int = 8  # number of recurrent layers looped by the backbone
    num_attention_heads: int = 8
    num_key_value_heads: int = 2  # K/V heads shared by groups of query heads (GQA)
    intermediate_size: int = 1024  # FFN width of experts / recurrent dense layers
    num_experts: int = 64  # routed experts per MoE layer
    num_experts_per_tok: int = 1  # top-k experts selected per token
    router_aux_loss_coef: float = 0.05  # not read by this evaluation script
    max_loop_iters: int = 1  # default recurrent loop count; also sizes the LoRA scale table
    act_threshold: float = 0.5  # cumulative halting probability at which a token stops
    max_position_embeddings: int = 131072
    rope_theta: float = 10000000.0
    # Optional RoPE scaling, e.g. {"type": "yarn", "factor": 4.0,
    # "original_max_position_embeddings": 32768}; None means plain RoPE.
    rope_scaling: "dict | None" = None
    sliding_window: int = 4096  # not applied by the attention in this file
    attention_dropout: float = 0.0
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    tie_word_embeddings: bool = True
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 32
    loop_embed_dim: int = 48
    # Accepted for compatibility with configs that specify it; the MoE block in this
    # file always builds exactly one shared expert regardless of this value.
    num_shared_experts: int = 1
|
|
|
|
def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Add a sinusoidal encoding of the loop step ``loop_t`` to the first ``loop_dim`` channels.

    The encoding is ``[sin(loop_t * f), cos(loop_t * f)]`` truncated to ``loop_dim`` values,
    zero-padded to the last dimension of ``h``, and broadcast over the two leading axes.
    """
    width = h.shape[-1]
    even_idx = torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype)
    inv_freq = 1.0 / (theta ** (even_idx / loop_dim))
    phase = loop_t * inv_freq
    signal = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)[:loop_dim]
    padded = torch.zeros(width, device=h.device, dtype=h.dtype)
    padded[:loop_dim] = signal
    return h + padded[None, None, :]
|
|
|
|
def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return RoPE inverse frequencies (length ``head_dim // 2``) with a per-frequency scale.

    ``beta = (2i / head_dim) * log(rope_theta) / log(orig_max)``. The scale is 1 where
    ``beta < beta_slow``, ``1 / factor`` where ``beta > beta_fast``, and a linear ramp
    between the two otherwise.

    NOTE(review): this ramp differs from the wavelength-based formulation in the YaRN paper;
    presumably it is kept to match the training code. Confirm before changing.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    base_inv_freq = 1.0 / (rope_theta ** exponents)
    beta = exponents * math.log(rope_theta) / math.log(orig_max)
    ones = torch.ones_like(beta)
    ramp = 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)
    scale = torch.where(beta > beta_fast, ones / factor, ramp)
    scale = torch.where(beta < beta_slow, ones, scale)
    return base_inv_freq * scale
|
|
|
|
class SpiderPortalRMSNorm(nn.Module):
    """RMS normalisation over the last axis with a learned per-channel gain ``weight``.

    Statistics are computed in float32; the result is cast back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + self.variance_epsilon)
        gain = self.weight.to(orig_dtype)
        return gain * normed.to(orig_dtype)
|
|
|
|
class SpiderPortalGQA(nn.Module):
    """Grouped-query self-attention with rotary position embeddings (optionally YaRN-scaled).

    ``num_attention_heads`` query heads share ``num_key_value_heads`` key/value heads.
    ``forward`` returns ``(output, past_kv)``. ``past_kv`` is ``(keys, values)`` shaped
    ``(batch, num_kv_heads, kv_len, head_dim)`` when ``use_cache`` is true, else ``None``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.attention_dropout = config.attention_dropout
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
            inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
        else:
            inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        # Non-persistent: rebuilt from the config and never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate_half(self, x):
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        return (x * cos) + (self._rotate_half(x) * sin)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Attend over the cached keys/values plus the current tokens.

        When ``attention_mask`` is None, a causal mask is applied. The mask is bottom-right
        aligned when a cache is present, so new tokens see every cached token.
        """
        bsz, q_len, _ = hidden_states.size()
        device = hidden_states.device
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
        if position_ids is None:
            # New tokens continue after the cached ones; starting at 0 would rotate them
            # as if they were the start of the sequence.
            position_ids = torch.arange(past_len, past_len + q_len, device=device).unsqueeze(0).expand(bsz, -1)
        seq_len = max(position_ids.max().item() + 1, q_len)
        # Build the rotary tables in float32 (positions lose precision in half types),
        # then match the query dtype so q/k/v stay in one dtype for SDPA.
        inv_freq = self.inv_freq.float()
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[position_ids].unsqueeze(1).to(query_states.dtype)
        sin = emb.sin()[position_ids].unsqueeze(1).to(query_states.dtype)
        query_states = self._apply_rotary(query_states, cos, sin)
        key_states = self._apply_rotary(key_states, cos, sin)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_kv = (key_states, value_states) if use_cache else None

        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        kv_len = key_states.shape[2]
        is_causal = False
        if attention_mask is None and q_len > 1:
            if kv_len == q_len:
                is_causal = True
            else:
                # SDPA's is_causal aligns the mask top-left, which is wrong when cached keys
                # precede the queries; build a bottom-right aligned boolean mask instead.
                attention_mask = torch.ones(q_len, kv_len, dtype=torch.bool, device=device).tril(kv_len - q_len)
        # A single query with no explicit mask may attend to every (cached) key.

        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output), past_kv
|
|
|
|
class SpiderPortalExpert(nn.Module):
    """SwiGLU feed-forward: ``down(silu(gate(x)) * up(x))``.

    ``intermediate_size`` overrides ``config.intermediate_size`` when truthy.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        width = intermediate_size if intermediate_size else config.intermediate_size
        d_model = config.hidden_size
        self.gate_proj = nn.Linear(d_model, width, bias=False)
        self.up_proj = nn.Linear(d_model, width, bias=False)
        self.down_proj = nn.Linear(width, d_model, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        gated = self.act_fn(self.gate_proj(hidden_states))
        lifted = self.up_proj(hidden_states)
        return self.down_proj(gated * lifted)
|
|
|
|
class SpiderPortalDenseLayer(nn.Module):
    """Pre-norm block: GQA attention, then a SwiGLU FFN of width ``hidden_size * 4 // 3``.

    Both sub-layers are residual. ``forward`` returns ``(hidden_states, past_kv)``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = SpiderPortalGQA(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=(config.hidden_size * 4) // 3)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        residual = hidden_states
        attn_out, past_kv = self.self_attn(
            self.input_layernorm(residual),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = residual + attn_out
        hidden_states = hidden_states + self.ffn(self.post_attention_layernorm(hidden_states))
        return hidden_states, past_kv
|
|
|
|
class SpiderPortalRecurrentDenseLayer(nn.Module):
    """Dense recurrent layer — matches checkpoint keys.

    Pre-norm GQA attention plus a SwiGLU FFN of width ``config.intermediate_size``.
    ``forward`` returns ``(hidden_states, 0.0, past_kv)``. The ``0.0`` stands in for the
    MoE layer's auxiliary loss so both layer kinds share one call signature.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = SpiderPortalGQA(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        normed = self.input_layernorm(hidden_states)
        attn_out, past_kv = self.self_attn(
            normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out
        ffn_out = self.ffn(self.post_attention_layernorm(hidden_states))
        return hidden_states + ffn_out, 0.0, past_kv
|
|
|
|
| |
class SpiderPortalRouter(nn.Module):
    """Top-k token router.

    Expert selection uses softmax(logits + router_bias). The auxiliary loss is
    ``num_experts * sum(mean_prob ** 2)`` over the *unbiased* softmax. Returns
    ``(top_weights, top_indices, aux_loss)``, where ``top_weights`` is renormalised to sum to
    1 per token and cast to the input dtype.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
        self.register_buffer("router_bias", torch.zeros(config.num_experts))

    def forward(self, hidden_states):
        tokens = hidden_states.view(-1, hidden_states.size(-1))
        logits = tokens @ self.weight
        unbiased_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
        selection_probs = F.softmax(logits + self.router_bias, dim=-1, dtype=torch.float32)
        top_weights, top_indices = selection_probs.topk(self.num_experts_per_tok, dim=-1)
        top_weights = (top_weights / top_weights.sum(dim=-1, keepdim=True)).to(hidden_states.dtype)
        avg_probs = unbiased_probs.mean(dim=0)
        aux_loss = self.num_experts * (avg_probs * avg_probs).sum()
        return top_weights, top_indices, aux_loss
|
|
|
|
class SpiderPortalMoE(nn.Module):
    """Routed mixture of SwiGLU experts plus one always-on shared expert.

    ``forward(hidden_states)`` takes ``(batch, seq, hidden)`` (unpacked on the first line)
    and returns ``(output of the same shape, router aux_loss)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
        self.shared_expert = SpiderPortalExpert(config)
        self.router = SpiderPortalRouter(config)

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape
        top_weights, top_indices, aux_loss = self.router(hidden_states)
        flat_hidden = hidden_states.view(-1, hidden_dim)
        final_output = torch.zeros_like(flat_hidden)
        for slot in range(self.num_experts_per_tok):
            expert_ids = top_indices[:, slot]
            slot_weights = top_weights[:, slot:slot + 1]
            # Visit only experts that received at least one token. One unique() replaces a
            # mask + .any() host sync for each of the num_experts experts.
            for e in torch.unique(expert_ids).tolist():
                mask = expert_ids == e
                expert_output = self.experts[e](flat_hidden[mask])
                final_output[mask] += expert_output * slot_weights[mask]
        final_output = final_output + self.shared_expert(flat_hidden)
        return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
|
|
|
|
class SpiderPortalMoELayer(nn.Module):
    """Pre-norm block: GQA attention, then the routed MoE feed-forward, both residual.

    ``forward`` returns ``(hidden_states, aux_loss, past_kv)``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = SpiderPortalGQA(config)
        self.moe = SpiderPortalMoE(config)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        normed = self.input_layernorm(hidden_states)
        attn_out, past_kv = self.self_attn(
            normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out
        moe_out, aux_loss = self.moe(self.post_attention_layernorm(hidden_states))
        return hidden_states + moe_out, aux_loss, past_kv
|
|
|
|
class LTIInjection(nn.Module):
    """Input injection ``A * h_t + B(e)`` with a per-channel diagonal ``A = -exp(log_A)``.

    ``delta_t`` is a registered parameter (so it appears in checkpoints) but is not read here.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.log_A = nn.Parameter(torch.full((width,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))
        self.B = nn.Linear(width, width, bias=False)
        with torch.no_grad():
            self.B.weight.data.normal_(mean=0.0, std=0.01)

    def get_A(self):
        return self.log_A.exp().neg()

    def forward(self, h_t, e):
        return self.get_A() * h_t + self.B(e)
|
|
|
|
class ACTHalting(nn.Module):
    """Per-position halting probability ``sigmoid(Linear(hidden) -> 1)``.

    Returns a tensor whose last dimension is 1. ``threshold`` is stored from
    ``config.act_threshold`` but not used by ``forward``.
    """

    def __init__(self, config):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1)
        self.threshold = config.act_threshold

    def forward(self, hidden_states):
        logits = self.halt_predictor(hidden_states)
        return logits.sigmoid()
|
|
|
|
class LoRAAdapter(nn.Module):
    """Low-rank, loop-step-conditioned delta: ``(down(x) * scale[t]) @ B``.

    ``loop_t`` values past the end of the ``scale`` table reuse its last row. ``scale`` is
    zero-initialised, so a freshly built adapter contributes nothing.
    """

    def __init__(self, config):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
        self.scale = nn.Embedding(config.max_loop_iters, rank)
        with torch.no_grad():
            self.scale.weight.data.zero_()
            self.down.weight.data.normal_(mean=0.0, std=0.001)

    def forward(self, x, loop_t):
        last_row = self.scale.num_embeddings - 1
        row = torch.tensor(min(loop_t, last_row), device=x.device)
        scaled = self.down(x) * self.scale(row)
        return scaled @ self.B
|
|
|
|
class SpiderPortalDenseModel(nn.Module):
    """Dense backbone: prelude layers -> looped recurrent stack with ACT halting -> coda -> norm.

    Submodule names are the checkpoint's state-dict keys and must not change.
    ``forward`` returns ``(hidden_states, 0.0, past_key_values_from_last_loop)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList([SpiderPortalRecurrentDenseLayer(config, i) for i in range(config.num_hidden_layers)])
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
        n_loops = n_loops or self.config.max_loop_iters
        input_embedding = input_embedding if input_embedding is not None else hidden_states
        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        batch_size, seq_len, _ = hidden_states.shape
        device, dtype = hidden_states.device, hidden_states.dtype
        threshold = self.config.act_threshold
        halted = torch.zeros(batch_size, seq_len, device=device, dtype=torch.bool)
        cumulative_p = torch.zeros(batch_size, seq_len, device=device, dtype=dtype)
        h_out = torch.zeros_like(hidden_states)
        if past_key_values is None:
            past_key_values = [None] * len(self.recurrent_layers)
        # Defined up front so a zero-iteration loop still returns something sensible.
        new_past_key_values = list(past_key_values)

        for t in range(n_loops):
            # NOTE(review): no loop-index embedding is applied to hidden_states here
            # (loop_index_embedding exists in this file but is unused); confirm this
            # matches the training code.
            if t > 0:
                # NOTE(review): injection reads input_embedding (the embeddings passed in),
                # not the prelude output — confirm against training.
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)
            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                # The cache is only fed to the first loop iteration.
                layer_past = past_key_values[i] if t == 0 else None
                hidden_states, _aux, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=layer_past, use_cache=use_cache)
                new_past_key_values.append(past_kv)
            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            # ACT: accumulate a halting-weighted mixture of per-loop states. A position that
            # crosses the threshold this step contributes its remaining (1 - cumulative) mass.
            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            weight = torch.where(cumulative_p + halt_prob >= threshold, remainder, halt_prob)
            weight = weight * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= threshold)
            if halted.all() and not self.training:
                break

        # Positions that never halted also receive the final loop state in full.
        never_halted = (~halted).to(dtype).unsqueeze(-1)
        hidden_states = h_out + never_halted * hidden_states
        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, 0.0, new_past_key_values
|
|
|
|
class SpiderPortalMoEModel(nn.Module):
    """MoE backbone: prelude layers -> looped MoE recurrent stack with ACT halting -> coda -> norm.

    Submodule names are the checkpoint's state-dict keys and must not change.
    ``forward`` returns ``(hidden_states, summed router aux_loss, past_key_values_from_last_loop)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
        n_loops = n_loops or self.config.max_loop_iters
        input_embedding = input_embedding if input_embedding is not None else hidden_states
        total_aux_loss = 0.0
        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        batch_size, seq_len, _ = hidden_states.shape
        device, dtype = hidden_states.device, hidden_states.dtype
        threshold = self.config.act_threshold
        halted = torch.zeros(batch_size, seq_len, device=device, dtype=torch.bool)
        cumulative_p = torch.zeros(batch_size, seq_len, device=device, dtype=dtype)
        h_out = torch.zeros_like(hidden_states)
        if past_key_values is None:
            past_key_values = [None] * len(self.recurrent_layers)
        # Defined up front so a zero-iteration loop still returns something sensible.
        new_past_key_values = list(past_key_values)

        for t in range(n_loops):
            # NOTE(review): no loop-index embedding is applied to hidden_states here
            # (loop_index_embedding exists in this file but is unused); confirm this
            # matches the training code.
            if t > 0:
                # NOTE(review): injection reads input_embedding (the embeddings passed in),
                # not the prelude output — confirm against training.
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)
            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                # The cache is only fed to the first loop iteration.
                layer_past = past_key_values[i] if t == 0 else None
                hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=layer_past, use_cache=use_cache)
                total_aux_loss = total_aux_loss + aux_loss
                new_past_key_values.append(past_kv)
            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            # ACT: accumulate a halting-weighted mixture of per-loop states. A position that
            # crosses the threshold this step contributes its remaining (1 - cumulative) mass.
            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            weight = torch.where(cumulative_p + halt_prob >= threshold, remainder, halt_prob)
            weight = weight * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= threshold)
            if halted.all() and not self.training:
                break

        # Positions that never halted also receive the final loop state in full.
        never_halted = (~halted).to(dtype).unsqueeze(-1)
        hidden_states = h_out + never_halted * hidden_states
        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, total_aux_loss, new_past_key_values
|
|
|
|
class SpiderPortalForConditionalGeneration(nn.Module):
    """Token embedding + SpiderPortal backbone + LM head (tied to the embedding when configured).

    ``model_class`` selects the backbone; ``None`` means ``SpiderPortalDenseModel``.
    ``forward`` ignores ``labels`` and always returns ``loss=None``.
    """

    def __init__(self, config, model_class=None):
        super().__init__()
        if model_class is None:
            # Resolved at call time instead of being bound as a default at class definition.
            model_class = SpiderPortalDenseModel
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = model_class(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Keep LTIInjection.B's own small-std init.
            if hasattr(self, 'model') and module is self.model.injection.B:
                return
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    @staticmethod
    def _build_causal_mask(attention_mask, batch_size, seq_len, dtype, device):
        """Return an additive ``(batch, 1, seq, seq)`` mask: 0 = attend, dtype-min = blocked.

        Future keys are blocked. When ``attention_mask`` (``(batch, seq)``, bool or 0/1 ints)
        is given, key columns where it is false/0 are blocked as well.
        """
        neg = torch.finfo(dtype).min
        # Start from a tensor filled with the block value: triu(1) keeps it only strictly
        # above the diagonal (future keys) and zeros the diagonal and everything below.
        causal = torch.full((seq_len, seq_len), neg, dtype=dtype, device=device).triu(1)
        mask = causal[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)
        if attention_mask is not None:
            key_padding = ~attention_mask.to(device=device, dtype=torch.bool)[:, None, None, :]
            mask = mask.masked_fill(key_padding, neg)
        return mask

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = next(self.model.parameters()).dtype
        hidden_states = hidden_states.to(model_dtype)
        input_embedding = hidden_states.clone()
        causal_mask = self._build_causal_mask(
            attention_mask, input_ids.size(0), input_ids.size(1), hidden_states.dtype, hidden_states.device
        )
        hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
        logits = self.lm_head(hidden_states)
        return {"loss": None, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
|
|
|
|
# Prompt suite used by --all; the first three are used when no --all/--prompts/--file
# option is given. Short completions spanning facts, arithmetic, narrative, and code.
DEFAULT_PROMPTS = [
    "The cat sat on the",
    "The capital of France is",
    "If I have 3 apples and eat 1, I have",
    "Once upon a time, there was a",
    "Python is a programming language that",
    "Two plus two equals",
    "When it rains, the ground gets",
    "The door opened slowly and",
    "What is the meaning of life? The",
    "def fibonacci(n):\n if n <= 1:\n return",
]
|
|
|
|
def _checkpoint_config(ckpt):
    """Return the SpiderPortalConfig for *ckpt*, falling back to the v5 defaults.

    Accepts a pickled config object or a plain dict under ``"cfg"``; dict keys that are not
    config fields are ignored. A top-level ``"vocab_size"`` entry overrides the config's.
    """
    from dataclasses import fields

    cfg = ckpt.get("cfg")
    if isinstance(cfg, dict):
        known = {f.name for f in fields(SpiderPortalConfig)}
        cfg = SpiderPortalConfig(**{k: v for k, v in cfg.items() if k in known})
    if cfg is None:
        cfg = SpiderPortalConfig(
            hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
            num_key_value_heads=2, intermediate_size=1024,
            num_experts=64, num_experts_per_tok=1,
            router_aux_loss_coef=0.05, max_loop_iters=1,
            prelude_layers=2, coda_layers=2, lora_rank=32,
            rope_theta=10000000.0,
            rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
            max_position_embeddings=131072, sliding_window=4096,
            tie_word_embeddings=True,
        )
    # Only override the config's own vocab size when the checkpoint states one explicitly.
    cfg.vocab_size = ckpt.get("vocab_size", getattr(cfg, "vocab_size", 50257))
    return cfg


def load_model(checkpoint_path, device="cpu", model_class=SpiderPortalDenseModel):
    """Load a SpiderPortal checkpoint into a fresh model in eval mode.

    Args:
        checkpoint_path: file passed to ``torch.load``.
        device: map location for the checkpoint and final device of the model.
        model_class: backbone class to instantiate (dense or MoE).

    Returns:
        ``(model, cfg)``.
    """
    print(f"Loading checkpoint: {checkpoint_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects (needed for a pickled
    # cfg). Only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    cfg = _checkpoint_config(ckpt)

    model_state = ckpt.get("model_state_dict", ckpt)
    model = SpiderPortalForConditionalGeneration(cfg, model_class=model_class)

    # strict=False lets partially matching checkpoints load; mismatches are reported below.
    missing, unexpected = model.load_state_dict(model_state, strict=False)
    if missing:
        print(f" Missing keys ({len(missing)}): {missing[:3]}...")
    if unexpected:
        print(f" Unexpected keys ({len(unexpected)}): {unexpected[:3]}...")
    if not missing and not unexpected:
        print(" All keys matched perfectly")

    model = model.to(device)
    model.eval()

    n_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {n_params:,} on {device}")

    return model, cfg
|
|
|
|
def _sample_top_p(scaled_logits, top_p):
    """Sample one token id from softmax(scaled_logits) restricted to the top-p nucleus.

    The most probable token is always kept. Tokens ranked after the point where the sorted
    cumulative probability first exceeds ``top_p`` are zeroed before renormalising.
    Returns a 1-element tensor.
    """
    probs = F.softmax(scaled_logits, dim=-1)
    ranked_probs, ranked_ids = torch.sort(probs, descending=True)
    beyond = torch.cumsum(ranked_probs, dim=-1) > top_p
    # Shift right by one so the token that crosses the threshold is itself kept.
    beyond[1:] = beyond[:-1].clone()
    beyond[0] = False
    probs[ranked_ids[beyond]] = 0.0
    return torch.multinomial(probs / probs.sum(), 1)


def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9, device="cpu"):
    """Autoregressively extend *prompt* and return only the newly generated text.

    Each step re-runs the full sequence (no KV cache). ``temperature > 0`` samples with
    top-p filtering; otherwise decoding is greedy. Generation stops after
    ``max_new_tokens`` tokens or once the tokenizer's EOS token is produced.
    """
    context = tokenizer.encode(prompt, return_tensors="pt").to(device)
    produced = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            last_logits = model(context, use_cache=False)["logits"][0, -1, :]
            if temperature > 0:
                token = _sample_top_p(last_logits / temperature, top_p)
            else:
                token = torch.argmax(last_logits, dim=-1, keepdim=True)
            token_id = token.item()
            produced.append(token_id)
            context = torch.cat([context, token.unsqueeze(0)], dim=1)
            if token_id == tokenizer.eos_token_id:
                break
    return tokenizer.decode(produced, skip_special_tokens=True)
|
|
|
|
def analyze_output(prompt, generated_text):
    """Compute simple quality heuristics for a generation.

    Word-level metrics use ``prompt + generated_text``; the English-character ratio uses
    only ``generated_text``. Returns a dict with ``total_words``, ``unique_words`` (case-folded
    via ``lower``), ``vocab_diversity``, ``repetition_rate`` (share of repeated 4-grams),
    ``has_obvious_repetition``, and ``english_char_ratio``.
    """
    text = prompt + generated_text
    tokens = text.split()
    vocab = {tok.lower() for tok in tokens}
    diversity = len(vocab) / max(len(tokens), 1)

    window = 4
    if len(tokens) < window:
        repetition_rate = 0.0
    else:
        grams = [tuple(tokens[k:k + window]) for k in range(len(tokens) - window + 1)]
        repetition_rate = 1.0 - len(set(grams)) / max(len(grams), 1)

    lowered = text.lower()
    has_repetition = any(
        marker in lowered for marker in ("... ", "!!!", " and and ", " the the ", " is is ")
    )

    allowed = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ '.,!?;:-\"()")
    english_hits = sum(1 for ch in generated_text if ch in allowed)
    char_ratio = english_hits / max(len(generated_text), 1)

    return {
        "total_words": len(tokens),
        "unique_words": len(vocab),
        "vocab_diversity": diversity,
        "repetition_rate": repetition_rate,
        "has_obvious_repetition": has_repetition,
        "english_char_ratio": char_ratio,
    }
|
|
|
|
def _mean_metric(results, key):
    """Average ``result["metrics"][key]`` over a non-empty list of result records."""
    return sum(r["metrics"][key] for r in results) / len(results)


def _print_summary(label, results):
    """Print averaged metrics and the repetition-hit count for one model's results."""
    avg_vocab = _mean_metric(results, "vocab_diversity")
    avg_rep = _mean_metric(results, "repetition_rate")
    avg_eng = _mean_metric(results, "english_char_ratio")
    total_rep = sum(1 for r in results if r["metrics"]["has_obvious_repetition"])
    print(f"\n{label}:")
    print(f" Vocab diversity: {avg_vocab:.2f}")
    print(f" Repetition rate: {avg_rep:.2f}")
    print(f" English chars: {avg_eng:.2f}")
    print(f" Repetition hits: {total_rep}/{len(results)}")


def _evaluate_prompt(model, label, tokenizer, prompt, args, device):
    """Generate for *prompt*, print the text and its metrics, and return a result record."""
    t0 = time.time()
    generated = generate(model, tokenizer, prompt, args.max_new_tokens, args.temperature, args.top_p, device)
    elapsed = time.time() - t0
    metrics = analyze_output(prompt, generated)
    print(f" [{label}] {generated}")
    # NOTE: tok/s assumes all max_new_tokens were produced, so it overstates throughput
    # when generation stops early at EOS.
    print(f" vocab_div={metrics['vocab_diversity']:.2f}, "
          f"repetition={metrics['repetition_rate']:.2f}, "
          f"english={metrics['english_char_ratio']:.2f}, "
          f"tok/s={args.max_new_tokens/max(elapsed,0.001):.1f}")
    return {"prompt": prompt, "generated": generated, "metrics": metrics}


def main():
    """CLI entry point: evaluate a dense checkpoint, optionally side by side with an MoE one."""
    parser = argparse.ArgumentParser(description="Evaluate SpiderPortal Dense vs MoE")
    parser.add_argument("--dense", required=True, help="Path to dense checkpoint")
    parser.add_argument("--moe", default=None, help="Path to MoE checkpoint for comparison")
    parser.add_argument("--prompts", nargs="*", default=None)
    parser.add_argument("--file", default=None, help="File with prompts")
    parser.add_argument("--all", action="store_true", help="Run default prompt suite")
    parser.add_argument("--max-new-tokens", type=int, default=80)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--device", default=None)
    args = parser.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    # Precedence: --all, then --prompts, then --file, else the first three defaults.
    if args.all:
        prompts = DEFAULT_PROMPTS
    elif args.prompts:
        prompts = args.prompts
    elif args.file:
        with open(args.file, encoding="utf-8") as f:
            prompts = [line.strip() for line in f if line.strip()]
    else:
        prompts = DEFAULT_PROMPTS[:3]
    if not prompts:
        # Averages below divide by the number of prompts.
        parser.error("no prompts to evaluate (is the --file empty?)")

    dense_model, _ = load_model(args.dense, device, model_class=SpiderPortalDenseModel)

    moe_model = None
    if args.moe:
        print()
        moe_model, _ = load_model(args.moe, device, model_class=SpiderPortalMoEModel)

    print(f"\nRunning {len(prompts)} prompts (max_new_tokens={args.max_new_tokens}, temp={args.temperature})\n")
    print("=" * 80)

    dense_results = []
    moe_results = []
    for i, prompt in enumerate(prompts):
        print(f"\n[Prompt {i+1}/{len(prompts)}]: {prompt}")
        dense_results.append(_evaluate_prompt(dense_model, "DENSE", tokenizer, prompt, args, device))
        if moe_model is not None:
            moe_results.append(_evaluate_prompt(moe_model, "MoE ", tokenizer, prompt, args, device))

    print("\n" + "=" * 80)
    print("SUMMARY")
    print("=" * 80)

    _print_summary("DENSE", dense_results)
    if moe_results:
        _print_summary("MoE ", moe_results)

        # The comparison needs MoE results; without --moe it would divide by zero.
        print("\nComparison:")
        d_vocab = _mean_metric(dense_results, "vocab_diversity")
        m_vocab = _mean_metric(moe_results, "vocab_diversity")
        d_eng = _mean_metric(dense_results, "english_char_ratio")
        m_eng = _mean_metric(moe_results, "english_char_ratio")

        if d_vocab > m_vocab:
            print(f" Dense has better vocabulary diversity (+{d_vocab - m_vocab:.2f})")
        else:
            print(f" MoE has better vocabulary diversity (+{m_vocab - d_vocab:.2f})")

        if d_eng > m_eng:
            print(f" Dense produces more English-like text (+{d_eng - m_eng:.2f})")
        else:
            print(f" MoE produces more English-like text (+{m_eng - d_eng:.2f})")


if __name__ == "__main__":
    main()
|
|