CLIWorks commited on
Commit
63d4191
·
verified ·
1 Parent(s): 9f0cc3b

Delete eval_dense.py

Browse files
Files changed (1) hide show
  1. eval_dense.py +0 -673
eval_dense.py DELETED
@@ -1,673 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Evaluate SpiderPortal v5-Dense checkpoint with side-by-side MoE comparison.
4
-
5
- Usage:
6
- python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-final-ep1.pt --moe checkpoints/spiderportal-v5-final-ep1.pt --all
7
- python eval_dense.py --dense checkpoints-dense/spiderportal-v5-dense-ep1-step1000.pt --prompts "The cat sat on the"
8
- """
9
-
10
- import argparse
11
- import math
12
- import sys
13
- import time
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- from dataclasses import dataclass
18
- from transformers import AutoTokenizer
19
-
20
-
21
- @dataclass
22
- class SpiderPortalConfig:
23
- vocab_size: int = 50257
24
- hidden_size: int = 384
25
- num_hidden_layers: int = 8
26
- num_attention_heads: int = 8
27
- num_key_value_heads: int = 2
28
- intermediate_size: int = 1024
29
- num_experts: int = 64
30
- num_experts_per_tok: int = 1
31
- router_aux_loss_coef: float = 0.05
32
- max_loop_iters: int = 1
33
- act_threshold: float = 0.5
34
- max_position_embeddings: int = 131072
35
- rope_theta: float = 10000000.0
36
- rope_scaling: dict = None
37
- sliding_window: int = 4096
38
- attention_dropout: float = 0.0
39
- rms_norm_eps: float = 1e-6
40
- initializer_range: float = 0.02
41
- tie_word_embeddings: bool = True
42
- prelude_layers: int = 2
43
- coda_layers: int = 2
44
- lora_rank: int = 32
45
- loop_embed_dim: int = 48
46
-
47
-
48
- def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
49
- freqs = 1.0 / (theta ** (torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim))
50
- angles = loop_t * freqs
51
- emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
52
- emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
53
- emb_full[:loop_dim] = emb
54
- return h + emb_full.unsqueeze(0).unsqueeze(0)
55
-
56
-
57
- def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
58
- dim = head_dim
59
- orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
60
- pos_freqs = torch.arange(0, dim, 2).float() / dim
61
- beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
62
- scale = torch.where(beta < beta_slow, torch.ones_like(beta), torch.where(beta > beta_fast, torch.ones_like(beta) / factor, 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)))
63
- return orig_inv_freq * scale
64
-
65
-
66
- class SpiderPortalRMSNorm(nn.Module):
67
- def __init__(self, hidden_size, eps=1e-6):
68
- super().__init__()
69
- self.weight = nn.Parameter(torch.ones(hidden_size))
70
- self.variance_epsilon = eps
71
- def forward(self, hidden_states):
72
- input_dtype = hidden_states.dtype
73
- hidden_states = hidden_states.to(torch.float32)
74
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
75
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
76
- return self.weight.to(input_dtype) * hidden_states.to(input_dtype)
77
-
78
-
79
- class SpiderPortalGQA(nn.Module):
80
- def __init__(self, config):
81
- super().__init__()
82
- self.config = config
83
- self.hidden_size = config.hidden_size
84
- self.num_heads = config.num_attention_heads
85
- self.num_kv_heads = config.num_key_value_heads
86
- self.head_dim = config.hidden_size // config.num_attention_heads
87
- self.num_key_value_groups = self.num_heads // self.num_kv_heads
88
- self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
89
- self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
90
- self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
91
- self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
92
- self.attention_dropout = config.attention_dropout
93
- rope_scaling = getattr(config, 'rope_scaling', None)
94
- if rope_scaling and rope_scaling.get("type") == "yarn":
95
- factor = rope_scaling.get("factor", 1.0)
96
- orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
97
- inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
98
- else:
99
- inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
100
- self.register_buffer("inv_freq", inv_freq, persistent=False)
101
- def _rotate_half(self, x):
102
- x1 = x[..., :x.shape[-1] // 2]
103
- x2 = x[..., x.shape[-1] // 2:]
104
- return torch.cat((-x2, x1), dim=-1)
105
- def _apply_rotary(self, x, cos, sin):
106
- return (x * cos) + (self._rotate_half(x) * sin)
107
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
108
- bsz, q_len, _ = hidden_states.size()
109
- query_states = self.q_proj(hidden_states)
110
- key_states = self.k_proj(hidden_states)
111
- value_states = self.v_proj(hidden_states)
112
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
113
- key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
114
- value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
115
- if position_ids is None:
116
- position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
117
- max_pos = position_ids.max().item() + 1
118
- seq_len = max(max_pos, q_len)
119
- t = torch.arange(seq_len, device=hidden_states.device, dtype=self.inv_freq.dtype)
120
- freqs = torch.outer(t, self.inv_freq)
121
- emb = torch.cat((freqs, freqs), dim=-1)
122
- cos, sin = emb.cos(), emb.sin()
123
- cos = cos[position_ids].unsqueeze(1)
124
- sin = sin[position_ids].unsqueeze(1)
125
- query_states = self._apply_rotary(query_states, cos, sin)
126
- key_states = self._apply_rotary(key_states, cos, sin)
127
- if past_key_value is not None:
128
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
129
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
130
- past_kv = (key_states, value_states) if use_cache else None
131
- key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
132
- value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
133
- attn_output = F.scaled_dot_product_attention(
134
- query_states, key_states, value_states,
135
- attn_mask=attention_mask,
136
- dropout_p=self.attention_dropout if self.training else 0.0,
137
- is_causal=attention_mask is None
138
- )
139
- attn_output = attn_output.transpose(1, 2).contiguous()
140
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
141
- return self.o_proj(attn_output), past_kv
142
-
143
-
144
- class SpiderPortalExpert(nn.Module):
145
- def __init__(self, config, intermediate_size=None):
146
- super().__init__()
147
- inter_size = intermediate_size or config.intermediate_size
148
- self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
149
- self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
150
- self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
151
- self.act_fn = nn.SiLU()
152
- def forward(self, hidden_states):
153
- return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
154
-
155
-
156
- class SpiderPortalDenseLayer(nn.Module):
157
- def __init__(self, config):
158
- super().__init__()
159
- self.self_attn = SpiderPortalGQA(config)
160
- dense_intermediate = config.hidden_size * 4 // 3
161
- self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
162
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
163
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
164
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
165
- attn_input = self.input_layernorm(hidden_states)
166
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
167
- hidden_states = hidden_states + attn_output
168
- ffn_input = self.post_attention_layernorm(hidden_states)
169
- ffn_output = self.ffn(ffn_input)
170
- hidden_states = hidden_states + ffn_output
171
- return hidden_states, past_kv
172
-
173
-
174
- class SpiderPortalRecurrentDenseLayer(nn.Module):
175
- """Dense recurrent layer — matches checkpoint keys."""
176
- def __init__(self, config, layer_idx):
177
- super().__init__()
178
- self.layer_idx = layer_idx
179
- self.self_attn = SpiderPortalGQA(config)
180
- self.ffn = SpiderPortalExpert(config, intermediate_size=config.intermediate_size)
181
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
182
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
183
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
184
- attn_input = self.input_layernorm(hidden_states)
185
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
186
- hidden_states = hidden_states + attn_output
187
- ffn_input = self.post_attention_layernorm(hidden_states)
188
- ffn_output = self.ffn(ffn_input)
189
- hidden_states = hidden_states + ffn_output
190
- return hidden_states, 0.0, past_kv
191
-
192
-
193
- # MoE layer for comparison model
194
- class SpiderPortalRouter(nn.Module):
195
- def __init__(self, config):
196
- super().__init__()
197
- self.num_experts = config.num_experts
198
- self.num_experts_per_tok = config.num_experts_per_tok
199
- self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
200
- self.register_buffer("router_bias", torch.zeros(config.num_experts))
201
- def forward(self, hidden_states):
202
- router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
203
- routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
204
- biased_logits = router_logits + self.router_bias
205
- biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
206
- top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
207
- top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
208
- top_weights = top_weights.to(hidden_states.dtype)
209
- mean_probs = routing_weights.mean(dim=0)
210
- aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
211
- return top_weights, top_indices, aux_loss
212
-
213
-
214
- class SpiderPortalMoE(nn.Module):
215
- def __init__(self, config):
216
- super().__init__()
217
- self.config = config
218
- self.num_experts = config.num_experts
219
- self.num_experts_per_tok = config.num_experts_per_tok
220
- self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
221
- self.shared_expert = SpiderPortalExpert(config)
222
- self.router = SpiderPortalRouter(config)
223
- def forward(self, hidden_states):
224
- batch_size, seq_len, hidden_dim = hidden_states.shape
225
- top_weights, top_indices, aux_loss = self.router(hidden_states)
226
- flat_hidden = hidden_states.view(-1, hidden_dim)
227
- final_output = torch.zeros_like(flat_hidden)
228
- for expert_idx in range(self.num_experts_per_tok):
229
- expert_ids = top_indices[:, expert_idx]
230
- expert_weights = top_weights[:, expert_idx:expert_idx+1]
231
- for e in range(self.num_experts):
232
- mask = expert_ids == e
233
- if mask.any():
234
- expert_output = self.experts[e](flat_hidden[mask])
235
- final_output[mask] += expert_output * expert_weights[mask]
236
- shared_output = self.shared_expert(flat_hidden)
237
- final_output = final_output + shared_output
238
- return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
239
-
240
-
241
- class SpiderPortalMoELayer(nn.Module):
242
- def __init__(self, config, layer_idx):
243
- super().__init__()
244
- self.layer_idx = layer_idx
245
- self.self_attn = SpiderPortalGQA(config)
246
- self.moe = SpiderPortalMoE(config)
247
- self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
248
- self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
249
- def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
250
- attn_input = self.input_layernorm(hidden_states)
251
- attn_output, past_kv = self.self_attn(attn_input, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache)
252
- hidden_states = hidden_states + attn_output
253
- moe_input = self.post_attention_layernorm(hidden_states)
254
- moe_output, aux_loss = self.moe(moe_input)
255
- hidden_states = hidden_states + moe_output
256
- return hidden_states, aux_loss, past_kv
257
-
258
-
259
- class LTIInjection(nn.Module):
260
- def __init__(self, config):
261
- super().__init__()
262
- self.hidden_size = config.hidden_size
263
- self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
264
- self.delta_t = nn.Parameter(torch.tensor(1.0))
265
- self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
266
- with torch.no_grad():
267
- self.B.weight.data.normal_(mean=0.0, std=0.01)
268
- def get_A(self):
269
- return -torch.exp(self.log_A)
270
- def forward(self, h_t, e):
271
- A = self.get_A()
272
- return A * h_t + self.B(e)
273
-
274
-
275
- class ACTHalting(nn.Module):
276
- def __init__(self, config):
277
- super().__init__()
278
- self.halt_predictor = nn.Linear(config.hidden_size, 1)
279
- self.threshold = config.act_threshold
280
- def forward(self, hidden_states):
281
- return torch.sigmoid(self.halt_predictor(hidden_states))
282
-
283
-
284
- class LoRAAdapter(nn.Module):
285
- def __init__(self, config):
286
- super().__init__()
287
- rank = config.lora_rank
288
- self.down = nn.Linear(config.hidden_size, rank, bias=False)
289
- self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
290
- self.scale = nn.Embedding(config.max_loop_iters, rank)
291
- with torch.no_grad():
292
- self.scale.weight.data.zero_()
293
- self.down.weight.data.normal_(mean=0.0, std=0.001)
294
- def forward(self, x, loop_t):
295
- max_t = self.scale.num_embeddings - 1
296
- t_idx = min(loop_t, max_t)
297
- s = self.scale(torch.tensor(t_idx, device=x.device))
298
- down = self.down(x) * s
299
- return down @ self.B
300
-
301
-
302
- class SpiderPortalDenseModel(nn.Module):
303
- def __init__(self, config):
304
- super().__init__()
305
- self.config = config
306
- self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
307
- self.recurrent_layers = nn.ModuleList([SpiderPortalRecurrentDenseLayer(config, i) for i in range(config.num_hidden_layers)])
308
- self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
309
- self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
310
- self.injection = LTIInjection(config)
311
- self.act_halting = ACTHalting(config)
312
- self.lora_adapter = LoRAAdapter(config)
313
- self.loop_embed_dim = config.loop_embed_dim
314
- def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
315
- n_loops = n_loops or self.config.max_loop_iters
316
- input_embedding = input_embedding if input_embedding is not None else hidden_states
317
- for layer in self.prelude_layers:
318
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
319
- e = hidden_states.clone()
320
- B, T_seq, D = hidden_states.shape
321
- halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
322
- cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
323
- h_out = torch.zeros_like(hidden_states)
324
- past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
325
- for t in range(n_loops):
326
- h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
327
- if t > 0:
328
- injection = self.injection(hidden_states, input_embedding)
329
- hidden_states = hidden_states + injection
330
- new_past_key_values = []
331
- for i, layer in enumerate(self.recurrent_layers):
332
- hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
333
- new_past_key_values.append(past_kv)
334
- lora_delta = self.lora_adapter(hidden_states, t)
335
- hidden_states = hidden_states + lora_delta
336
- halt_prob = self.act_halting(hidden_states).squeeze(-1)
337
- still_running = ~halted
338
- remainder = (1.0 - cumulative_p).clamp(min=0)
339
- weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
340
- weight = weight * still_running.to(hidden_states.dtype)
341
- h_out = h_out + weight.unsqueeze(-1) * hidden_states
342
- cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
343
- halted = halted | (cumulative_p >= self.config.act_threshold)
344
- if halted.all() and not self.training:
345
- break
346
- never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
347
- hidden_states = h_out + never_halted * hidden_states
348
- for layer in self.coda_layers:
349
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
350
- hidden_states = self.norm(hidden_states)
351
- return hidden_states, 0.0, new_past_key_values
352
-
353
-
354
- class SpiderPortalMoEModel(nn.Module):
355
- def __init__(self, config):
356
- super().__init__()
357
- self.config = config
358
- self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
359
- self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
360
- self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
361
- self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
362
- self.injection = LTIInjection(config)
363
- self.act_halting = ACTHalting(config)
364
- self.lora_adapter = LoRAAdapter(config)
365
- self.loop_embed_dim = config.loop_embed_dim
366
- def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
367
- n_loops = n_loops or self.config.max_loop_iters
368
- input_embedding = input_embedding if input_embedding is not None else hidden_states
369
- total_aux_loss = 0.0
370
- for layer in self.prelude_layers:
371
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
372
- e = hidden_states.clone()
373
- B, T_seq, D = hidden_states.shape
374
- halted = torch.zeros(B, T_seq, device=hidden_states.device, dtype=torch.bool)
375
- cumulative_p = torch.zeros(B, T_seq, device=hidden_states.device, dtype=hidden_states.dtype)
376
- h_out = torch.zeros_like(hidden_states)
377
- past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
378
- for t in range(n_loops):
379
- h_loop = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
380
- if t > 0:
381
- injection = self.injection(hidden_states, input_embedding)
382
- hidden_states = hidden_states + injection
383
- new_past_key_values = []
384
- for i, layer in enumerate(self.recurrent_layers):
385
- hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values[i] if t == 0 else None, use_cache=use_cache)
386
- total_aux_loss = total_aux_loss + aux_loss
387
- new_past_key_values.append(past_kv)
388
- lora_delta = self.lora_adapter(hidden_states, t)
389
- hidden_states = hidden_states + lora_delta
390
- halt_prob = self.act_halting(hidden_states).squeeze(-1)
391
- still_running = ~halted
392
- remainder = (1.0 - cumulative_p).clamp(min=0)
393
- weight = torch.where(cumulative_p + halt_prob >= self.config.act_threshold, remainder, halt_prob)
394
- weight = weight * still_running.to(hidden_states.dtype)
395
- h_out = h_out + weight.unsqueeze(-1) * hidden_states
396
- cumulative_p = cumulative_p + halt_prob * still_running.to(hidden_states.dtype)
397
- halted = halted | (cumulative_p >= self.config.act_threshold)
398
- if halted.all() and not self.training:
399
- break
400
- never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
401
- hidden_states = h_out + never_halted * hidden_states
402
- for layer in self.coda_layers:
403
- hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
404
- hidden_states = self.norm(hidden_states)
405
- return hidden_states, total_aux_loss, new_past_key_values
406
-
407
-
408
- class SpiderPortalForConditionalGeneration(nn.Module):
409
- def __init__(self, config, model_class=SpiderPortalDenseModel):
410
- super().__init__()
411
- self.config = config
412
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
413
- self.model = model_class(config)
414
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
415
- if config.tie_word_embeddings:
416
- self.lm_head.weight = self.embed_tokens.weight
417
- self.apply(self._init_weights)
418
- def _init_weights(self, module):
419
- if isinstance(module, nn.Linear):
420
- if hasattr(self, 'model') and module is self.model.injection.B:
421
- return
422
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
423
- if module.bias is not None:
424
- module.bias.data.zero_()
425
- elif isinstance(module, nn.Embedding):
426
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
427
- def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
428
- hidden_states = self.embed_tokens(input_ids)
429
- model_dtype = next(self.model.parameters()).dtype
430
- hidden_states = hidden_states.to(model_dtype)
431
- input_embedding = hidden_states.clone()
432
- if attention_mask is None:
433
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
434
- causal_mask = torch.full((attention_mask.size(0), 1, attention_mask.size(1), attention_mask.size(1)), 0.0, dtype=hidden_states.dtype, device=hidden_states.device)
435
- causal_mask = causal_mask.masked_fill(~attention_mask.unsqueeze(1).unsqueeze(2), torch.finfo(hidden_states.dtype).min)
436
- causal_mask = causal_mask.triu(1)
437
- hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
438
- logits = self.lm_head(hidden_states)
439
- return {"loss": None, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}
440
-
441
-
442
- DEFAULT_PROMPTS = [
443
- "The cat sat on the",
444
- "The capital of France is",
445
- "If I have 3 apples and eat 1, I have",
446
- "Once upon a time, there was a",
447
- "Python is a programming language that",
448
- "Two plus two equals",
449
- "When it rains, the ground gets",
450
- "The door opened slowly and",
451
- "What is the meaning of life? The",
452
- "def fibonacci(n):\n if n <= 1:\n return",
453
- ]
454
-
455
-
456
- def load_model(checkpoint_path, device="cpu", model_class=SpiderPortalDenseModel):
457
- print(f"Loading checkpoint: {checkpoint_path}")
458
- ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
459
-
460
- cfg = ckpt.get("cfg")
461
- vocab_size = ckpt.get("vocab_size", 50257)
462
-
463
- if cfg is None:
464
- cfg = SpiderPortalConfig(
465
- hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
466
- num_key_value_heads=2, intermediate_size=1024,
467
- num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
468
- router_aux_loss_coef=0.05, max_loop_iters=1,
469
- prelude_layers=2, coda_layers=2, lora_rank=32,
470
- rope_theta=10000000.0,
471
- rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
472
- max_position_embeddings=131072, sliding_window=4096,
473
- tie_word_embeddings=True,
474
- )
475
- cfg.vocab_size = vocab_size
476
-
477
- model_state = ckpt.get("model_state_dict", ckpt)
478
- model = SpiderPortalForConditionalGeneration(cfg, model_class=model_class)
479
-
480
- missing, unexpected = model.load_state_dict(model_state, strict=False)
481
- if missing:
482
- print(f" Missing keys ({len(missing)}): {missing[:3]}...")
483
- if unexpected:
484
- print(f" Unexpected keys ({len(unexpected)}): {unexpected[:3]}...")
485
- if not missing and not unexpected:
486
- print(" All keys matched perfectly")
487
-
488
- model = model.to(device)
489
- model.eval()
490
-
491
- n_params = sum(p.numel() for p in model.parameters())
492
- print(f" Parameters: {n_params:,} on {device}")
493
-
494
- return model, cfg
495
-
496
-
497
- def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_p=0.9, device="cpu"):
498
- input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
499
-
500
- generated = []
501
- with torch.no_grad():
502
- for _ in range(max_new_tokens):
503
- outputs = model(input_ids, use_cache=False)
504
- logits = outputs["logits"][0, -1, :]
505
-
506
- if temperature > 0:
507
- logits = logits / temperature
508
- probs = F.softmax(logits, dim=-1)
509
- sorted_probs, sorted_indices = torch.sort(probs, descending=True)
510
- cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
511
- sorted_indices_to_remove = cumulative_probs > top_p
512
- sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
513
- sorted_indices_to_remove[0] = False
514
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
515
- probs[indices_to_remove] = 0.0
516
- probs = probs / probs.sum()
517
- next_token = torch.multinomial(probs, 1)
518
- else:
519
- next_token = torch.argmax(logits, dim=-1, keepdim=True)
520
-
521
- generated.append(next_token.item())
522
- input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
523
-
524
- if next_token.item() == tokenizer.eos_token_id:
525
- break
526
-
527
- return tokenizer.decode(generated, skip_special_tokens=True)
528
-
529
-
530
- def analyze_output(prompt, generated_text):
531
- full = prompt + generated_text
532
- words = full.split()
533
- unique_words = set(w.lower() for w in words)
534
- vocab_diversity = len(unique_words) / max(len(words), 1)
535
-
536
- n = 4
537
- if len(words) >= n:
538
- ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
539
- unique_ngrams = set(ngrams)
540
- repetition_rate = 1.0 - len(unique_ngrams) / max(len(ngrams), 1)
541
- else:
542
- repetition_rate = 0.0
543
-
544
- has_repetition = False
545
- for pattern in ["... ", "!!!", " and and ", " the the ", " is is "]:
546
- if pattern in full.lower():
547
- has_repetition = True
548
- break
549
-
550
- english_chars = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ '.,!?;:-\"()")
551
- char_ratio = sum(1 for c in generated_text if c in english_chars) / max(len(generated_text), 1)
552
-
553
- return {
554
- "total_words": len(words),
555
- "unique_words": len(unique_words),
556
- "vocab_diversity": vocab_diversity,
557
- "repetition_rate": repetition_rate,
558
- "has_obvious_repetition": has_repetition,
559
- "english_char_ratio": char_ratio,
560
- }
561
-
562
-
563
- def main():
564
- parser = argparse.ArgumentParser(description="Evaluate SpiderPortal Dense vs MoE")
565
- parser.add_argument("--dense", required=True, help="Path to dense checkpoint")
566
- parser.add_argument("--moe", default=None, help="Path to MoE checkpoint for comparison")
567
- parser.add_argument("--prompts", nargs="*", default=None)
568
- parser.add_argument("--file", default=None, help="File with prompts")
569
- parser.add_argument("--all", action="store_true", help="Run default prompt suite")
570
- parser.add_argument("--max-new-tokens", type=int, default=80)
571
- parser.add_argument("--temperature", type=float, default=0.8)
572
- parser.add_argument("--top-p", type=float, default=0.9)
573
- parser.add_argument("--device", default=None)
574
- args = parser.parse_args()
575
-
576
- device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
577
- print(f"Device: {device}")
578
-
579
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
580
- tokenizer.pad_token = tokenizer.eos_token
581
-
582
- prompts = []
583
- if args.all:
584
- prompts = DEFAULT_PROMPTS
585
- elif args.prompts:
586
- prompts = args.prompts
587
- elif args.file:
588
- with open(args.file) as f:
589
- prompts = [line.strip() for line in f if line.strip()]
590
- else:
591
- prompts = DEFAULT_PROMPTS[:3]
592
-
593
- dense_model, _ = load_model(args.dense, device, model_class=SpiderPortalDenseModel)
594
-
595
- moe_model = None
596
- if args.moe:
597
- print()
598
- moe_model, _ = load_model(args.moe, device, model_class=SpiderPortalMoEModel)
599
-
600
- print(f"\nRunning {len(prompts)} prompts (max_new_tokens={args.max_new_tokens}, temp={args.temperature})\n")
601
- print("=" * 80)
602
-
603
- dense_results = []
604
- moe_results = []
605
-
606
- for i, prompt in enumerate(prompts):
607
- print(f"\n[Prompt {i+1}/{len(prompts)}]: {prompt}")
608
-
609
- t0 = time.time()
610
- dense_gen = generate(dense_model, tokenizer, prompt, args.max_new_tokens, args.temperature, args.top_p, device)
611
- dense_elapsed = time.time() - t0
612
- dense_metrics = analyze_output(prompt, dense_gen)
613
-
614
- print(f" [DENSE] {dense_gen}")
615
- print(f" vocab_div={dense_metrics['vocab_diversity']:.2f}, "
616
- f"repetition={dense_metrics['repetition_rate']:.2f}, "
617
- f"english={dense_metrics['english_char_ratio']:.2f}, "
618
- f"tok/s={args.max_new_tokens/max(dense_elapsed,0.001):.1f}")
619
-
620
- if moe_model:
621
- t0 = time.time()
622
- moe_gen = generate(moe_model, tokenizer, prompt, args.max_new_tokens, args.temperature, args.top_p, device)
623
- moe_elapsed = time.time() - t0
624
- moe_metrics = analyze_output(prompt, moe_gen)
625
-
626
- print(f" [MoE ] {moe_gen}")
627
- print(f" vocab_div={moe_metrics['vocab_diversity']:.2f}, "
628
- f"repetition={moe_metrics['repetition_rate']:.2f}, "
629
- f"english={moe_metrics['english_char_ratio']:.2f}, "
630
- f"tok/s={args.max_new_tokens/max(moe_elapsed,0.001):.1f}")
631
-
632
- moe_results.append({"prompt": prompt, "generated": moe_gen, "metrics": moe_metrics})
633
-
634
- dense_results.append({"prompt": prompt, "generated": dense_gen, "metrics": dense_metrics})
635
-
636
- print("\n" + "=" * 80)
637
- print("SUMMARY")
638
- print("=" * 80)
639
-
640
- def print_summary(label, results):
641
- avg_vocab = sum(r["metrics"]["vocab_diversity"] for r in results) / len(results)
642
- avg_rep = sum(r["metrics"]["repetition_rate"] for r in results) / len(results)
643
- avg_eng = sum(r["metrics"]["english_char_ratio"] for r in results) / len(results)
644
- total_rep = sum(1 for r in results if r["metrics"]["has_obvious_repetition"])
645
- print(f"\n{label}:")
646
- print(f" Vocab diversity: {avg_vocab:.2f}")
647
- print(f" Repetition rate: {avg_rep:.2f}")
648
- print(f" English chars: {avg_eng:.2f}")
649
- print(f" Repetition hits: {total_rep}/{len(results)}")
650
-
651
- print_summary("DENSE", dense_results)
652
- if moe_results:
653
- print_summary("MoE ", moe_results)
654
-
655
- print("\nComparison:")
656
- d_vocab = sum(r["metrics"]["vocab_diversity"] for r in dense_results) / len(dense_results)
657
- m_vocab = sum(r["metrics"]["vocab_diversity"] for r in moe_results) / len(moe_results)
658
- d_eng = sum(r["metrics"]["english_char_ratio"] for r in dense_results) / len(dense_results)
659
- m_eng = sum(r["metrics"]["english_char_ratio"] for r in moe_results) / len(moe_results)
660
-
661
- if d_vocab > m_vocab:
662
- print(f" Dense has better vocabulary diversity (+{d_vocab - m_vocab:.2f})")
663
- else:
664
- print(f" MoE has better vocabulary diversity (+{m_vocab - d_vocab:.2f})")
665
-
666
- if d_eng > m_eng:
667
- print(f" Dense produces more English-like text (+{d_eng - m_eng:.2f})")
668
- else:
669
- print(f" MoE produces more English-like text (+{m_eng - d_eng:.2f})")
670
-
671
-
672
- if __name__ == "__main__":
673
- main()