# JiRack_empty / source_jit / JiRack_H4_L2_V50257_D768_MSL8192_FF768x4.py
# NOTE(review): Hugging Face page residue ("kgrabko's picture", "Upload 16
# files", commit hash c88fe21) was not valid Python and broke the file;
# converted to this provenance comment.
# Copyright (c) 2025 CMS Manhattan
# All rights reserved.
# Author: Konstantin Vladimirovich Grabko
# Email: grabko@cmsmanhattan.com
# Phone: +1(516)777-0945
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# Additional terms:
# Any commercial use or distribution of this software or derivative works
# requires explicit written permission from the copyright holder.
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from pathlib import Path
# ========================================
# Model configuration ("exactly your configuration")
# ========================================
VOCAB_SIZE = 50257
MODEL_DIM = 768
NUM_HEADS = 4 # ← as requested
NUM_LAYERS = 2 # ← as requested
MAX_SEQ_LEN = 8192
FFN_HIDDEN = 4 * MODEL_DIM
HEAD_DIM = MODEL_DIM // NUM_HEADS # 768 // 4 = 192
# Pick GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Запуск на: {device}")
# ========================================
# Fully stable and JIT-friendly blocks
# ========================================
class PositionalEmbedding(nn.Module):
    """Learned absolute positional embedding added to token embeddings.

    Holds one trainable vector per position (zero-initialized), broadcast
    across the batch dimension.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Parameter(torch.zeros(1, MAX_SEQ_LEN, MODEL_DIM))

    def forward(self, x, offset=0):
        """Add position vectors [offset, offset + seq_len) to `x`.

        `offset` supports KV-cache decoding, where `x` holds only the newest
        tokens but their absolute positions start past the cached prefix.
        """
        seq_len = x.size(1)
        pos = self.emb[:, offset:offset + seq_len]
        return x + pos
class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention + GELU(tanh) MLP.

    Both sub-layers are wrapped in residual connections; every projection
    is bias-free. Supports incremental decoding via an optional (k, v)
    cache of shape (B, NUM_HEADS, T_past, HEAD_DIM).
    """

    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(MODEL_DIM, eps=1e-5)
        self.ln2 = nn.LayerNorm(MODEL_DIM, eps=1e-5)
        self.q_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.k_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.v_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.o_proj = nn.Linear(MODEL_DIM, MODEL_DIM, bias=False)
        self.mlp1 = nn.Linear(MODEL_DIM, FFN_HIDDEN, bias=False)
        self.mlp2 = nn.Linear(FFN_HIDDEN, MODEL_DIM, bias=False)

    def forward(self, x, past_kv=None):
        """Run one block.

        Args:
            x: (B, T, MODEL_DIM) hidden states.
            past_kv: optional (k, v) tensors from earlier steps, or None.
        Returns:
            (x, (k, v)) — updated hidden states and the full key/value
            tensors (past + current) for use as the next step's cache.
        """
        B, T, C = x.shape
        # Attention. Normalize once: the original recomputed self.ln1(x)
        # separately for q, k and v — identical result, triple the work.
        h = self.ln1(x)
        q = self.q_proj(h).view(B, T, NUM_HEADS, HEAD_DIM).transpose(1, 2)
        k = self.k_proj(h).view(B, T, NUM_HEADS, HEAD_DIM).transpose(1, 2)
        v = self.v_proj(h).view(B, T, NUM_HEADS, HEAD_DIM).transpose(1, 2)
        if past_kv is not None:
            pk, pv = past_kv
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)
        # NOTE(review): is_causal=False on the cached path is only correct
        # when T == 1 (single-token decode); a multi-token continuation with
        # a cache would need an explicit attention mask — confirm callers.
        out = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=(past_kv is None),
            dropout_p=0.0
        )
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        x = x + self.o_proj(out)
        # MLP
        x = x + self.mlp2(F.gelu(self.mlp1(self.ln2(x)), approximate='tanh'))
        # Bug fix: always return the (k, v) cache. The original returned None
        # when past_kv was None, so the first (prefill) call could never seed
        # a KV cache. Backward-compatible: GPTPyTorch.forward discards the
        # second element whenever it was called without a cache.
        return x, (k, v)
