#!/usr/bin/env python3
"""SpiderPortal v5 — Single-GPU Optimized Training.

For RTX PRO 6000 (96GB) or similar large-VRAM GPU.
No DDP, maximal batch size, pre-tokenized data.
NOTE: torch.compile is not applied anywhere in this script; the model runs
in eager mode.

Usage:
    python train_single_gpu.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import json
import gc
import random
import time
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List

from torch.nn import CrossEntropyLoss


@dataclass
class SpiderPortalConfig:
    """Hyper-parameters for the SpiderPortal recurrent-depth MoE language model.

    hidden_act, num_shared_experts, sliding_window, use_cache, model_type,
    torch_dtype and the vision_*/audio_* fields are not read by the model
    modules defined in this file.
    """

    vocab_size: int = 50278
    hidden_size: int = 384
    num_hidden_layers: int = 8
    num_attention_heads: int = 8
    num_key_value_heads: int = 2
    intermediate_size: int = 1024
    hidden_act: str = "silu"
    num_experts: int = 64
    num_experts_per_tok: int = 1
    num_shared_experts: int = 1
    router_aux_loss_coef: float = 0.05
    max_loop_iters: int = 4
    act_threshold: float = 0.5
    max_position_embeddings: int = 131072
    rope_theta: float = 10000000.0
    rope_scaling: Optional[dict] = None
    sliding_window: int = 4096
    attention_dropout: float = 0.0
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    use_cache: bool = True
    tie_word_embeddings: bool = True
    prelude_layers: int = 2
    coda_layers: int = 2
    lora_rank: int = 32
    loop_embed_dim: int = 48
    vision_hidden_size: int = 384
    audio_hidden_size: int = 512
    vision_num_frames: int = 60
    vision_tokens_per_frame: int = 256
    vision_temporal_tokens: int = 64
    vision_temporal_layers: int = 2
    model_type: str = "spiderportal"
    torch_dtype: str = "bfloat16"


def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Add a sinusoidal encoding of loop iteration ``loop_t`` to ``h``.

    The encoding ([sin(t*f), cos(t*f)] truncated to ``loop_dim``) fills the
    first ``loop_dim`` channels of the last dimension; other channels get 0.
    It is broadcast over two leading dimensions (h presumably
    (batch, seq, hidden) — TODO confirm for other ranks).

    Frequencies and angles are computed in float32 and only the final vector
    is cast to ``h.dtype``, so bf16 models still get an accurate encoding.
    """
    even_idx = torch.arange(0, loop_dim, 2, device=h.device, dtype=torch.float32)
    freqs = 1.0 / (theta ** (even_idx / loop_dim))
    angles = loop_t * freqs
    emb = torch.cat([angles.sin(), angles.cos()], dim=-1)[:loop_dim]
    emb_full = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
    emb_full[:loop_dim] = emb.to(h.dtype)
    return h + emb_full.unsqueeze(0).unsqueeze(0)


class SpiderPortalRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-channel scale.

    Statistics are computed in float32; the result is returned in the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight.to(input_dtype) * hidden_states.to(input_dtype)


