| |
| """SpiderPortal v5 — Single-GPU Optimized Training. |
| |
| For RTX PRO 6000 (96GB) or similar large-VRAM GPU. |
| No DDP, large fixed batch size, pre-tokenized data (torch.compile is not used). |
| |
| Usage: |
| python train_single_gpu.py |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import os |
| import json |
| import gc |
| import random |
| import time |
| from pathlib import Path |
| from dataclasses import dataclass |
| from typing import Optional, Tuple, Dict, List |
| from torch.nn import CrossEntropyLoss |
|
|
@dataclass
class SpiderPortalConfig:
    """Hyper-parameters for the SpiderPortal recurrent-depth MoE language model.

    Architecture: ``prelude_layers`` dense blocks, then ``num_hidden_layers`` MoE
    blocks that are re-applied up to ``max_loop_iters`` times (ACT halting), then
    ``coda_layers`` dense blocks and a final RMSNorm.

    NOTE(review): ``hidden_act``, ``num_shared_experts``, ``sliding_window``,
    ``use_cache``, ``model_type``, ``torch_dtype`` and the ``vision_*`` /
    ``audio_*`` fields are not read by the model classes in this file (experts
    always use SiLU, the MoE always builds exactly one shared expert, attention is
    never windowed). Confirm whether other code consumes them.
    """

    vocab_size: int = 50278
    hidden_size: int = 384
    num_hidden_layers: int = 8  # number of MoE blocks inside the recurrent loop
    num_attention_heads: int = 8
    num_key_value_heads: int = 2  # GQA: K/V heads shared across query-head groups
    intermediate_size: int = 1024  # width of each routed / shared expert
    hidden_act: str = "silu"
    num_experts: int = 64
    num_experts_per_tok: int = 1  # top-k experts selected per token
    num_shared_experts: int = 1
    router_aux_loss_coef: float = 0.05  # weight of the router aux loss in the LM loss
    max_loop_iters: int = 4  # default number of recurrent iterations
    act_threshold: float = 0.5  # cumulative halting probability at which a token halts
    max_position_embeddings: int = 131072
    rope_theta: float = 10000000.0
    # e.g. {"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768}
    rope_scaling: Optional[dict] = None
    sliding_window: int = 4096
    attention_dropout: float = 0.0
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02  # std of the normal init for Linear/Embedding weights
    use_cache: bool = True
    tie_word_embeddings: bool = True  # share lm_head weight with the token embedding
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 32  # rank of the per-loop LoRA adapter
    loop_embed_dim: int = 48  # leading hidden dims that receive the loop-index embedding
    vision_hidden_size: int = 384
    audio_hidden_size: int = 512
    vision_num_frames: int = 60
    vision_tokens_per_frame: int = 256
    vision_temporal_tokens: int = 64
    vision_temporal_layers: int = 2
    model_type: str = "spiderportal"
    torch_dtype: str = "bfloat16"
|
|
def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Add a sinusoidal embedding of the loop iteration index to ``h``.

    The first ``loop_dim`` channels of the last axis receive
    ``[sin(loop_t * f_0..), cos(loop_t * f_0..)]`` (truncated to ``loop_dim``),
    with ``f_i = theta ** (-2i / loop_dim)``. The remaining channels are unchanged.

    Args:
        h: tensor whose last axis is the hidden dimension; any leading shape.
        loop_t: integer loop iteration index.
        loop_dim: number of leading hidden channels that receive the embedding.
        theta: frequency base.

    Returns:
        ``h + emb``, with the same dtype as ``h``.

    Raises:
        ValueError: if ``loop_dim`` exceeds the hidden dimension of ``h``.
    """
    hidden = h.shape[-1]
    if loop_dim > hidden:
        raise ValueError(f"loop_dim ({loop_dim}) exceeds hidden size ({hidden})")
    # Build frequencies and angles in float32: in bf16, theta ** x and
    # loop_t * freq are rounded to ~3 significant digits.
    exponents = torch.arange(0, loop_dim, 2, device=h.device, dtype=torch.float32) / loop_dim
    freqs = 1.0 / (theta ** exponents)
    angles = loop_t * freqs
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
    emb_full = torch.zeros(hidden, device=h.device, dtype=torch.float32)
    emb_full[:loop_dim] = emb
    # A 1-D embedding broadcasts over any leading shape of h.
    return h + emb_full.to(h.dtype)