class GPTPyTorch(nn.Module):
    """Decoder-only GPT-style language model.

    Token embedding + learned positional embedding, NUM_LAYERS pre-norm
    blocks, a final LayerNorm, and a vocabulary head whose weight is tied
    to the token-embedding matrix.
    """

    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.pos_emb = PositionalEmbedding()
        self.blocks = nn.ModuleList(Block() for _ in range(NUM_LAYERS))
        self.ln_f = nn.LayerNorm(MODEL_DIM, eps=1e-5)
        self.head = nn.Linear(MODEL_DIM, VOCAB_SIZE, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.tok_emb.weight
        # Author signature, persisted with the checkpoint as buffers.
        sig = "Konstantin V Gbabko . original author 2025"
        self.register_buffer("author_sig", torch.tensor([ord(c) for c in sig], dtype=torch.uint8))
        self.register_buffer("birth_date", torch.tensor([20251126], dtype=torch.int64))
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Depth-scaled normal init for Linear layers; std=0.02 for embeddings."""
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02 / math.sqrt(2 * NUM_LAYERS))
        elif isinstance(m, nn.Embedding):
            torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, past_kv=None):
        """Return logits, or (logits, new_kv) when a KV cache is supplied.

        Args:
            idx: (B, T) token ids.
            past_kv: list of per-layer (k, v) caches, or None for training.
        """
        B, T = idx.shape  # also asserts idx is 2-D
        x = self.tok_emb(idx)
        # Position offset = number of tokens already held in layer 0's keys.
        offset = past_kv[0][0].size(2) if past_kv and len(past_kv) > 0 else 0
        x = self.pos_emb(x, offset)
        caching = past_kv is not None
        new_kv = [] if caching else None
        for layer_idx, block in enumerate(self.blocks):
            layer_past = past_kv[layer_idx] if caching else None
            x, kv = block(x, layer_past)
            if caching:
                new_kv.append(kv)
        logits = self.head(self.ln_f(x))
        if caching:
            return logits, new_kv
        return logits
# Thin adapter for JIT tracing (training path only).
class JITWrapper(nn.Module):
    """Expose a single-argument forward(x), as torch.jit.trace requires.

    Always calls the wrapped model without a KV cache, i.e. the pure
    training/inference path that returns bare logits.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        logits = self.model(x, None)
        return logits
# ========================================
# Export
# ========================================
if __name__ == "__main__":
    # Build the model, smoke-test one forward pass, then save both a traced
    # TorchScript module and a plain state_dict checkpoint.
    models_dir = Path("models")  # Path was imported but previously unused
    models_dir.mkdir(exist_ok=True)
    model = GPTPyTorch().to(device)
    model.eval()
    params = sum(p.numel() for p in model.parameters())
    # Derive the banner from the config constants instead of hardcoding
    # "4 heads | 2 layers | 768 dim" (renders identically today).
    print(f"GPTPyTorch | {NUM_HEADS} heads | {NUM_LAYERS} layers | {MODEL_DIM} dim")
    print(f"Параметры: {params/1e6:.2f}M ≈ 46M")
    # Smoke test: one no-grad forward pass over random token ids.
    dummy = torch.randint(0, VOCAB_SIZE, (1, 256), device=device)
    with torch.no_grad():
        test = model(dummy, None)
    print(f"Test forward → {test.shape} OK")
    # TorchScript export (training path only — wrapper passes no KV cache).
    jit_path = models_dir / "JiRack_H4_L2_V50257_D768_MSL8192_FF768x4.script.pt"
    jit = torch.jit.trace(JITWrapper(model), dummy)
    jit.save(str(jit_path))
    # Plain PyTorch checkpoint.
    pt_path = models_dir / "JiRack_H4_L2_V50257_D768_MSL8192_FF768x4.pt"
    torch.save(model.state_dict(), pt_path)
    print(f"\nГОТОВО!")
    print(f" JIT → {jit_path}")
    # Bug fix: the original printed a placeholder ("models/GPTPyTorch_....pt")
    # that did not match the checkpoint file actually written above.
    print(f" PyTorch → {pt_path}")
    print(f"Теперь смело запускай свой fine-tune скрипт — NaN не будет никогда")