# Source: train_single_gpu.py, uploaded by CLIWorks (commit e9af599, verified, 28.6 kB).
#!/usr/bin/env python3
"""SpiderPortal v5 — Single-GPU Optimized Training.
For RTX PRO 6000 (96GB) or similar large-VRAM GPU.
No DDP, large fixed batch size, pre-tokenized data.
Note: torch.compile is not currently applied anywhere in this script.
Usage:
python train_single_gpu.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import json
import gc
import random
import time
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List
from torch.nn import CrossEntropyLoss
@dataclass
class SpiderPortalConfig:
    """Hyper-parameters for the SpiderPortal recurrent mixture-of-experts model.

    Every field has a default, so ``SpiderPortalConfig()`` is a valid
    configuration.  Field order is part of the interface (the dataclass accepts
    positional arguments), so new fields should only be appended.
    """

    # --- Core transformer dimensions ---
    vocab_size: int = 50278
    hidden_size: int = 384
    num_hidden_layers: int = 8  # number of MoE layers inside the recurrent stack
    num_attention_heads: int = 8
    num_key_value_heads: int = 2  # grouped-query attention: KV heads shared by query heads
    intermediate_size: int = 1024  # inner width of each routed/shared expert
    hidden_act: str = "silu"

    # --- Mixture-of-experts routing ---
    num_experts: int = 64
    num_experts_per_tok: int = 1
    num_shared_experts: int = 1
    router_aux_loss_coef: float = 0.05  # weight of the router auxiliary loss in the total loss

    # --- Recurrent loop / adaptive computation ---
    max_loop_iters: int = 4  # default number of passes through the recurrent stack
    act_threshold: float = 0.5  # cumulative halting probability at which a token stops

    # --- Positional encoding and attention ---
    max_position_embeddings: int = 131072
    rope_theta: float = 10000000.0
    # Optional rotary scaling spec, e.g. {"type": "yarn", "factor": 4.0,
    # "original_max_position_embeddings": 32768}; None disables scaling.
    rope_scaling: Optional[dict] = None
    sliding_window: int = 4096
    attention_dropout: float = 0.0

    # --- Normalisation, initialisation, misc ---
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    use_cache: bool = True
    tie_word_embeddings: bool = True

    # --- Model layout ---
    prelude_layers: int = 2  # dense layers before the recurrent stack
    coda_layers: int = 2  # dense layers after the recurrent stack
    lora_rank: int = 32  # rank of the per-loop low-rank adapter
    loop_embed_dim: int = 48  # channels reserved for the loop-index encoding

    # --- Multimodal sizes ---
    vision_hidden_size: int = 384
    audio_hidden_size: int = 512
    vision_num_frames: int = 60
    vision_tokens_per_frame: int = 256
    vision_temporal_tokens: int = 64
    vision_temporal_layers: int = 2

    # --- Serialisation metadata ---
    model_type: str = "spiderportal"
    torch_dtype: str = "bfloat16"
def loop_index_embedding(h, loop_t, loop_dim, theta=10000.0):
    """Add a sinusoidal encoding of the loop index to the leading channels of ``h``.

    Args:
        h: Tensor whose last dimension is the feature dimension; its device
            and dtype are used for the encoding.
        loop_t: Loop iteration index; scales every frequency.
        loop_dim: Number of leading feature channels that receive the encoding.
        theta: Base of the geometric frequency progression.

    Returns:
        A new tensor equal to ``h`` plus the encoding; channels from
        ``loop_dim`` onward are left unchanged.
    """
    exponents = torch.arange(0, loop_dim, 2, device=h.device, dtype=h.dtype) / loop_dim
    phase = loop_t * (1.0 / (theta ** exponents))
    # Sine half first, cosine half second, truncated to exactly ``loop_dim`` values.
    sincos = torch.cat([phase.sin(), phase.cos()], dim=-1)[:loop_dim]
    padded = torch.zeros(h.shape[-1], device=h.device, dtype=h.dtype)
    padded[:loop_dim] = sincos
    # Shape (1, 1, features) so the same vector is added at every position.
    return h + padded[None, None, :]
class SpiderPortalRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        """Create a gain vector of ones with ``hidden_size`` entries; ``eps`` stabilises the division."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` in float32 and return it in its original dtype."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = torch.square(upcast).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # Both factors are cast back before multiplying, so the product is in the input dtype.
        gain = self.weight.to(original_dtype)
        return gain * normalised.to(original_dtype)
