"""Shared model architecture for multilingual 3B GPT — must match training exactly."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
VOCAB_SIZE = 32000
DIM = 3072
DEPTH = 26
N_HEADS = 24
HEAD_DIM = DIM // N_HEADS # 128
MAX_SEQ_LEN = 2048
ROPE_THETA = 10000.0
HIDDEN_DIM = ((int(2 * DIM * 4 / 3) + 63) // 64) * 64 # SwiGLU hidden
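
# Rough parameter budget implied by these constants (back-of-envelope, added
# here for orientation; the numbers follow from the layer shapes below):
#   token embedding (tied with head): 32000 * 3072               ~  98.3M
#   per block: qkv (28.3M) + out_proj (9.4M) + SwiGLU (75.5M)    ~ 113.3M
#   26 blocks + embedding + norms                                ~   3.04B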


class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize in float32 for numerical stability, then cast back.
        norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * norm).type_as(x) * self.weight


def precompute_freqs_cis(dim, max_seq_len, theta=ROPE_THETA):
    """Precompute the complex RoPE table, shape (max_seq_len, dim // 2)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    # polar(1, theta) = cos(theta) + i*sin(theta): a unit rotation per (position, frequency).
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rotary_emb(x, freqs_cis):
    # x: (B, n_heads, S, head_dim)
    B, H, S, D = x.shape
    # View consecutive feature pairs as complex numbers and rotate them.
    x_complex = torch.view_as_complex(x.float().reshape(B, H, S, D // 2, 2))
    freqs = freqs_cis[:S].unsqueeze(0).unsqueeze(1)  # (1, 1, S, D//2)
    x_rot = torch.view_as_real(x_complex * freqs).reshape(B, H, S, D)
    return x_rot.type_as(x)
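
# Illustrative sanity check for the RoPE helpers (not part of the original
# file): each feature pair is multiplied by a unit complex number, so the
# rotation preserves shapes and per-vector norms.
#
#   q = torch.randn(1, N_HEADS, 16, HEAD_DIM)
#   fc = precompute_freqs_cis(HEAD_DIM, MAX_SEQ_LEN)
#   q_rot = apply_rotary_emb(q, fc)
#   assert q_rot.shape == q.shape
#   assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)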


class FusedAttention(nn.Module):
    """Multi-head attention with a single fused QKV projection and RoPE."""

    def __init__(self, dim, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x, freqs_cis, mask=None):
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.n_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        q = q.transpose(1, 2)  # (B, H, S, head_dim)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)
        # Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) V
        scale = math.sqrt(self.head_dim)
        attn = (q @ k.transpose(-2, -1)) / scale
        if mask is not None:
            attn = attn + mask  # additive causal mask (-inf above the diagonal)
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, S, D)
        return self.out_proj(out)
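
# Note: the forward pass above materializes the full (S, S) score matrix. On
# PyTorch >= 2.0 the explicit score/softmax/matmul sequence could be replaced
# by the fused kernel below (an equivalent sketch, not what this file ships):
#
#   out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
#   out = out.transpose(1, 2).reshape(B, S, D)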


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, dim, n_heads, hidden_dim):
        super().__init__()
        self.attn_norm = RMSNorm(dim)
        self.attn = FusedAttention(dim, n_heads)
        self.ffn_norm = RMSNorm(dim)
        self.ffn = SwiGLUFFN(dim, hidden_dim)

    def forward(self, x, freqs_cis, mask=None):
        x = x + self.attn(self.attn_norm(x), freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x


class MultilingualGPT(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM)
        self.layers = nn.ModuleList([
            TransformerBlock(DIM, N_HEADS, HIDDEN_DIM) for _ in range(DEPTH)
        ])
        self.norm = RMSNorm(DIM)
        self.head = nn.Linear(DIM, VOCAB_SIZE, bias=False)
        # Tied embeddings: the output head shares the token embedding matrix.
        self.head.weight = self.tok_emb.weight
        # Precompute the RoPE table once; as a buffer it follows .to(device).
        self.register_buffer('freqs_cis', precompute_freqs_cis(HEAD_DIM, MAX_SEQ_LEN))

    def forward(self, tokens, targets=None):
        B, S = tokens.shape
        x = self.tok_emb(tokens)
        # Causal mask: -inf above the diagonal, rebuilt per call for the current S.
        mask = torch.triu(torch.full((S, S), float('-inf'), device=tokens.device), diagonal=1)
        mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, S, S)
        for layer in self.layers:
            x = layer(x, self.freqs_cis, mask)
        x = self.norm(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
        return logits, loss
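

# Illustrative sampling helper, added for documentation; it is not part of the
# original training code, and every default below is a placeholder.
@torch.no_grad()
def generate(model, tokens, max_new_tokens=50, temperature=1.0, top_k=None):
    """Autoregressively sample continuations. `tokens` is (B, S) int64."""
    for _ in range(max_new_tokens):
        ctx = tokens[:, -MAX_SEQ_LEN:]  # crop context to the RoPE table length
        logits, _ = model(ctx)
        logits = logits[:, -1, :] / max(temperature, 1e-5)
        if top_k is not None:
            # Zero out everything below the k-th largest logit.
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = float('-inf')
        probs = F.softmax(logits, dim=-1)
        next_tok = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_tok], dim=1)
    return tokens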


def load_model(path, device='cuda'):
    """Load model from checkpoint, stripping compile/DDP prefixes."""
    model = MultilingualGPT()
    ckpt = torch.load(path, map_location='cpu', weights_only=False)
    state = ckpt.get('model_state_dict', ckpt)
    # Strip torch.compile ('_orig_mod.') and DataParallel/DDP ('module.') prefixes.
    cleaned = {}
    for k, v in state.items():
        new_k = k
        for prefix in ['_orig_mod.', 'module.']:
            if new_k.startswith(prefix):
                new_k = new_k[len(prefix):]
        cleaned[new_k] = v
    # Handle tied weights: drop head.weight if it duplicates tok_emb.weight
    # (the tie in __init__ restores it on load).
    if 'head.weight' in cleaned and 'tok_emb.weight' in cleaned:
        if torch.equal(cleaned['head.weight'], cleaned['tok_emb.weight']):
            del cleaned['head.weight']
    # strict=False tolerates the dropped tied weight and a missing freqs_cis buffer.
    model.load_state_dict(cleaned, strict=False)
    model = model.to(device).eval()
    return model


def load_tokenizer(path):
    """Load SentencePiece tokenizer."""
    import sentencepiece as spm
    sp = spm.SentencePieceProcessor()
    sp.Load(path)
    return sp
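

# Minimal end-to-end usage sketch. Both paths are placeholders: this file does
# not fix checkpoint or tokenizer filenames, so adjust them to your local files.
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = load_model('checkpoint.pt', device=device)
    sp = load_tokenizer('tokenizer.model')
    ids = sp.encode('Hello')  # SentencePiece returns a list of int ids
    tokens = torch.tensor([ids], dtype=torch.long, device=device)
    out = generate(model, tokens, max_new_tokens=20)
    print(sp.decode(out[0].tolist()))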