|
|
class SpiderPortalRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel gain.

    Normalisation is computed in float32. The normalised activations are cast
    back to the input dtype before the gain is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.float()
        mean_sq = x32.square().mean(dim=-1, keepdim=True)
        normed = (x32 * torch.rsqrt(mean_sq + self.variance_epsilon)).to(orig_dtype)
        gain = self.weight.to(orig_dtype)
        return gain * normed
|
|
def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return YaRN ("NTK-by-parts") RoPE inverse frequencies for context extension.

    Each rotary pair ``i`` has base inverse frequency
    ``rope_theta ** (-2i / head_dim)``. Pairs are treated by how many full
    rotations they complete over the original context ``orig_max``:

    * more than ``beta_fast`` rotations (high frequency): left unchanged;
    * fewer than ``beta_slow`` rotations (low frequency): divided by ``factor``;
    * in between: linearly blended across the correction range.

    Args:
        head_dim: rotary dimension (per attention head).
        rope_theta: RoPE frequency base.
        factor: context extension factor (1.0 means no change).
        orig_max: original training context length.
        beta_fast, beta_slow: rotation-count bounds of the blending ramp.

    Returns:
        float32 tensor of length ``ceil(head_dim / 2)``.

    NOTE: YaRN also prescribes an attention temperature
    (``0.1 * ln(factor) + 1``) applied to cos/sin. It is not returned here, so
    callers must apply it separately if wanted.
    """
    dim = head_dim
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    pos_freqs = rope_theta ** exponents
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    def _correction_dim(num_rotations):
        # Dimension index whose wavelength completes `num_rotations` turns over orig_max.
        return (dim * math.log(orig_max / (num_rotations * 2 * math.pi))) / (2 * math.log(rope_theta))

    low = max(math.floor(_correction_dim(beta_fast)), 0)
    high = min(math.ceil(_correction_dim(beta_slow)), dim - 1)
    if low == high:
        high += 0.001  # avoid a zero-width ramp
    ramp = ((torch.arange(exponents.numel(), dtype=torch.float32) - low) / (high - low)).clamp(0, 1)
    extrapolation_factor = 1.0 - ramp
    return inv_freq_interpolation * (1 - extrapolation_factor) + inv_freq_extrapolation * extrapolation_factor
|
|
class SpiderPortalGQA(nn.Module):
    """Grouped-query self-attention with rotary position embeddings (RoPE).

    ``num_attention_heads`` query heads share ``num_key_value_heads`` key/value
    heads. Each K/V head is repeated ``num_heads // num_kv_heads`` times.
    ``attention_mask`` is an additive mask broadcast onto the
    ``(bsz, num_heads, q_len, kv_len)`` score tensor. This module applies no
    causal masking of its own. ``config.sliding_window`` is not used.

    ``forward`` returns ``(output, past_kv)``. ``past_kv`` is the
    ``(keys, values)`` pair (post-RoPE, before head repetition) when
    ``use_cache`` is true, else ``None``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.attention_dropout = config.attention_dropout
        rope_scaling = getattr(config, 'rope_scaling', None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
            inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
        else:
            inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate_half(self, x):
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        return (x * cos) + (self._rotate_half(x) * sin)

    def _rotary_cos_sin(self, position_ids, dtype):
        """Return (cos, sin) for ``position_ids``, with a head axis inserted at dim 1.

        Angles are computed in float32. ``model.to(torch.bfloat16)`` also casts
        the ``inv_freq`` buffer, and bf16 positions/angles would be quantized
        (bf16 integers are exact only up to 256). The results are cast to
        ``dtype`` so q/k keep their own dtype.
        """
        freqs = position_ids.float().unsqueeze(-1) * self.inv_freq.float()
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(dtype).unsqueeze(1), emb.sin().to(dtype).unsqueeze(1)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if position_ids is None:
            # New tokens continue after any cached prefix (cache is (bsz, kv_heads, seq, head_dim)).
            past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
            position_ids = torch.arange(
                past_len, past_len + q_len, device=hidden_states.device
            ).unsqueeze(0).expand(bsz, -1)
        # Computing cos/sin directly from position_ids avoids a .item() host sync
        # and building a full table up to the max position.
        cos, sin = self._rotary_cos_sin(position_ids, query_states.dtype)
        query_states = self._apply_rotary(query_states, cos, sin)
        key_states = self._apply_rotary(key_states, cos, sin)
        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_kv = (key_states, value_states) if use_cache else None
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output), past_kv
|
|
class SpiderPortalExpert(nn.Module):
    """SiLU-gated feed-forward expert: ``down(silu(gate(x)) * up(x))``.

    ``intermediate_size`` overrides ``config.intermediate_size`` when truthy.
    ``config.hidden_act`` is not consulted; the activation is always SiLU.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        width = intermediate_size if intermediate_size else config.intermediate_size
        hidden = config.hidden_size
        self.gate_proj = nn.Linear(hidden, width, bias=False)
        self.up_proj = nn.Linear(hidden, width, bias=False)
        self.down_proj = nn.Linear(width, hidden, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        gated = self.act_fn(self.gate_proj(hidden_states))
        projected = self.up_proj(hidden_states)
        return self.down_proj(gated * projected)
|
|
class SpiderPortalRouter(nn.Module):
    """Token-choice top-k router over ``num_experts`` experts.

    ``forward(hidden_states)`` returns ``(top_weights, top_indices, aux_loss)``:

    * ``top_indices`` has shape ``(num_tokens, k)``. Experts are selected on
      ``logits + router_bias``.
    * ``top_weights`` are the unbiased softmax probabilities of the selected
      experts. They are renormalised to sum to 1 only when ``k > 1``. For
      ``k == 1`` the renormalised gate would be the constant 1.0, which would
      cut the router off from the task-loss gradient.
    * ``aux_loss = num_experts * sum(mean_prob ** 2)``. It equals 1 when the
      batch-averaged routing distribution is uniform and grows as it
      concentrates.

    NOTE(review): ``router_bias`` is a zero buffer and nothing in this file
    updates it.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
        self.register_buffer("router_bias", torch.zeros(config.num_experts))

    def forward(self, hidden_states):
        router_logits = hidden_states.reshape(-1, hidden_states.size(-1)) @ self.weight
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        # The bias only steers which experts are chosen (softmax is monotonic,
        # so top-k on logits equals top-k on probabilities).
        biased_logits = router_logits + self.router_bias
        _, top_indices = torch.topk(biased_logits, self.num_experts_per_tok, dim=-1)
        top_weights = routing_weights.gather(-1, top_indices)
        if self.num_experts_per_tok > 1:
            top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
        top_weights = top_weights.to(hidden_states.dtype)
        mean_probs = routing_weights.mean(dim=0)
        aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
        return top_weights, top_indices, aux_loss