def compute_yarn_inv_freq(head_dim, rope_theta, factor, orig_max, beta_fast=32.0, beta_slow=1.0):
    """Return rotary inverse frequencies rescaled per channel.

    Each channel pair ``i`` has exponent ``2i / head_dim`` and a ramp value
    ``beta = exponent * ln(rope_theta) / ln(orig_max)``:

    * ``beta < beta_slow``: frequency is kept unchanged;
    * ``beta > beta_fast``: frequency is divided by ``factor``;
    * otherwise: linear blend between scale 1 and ``1 / factor``.

    Args:
        head_dim: Per-head feature size; ``ceil(head_dim / 2)`` frequencies are produced.
        rope_theta: Rotary base.
        factor: Context extension factor.
        orig_max: Original maximum sequence length used in the ramp.
        beta_fast: Upper ramp boundary.
        beta_slow: Lower ramp boundary.

    Returns:
        1-D float32 tensor of scaled inverse frequencies.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    base_inv_freq = 1.0 / (rope_theta ** exponents)
    beta = exponents * math.log(rope_theta) / math.log(orig_max)

    ones = torch.ones_like(beta)
    blend = 1.0 - (beta - beta_slow) / (beta_fast - beta_slow) * (1.0 - 1.0 / factor)
    above_slow = torch.where(beta > beta_fast, ones / factor, blend)
    scale = torch.where(beta < beta_slow, ones, above_slow)
    return base_inv_freq * scale
class SpiderPortalGQA(nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    ``num_key_value_heads`` key/value heads are each shared by
    ``num_attention_heads // num_key_value_heads`` query heads.  Rotary
    frequencies come from ``config.rope_theta`` and are rescaled by
    ``compute_yarn_inv_freq`` when ``config.rope_scaling["type"] == "yarn"``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.attention_dropout = config.attention_dropout

        rope_scaling = getattr(config, 'rope_scaling', None)
        if rope_scaling and rope_scaling.get("type") == "yarn":
            factor = rope_scaling.get("factor", 1.0)
            orig_max_pos = rope_scaling.get("original_max_position_embeddings", config.max_position_embeddings)
            inv_freq = compute_yarn_inv_freq(self.head_dim, config.rope_theta, factor, orig_max_pos)
        else:
            inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        # Non-persistent: rebuilt from the config, never stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rotate_half(self, x):
        """Map the two halves ``(x1, x2)`` of the last dimension to ``(-x2, x1)``."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def _apply_rotary(self, x, cos, sin):
        """Rotate ``x`` by the angles encoded in ``cos`` / ``sin``."""
        return (x * cos) + (self._rotate_half(x) * sin)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Run attention over ``hidden_states`` of shape (batch, seq, hidden).

        Args:
            hidden_states: Input activations.
            attention_mask: Optional additive mask broadcastable to
                (batch, heads, query_len, key_len).
            position_ids: Optional integer positions, (batch, seq).  When
                omitted, positions continue after any cached keys.
            past_key_value: Optional ``(keys, values)`` tuple from a previous
                call; new keys/values are appended along the sequence axis.
            use_cache: Return the (un-repeated) keys/values for reuse.

        Returns:
            ``(output, past_kv)`` where ``past_kv`` is ``None`` unless
            ``use_cache`` is true.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if position_ids is None:
            # New tokens follow the cached ones; starting at 0 again would give
            # them the same rotary phase as the beginning of the sequence.
            past_len = past_key_value[0].size(2) if past_key_value is not None else 0
            position_ids = torch.arange(
                past_len, past_len + q_len, device=hidden_states.device
            ).unsqueeze(0).expand(bsz, -1)

        # Angles are computed directly from the positions in float32.  The
        # ``inv_freq`` buffer follows ``module.to(dtype)``; in bfloat16 integer
        # positions above 256 are not exactly representable, so building the
        # position table in the buffer dtype makes neighbouring tokens collide.
        # This also avoids the ``position_ids.max().item()`` host sync.
        angles = position_ids.to(torch.float32).unsqueeze(-1) * self.inv_freq.to(torch.float32)
        emb = torch.cat((angles, angles), dim=-1)
        cos = emb.cos().to(query_states.dtype).unsqueeze(1)
        sin = emb.sin().to(query_states.dtype).unsqueeze(1)
        query_states = self._apply_rotary(query_states, cos, sin)
        key_states = self._apply_rotary(key_states, cos, sin)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_kv = (key_states, value_states) if use_cache else None

        # Expand KV heads so every query head has a matching key/value head.
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Softmax in float32 for stability, then back to the value dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)
        return self.o_proj(attn_output), past_kv