def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return RoPE inverse frequencies rescaled by a YaRN-style piecewise ramp.

    For each rotary pair i, beta = (2i/head_dim) * ln(rope_theta) / ln(orig_max).
    Pairs with beta < beta_slow keep their original frequency. Pairs with
    beta > beta_fast are divided by ``factor``. Pairs in between are scaled
    linearly from 1 down to 1/factor.
    """
    dim = head_dim
    orig_inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
    pos_freqs = torch.arange(0, dim, 2).float() / dim
    beta = (pos_freqs * math.log(rope_theta) / math.log(orig_max))
    scale = torch.where(
        beta < beta_slow,
        torch.ones_like(beta),
        torch.where(
            beta > beta_fast,
            torch.ones_like(beta) / factor,
            1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor),
        ),
    )
    return orig_inv_freq * scale


class SpiderPortalGQA(nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    ``num_attention_heads`` query heads share ``num_key_value_heads`` K/V
    heads. An additive ``attention_mask`` (broadcastable to
    (batch, heads, q_len, kv_len)) is added to the scores before softmax.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.attention_dropout = config.attention_dropout
        rope_scaling = getattr(config, 'rope_scaling', None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
            inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
        else:
            inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate_half(self, x):
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        return (x * cos) + (self._rotate_half(x) * sin)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return (attn_output, past_kv); past_kv is (keys, values) only when use_cache."""
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if position_ids is None:
            position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
            seq_len = q_len  # avoids a host sync on position_ids.max()
        else:
            seq_len = max(int(position_ids.max().item()) + 1, q_len)

        # Build the rotary table in float32. model.to(bfloat16) also casts the
        # inv_freq buffer, and bf16 cannot represent integers above 256
        # exactly, so bf16 positions would collide and large angles would be
        # badly rounded.
        t = torch.arange(seq_len, device=hidden_states.device, dtype=torch.float32)
        freqs = torch.outer(t, self.inv_freq.float())
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[position_ids].unsqueeze(1).to(query_states.dtype)
        sin = emb.sin()[position_ids].unsqueeze(1).to(query_states.dtype)
        query_states = self._apply_rotary(query_states, cos, sin)
        key_states = self._apply_rotary(key_states, cos, sin)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_kv = (key_states, value_states) if use_cache else None

        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output), past_kv