|
|
class SpiderPortalMoE(nn.Module):
    """Mixture-of-experts FFN: routed top-k experts plus one always-on shared expert.

    ``forward(hidden_states)`` takes a ``(batch, seq, hidden)`` tensor. It returns
    the combined expert output with the same shape, plus the router's aux loss.
    NOTE(review): exactly one shared expert is built; ``config.num_shared_experts``
    is not consulted.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
        self.shared_expert = SpiderPortalExpert(config)
        self.router = SpiderPortalRouter(config)

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape
        top_weights, top_indices, aux_loss = self.router(hidden_states)
        flat_hidden = hidden_states.reshape(-1, hidden_dim)
        num_tokens = flat_hidden.size(0)
        k = top_indices.size(-1)

        # Flatten (token, slot) assignments and group them by expert. A single
        # bincount().tolist() replaces one mask.any() host sync per expert.
        flat_expert_ids = top_indices.reshape(-1)
        flat_token_ids = torch.arange(num_tokens, device=hidden_states.device).repeat_interleave(k)
        flat_gates = top_weights.reshape(-1, 1)
        order = torch.argsort(flat_expert_ids)
        sorted_token_ids = flat_token_ids[order]
        sorted_gates = flat_gates[order]
        counts = torch.bincount(flat_expert_ids, minlength=self.num_experts).tolist()

        final_output = torch.zeros_like(flat_hidden)
        start = 0
        for expert, count in zip(self.experts, counts):
            if count == 0:
                continue
            end = start + count
            token_ids = sorted_token_ids[start:end]
            expert_out = expert(flat_hidden[token_ids]) * sorted_gates[start:end]
            # A token appears at most once per expert, so index_add_ has no duplicate targets here.
            final_output.index_add_(0, token_ids, expert_out)
            start = end

        final_output = final_output + self.shared_expert(flat_hidden)
        return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
|
|
class SpiderPortalDenseLayer(nn.Module):
    """Pre-norm block: GQA self-attention, then a dense SiLU-gated FFN.

    The FFN width is ``hidden_size * 4 // 3`` rather than
    ``config.intermediate_size``. ``forward`` returns ``(hidden_states, past_kv)``.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = SpiderPortalGQA(config)
        self.ffn = SpiderPortalExpert(config, intermediate_size=(config.hidden_size * 4) // 3)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        residual = hidden_states
        attn_output, past_kv = self.self_attn(
            self.input_layernorm(residual),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = residual + attn_output
        hidden_states = hidden_states + self.ffn(self.post_attention_layernorm(hidden_states))
        return hidden_states, past_kv
|
|
class SpiderPortalMoELayer(nn.Module):
    """Pre-norm block: GQA self-attention, then a mixture-of-experts FFN.

    ``forward`` returns ``(hidden_states, aux_loss, past_kv)``, where ``aux_loss``
    is the router auxiliary loss from the MoE.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = SpiderPortalGQA(config)
        self.moe = SpiderPortalMoE(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        residual = hidden_states
        attn_output, past_kv = self.self_attn(
            self.input_layernorm(residual),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = residual + attn_output
        moe_output, aux_loss = self.moe(self.post_attention_layernorm(hidden_states))
        return hidden_states + moe_output, aux_loss, past_kv
|
|
class LTIInjection(nn.Module):
    """Linear state injection ``A * h_t + B(e)`` with a diagonal, negative ``A``.

    ``A = -exp(log_A)``; ``log_A`` starts at -2, so ``A`` starts at about -0.135.
    ``B`` is a bias-free linear map initialised with std 0.01.
    NOTE(review): ``delta_t`` is a learnable parameter that ``forward`` never uses.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))
        self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        nn.init.normal_(self.B.weight, mean=0.0, std=0.01)

    def get_A(self):
        return self.log_A.exp().neg()

    def forward(self, h_t, e):
        decay = self.get_A()
        injected = self.B(e)
        return decay * h_t + injected