class SpiderPortalExpert(nn.Module):
    """Bias-free gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config, intermediate_size=None):
        """Build the three projections.

        ``intermediate_size`` overrides ``config.intermediate_size`` when it is
        given and truthy.
        """
        super().__init__()
        width = intermediate_size or config.intermediate_size
        model_dim = config.hidden_size
        self.gate_proj = nn.Linear(model_dim, width, bias=False)
        self.up_proj = nn.Linear(model_dim, width, bias=False)
        self.down_proj = nn.Linear(width, model_dim, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states):
        """Apply the gated feed-forward transform to ``hidden_states``."""
        gate = self.act_fn(self.gate_proj(hidden_states))
        value = self.up_proj(hidden_states)
        return self.down_proj(gate * value)
class SpiderPortalRouter(nn.Module):
    """Top-k token-to-expert router with an auxiliary balance term."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        initial = torch.randn(config.hidden_size, config.num_experts) * config.initializer_range
        self.weight = nn.Parameter(initial)
        # Added to the logits for expert selection only; starts at zero.
        self.register_buffer("router_bias", torch.zeros(config.num_experts))

    def forward(self, hidden_states):
        """Route every token.

        Returns:
            ``(top_weights, top_indices, aux_loss)``.  The first two have shape
            (tokens, num_experts_per_tok); the weights are renormalised to sum
            to 1 per token (so with a single expert per token they are all 1).
            ``aux_loss`` is ``num_experts * sum(mean_prob ** 2)`` computed from
            the unbiased routing probabilities.
        """
        tokens = hidden_states.view(-1, hidden_states.size(-1))
        logits = tokens @ self.weight

        unbiased_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
        selection_probs = F.softmax(logits + self.router_bias, dim=-1, dtype=torch.float32)

        picked, top_indices = torch.topk(selection_probs, self.num_experts_per_tok, dim=-1)
        top_weights = (picked / picked.sum(dim=-1, keepdim=True)).to(hidden_states.dtype)

        mean_probs = unbiased_probs.mean(dim=0)
        aux_loss = self.num_experts * (mean_probs * mean_probs).sum()
        return top_weights, top_indices, aux_loss
class SpiderPortalMoE(nn.Module):
    """Mixture-of-experts feed-forward: routed experts plus one always-on shared expert."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList([SpiderPortalExpert(config) for _ in range(config.num_experts)])
        self.shared_expert = SpiderPortalExpert(config)
        self.router = SpiderPortalRouter(config)

    def forward(self, hidden_states):
        """Dispatch each token to its routed experts and add the shared expert.

        Args:
            hidden_states: Tensor of shape (batch, seq, hidden).

        Returns:
            ``(output, aux_loss)`` where ``output`` has the input's shape and
            ``aux_loss`` is the router's auxiliary loss.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        top_weights, top_indices, aux_loss = self.router(hidden_states)
        flat_hidden = hidden_states.view(-1, hidden_dim)
        final_output = torch.zeros_like(flat_hidden)

        for slot in range(self.num_experts_per_tok):
            expert_ids = top_indices[:, slot]
            expert_weights = top_weights[:, slot:slot + 1]
            # Ask the device once which experts received tokens instead of
            # testing ``mask.any()`` (a host sync) for every expert.
            # ``torch.unique`` returns sorted ids, so experts are still visited
            # in ascending order and the result is unchanged.
            for e in torch.unique(expert_ids).tolist():
                mask = expert_ids == e
                expert_output = self.experts[e](flat_hidden[mask])
                final_output[mask] += expert_output * expert_weights[mask]

        final_output = final_output + self.shared_expert(flat_hidden)
        return final_output.view(batch_size, seq_len, hidden_dim), aux_loss
