import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import os
from transformers import AutoTokenizer, logging
import pandas as pd
from tqdm import tqdm
from safetensors.torch import save_file


logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

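# ----- Training configuration -----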
SAVE_EVERY = 5
MODEL_NAME = "mini_transformer_v3"
N_DATA_WORKERS = 8
PIN_MEMORY = N_DATA_WORKERS > 0 and torch.cuda.is_available()
BATCH_SIZE = 512
EVAL_EVERY = 5
LEARNING_RATE = 3e-4
NUM_EPOCHS = 50
USE_AMP = True
STRIDE = 64
CHECKPOINT_DIR = f"MODELS/checkpoints/{MODEL_NAME}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
DATASET = "DATA/generated_dataset_very_big.csv"

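# ----- Model hyperparameters -----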
CONTEXT_LENGTH = 128
EMBEDDING_DIMENSION = 512
HEAD_NUMBER = 4
N_LAYER = 4

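# ----- Model definition -----
# A small decoder-only transformer: token + learned positional embeddings,
# a stack of pre-norm attention/MLP blocks, and a linear head that maps the
# final hidden states back to vocabulary logits.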
class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads, context_length, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(dropout),
        )
        # Causal mask: True above the diagonal means "may not attend there",
        # so position i only sees positions <= i (required for next-token training).
        self.register_buffer(
            "causal_mask",
            torch.triu(
                torch.ones(context_length, context_length, dtype=torch.bool),
                diagonal=1,
            ),
            persistent=False,
        )

    def forward(self, x):
        T = x.size(1)
        # Pre-norm attention: normalize once and reuse for query, key and value.
        h = self.ln1(x)
        attn_out, _ = self.attn(
            h, h, h, attn_mask=self.causal_mask[:T, :T], need_weights=False
        )
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x


class MiniTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        emb_dim,
        context_length,
        num_heads,
        num_layers,
        dropout=0.1,
    ):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(emb_dim, num_heads, context_length, dropout)
                for _ in range(num_layers)
            ]
        )
        self.ln_f = nn.LayerNorm(emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        self.context_length = context_length

    def forward(self, x):
        # x: (batch, seq_len) token ids -> logits: (batch, seq_len, vocab_size)
        B, T = x.shape
        pos = torch.arange(T, device=x.device)
        x = self.emb(x) + self.pos_emb(pos)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


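# ----- Dataset -----
# The CSV rows are concatenated into one long token stream and cut into
# overlapping windows; each window is trained on next-token prediction
# (inputs are tokens [i, i + context_length), targets are shifted by one).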
class SlidingWindowDataset(Dataset):
    def __init__(self, texts, tokenizer, context_length=128, stride=64):
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.stride = stride

        # Tokenize every text and concatenate into a single flat id stream.
        self.tokens = []
        for text in texts:
            ids = tokenizer.encode(text, add_special_tokens=False)
            self.tokens.extend(ids)
        self.tokens = torch.tensor(self.tokens, dtype=torch.long)

        self.n_samples = (len(self.tokens) - context_length) // stride

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.context_length + 1
        chunk = self.tokens[start:end]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y


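# Illustrative example (values made up): with context_length=4 and stride=2,
# a token stream [t0, t1, ..., t9] yields windows
#   x = [t0, t1, t2, t3], y = [t1, t2, t3, t4]
#   x = [t2, t3, t4, t5], y = [t3, t4, t5, t6]
# and so on, so consecutive samples overlap by context_length - stride tokens.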
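# ----- Device selection -----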
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
    print(torch.cuda.memory_reserved() / 1024**2, "MB reserved")


# Each CSV row is flattened into a single training string.
df = pd.read_csv(DATASET)
texts = [
    f"{row['system_prompt']} {row['question']} {row['answer']}"
    for _, row in df.iterrows()
]


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
vocab_size = tokenizer.vocab_size


dataset = SlidingWindowDataset(texts, tokenizer, CONTEXT_LENGTH, STRIDE)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
print(f"dataset train length: {len(train_dataset)}")
loader_train = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_DATA_WORKERS,
    pin_memory=PIN_MEMORY,
)
loader_test = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=N_DATA_WORKERS,
    pin_memory=PIN_MEMORY,
)


model = MiniTransformer(
    vocab_size=vocab_size,
    emb_dim=EMBEDDING_DIMENSION,
    context_length=CONTEXT_LENGTH,
    num_heads=HEAD_NUMBER,
    num_layers=N_LAYER,
).to(device)


n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"number of parameters: {n_params}")
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scaler = torch.amp.GradScaler(enabled=USE_AMP and device.type == "cuda")
# The sliding-window dataset never pads, so ignore_index is only a safeguard.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


# Resume from the most recent checkpoint, if any. Sort by epoch number rather
# than lexicographically, so epoch 10 comes after epoch 5.
checkpoint_files = sorted(
    [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith(".pt")],
    key=lambda f: int(f.rsplit("_", 1)[-1].split(".")[0]),
)
if checkpoint_files:
    latest_ckpt = os.path.join(CHECKPOINT_DIR, checkpoint_files[-1])
    ckpt = torch.load(latest_ckpt, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    scaler.load_state_dict(ckpt["scaler_state"])
    start_epoch = ckpt["epoch"] + 1
    print(f"Resumed from {latest_ckpt}")
else:
    start_epoch = 0


# Keep a handle on the eager-mode module: torch.compile wraps the model and
# prefixes its state_dict keys with "_orig_mod.", which would break resuming.
uncompiled_model = model
model = torch.compile(model)


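# ----- Training loop -----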
for epoch in range(start_epoch, NUM_EPOCHS):
    model.train()
    total_loss = 0

    for x, y in tqdm(loader_train, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad()

        # Mixed-precision forward pass (only active on CUDA).
        with torch.amp.autocast(
            "cuda", dtype=torch.float16, enabled=USE_AMP and device.type == "cuda"
        ):
            logits = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))

        # Scaled backward pass; the scaler is a pass-through when AMP is disabled.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item() * x.size(0)

    avg_train_loss = total_loss / len(train_dataset)
    print(f"Train Loss: {avg_train_loss:.4f}")

    if (epoch + 1) % EVAL_EVERY == 0:
        model.eval()
        total_loss = 0
        with torch.no_grad():
            for x, y in loader_test:
                x, y = x.to(device), y.to(device)
                with torch.amp.autocast(
                    "cuda",
                    dtype=torch.float16,
                    enabled=USE_AMP and device.type == "cuda",
                ):
                    logits = model(x)
                    loss = criterion(logits.view(-1, vocab_size), y.view(-1))
                total_loss += loss.item() * x.size(0)
        avg_test_loss = total_loss / len(test_dataset)
        print(f"Test Loss: {avg_test_loss:.4f}")

    # Periodic checkpointing: full training state as .pt, weights-only as
    # safetensors. Save the uncompiled module so keys carry no "_orig_mod." prefix.
    if SAVE_EVERY > 0 and (epoch + 1) % SAVE_EVERY == 0:
        torch.save(
            {
                "epoch": epoch,
                "model_state": uncompiled_model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scaler_state": scaler.state_dict(),
            },
            os.path.join(CHECKPOINT_DIR, f"checkpoint_{MODEL_NAME}_epoch_{epoch+1}.pt"),
        )
        save_file(
            uncompiled_model.state_dict(),
            os.path.join(CHECKPOINT_DIR, f"model_{epoch+1}.safetensors"),
        )