|
|
class ACTHalting(nn.Module):
    """Per-position halting probability ``sigmoid(Linear(hidden) -> 1)``.

    Returns a tensor with a trailing size-1 axis. ``threshold`` is stored for
    callers; this module does not use it.
    """

    def __init__(self, config):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1)
        self.threshold = config.act_threshold

    def forward(self, hidden_states):
        logit = self.halt_predictor(hidden_states)
        return logit.sigmoid()
|
|
class LoRAAdapter(nn.Module):
    """Low-rank residual delta ``(down(x) * scale[t]) @ B`` with a per-loop scale.

    ``scale`` is an embedding table with one rank-sized vector per loop
    iteration. It is zero-initialised, so the delta starts at zero. Loop indices
    beyond the table are clamped to its last row.
    """

    def __init__(self, config):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
        self.scale = nn.Embedding(config.max_loop_iters, rank)
        with torch.no_grad():
            self.scale.weight.zero_()
            self.down.weight.normal_(mean=0.0, std=0.001)

    def forward(self, x, loop_t):
        last_row = self.scale.num_embeddings - 1
        row = loop_t if loop_t < last_row else last_row
        step_scale = self.scale(torch.tensor(row, device=x.device))
        projected = self.down(x) * step_scale
        return projected @ self.B
|
|
class SpiderPortalMoEModel(nn.Module):
    """Recurrent-depth backbone: prelude -> looped MoE stack with ACT -> coda -> norm.

    Each loop iteration does the following:

    1. Adds the sinusoidal loop-index embedding to the state.
    2. From the second iteration on, adds the LTI injection of ``input_embedding``.
    3. Runs all recurrent MoE layers and adds the per-loop LoRA delta.
    4. Accumulates an ACT-weighted average of the per-iteration states.

    Tokens halt once their cumulative halting probability reaches
    ``config.act_threshold``. Their halting step takes the remaining probability
    mass, so each token's ACT weights sum to 1.

    NOTE(review): the injection uses the caller-supplied ``input_embedding``
    (the raw token embedding), not the prelude output. Confirm this is intended.

    Returns ``(hidden_states, total_aux_loss, past_key_values)``.
    ``past_key_values`` comes from the last iteration that ran.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
        n_loops = n_loops or self.config.max_loop_iters
        input_embedding = input_embedding if input_embedding is not None else hidden_states
        total_aux_loss = 0.0
        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        B, T_seq, _ = hidden_states.shape
        device, dtype = hidden_states.device, hidden_states.dtype
        threshold = self.config.act_threshold
        halted = torch.zeros(B, T_seq, device=device, dtype=torch.bool)
        cumulative_p = torch.zeros(B, T_seq, device=device, dtype=dtype)
        h_out = torch.zeros_like(hidden_states)
        if past_key_values is None:
            past_key_values = [None] * len(self.recurrent_layers)
        # Defined up front so the return value exists even if no iteration runs.
        new_past_key_values = [None] * len(self.recurrent_layers)

        for t in range(n_loops):
            # Previously computed into an unused variable; now actually applied.
            hidden_states = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
            if t > 0:
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)
            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                # A provided cache only applies to the first pass over the stack.
                layer_past = past_key_values[i] if t == 0 else None
                hidden_states, aux_loss, past_kv = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=layer_past, use_cache=use_cache)
                total_aux_loss = total_aux_loss + aux_loss
                new_past_key_values.append(past_kv)
            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            # A token crossing the threshold this step takes all remaining mass.
            weight = torch.where(cumulative_p + halt_prob >= threshold, remainder, halt_prob) * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= threshold)
            if halted.all() and not self.training:
                break

        # Tokens that never halted give their leftover mass to the final state.
        # Adding the full state would make their ACT weights sum above 1.
        leftover = ((1.0 - cumulative_p).clamp(min=0) * (~halted).to(dtype)).unsqueeze(-1)
        hidden_states = h_out + leftover * hidden_states

        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, total_aux_loss, new_past_key_values
|
|
class SpiderPortalForConditionalGeneration(nn.Module):
    """Causal LM: token embedding -> SpiderPortalMoEModel -> (optionally tied) LM head.

    ``forward`` returns a dict with these keys:

    * ``loss``: next-token cross-entropy plus ``router_aux_loss_coef * aux_loss``
      when ``labels`` is given, else None. Labels equal to -100 are ignored.
    * ``logits``, ``aux_loss`` and ``past_key_values``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = SpiderPortalMoEModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Apply Normal(0, initializer_range) to Linear/Embedding weights and zero biases.

        Skips sub-modules whose constructors set a deliberate init, so it is not
        overwritten: the LTI injection ``B`` and the LoRA adapter's ``down``
        projection and zero-initialised per-loop ``scale``.
        """
        if hasattr(self, 'model'):
            keep_own_init = (
                self.model.injection.B,
                self.model.lora_adapter.down,
                self.model.lora_adapter.scale,
            )
            if any(module is m for m in keep_own_init):
                return
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    @staticmethod
    def _build_causal_mask(attention_mask, dtype, device):
        """Build an additive ``(bsz, 1, seq, seq)`` mask from a ``(bsz, seq)`` key mask.

        A position is blocked with ``finfo(dtype).min`` when the key lies in the
        future (``key > query``) or is padding (``attention_mask == 0``).
        Otherwise it is 0. ``masked_fill`` is used instead of addition to avoid
        overflowing ``min + min``.
        """
        bsz, seq_len = attention_mask.shape
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
        key_is_pad = ~attention_mask.to(device=device, dtype=torch.bool)
        blocked = future.view(1, 1, seq_len, seq_len) | key_is_pad.view(bsz, 1, 1, seq_len)
        mask = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype, device=device)
        return mask.masked_fill(blocked, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = next(self.model.parameters()).dtype
        hidden_states = hidden_states.to(model_dtype)
        input_embedding = hidden_states.clone()
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        # Previously: padding-fill followed by .triu(1), which zeroed the lower
        # triangle and left non-padded future tokens visible (no causal masking).
        causal_mask = self._build_causal_mask(attention_mask, hidden_states.dtype, hidden_states.device)
        hidden_states, aux_loss, past_kv = self.model(hidden_states, input_embedding=input_embedding, attention_mask=causal_mask, position_ids=position_ids, use_cache=use_cache, n_loops=n_loops)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            # Upcast so the large-vocab log-softmax and loss are not rounded to bf16.
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)).float(), shift_labels.view(-1))
            loss = loss + self.config.router_aux_loss_coef * aux_loss
        return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}

    def get_num_params(self):
        """Return total and trainable parameter counts (tied weights counted once)."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {"total": total, "trainable": trainable}