class SpiderPortalExpert(nn.Module):
    """Gated MLP: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        inter_size = intermediate_size or config.intermediate_size
        self.gate_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, inter_size, bias=False)
        self.down_proj = nn.Linear(inter_size, config.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))


class SpiderPortalRouter(nn.Module):
    """Token router returning (top_weights, top_indices, aux_loss).

    Experts are selected by top-k over softmax(logits + router_bias), with the
    selected weights renormalised to sum to 1. The auxiliary loss uses the
    unbiased softmax: num_experts * sum(mean_prob ** 2). ``router_bias`` is a
    zero-initialised buffer that nothing in this file updates.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.weight = nn.Parameter(torch.randn(config.hidden_size, config.num_experts) * config.initializer_range)
        self.register_buffer("router_bias", torch.zeros(config.num_experts))

    def forward(self, hidden_states):
        router_logits = hidden_states.view(-1, hidden_states.size(-1)) @ self.weight
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        biased_logits = router_logits + self.router_bias
        biased_weights = F.softmax(biased_logits, dim=-1, dtype=torch.float32)
        top_weights, top_indices = torch.topk(biased_weights, self.num_experts_per_tok, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
        top_weights = top_weights.to(hidden_states.dtype)
        mean_probs = routing_weights.mean(dim=0)
        aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
        return top_weights, top_indices, aux_loss


class SpiderPortalMoE(nn.Module):
    """Routed experts plus one always-on shared expert.

    Returns (output with the input's shape, router aux_loss).
    NOTE: config.num_shared_experts is not read; exactly one shared expert is built.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
        self.shared_expert = SpiderPortalExpert(config)
        self.router = SpiderPortalRouter(config)

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape
        top_weights, top_indices, aux_loss = self.router(hidden_states)
        flat_hidden = hidden_states.view(-1, hidden_dim)
        final_output = torch.zeros_like(flat_hidden)
        for expert_idx in range(self.num_experts_per_tok):
            expert_ids = top_indices[:, expert_idx]
            expert_weights = top_weights[:, expert_idx:expert_idx + 1]
            for e in range(self.num_experts):
                mask = expert_ids == e
                if mask.any():
                    expert_output = self.experts[e](flat_hidden[mask])
                    final_output[mask] += expert_output * expert_weights[mask]
        shared_output = self.shared_expert(flat_hidden)
        final_output = final_output + shared_output
        return final_output.view(batch_size, seq_len, hidden_dim), aux_loss


class SpiderPortalDenseLayer(nn.Module):
    """Pre-norm transformer layer: GQA attention then a dense gated MLP (4/3 * hidden wide).

    Returns (hidden_states, past_kv).
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = SpiderPortalGQA(config)
        dense_intermediate = config.hidden_size * 4 // 3
        self.ffn = SpiderPortalExpert(config, intermediate_size=dense_intermediate)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        attn_input = self.input_layernorm(hidden_states)
        attn_output, past_kv = self.self_attn(
            attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output
        ffn_input = self.post_attention_layernorm(hidden_states)
        ffn_output = self.ffn(ffn_input)
        hidden_states = hidden_states + ffn_output
        return hidden_states, past_kv


class SpiderPortalMoELayer(nn.Module):
    """Pre-norm transformer layer: GQA attention then an MoE feed-forward.

    Returns (hidden_states, aux_loss, past_kv).
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = SpiderPortalGQA(config)
        self.moe = SpiderPortalMoE(config)
        self.input_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        attn_input = self.input_layernorm(hidden_states)
        attn_output, past_kv = self.self_attn(
            attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output
        moe_input = self.post_attention_layernorm(hidden_states)
        moe_output, aux_loss = self.moe(moe_input)
        hidden_states = hidden_states + moe_output
        return hidden_states, aux_loss, past_kv


class LTIInjection(nn.Module):
    """Returns A * h_t + B(e), where A = -exp(log_A) is element-wise and strictly negative.

    NOTE: ``delta_t`` is a parameter that forward() does not use.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.log_A = nn.Parameter(torch.full((config.hidden_size,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))
        self.B = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        with torch.no_grad():
            self.B.weight.data.normal_(mean=0.0, std=0.01)

    def get_A(self):
        return -torch.exp(self.log_A)

    def forward(self, h_t, e):
        A = self.get_A()
        return A * h_t + self.B(e)


class ACTHalting(nn.Module):
    """Per-position halting probability sigmoid(Linear(h)), shape (..., 1).

    ``threshold`` is stored for reference; the halting loop reads
    config.act_threshold directly.
    """

    def __init__(self, config):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1)
        self.threshold = config.act_threshold

    def forward(self, hidden_states):
        return torch.sigmoid(self.halt_predictor(hidden_states))


class LoRAAdapter(nn.Module):
    """Low-rank, per-iteration delta: (down(x) * scale[t]) @ B.

    ``t`` is clamped to the last row of the scale table. ``scale`` is
    zero-initialised, so the adapter contributes nothing until it is trained.
    """

    def __init__(self, config):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
        self.scale = nn.Embedding(config.max_loop_iters, rank)
        with torch.no_grad():
            self.scale.weight.data.zero_()
            self.down.weight.data.normal_(mean=0.0, std=0.001)

    def forward(self, x, loop_t):
        max_t = self.scale.num_embeddings - 1
        t_idx = min(loop_t, max_t)
        s = self.scale(torch.tensor(t_idx, device=x.device))
        down = self.down(x) * s
        return down @ self.B


class SpiderPortalMoEModel(nn.Module):
    """Prelude dense layers -> looped MoE block with ACT halting -> coda dense layers -> norm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList(
            [SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None,
                past_key_values=None, use_cache=False, n_loops=None):
        """Return (hidden_states, total_aux_loss, new_past_key_values).

        total_aux_loss sums the router aux losses of every recurrent layer over
        every executed loop. new_past_key_values holds the recurrent layers'
        caches from the last executed loop. n_loops defaults to (and a falsy
        value is replaced by) config.max_loop_iters.
        """
        n_loops = n_loops or self.config.max_loop_iters
        if input_embedding is None:
            input_embedding = hidden_states
        total_aux_loss = 0.0

        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        bsz, seq_len, _ = hidden_states.shape
        dtype = hidden_states.dtype
        device = hidden_states.device
        halted = torch.zeros(bsz, seq_len, device=device, dtype=torch.bool)
        cumulative_p = torch.zeros(bsz, seq_len, device=device, dtype=dtype)
        h_out = torch.zeros_like(hidden_states)
        if past_key_values is None:
            past_key_values = [None] * len(self.recurrent_layers)
        new_past_key_values = []

        for t in range(n_loops):
            # Mark which iteration the shared recurrent weights are running.
            hidden_states = loop_index_embedding(hidden_states, t, self.loop_embed_dim)
            if t > 0:
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)

            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                hidden_states, aux_loss, past_kv = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    # Only the first iteration consumes an incoming cache.
                    past_key_value=past_key_values[i] if t == 0 else None,
                    use_cache=use_cache,
                )
                total_aux_loss = total_aux_loss + aux_loss
                new_past_key_values.append(past_kv)

            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            # Accumulate a halting-weighted mixture of per-iteration states.
            # On the step a token crosses the threshold it receives the
            # remaining mass (1 - cumulative_p), so its weights sum to 1.
            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            running = (~halted).to(dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            crosses = cumulative_p + halt_prob >= self.config.act_threshold
            weight = torch.where(crosses, remainder, halt_prob) * running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * running
            halted = halted | (cumulative_p >= self.config.act_threshold)
            if halted.all() and not self.training:
                break

        # Tokens that never crossed the threshold put their leftover mass on
        # the final state, so every token's mixture weights sum to 1.
        never_halted = (~halted).to(dtype)
        leftover = (1.0 - cumulative_p).clamp(min=0)
        hidden_states = h_out + (never_halted * leftover).unsqueeze(-1) * hidden_states

        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, total_aux_loss, new_past_key_values


class SpiderPortalForConditionalGeneration(nn.Module):
    """Token embedding + SpiderPortalMoEModel + LM head (optionally tied to the embedding)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = SpiderPortalMoEModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal-init Linear/Embedding weights and zero Linear biases.

        Sub-modules that set a deliberate initialisation in their own
        __init__ are skipped so that init survives: the LTI injection matrix
        and the LoRA adapter's down-projection and zeroed per-loop scale table.
        """
        if hasattr(self, 'model'):
            keep_own_init = (
                self.model.injection.B,
                self.model.lora_adapter.down,
                self.model.lora_adapter.scale,
            )
            if any(module is m for m in keep_own_init):
                return
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    @staticmethod
    def _build_attention_mask(attention_mask, input_ids, reference):
        """Build the additive (batch, 1, seq, seq) attention mask.

        Entry [b, 0, i, j] is 0 when query i may attend to key j. It is the most
        negative finite value of ``reference.dtype`` when j > i (a future token)
        or when ``attention_mask[b, j]`` is falsy (padding).
        """
        bsz, seq_len = input_ids.shape
        device = reference.device
        if attention_mask is None:
            attention_mask = torch.ones(bsz, seq_len, dtype=torch.bool, device=device)
        positions = torch.arange(seq_len, device=device)
        is_future = positions.view(1, -1) > positions.view(-1, 1)  # [query, key]
        is_padding = ~attention_mask.to(device=device, dtype=torch.bool)
        blocked = is_future.view(1, 1, seq_len, seq_len) | is_padding.view(bsz, 1, 1, seq_len)
        mask = torch.zeros(bsz, 1, seq_len, seq_len, dtype=reference.dtype, device=device)
        return mask.masked_fill(blocked, torch.finfo(reference.dtype).min)

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
        """Return a dict with "loss", "logits", "aux_loss" and "past_key_values".

        When ``labels`` is given, loss is next-token cross-entropy (labels of
        -100 are ignored) plus router_aux_loss_coef * aux_loss; otherwise None.
        """
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = next(self.model.parameters()).dtype
        hidden_states = hidden_states.to(model_dtype)
        input_embedding = hidden_states.clone()
        causal_mask = self._build_attention_mask(attention_mask, input_ids, hidden_states)
        hidden_states, aux_loss, past_kv = self.model(
            hidden_states,
            input_embedding=input_embedding,
            attention_mask=causal_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            n_loops=n_loops,
        )
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = loss + self.config.router_aux_loss_coef * aux_loss
        return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}

    def get_num_params(self):
        total = sum(p.numel() for p in self.parameters())
        return {"total": total, "trainable": total}


def train_single_gpu():
    """Train SpiderPortal on one CUDA device, writing checkpoints to ./checkpoints.

    Data is read from data/spiderportal_combined.pkl next to this file when it
    exists; otherwise a synthetic arithmetic Q&A dataset is generated.
    """
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")

    config = SpiderPortalConfig(
        hidden_size=384,
        num_hidden_layers=8,
        num_attention_heads=8,
        num_key_value_heads=2,
        intermediate_size=1024,
        num_experts=64,
        num_experts_per_tok=1,
        num_shared_experts=1,
        router_aux_loss_coef=0.05,
        max_loop_iters=2,
        prelude_layers=2,
        coda_layers=2,
        lora_rank=32,
        rope_theta=10000000.0,
        rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
        max_position_embeddings=131072,
        sliding_window=4096,
        tie_word_embeddings=True,
    )

    print("Building model...")
    model = SpiderPortalForConditionalGeneration(config)
    model = model.to(torch.bfloat16).to(device)
    params = model.get_num_params()
    print(f"Model: {params['total']/1e6:.1f}M params")
    print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")

    BASE_LR = 1e-3
    WARMUP_STEPS = 500
    optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01, betas=(0.9, 0.95))

    import pandas as pd

    data_dir = Path(__file__).parent / "data"
    pkl_file = data_dir / "spiderportal_combined.pkl"
    if pkl_file.exists():
        print(f"Loading dataset from {pkl_file}...")
        # SECURITY: read_pickle unpickles arbitrary objects — only load trusted files.
        df = pd.read_pickle(pkl_file)
        all_records = df.to_dict("records")
    else:
        print(f"No dataset found at {pkl_file}, creating synthetic data...")
        all_records = [
            {"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."}
            for i in range(10000)
        ]
    print(f"Loaded {len(all_records):,} samples")

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    BATCH_SIZE = 128
    MAX_LEN = 256
    EPOCHS = 3
    N_LOOPS = 2
    print(f"Batch size: {BATCH_SIZE} (no grad accum)")
    print(f"Effective batch: {BATCH_SIZE}")
    print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup (high LR for recurrent MoE)")
    print(f"Max seq len: {MAX_LEN}, N_LOOPS: {N_LOOPS}")

    def build_prompt(sample):
        instruction = str(sample.get("instruction", "")).strip()
        inp = str(sample.get("input", "")).strip()
        output = str(sample.get("output", "")).strip()
        if inp:
            return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
        return f"Question: Instruction: {instruction}\nAnswer: {output}\n"

    print("Pre-tokenizing dataset...")
    # Only the tokens of the leading "Question:" marker are excluded from the loss.
    prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
    mask_len = len(prefix_ids)
    pre_tokenized = []
    for i, sample in enumerate(all_records):
        text = build_prompt(sample) + tokenizer.eos_token
        enc = tokenizer(text, truncation=True, max_length=MAX_LEN, padding="max_length")
        input_ids = enc["input_ids"]
        # pad_token == eos_token, so padding can only be told apart through the
        # attention mask. Without masking it, every padded slot would be trained
        # to predict EOS and dominate the loss on short samples.
        labels = [
            -100 if (j < mask_len or not attended) else token
            for j, (token, attended) in enumerate(zip(input_ids, enc["attention_mask"]))
        ]
        pre_tokenized.append((input_ids, labels))
        if (i + 1) % 50000 == 0:
            print(f" Tokenized {i+1:,}/{len(all_records):,}")
    print(f"Pre-tokenization complete: {len(pre_tokenized):,} samples")
    del all_records
    gc.collect()

    global_step = 0
    best_loss = float('inf')
    start_time = time.time()
    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)
    step_ckpt_files = []

    def save_training_state(path, step, epoch):
        """Save model/optimizer state and config to *path*; return the file size in MB."""
        torch.save(
            {
                "step": step,
                "epoch": epoch,
                "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
                "optimizer_state_dict": optimizer.state_dict(),
                "config": config.__dict__,
            },
            path,
        )
        return path.stat().st_size / (1024 * 1024)

    for epoch in range(1, EPOCHS + 1):
        if epoch > 1:
            # Weights-only step checkpoints from earlier epochs are discarded.
            for f in step_ckpt_files:
                if f.exists():
                    f.unlink()
                    print(f" Deleted old step checkpoint: {f.name}")
            step_ckpt_files.clear()
            gc.collect()

        indices = list(range(len(pre_tokenized)))
        random.shuffle(indices)
        total_loss = 0
        num_batches = 0
        optimizer.zero_grad()

        for batch_start in range(0, len(indices), BATCH_SIZE):
            batch_indices = indices[batch_start:batch_start + BATCH_SIZE]
            if len(batch_indices) < BATCH_SIZE:
                continue  # drop the ragged final batch

            if global_step < WARMUP_STEPS:
                lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            batch_input_ids = [pre_tokenized[idx][0] for idx in batch_indices]
            batch_labels = [pre_tokenized[idx][1] for idx in batch_indices]
            input_ids = torch.tensor(batch_input_ids, dtype=torch.long, device=device)
            labels = torch.tensor(batch_labels, dtype=torch.long, device=device)

            if global_step == 0:
                # NOTE: no torch.compile is applied; the message is historical.
                print(" [First forward pass - compiling...]")
            outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
            loss = outputs["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            total_loss += loss.item()
            num_batches += 1

            if (batch_start // BATCH_SIZE) == 0 or global_step < 20 or global_step % 100 == 0:
                avg_loss = total_loss / max(num_batches, 1)
                elapsed = time.time() - start_time
                # global_step was already incremented, so it counts completed steps.
                steps_per_hour = global_step / elapsed * 3600 if elapsed > 0 else 0
                current_lr = optimizer.param_groups[0]['lr']
                samples_per_sec = (global_step * BATCH_SIZE) / elapsed if elapsed > 0 else 0
                print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.0f} steps/hr | {samples_per_sec:.0f} samples/sec")

            if global_step > 0 and global_step % 500 == 0:
                ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}-step{global_step}.pt"
                state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save(state_dict, ckpt_path)
                step_ckpt_files.append(ckpt_path)
                size_mb = ckpt_path.stat().st_size / (1024 * 1024)
                print(f"Saved weights-only checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")

        avg_loss = total_loss / max(num_batches, 1)
        # Minutes since training started (cumulative across epochs).
        epoch_time = (time.time() - start_time) / 60
        print(f"Epoch {epoch}/{EPOCHS} complete | avg_loss={avg_loss:.4f} | Time: {epoch_time:.1f}min")

        ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}.pt"
        size_mb = save_training_state(ckpt_path, global_step, epoch)
        print(f"Saved epoch checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")

        if avg_loss < best_loss:
            best_loss = avg_loss
            best_path = checkpoint_dir / "spiderportal-v5-best.pt"
            size_mb = save_training_state(best_path, global_step, epoch)
            print(f"Saved best checkpoint: {best_path.name} ({size_mb:.0f}MB)")

    total_time = (time.time() - start_time) / 3600
    print("\nTraining complete!")
    print(f"Best loss: {best_loss:.4f}")
    print(f"Total time: {total_time:.2f} hours")
    print(f"Total steps: {global_step}")
    print(f"Checkpoints saved to: {checkpoint_dir}")


if __name__ == "__main__":
    train_single_gpu()