import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Model hyperparameters
VOCAB_SIZE = 50257
MODEL_DIM = 768
NUM_HEADS = 4
NUM_LAYERS = 2
MAX_SEQ_LEN = 8192
FFN_HIDDEN = 4 * MODEL_DIM
HEAD_DIM = MODEL_DIM // NUM_HEADS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")
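

# A quick sanity check of the parameter count these constants imply (the
# output head is tied to the token embedding below, so it is counted once).
# Illustrative only; nothing in this script calls it.
def _param_budget():
    emb = VOCAB_SIZE * MODEL_DIM       # token embedding, ~38.60M
    pos = MAX_SEQ_LEN * MODEL_DIM      # positional embedding, ~6.29M
    attn = 4 * MODEL_DIM * MODEL_DIM   # q/k/v/o projections per block
    mlp = 2 * MODEL_DIM * FFN_HIDDEN   # two MLP matrices per block
    norms = 2 * 2 * MODEL_DIM          # two LayerNorms per block (weight + bias)
    # Per block ~7.08M; x2 blocks plus ln_f gives ~59.05M total.
    return emb + pos + NUM_LAYERS * (attn + mlp + norms) + 2 * MODEL_DIM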


class PositionalEmbedding(nn.Module):
    """Learned absolute positional embeddings, sliced by `offset` so that
    incremental decoding with a KV cache resumes at the right position."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Parameter(torch.zeros(1, MAX_SEQ_LEN, MODEL_DIM))

    def forward(self, x, offset=0):
        return x + self.emb[:, offset:offset + x.size(1)]
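

# A minimal shape check for the offset slicing above: during incremental
# decoding, `offset` equals the number of already-cached tokens, so a 3-token
# continuation after 5 cached tokens receives positions 5..7. Illustrative
# only; not called by this script.
def _demo_pos_offset():
    pe = PositionalEmbedding()
    x = torch.zeros(1, 3, MODEL_DIM)
    y = pe(x, offset=5)
    assert y.shape == (1, 3, MODEL_DIM)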


class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention + GELU MLP."""

    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(MODEL_DIM, eps=1e-5)
        self.ln2 = nn.LayerNorm(MODEL_DIM, eps=1e-5)

        self.q_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.k_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.v_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.o_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)

        self.mlp1 = nn.Linear(MODEL_DIM, FFN_HIDDEN, bias=False)
        self.mlp2 = nn.Linear(FFN_HIDDEN, MODEL_DIM, bias=False)

    def forward(self, x, past_kv=None):
        B, T, C = x.shape

        # Normalize once instead of three times.
        h = self.ln1(x)
        q = self.q_proj(h).view(B, T, NUM_HEADS, HEAD_DIM).transpose(1, 2)
        k = self.k_proj(h).view(B, T, NUM_HEADS, HEAD_DIM).transpose(1, 2)
        v = self.v_proj(h).view(B, T, NUM_HEADS, HEAD_DIM).transpose(1, 2)

        if past_kv is not None:
            pk, pv = past_kv
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)

        # With a cache, the single new query may attend to every cached
        # position, so the causal mask is only needed on the prefill pass.
        out = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=(past_kv is None),
            dropout_p=0.0,
        )
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        x = x + self.o_proj(out)

        x = x + self.mlp2(F.gelu(self.mlp1(self.ln2(x)), approximate='tanh'))

        # Always return the (possibly extended) cache so a prefill pass can
        # seed it for later single-token steps.
        return x, (k, v)
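

# A minimal sketch of the per-block cache contract: prefill once without a
# cache, then feed one token at a time while k/v grow along dim 2. The random
# inputs are illustrative assumptions; nothing in this script calls this.
def _demo_block_kv_cache():
    block = Block().eval()
    prompt = torch.randn(1, 4, MODEL_DIM)    # a 4-token "prompt"
    with torch.no_grad():
        _, kv = block(prompt)                # prefill: cache holds 4 positions
        step = torch.randn(1, 1, MODEL_DIM)  # one new token
        _, kv = block(step, past_kv=kv)      # decode step: cache grows to 5
    assert kv[0].size(2) == 5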


class GPTPyTorch(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.pos_emb = PositionalEmbedding()
        self.blocks = nn.ModuleList([Block() for _ in range(NUM_LAYERS)])
        self.ln_f = nn.LayerNorm(MODEL_DIM, eps=1e-5)
        self.head = nn.Linear(MODEL_DIM, VOCAB_SIZE, bias=False)

        # Author watermark buffers (stored with the state dict).
        sig = "Konstantin V Gbabko . original author 2025"
        self.register_buffer("author_sig", torch.tensor([ord(c) for c in sig], dtype=torch.uint8))
        self.register_buffer("birth_date", torch.tensor([20251126], dtype=torch.int64))

        self.apply(self.init_weights)
        # Tie the output head to the token embedding *after* init, so the
        # Linear branch of init_weights does not overwrite the embedding init.
        self.head.weight = self.tok_emb.weight

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            # Scale the init down with depth to keep the residual stream stable.
            std = 0.02 / math.sqrt(2 * NUM_LAYERS)
            torch.nn.init.normal_(m.weight, mean=0.0, std=std)
        elif isinstance(m, nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, past_kv=None):
        B, T = idx.shape
        x = self.tok_emb(idx)

        # Offset positions by the number of already-cached tokens; on the
        # prefill pass (a list of per-layer Nones) the offset is zero.
        offset = past_kv[0][0].size(2) if past_kv is not None and past_kv[0] is not None else 0
        x = self.pos_emb(x, offset)

        new_kv = [] if past_kv is not None else None

        for i, block in enumerate(self.blocks):
            layer_past = past_kv[i] if past_kv is not None else None
            x, kv = block(x, layer_past)
            if new_kv is not None:
                new_kv.append(kv)

        x = self.ln_f(x)
        logits = self.head(x)

        return logits if past_kv is None else (logits, new_kv)


class JITWrapper(nn.Module):
    """Cache-free forward for torch.jit.trace (returns logits only)."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x, None)
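

# A minimal greedy-decoding sketch using the cache API above: pass a list of
# per-layer Nones to prefill and get the cache back, then feed one token at a
# time. The greedy argmax choice is an illustrative assumption; nothing in
# this script calls this helper.
@torch.no_grad()
def _demo_generate(model, prompt_ids, num_new_tokens=8):
    logits, kv = model(prompt_ids, past_kv=[None] * NUM_LAYERS)  # prefill
    out = prompt_ids
    for _ in range(num_new_tokens):
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)     # greedy pick
        out = torch.cat([out, next_id], dim=1)
        logits, kv = model(next_id, past_kv=kv)                  # 1 token + cache
    return out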


if __name__ == "__main__":
    os.makedirs("models", exist_ok=True)

    model = GPTPyTorch().to(device)
    model.eval()

    params = sum(p.numel() for p in model.parameters())
    print(f"GPTPyTorch | {NUM_HEADS} heads | {NUM_LAYERS} layers | {MODEL_DIM} dim")
    print(f"Parameters: {params/1e6:.2f}M ≈ 59M")

    dummy = torch.randint(0, VOCAB_SIZE, (1, 256), device=device)

    with torch.no_grad():
        test = model(dummy, None)
    print(f"Test forward → {test.shape} OK")

    jit = torch.jit.trace(JITWrapper(model), dummy)
    jit_path = "models/JiRack_H4_L2_V50257_D768_MSL8192_FF768x4.script.pt"
    jit.save(jit_path)

    state_path = "models/JiRack_H4_L2_V50257_D768_MSL8192_FF768x4.pt"
    torch.save(model.state_dict(), state_path)

    print("\nDONE!")
    print(f"  JIT     → {jit_path}")
    print(f"  PyTorch → {state_path}")
    print("Now go ahead and run your fine-tune script; the depth-scaled init should keep it free of NaNs.")