|
|
def _build_prompt(sample):
    """Render one instruction record as plain training text (no EOS appended)."""
    instruction = str(sample.get("instruction", "")).strip()
    inp = str(sample.get("input", "")).strip()
    output = str(sample.get("output", "")).strip()
    if inp:
        return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
    return f"Question: Instruction: {instruction}\nAnswer: {output}\n"


def _load_records(data_dir):
    """Load records from ``data_dir/spiderportal_combined.pkl``, or synthesize 10k toy records."""
    import pandas as pd

    pkl_file = data_dir / "spiderportal_combined.pkl"
    if pkl_file.exists():
        print(f"Loading dataset from {pkl_file}...")
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        df = pd.read_pickle(pkl_file)
        return df.to_dict("records")
    print(f"No dataset found at {pkl_file}, creating synthetic data...")
    return [{"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."} for i in range(10000)]


def _pre_tokenize(records, tokenizer, max_len):
    """Tokenize every record once into fixed-length ``(input_ids, labels)`` lists.

    Labels are -100 for the leading "Question:" tokens and for padding. Padding
    is detected through the tokenizer's ``attention_mask``, because pad == eos
    and masking by token id would also drop the real end-of-text target.
    """
    prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
    mask_len = len(prefix_ids)
    pre_tokenized = []
    for i, sample in enumerate(records):
        text = _build_prompt(sample) + tokenizer.eos_token
        enc = tokenizer(text, truncation=True, max_length=max_len, padding="max_length")
        input_ids = enc["input_ids"]
        labels = [tok if keep else -100 for tok, keep in zip(input_ids, enc["attention_mask"])]
        for j in range(min(mask_len, len(labels))):
            labels[j] = -100
        pre_tokenized.append((input_ids, labels))
        if (i + 1) % 50000 == 0:
            print(f" Tokenized {i+1:,}/{len(records):,}")
    return pre_tokenized