class SpiderPortalDenseLayer(nn.Module):
    """Pre-norm block: grouped-query attention followed by a dense gated feed-forward."""

    def __init__(self, config):
        super().__init__()
        model_dim = config.hidden_size
        self.self_attn = SpiderPortalGQA(config)
        # The dense feed-forward uses 4/3 of the model width (integer division).
        self.ffn = SpiderPortalExpert(config, intermediate_size=model_dim * 4 // 3)
        self.input_layernorm = SpiderPortalRMSNorm(model_dim, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(model_dim, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states, past_kv)`` after the attention and feed-forward residual updates."""
        attn_out, present = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out
        ffn_out = self.ffn(self.post_attention_layernorm(after_attn))
        return after_attn + ffn_out, present
class SpiderPortalMoELayer(nn.Module):
    """Pre-norm block: grouped-query attention followed by a mixture-of-experts feed-forward."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        model_dim = config.hidden_size
        self.self_attn = SpiderPortalGQA(config)
        self.moe = SpiderPortalMoE(config)
        self.input_layernorm = SpiderPortalRMSNorm(model_dim, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SpiderPortalRMSNorm(model_dim, eps=config.rms_norm_eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        """Return ``(hidden_states, aux_loss, past_kv)`` after both residual updates."""
        attn_out, present = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out
        moe_out, aux_loss = self.moe(self.post_attention_layernorm(after_attn))
        return after_attn + moe_out, aux_loss, present
class LTIInjection(nn.Module):
    """Linear injection term ``A * h_t + B(e)`` with a strictly negative diagonal ``A``."""

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size
        self.hidden_size = dim
        # A is parameterised as -exp(log_A), which keeps every entry negative.
        self.log_A = nn.Parameter(torch.full((dim,), -2.0))
        self.delta_t = nn.Parameter(torch.tensor(1.0))  # not read by forward()
        self.B = nn.Linear(dim, dim, bias=False)
        nn.init.normal_(self.B.weight, mean=0.0, std=0.01)

    def get_A(self):
        """Return the per-channel coefficient ``-exp(log_A)``."""
        return -self.log_A.exp()

    def forward(self, h_t, e):
        """Return ``A * h_t + B(e)``."""
        decay = self.get_A()
        drive = self.B(e)
        return decay * h_t + drive
class ACTHalting(nn.Module):
    """Per-token halting probability: a linear layer followed by a sigmoid."""

    def __init__(self, config):
        super().__init__()
        self.halt_predictor = nn.Linear(config.hidden_size, 1)
        self.threshold = config.act_threshold

    def forward(self, hidden_states):
        """Return halting probabilities in (0, 1) with a trailing dimension of size 1."""
        halt_logits = self.halt_predictor(hidden_states)
        return halt_logits.sigmoid()
class LoRAAdapter(nn.Module):
    """Low-rank residual adapter whose per-rank gain depends on the loop index."""

    def __init__(self, config):
        super().__init__()
        rank = config.lora_rank
        self.down = nn.Linear(config.hidden_size, rank, bias=False)
        self.B = nn.Parameter(torch.randn(rank, config.hidden_size) * 0.02)
        self.scale = nn.Embedding(config.max_loop_iters, rank)
        # Zero gains make the adapter output exactly zero right after construction.
        nn.init.zeros_(self.scale.weight)
        nn.init.normal_(self.down.weight, mean=0.0, std=0.001)

    def forward(self, x, loop_t):
        """Return the low-rank update for ``x`` at loop iteration ``loop_t``.

        Iterations beyond the last gain row reuse the last row.
        """
        last_row = self.scale.num_embeddings - 1
        row = min(loop_t, last_row)
        gain = self.scale(torch.tensor(row, device=x.device))
        projected = self.down(x) * gain
        return projected @ self.B
class SpiderPortalMoEModel(nn.Module):
    """Prelude layers, a looped MoE stack with adaptive halting, then coda layers.

    The recurrent stack is applied ``n_loops`` times.  After each pass a
    halting probability is predicted per token; pass outputs are accumulated
    with those probabilities as weights until a token's cumulative probability
    reaches ``config.act_threshold``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.prelude_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.prelude_layers)])
        self.recurrent_layers = nn.ModuleList([SpiderPortalMoELayer(config, i) for i in range(config.num_hidden_layers)])
        self.coda_layers = nn.ModuleList([SpiderPortalDenseLayer(config) for _ in range(config.coda_layers)])
        self.norm = SpiderPortalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.injection = LTIInjection(config)
        self.act_halting = ACTHalting(config)
        self.lora_adapter = LoRAAdapter(config)
        self.loop_embed_dim = config.loop_embed_dim

    def forward(self, hidden_states, input_embedding=None, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, n_loops=None):
        """Run the full stack.

        Args:
            hidden_states: Token activations, shape (batch, seq, hidden).
            input_embedding: Signal re-injected before every loop after the
                first; defaults to the incoming ``hidden_states``.
            attention_mask: Additive attention mask passed to every layer.
            position_ids: Optional positions passed to every layer.
            past_key_values: Optional per-recurrent-layer caches, consumed on
                the first loop only.
            use_cache: Ask recurrent layers to return their keys/values.
            n_loops: Number of recurrent passes; falls back to
                ``config.max_loop_iters`` when ``None`` or 0.

        Returns:
            ``(hidden_states, total_aux_loss, new_past_key_values)``; the aux
            loss is summed over every recurrent layer and loop, and the caches
            are those produced by the last executed loop.
        """
        n_loops = n_loops or self.config.max_loop_iters
        input_embedding = input_embedding if input_embedding is not None else hidden_states
        total_aux_loss = 0.0

        for layer in self.prelude_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)

        batch, seq_len, _ = hidden_states.shape
        halted = torch.zeros(batch, seq_len, device=hidden_states.device, dtype=torch.bool)
        cumulative_p = torch.zeros(batch, seq_len, device=hidden_states.device, dtype=hidden_states.dtype)
        h_out = torch.zeros_like(hidden_states)
        past_key_values = past_key_values if past_key_values is not None else [None] * len(self.recurrent_layers)
        threshold = self.config.act_threshold
        # Defined up front so the return value exists even if no loop runs.
        new_past_key_values = []

        # NOTE(review): the previous revision computed
        # ``loop_index_embedding(hidden_states, t, self.loop_embed_dim)`` and a
        # clone of the prelude output on every call but never used either
        # result.  Both were dropped as dead work; outputs are unchanged.
        # If the loop-index signal is meant to reach the recurrent layers it
        # still has to be wired in deliberately.
        for t in range(n_loops):
            if t > 0:
                hidden_states = hidden_states + self.injection(hidden_states, input_embedding)

            new_past_key_values = []
            for i, layer in enumerate(self.recurrent_layers):
                hidden_states, aux_loss, past_kv = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values[i] if t == 0 else None,
                    use_cache=use_cache,
                )
                total_aux_loss = total_aux_loss + aux_loss
                new_past_key_values.append(past_kv)

            hidden_states = hidden_states + self.lora_adapter(hidden_states, t)

            halt_prob = self.act_halting(hidden_states).squeeze(-1)
            still_running = (~halted).to(hidden_states.dtype)
            remainder = (1.0 - cumulative_p).clamp(min=0)
            # A token that crosses the threshold on this pass contributes its
            # remaining probability mass; otherwise it contributes halt_prob.
            weight = torch.where(cumulative_p + halt_prob >= threshold, remainder, halt_prob)
            weight = weight * still_running
            h_out = h_out + weight.unsqueeze(-1) * hidden_states
            cumulative_p = cumulative_p + halt_prob * still_running
            halted = halted | (cumulative_p >= threshold)
            # Check the Python flag first so training never pays for the
            # device sync that ``halted.all()`` implies.
            if not self.training and halted.all():
                break

        # Tokens that never reached the threshold also receive the final state.
        never_halted = (~halted).to(hidden_states.dtype).unsqueeze(-1)
        hidden_states = h_out + never_halted * hidden_states

        for layer in self.coda_layers:
            hidden_states, _ = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states, total_aux_loss, new_past_key_values
class SpiderPortalForConditionalGeneration(nn.Module):
    """Token embedding + ``SpiderPortalMoEModel`` + language-model head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = SpiderPortalMoEModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.apply(self._init_weights)

    def _keeps_custom_init(self, module):
        """Return True for sub-modules that set their own initial values.

        ``LTIInjection.B`` and the LoRA adapter's ``down`` / ``scale`` are
        initialised in their own constructors (the adapter gains start at
        zero); the generic initialiser must not overwrite them.
        """
        if not hasattr(self, 'model'):
            return False
        adapter = self.model.lora_adapter
        protected = (self.model.injection.B, adapter.down, adapter.scale)
        return any(module is candidate for candidate in protected)

    def _init_weights(self, module):
        """Normal-initialise Linear/Embedding weights and zero Linear biases."""
        if self._keeps_custom_init(module):
            return
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    @staticmethod
    def _build_causal_mask(attention_mask, dtype):
        """Build an additive (batch, 1, seq, seq) mask.

        Entry ``[b, 0, i, j]`` is ``finfo(dtype).min`` when key ``j`` lies
        after query ``i`` or when ``attention_mask[b, j]`` is false/zero, and
        0 otherwise.

        Args:
            attention_mask: (batch, seq) tensor, bool or 0/1 integers.
            dtype: Floating dtype of the returned mask.
        """
        keep = attention_mask.to(torch.bool)
        seq_len = keep.size(1)
        idx = torch.arange(seq_len, device=keep.device)
        future = idx[None, :] > idx[:, None]
        blocked = future[None, None, :, :] | ~keep[:, None, None, :]
        mask = torch.zeros(blocked.shape, dtype=dtype, device=keep.device)
        return mask.masked_fill(blocked, torch.finfo(dtype).min)

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, n_loops=None, use_cache=False):
        """Compute logits and, when ``labels`` is given, the training loss.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: Optional (batch, seq) padding mask; true/1 marks
                real tokens.  Defaults to all ones.
            position_ids: Optional positions forwarded to the attention layers.
            labels: Optional (batch, seq) targets; they are shifted by one
                internally and entries equal to -100 are ignored by the loss.
            n_loops: Number of recurrent passes (see ``SpiderPortalMoEModel``).
            use_cache: Forwarded to the recurrent layers.

        Returns:
            Dict with ``loss`` (cross-entropy plus
            ``router_aux_loss_coef * aux_loss``, or ``None``), ``logits``,
            ``aux_loss`` and ``past_key_values``.
        """
        hidden_states = self.embed_tokens(input_ids)
        model_dtype = next(self.model.parameters()).dtype
        hidden_states = hidden_states.to(model_dtype)
        input_embedding = hidden_states.clone()

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        # The earlier mask applied ``triu(1)`` to a matrix that was zero for
        # every real token, so future positions were never masked and the
        # model could read the tokens it was asked to predict.
        causal_mask = self._build_causal_mask(attention_mask.to(hidden_states.device), hidden_states.dtype)

        hidden_states, aux_loss, past_kv = self.model(
            hidden_states,
            input_embedding=input_embedding,
            attention_mask=causal_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            n_loops=n_loops,
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            loss = loss + self.config.router_aux_loss_coef * aux_loss
        return {"loss": loss, "logits": logits, "aux_loss": aux_loss, "past_key_values": past_kv}

    def get_num_params(self):
        """Return ``{"total": ..., "trainable": ...}`` parameter counts (tied weights counted once)."""
        total = 0
        trainable = 0
        for p in self.parameters():
            total += p.numel()
            if p.requires_grad:
                trainable += p.numel()
        return {"total": total, "trainable": trainable}
def _build_prompt(sample):
    """Render one record as the ``Question: ... Answer: ...`` training text (no EOS)."""
    instruction = str(sample.get("instruction", "")).strip()
    inp = str(sample.get("input", "")).strip()
    output = str(sample.get("output", "")).strip()
    if inp:
        return f"Question: Instruction: {instruction}\nInput: {inp}\nAnswer: {output}\n"
    return f"Question: Instruction: {instruction}\nAnswer: {output}\n"


def _pretokenize(records, tokenizer, max_len):
    """Tokenize every record once and return ``(input_ids, labels)`` list pairs.

    Labels copy the input ids except that the leading "Question:" tokens and
    every padding position are set to -100 so the loss ignores them.
    """
    prefix_ids = tokenizer("Question:", add_special_tokens=False)["input_ids"]
    mask_len = len(prefix_ids)
    total = len(records)
    pre_tokenized = []
    for i, sample in enumerate(records):
        text = _build_prompt(sample) + tokenizer.eos_token
        enc = tokenizer(
            text,
            truncation=True,
            max_length=max_len,
            padding="max_length",
            return_attention_mask=True,
        )
        input_ids = enc["input_ids"]
        # pad_token == eos_token, so padding must be identified through the
        # attention mask; the real EOS (mask 1) stays a training target.
        labels = [tok if keep else -100 for tok, keep in zip(input_ids, enc["attention_mask"])]
        for j in range(min(mask_len, len(labels))):
            labels[j] = -100
        pre_tokenized.append((input_ids, labels))
        if (i + 1) % 50000 == 0:
            print(f" Tokenized {i+1:,}/{total:,}")
    return pre_tokenized


def _save_full_checkpoint(path, model, optimizer, config, step, epoch):
    """Write model and optimizer state to ``path``; return the file size in MiB."""
    torch.save({
        "step": step,
        "epoch": epoch,
        "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
        "optimizer_state_dict": optimizer.state_dict(),
        "config": config.__dict__,
    }, path)
    return path.stat().st_size / (1024 * 1024)


def train_single_gpu():
    """Train SpiderPortal on one CUDA GPU and write checkpoints to ``./checkpoints``.

    Loads ``data/spiderportal_combined.pkl`` next to this file (or falls back
    to synthetic arithmetic samples), pre-tokenizes it with the GPT-2
    tokenizer, and trains in bfloat16 with AdamW and linear LR warm-up.  The
    final incomplete batch of each epoch is skipped.  Weights-only checkpoints
    are written every 500 steps (and removed at the start of the next epoch);
    full checkpoints are written per epoch and for the best epoch loss.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("train_single_gpu requires a CUDA-capable GPU")
    device = torch.device("cuda")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem:.1f}GB)")

    config = SpiderPortalConfig(
        hidden_size=384, num_hidden_layers=8, num_attention_heads=8,
        num_key_value_heads=2, intermediate_size=1024,
        num_experts=64, num_experts_per_tok=1, num_shared_experts=1,
        router_aux_loss_coef=0.05, max_loop_iters=2,
        prelude_layers=2, coda_layers=2, lora_rank=32,
        rope_theta=10000000.0,
        rope_scaling={"type": "yarn", "factor": 4.0, "original_max_position_embeddings": 32768},
        max_position_embeddings=131072, sliding_window=4096,
        tie_word_embeddings=True,
    )

    print("Building model...")
    model = SpiderPortalForConditionalGeneration(config)
    model = model.to(torch.bfloat16).to(device)
    params = model.get_num_params()
    print(f"Model: {params['total']/1e6:.1f}M params")
    print(f"Experts: {config.num_experts} routed + {config.num_shared_experts} shared")

    BASE_LR = 1e-3
    WARMUP_STEPS = 500
    optimizer = torch.optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=0.01, betas=(0.9, 0.95))

    import pandas as pd
    data_dir = Path(__file__).parent / "data"
    pkl_file = data_dir / "spiderportal_combined.pkl"
    if pkl_file.exists():
        print(f"Loading dataset from {pkl_file}...")
        # SECURITY: unpickling executes arbitrary code; only load files you created.
        df = pd.read_pickle(pkl_file)
        all_records = df.to_dict("records")
    else:
        print(f"No dataset found at {pkl_file}, creating synthetic data...")
        all_records = [
            {"instruction": f"Question {i}: What is {i} + {i}?", "input": "", "output": f"The answer is {i+i}."}
            for i in range(10000)
        ]
    print(f"Loaded {len(all_records):,} samples")

    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    BATCH_SIZE = 128
    MAX_LEN = 256
    EPOCHS = 3
    N_LOOPS = 2
    print(f"Batch size: {BATCH_SIZE} (no grad accum)")
    print(f"Effective batch: {BATCH_SIZE}")
    print(f"LR: {BASE_LR} with {WARMUP_STEPS}-step warmup (high LR for recurrent MoE)")
    print(f"Max seq len: {MAX_LEN}, N_LOOPS: {N_LOOPS}")

    print("Pre-tokenizing dataset...")
    pre_tokenized = _pretokenize(all_records, tokenizer, MAX_LEN)
    print(f"Pre-tokenization complete: {len(pre_tokenized):,} samples")
    del all_records
    gc.collect()

    global_step = 0
    best_loss = float('inf')
    start_time = time.time()
    checkpoint_dir = Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)
    step_ckpt_files = []

    for epoch in range(1, EPOCHS + 1):
        if epoch > 1:
            # Step checkpoints from the previous epoch are superseded by its epoch checkpoint.
            for old_ckpt in step_ckpt_files:
                if old_ckpt.exists():
                    old_ckpt.unlink()
                    print(f" Deleted old step checkpoint: {old_ckpt.name}")
            step_ckpt_files.clear()
            gc.collect()

        indices = list(range(len(pre_tokenized)))
        random.shuffle(indices)
        total_loss = 0
        num_batches = 0
        optimizer.zero_grad()

        for batch_start in range(0, len(indices), BATCH_SIZE):
            batch_indices = indices[batch_start:batch_start + BATCH_SIZE]
            if len(batch_indices) < BATCH_SIZE:
                continue  # drop the final incomplete batch

            if global_step < WARMUP_STEPS:
                lr = BASE_LR * (global_step + 1) / WARMUP_STEPS
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            input_ids = torch.tensor(
                [pre_tokenized[idx][0] for idx in batch_indices], dtype=torch.long, device=device
            )
            labels = torch.tensor(
                [pre_tokenized[idx][1] for idx in batch_indices], dtype=torch.long, device=device
            )

            if global_step == 0:
                print(" [First forward pass - compiling...]")
            outputs = model(input_ids=input_ids, labels=labels, n_loops=N_LOOPS)
            loss = outputs["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1
            total_loss += loss.item()
            num_batches += 1

            if (batch_start // BATCH_SIZE) == 0 or global_step < 20 or global_step % 100 == 0:
                avg_loss = total_loss / max(num_batches, 1)
                elapsed = time.time() - start_time
                # global_step already counts the step just taken.
                steps_per_hour = global_step / elapsed * 3600 if elapsed > 0 else 0
                current_lr = optimizer.param_groups[0]['lr']
                samples_per_sec = (global_step * BATCH_SIZE) / elapsed if elapsed > 0 else 0
                print(f"Epoch {epoch}/{EPOCHS} | Step {global_step} | loss={avg_loss:.4f} | LR={current_lr:.2e} | {steps_per_hour:.0f} steps/hr | {samples_per_sec:.0f} samples/sec")

            if global_step > 0 and global_step % 500 == 0:
                ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}-step{global_step}.pt"
                state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
                torch.save(state_dict, ckpt_path)
                step_ckpt_files.append(ckpt_path)
                size_mb = ckpt_path.stat().st_size / (1024 * 1024)
                print(f"Saved weights-only checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")

        avg_loss = total_loss / max(num_batches, 1)
        epoch_time = (time.time() - start_time) / 60
        print(f"Epoch {epoch}/{EPOCHS} complete | avg_loss={avg_loss:.4f} | Time: {epoch_time:.1f}min")

        ckpt_path = checkpoint_dir / f"spiderportal-v5-ep{epoch}.pt"
        size_mb = _save_full_checkpoint(ckpt_path, model, optimizer, config, global_step, epoch)
        print(f"Saved epoch checkpoint: {ckpt_path.name} ({size_mb:.0f}MB)")

        if avg_loss < best_loss:
            best_loss = avg_loss
            best_path = checkpoint_dir / "spiderportal-v5-best.pt"
            size_mb = _save_full_checkpoint(best_path, model, optimizer, config, global_step, epoch)
            print(f"Saved best checkpoint: {best_path.name} ({size_mb:.0f}MB)")

    total_time = (time.time() - start_time) / 3600
    print("\nTraining complete!")
    print(f"Best loss: {best_loss:.4f}")
    print(f"Total time: {total_time:.2f} hours")
    print(f"Total steps: {global_step}")
    print(f"Checkpoints saved to: {checkpoint_dir}")
# Start training only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    train_single_gpu()