def _save_training_checkpoint(path, model, optimizer, config, step, epoch):
    """Save model/optimizer state plus config to ``path``; return the file size in MB."""
    torch.save({
        "step": step,
        "epoch": epoch,
        "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
        "optimizer_state_dict": optimizer.state_dict(),
        "config": config.__dict__,
    }, path)
    return path.stat().st_size / (1024 * 1024)


def train_single_gpu():
    """Train SpiderPortal v5 on one CUDA GPU.

    Uses pre-tokenized data, a fixed batch size, linear LR warmup and gradient
    clipping. Writes step, epoch and best-loss checkpoints to ``./checkpoints``.
    """
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")

    config = SpiderPortalConfig(
        hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
        num_key_value_heads=2, intermediate_size=1024,
        num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
        router_aux_loss_coef=0.05, max_loop_iters=2,
        prelude_layers=2, coda_layers=2, lora_rank=32,
        rope_theta=10000000.0,
        rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
        max_position_embeddings=131072, sliding_window=4096,
        tie_word_embeddings=True,
    )

    print("Building model...")
    model = SpiderPortalForConditionalGeneration(config)
    # NOTE(review): parameters are stored in bf16 and updated in place by AdamW
    # with no fp32 master copy. Small updates (e.g. early warmup) can round away.
    model = model.to(torch.bfloat16).to(device)

    params = model.get_num_params()
    print(f"Model: {params['total']/1e6:.1f}M params")
    print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")

    BASE_LR = 1e-3
    WARMUP_STEPS = 500
    optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01, betas=(0.9, 0.95))

    data_dir = Path(__file__).parent / "data"
    all_records = _load_records(data_dir)
    print(f"Loaded {len(all_records):,} samples")

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    # Right padding: with causal attention, real tokens never attend to padding,
    # so no attention_mask needs to be passed to the model.
    tokenizer.padding_side = "right"

    BATCH_SIZE = 128
    MAX_LEN = 256
    EPOCHS = 3
    N_LOOPS = 2

    print(f"Batch size: {BATCH_SIZE} (no grad accum)")
    print(f"Effective batch: {BATCH_SIZE}")
    print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup (high LR for recurrent MoE)")
    print(f"Max seq len: {MAX_LEN}, N_LOOPS: {N_LOOPS}")

    print("Pre-tokenizing dataset...")
    pre_tokenized = _pre_tokenize(all_records, tokenizer, MAX_LEN)
    print(f"Pre-tokenization complete: {len(pre_tokenized):,} samples")
    del all_records
    gc.collect()

    global_step = 0
    best_loss = float('inf')
    start_time = time.time()
    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)
    step_ckpt_files = []

    for epoch in range(1, EPOCHS + 1):
        if epoch > 1:
            # Step checkpoints are superseded by the previous epoch's checkpoint.
            for f in step_ckpt_files:
                if f.exists():
                    f.unlink()
                    print(f" Deleted old step checkpoint: {f.name}")
            step_ckpt_files.clear()
            gc.collect()

        indices = list(range(len(pre_tokenized)))
        random.shuffle(indices)
        total_loss = 0.0
        num_batches = 0
        optimizer.zero_grad()

        # Only full batches; a trailing partial batch is dropped.
        for batch_start in range(0, len(indices) - BATCH_SIZE + 1, BATCH_SIZE):
            batch_indices = indices[batch_start:batch_start + BATCH_SIZE]

            if global_step < WARMUP_STEPS:
                lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            batch = [pre_tokenized[idx] for idx in batch_indices]
            input_ids = torch.tensor([ids for ids, _ in batch], dtype=torch.long, device=device)
            labels = torch.tensor([lab for _, lab in batch], dtype=torch.long, device=device)

            if global_step == 0:
                print(" [First forward pass...]")

            outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
            loss = outputs["loss"]
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            total_loss += loss.item()
            num_batches += 1

            if (batch_start // BATCH_SIZE) == 0 or global_step < 20 or global_step % 100 == 0:
                avg_loss = total_loss / max(num_batches, 1)
                elapsed = time.time() - start_time
                # global_step has already been incremented for this step.
                steps_per_hour = global_step / elapsed * 3600 if elapsed > 0 else 0
                current_lr = optimizer.param_groups[0]['lr']
                samples_per_sec = (global_step * BATCH_SIZE) / elapsed if elapsed > 0 else 0
                print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.0f} steps/hr | {samples_per_sec:.0f} samples/sec")

            if global_step % 500 == 0:
                ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}-step{global_step}.pt"
                state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save(state_dict, ckpt_path)
                step_ckpt_files.append(ckpt_path)
                size_mb = ckpt_path.stat().st_size / (1024 * 1024)
                print(f"Saved weights-only checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")

        avg_loss = total_loss / max(num_batches, 1)
        epoch_time = (time.time() - start_time) / 60
        print(f"Epoch {epoch}/{EPOCHS} complete | avg_loss={avg_loss:.4f} | Time: {epoch_time:.1f}min")

        ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}.pt"
        size_mb = _save_training_checkpoint(ckpt_path, model, optimizer, config, global_step, epoch)
        print(f"Saved epoch checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")

        if avg_loss < best_loss:
            best_loss = avg_loss
            best_path = checkpoint_dir / "spiderportal-v5-best.pt"
            size_mb = _save_training_checkpoint(best_path, model, optimizer, config, global_step, epoch)
            print(f"Saved best checkpoint: {best_path.name} ({size_mb:.0f}MB)")

    total_time = (time.time() - start_time) / 3600
    print("\nTraining complete!")
    print(f"Best loss: {best_loss:.4f}")
    print(f"Total time: {total_time:.2f} hours")
    print(f"Total steps: {global_step}")
    print(f"Checkpoints saved to: {checkpoint_dir}")


if __name__ == "__main__":
    train_single_gpu()
|
|