spedrox-sac/models / pretrain_slimpajama.ipynb
download
raw
46.9 kB
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pre-Training on SlimPajama-6B (Azure A100)\n",
"\n",
"**Model**: 154M Decoder-only Transformer (GQA + REPO-Attention + Flash-Attention) \n",
"**Dataset**: SlimPajama-6B via Oxen \n",
"**Tracking**: Weights & Biases"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Uncomment on Azure if needed\n",
"# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
"# !pip install oxen wandb transformers tokenizers pyarrow"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/spedrox/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Pad token: [PAD], ID: 50257\n",
"Added 3 special tokens\n",
"PyTorch: 2.5.1+cu121\n",
"CUDA: True\n",
"GPU: NVIDIA A100 80GB PCIe\n",
"VRAM: 85.0 GB\n"
]
}
],
"source": [
"import os, sys, time, math\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.amp import autocast, GradScaler\n",
"from torch.utils.data import DataLoader\n",
"from datetime import datetime\n",
"import wandb\n",
"\n",
"PROJECT_ROOT = os.path.dirname(os.path.abspath(\"__file__\"))\n",
"TRAIN_DIR = os.path.join(PROJECT_ROOT, \"train\")\n",
"sys.path.insert(0, PROJECT_ROOT)\n",
"sys.path.insert(0, TRAIN_DIR)\n",
"\n",
"from transformer.build_transformer import build_transformer\n",
"from dataset_define import SlimPajamaDataset\n",
"from save_checkpoint import save_checkpoint\n",
"from tokenizer import tokenizer\n",
"\n",
"print(f\"PyTorch: {torch.__version__}\")\n",
"print(f\"CUDA: {torch.cuda.is_available()}\")\n",
"if torch.cuda.is_available():\n",
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
" print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Vocab: 50260, Pad ID: 50257\n",
"Tokens/step: 360,448\n",
"Est steps/epoch: 16,645\n",
"Total steps: 16,645\n",
"Warmup: 200 steps\n"
]
}
],
"source": [
"# ======================== CONFIG ========================\n",
"DATASET_DIR = os.path.join(PROJECT_ROOT, \"SlimPajama-6B\")\n",
"CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, \"checkpoints\")\n",
"\n",
"# Model\n",
"D_MODEL = 768\n",
"NUM_LAYERS = 12\n",
"NUM_HEADS = 12\n",
"KV_HEADS = 4\n",
"D_FF = 3072\n",
"DROPOUT = 0.1\n",
"MAX_SEQ_LEN = 2048\n",
"USE_REPO = True\n",
"USE_FLASH = True\n",
"\n",
"# Training\n",
"EPOCHS = 1\n",
"BATCH_SIZE = 22\n",
"GRAD_ACCUM = 8\n",
"LEARNING_RATE = 2e-4\n",
"MIN_LR = 2e-5 # 10% of peak\n",
"WEIGHT_DECAY = 0.01\n",
"MAX_GRAD_NORM = 1.0\n",
"WARMUP_STEPS = 200 # Short warmup so model starts learning fast\n",
"\n",
"# Estimated steps (for cosine schedule)\n",
"TOKENS_IN_DATASET = 6_000_000_000\n",
"TOKENS_PER_STEP = BATCH_SIZE * GRAD_ACCUM * MAX_SEQ_LEN\n",
"EST_STEPS_PER_EPOCH = TOKENS_IN_DATASET // TOKENS_PER_STEP\n",
"TOTAL_STEPS = EST_STEPS_PER_EPOCH * EPOCHS\n",
"\n",
"# WandB\n",
"WANDB_PROJECT = \"Spedrox_llm\"\n",
"USE_WANDB = True\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"VOCAB_SIZE = len(tokenizer)\n",
"PAD_ID = tokenizer.pad_token_id\n",
"\n",
"print(f\"Vocab: {VOCAB_SIZE}, Pad ID: {PAD_ID}\")\n",
"print(f\"Tokens/step: {TOKENS_PER_STEP:,}\")\n",
"print(f\"Est steps/epoch: {EST_STEPS_PER_EPOCH:,}\")\n",
"print(f\"Total steps: {TOTAL_STEPS:,}\")\n",
"print(f\"Warmup: {WARMUP_STEPS} steps\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset exists at /home/spedrox/Transformers/SlimPajama-6B\n"
]
}
],
"source": [
"# ======================== CLONE DATASET ========================\n",
"import oxen\n",
"\n",
"if not os.path.exists(DATASET_DIR):\n",
" print(\"Cloning SlimPajama-6B...\")\n",
" oxen.clone(\"https://hub.oxen.ai/datasets/SlimPajama-6B\", DATASET_DIR)\n",
" print(\"Done!\")\n",
"else:\n",
" print(f\"Dataset exists at {DATASET_DIR}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameters: 154,582,996 (154.6M)\n",
"REPO: ON, Flash: ON\n",
"[OK] RMSNorm gamma parameters are non-zero\n"
]
}
],
"source": [
"# ======================== BUILD MODEL ========================\n",
"model = build_transformer(\n",
" src_vocab_size=VOCAB_SIZE, tgt_vocab_size=VOCAB_SIZE,\n",
" src_seq_len=MAX_SEQ_LEN, tgt_seq_len=MAX_SEQ_LEN,\n",
" d_model=D_MODEL, N=NUM_LAYERS, h=NUM_HEADS, kv_h=KV_HEADS,\n",
" dropout=DROPOUT, d_ff=D_FF, use_repo=USE_REPO, use_flash=USE_FLASH,\n",
")\n",
"model = model.to(device)\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Parameters: {total_params:,} ({total_params/1e6:.1f}M)\")\n",
"print(f\"REPO: {'ON' if USE_REPO else 'OFF'}, Flash: {'ON' if USE_FLASH else 'OFF'}\")\n",
"\n",
"# Verify RMSNorm gamma is NOT zero\n",
"for name, p in model.named_parameters():\n",
" if 'gamma' in name:\n",
" assert p.abs().sum() > 0, f\"FATAL: {name} is all zeros!\"\n",
"print(\"[OK] RMSNorm gamma parameters are non-zero\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running sanity check...\n",
"Output shape: torch.Size([2, 128, 50260])\n",
"Output range: [-0.8724, 0.9890]\n",
"Output std: 0.1734\n",
"[OK] Sanity check passed\n"
]
}
],
"source": [
"# ======================== SANITY CHECK ========================\n",
"print(\"Running sanity check...\")\n",
"model.train()\n",
"dummy = torch.randint(0, VOCAB_SIZE, (2, 128), device=device)\n",
"\n",
"with torch.no_grad():\n",
" x = model.tgt_embed(dummy)\n",
" for layer in model.decoder.layers:\n",
" x, _ = layer(x, tgt_mask=None, use_cache=False)\n",
" x = model.decoder.norm(x)\n",
" logits = model.project(x)\n",
"\n",
"print(f\"Output shape: {logits.shape}\")\n",
"print(f\"Output range: [{logits.min().item():.4f}, {logits.max().item():.4f}]\")\n",
"print(f\"Output std: {logits.std().item():.4f}\")\n",
"assert logits.std().item() > 0.01, \"FATAL: Model output has near-zero variance!\"\n",
"print(\"[OK] Sanity check passed\")\n",
"\n",
"del dummy, x, logits\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 50 parquet files in /home/spedrox/Transformers/SlimPajama-6B\n",
"DataLoader ready (batch=22, workers=16)\n"
]
}
],
"source": [
"# ======================== DATASET & DATALOADER ========================\n",
"train_dataset = SlimPajamaDataset(\n",
" data_dir=DATASET_DIR, tokenizer=tokenizer, max_length=MAX_SEQ_LEN,\n",
")\n",
"train_loader = DataLoader(\n",
" train_dataset, batch_size=BATCH_SIZE,\n",
" num_workers=16, pin_memory=True, prefetch_factor=2,\n",
")\n",
"print(f\"DataLoader ready (batch={BATCH_SIZE}, workers=16)\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: [wandb.login()] Using explicit session credentials for https://api.wandb.ai.\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /home/spedrox/.netrc\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mdinmaybrahmaofficial\u001b[0m (\u001b[33mdinmaybrahmaofficial-indian-institute-of-technology\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
]
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Tracking run with wandb version 0.25.1"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Run data is saved locally in <code>/home/spedrox/Transformers/wandb/run-20260402_174756-yomr6qws</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm/runs/yomr6qws' target=\"_blank\">slimpajama_0402_1747</a></strong> to <a href='https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View project at <a href='https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm' target=\"_blank\">https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm/runs/yomr6qws' target=\"_blank\">https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm/runs/yomr6qws</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"WandB: https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm/runs/yomr6qws\n"
]
}
],
"source": [
"# ======================== WANDB ========================\n",
"wandb.login(key=\"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\")\n",
"\n",
"if USE_WANDB:\n",
" wandb.init(\n",
" project=WANDB_PROJECT,\n",
" name=f\"slimpajama_{datetime.now().strftime('%m%d_%H%M')}\",\n",
" config={\n",
" \"model\": \"decoder_only_transformer\",\n",
" \"params\": total_params,\n",
" \"d_model\": D_MODEL, \"layers\": NUM_LAYERS,\n",
" \"heads\": NUM_HEADS, \"kv_heads\": KV_HEADS,\n",
" \"d_ff\": D_FF, \"seq_len\": MAX_SEQ_LEN,\n",
" \"batch_size\": BATCH_SIZE, \"grad_accum\": GRAD_ACCUM,\n",
" \"effective_batch\": BATCH_SIZE * GRAD_ACCUM,\n",
" \"lr\": LEARNING_RATE, \"min_lr\": MIN_LR,\n",
" \"warmup\": WARMUP_STEPS, \"total_steps\": TOTAL_STEPS,\n",
" \"dataset\": \"SlimPajama-6B\",\n",
" \"features\": [\"GQA\", \"REPO-Attention\", \"Flash-Attention\", \"RMSNorm\"],\n",
" },\n",
" tags=[\"pre-training\", \"slimpajama\", \"a100\"]\n",
" )\n",
" wandb.watch(model, log=\"all\", log_freq=200)\n",
" print(f\"WandB: {wandb.run.url}\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Step 0: LR = 0.00000100\n",
" Step 10: LR = 0.00001100\n",
" Step 50: LR = 0.00005100\n",
" Step 100: LR = 0.00010100\n",
" Step 200: LR = 0.00020000\n",
" Step 500: LR = 0.00019985\n",
" Step 1000: LR = 0.00019895\n",
" Step 5000: LR = 0.00016474\n",
" Step 10000: LR = 0.00008329\n"
]
}
],
"source": [
"# ======================== LR SCHEDULE ========================\n",
"def get_lr(step):\n",
" \"\"\"Cosine schedule with linear warmup.\"\"\"\n",
" if step < WARMUP_STEPS:\n",
" # Linear warmup\n",
" return LEARNING_RATE * (step + 1) / WARMUP_STEPS\n",
" # Cosine decay\n",
" progress = (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS)\n",
" progress = min(progress, 1.0)\n",
" cosine = 0.5 * (1.0 + math.cos(math.pi * progress))\n",
" return MIN_LR + (LEARNING_RATE - MIN_LR) * cosine\n",
"\n",
"# Quick preview\n",
"for s in [0, 10, 50, 100, 200, 500, 1000, 5000, 10000]:\n",
" print(f\" Step {s:>6}: LR = {get_lr(s):.8f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Params with weight decay : 115,867,776\n",
"Params without weight decay: 38,715,220\n",
"Mixed precision: BF16\n",
"GradScaler: OFF (BF16)\n",
"\n",
"Starting training...\n",
" Epochs: 1, Batch: 22, Accum: 8\n",
" Effective batch: 176\n",
" Peak LR: 0.0002, Warmup: 200 steps\n",
" Est total steps: 16,645\n",
"\n",
"E1 B17000 | loss=5.0662 | lr=0.000194 | step=2125 | 9266s | 12.0GB\n",
"E1 B17020 | loss=5.0566 | lr=0.000194 | step=2127 | 9277s | 12.0GB\n",
"E1 B17040 | loss=4.9669 | lr=0.000194 | step=2130 | 9288s | 12.0GB\n",
"E1 B17060 | loss=4.8286 | lr=0.000194 | step=2132 | 9299s | 12.0GB\n",
"E1 B17080 | loss=4.9056 | lr=0.000194 | step=2135 | 9310s | 12.0GB\n",
"E1 B17100 | loss=4.5468 | lr=0.000194 | step=2137 | 9320s | 12.0GB\n",
"E1 B17120 | loss=5.1199 | lr=0.000194 | step=2140 | 9332s | 12.0GB\n",
"E1 B17140 | loss=3.7074 | lr=0.000194 | step=2142 | 9342s | 12.0GB\n",
"E1 B17160 | loss=4.3470 | lr=0.000194 | step=2145 | 9353s | 12.0GB\n",
"E1 B17180 | loss=4.9766 | lr=0.000194 | step=2147 | 9364s | 12.0GB\n",
"E1 B17200 | loss=5.0189 | lr=0.000194 | step=2150 | 9375s | 12.0GB\n",
"E1 B17220 | loss=5.0189 | lr=0.000194 | step=2152 | 9386s | 12.0GB\n",
"E1 B17240 | loss=5.0028 | lr=0.000194 | step=2155 | 9397s | 12.0GB\n",
"E1 B17260 | loss=4.9502 | lr=0.000194 | step=2157 | 9408s | 12.0GB\n",
"E1 B17280 | loss=5.0946 | lr=0.000194 | step=2160 | 9419s | 12.0GB\n",
"E1 B17300 | loss=5.2172 | lr=0.000194 | step=2162 | 9430s | 12.0GB\n",
"E1 B17320 | loss=4.7954 | lr=0.000194 | step=2165 | 9441s | 12.0GB\n",
"E1 B17340 | loss=4.9832 | lr=0.000194 | step=2167 | 9451s | 12.0GB\n",
"E1 B17360 | loss=4.9154 | lr=0.000194 | step=2170 | 9463s | 12.0GB\n",
"E1 B17380 | loss=4.7914 | lr=0.000194 | step=2172 | 9473s | 12.0GB\n",
"E1 B17400 | loss=4.9853 | lr=0.000194 | step=2175 | 9484s | 12.0GB\n",
"E1 B17420 | loss=4.9046 | lr=0.000194 | step=2177 | 9495s | 12.0GB\n",
"E1 B17440 | loss=4.7686 | lr=0.000194 | step=2180 | 9506s | 12.0GB\n",
"E1 B17460 | loss=4.9282 | lr=0.000194 | step=2182 | 9517s | 12.0GB\n",
"E1 B17480 | loss=5.2533 | lr=0.000194 | step=2185 | 9528s | 12.0GB\n",
"E1 B17500 | loss=4.7660 | lr=0.000194 | step=2187 | 9538s | 12.0GB\n",
"E1 B17520 | loss=4.8343 | lr=0.000194 | step=2190 | 9550s | 12.0GB\n",
"E1 B17540 | loss=3.9227 | lr=0.000194 | step=2192 | 9560s | 12.0GB\n",
"E1 B17560 | loss=5.1872 | lr=0.000194 | step=2195 | 9571s | 12.0GB\n",
"E1 B17580 | loss=5.2201 | lr=0.000194 | step=2197 | 9582s | 12.0GB\n",
"E1 B17600 | loss=4.9061 | lr=0.000194 | step=2200 | 9593s | 12.0GB\n",
"E1 B17620 | loss=4.5560 | lr=0.000193 | step=2202 | 9604s | 12.0GB\n",
"E1 B17640 | loss=5.0290 | lr=0.000193 | step=2205 | 9615s | 12.0GB\n",
"E1 B83360 | loss=4.0679 | lr=0.000076 | step=10420 | 45447s | 12.0GB\n",
"E1 B83380 | loss=3.8272 | lr=0.000076 | step=10422 | 45458s | 12.0GB\n",
"E1 B83400 | loss=3.8961 | lr=0.000076 | step=10425 | 45469s | 12.0GB\n",
"E1 B83420 | loss=3.6188 | lr=0.000076 | step=10427 | 45480s | 12.0GB\n",
"E1 B83440 | loss=3.6587 | lr=0.000076 | step=10430 | 45490s | 12.0GB\n",
"E1 B83460 | loss=3.7851 | lr=0.000076 | step=10432 | 45501s | 12.0GB\n",
"E1 B83480 | loss=3.4612 | lr=0.000076 | step=10435 | 45512s | 12.0GB\n",
"E1 B83500 | loss=3.9424 | lr=0.000076 | step=10437 | 45523s | 12.0GB\n",
"E1 B83520 | loss=4.2557 | lr=0.000076 | step=10440 | 45534s | 12.0GB\n",
"E1 B83540 | loss=4.2505 | lr=0.000076 | step=10442 | 45545s | 12.0GB\n",
"E1 B83560 | loss=4.0652 | lr=0.000076 | step=10445 | 45556s | 12.0GB\n",
"E1 B83580 | loss=3.4681 | lr=0.000076 | step=10447 | 45567s | 12.0GB\n",
"E1 B83600 | loss=3.1749 | lr=0.000076 | step=10450 | 45578s | 12.0GB\n",
"E1 B83620 | loss=2.6673 | lr=0.000076 | step=10452 | 45589s | 12.0GB\n",
"E1 B83640 | loss=3.8833 | lr=0.000076 | step=10455 | 45599s | 12.0GB\n",
"E1 B83660 | loss=3.7466 | lr=0.000076 | step=10457 | 45610s | 12.0GB\n",
"E1 B83680 | loss=3.7969 | lr=0.000076 | step=10460 | 45621s | 12.0GB\n",
"E1 B83700 | loss=4.0344 | lr=0.000076 | step=10462 | 45632s | 12.0GB\n",
"E1 B83720 | loss=3.9571 | lr=0.000076 | step=10465 | 45643s | 12.0GB\n",
"E1 B83740 | loss=4.1347 | lr=0.000076 | step=10467 | 45654s | 12.0GB\n",
"E1 B83760 | loss=4.0433 | lr=0.000076 | step=10470 | 45665s | 12.0GB\n",
"E1 B83780 | loss=3.6965 | lr=0.000076 | step=10472 | 45675s | 12.0GB\n",
"E1 B83800 | loss=4.0063 | lr=0.000076 | step=10475 | 45686s | 12.0GB\n",
"E1 B83820 | loss=3.9130 | lr=0.000076 | step=10477 | 45697s | 12.0GB\n",
"E1 B83840 | loss=3.6705 | lr=0.000076 | step=10480 | 45708s | 12.0GB\n",
"E1 B83860 | loss=4.1480 | lr=0.000075 | step=10482 | 45719s | 12.0GB\n",
"E1 B83880 | loss=3.7317 | lr=0.000075 | step=10485 | 45730s | 12.0GB\n",
"E1 B83900 | loss=4.1277 | lr=0.000075 | step=10487 | 45741s | 12.0GB\n",
"E1 B83920 | loss=4.2789 | lr=0.000075 | step=10490 | 45752s | 12.0GB\n",
"E1 B83940 | loss=4.2223 | lr=0.000075 | step=10492 | 45763s | 12.0GB\n",
"E1 B83960 | loss=3.6111 | lr=0.000075 | step=10495 | 45774s | 12.0GB\n",
"E1 B83980 | loss=3.0651 | lr=0.000075 | step=10497 | 45784s | 12.0GB\n",
"E1 B84000 | loss=3.9246 | lr=0.000075 | step=10500 | 45795s | 12.0GB\n",
"E1 B84020 | loss=2.6241 | lr=0.000075 | step=10502 | 45806s | 12.0GB\n",
"E1 B84040 | loss=3.7022 | lr=0.000075 | step=10505 | 45817s | 12.0GB\n",
"E1 B84060 | loss=3.1323 | lr=0.000075 | step=10507 | 45828s | 12.0GB\n",
"E1 B84080 | loss=3.9758 | lr=0.000075 | step=10510 | 45839s | 12.0GB\n",
"E1 B84100 | loss=2.6221 | lr=0.000075 | step=10512 | 45850s | 12.0GB\n",
"E1 B84120 | loss=3.9541 | lr=0.000075 | step=10515 | 45861s | 12.0GB\n",
"E1 B84140 | loss=4.1421 | lr=0.000075 | step=10517 | 45872s | 12.0GB\n",
"E1 B84160 | loss=3.3570 | lr=0.000075 | step=10520 | 45883s | 12.0GB\n",
"E1 B84180 | loss=3.7867 | lr=0.000075 | step=10522 | 45894s | 12.0GB\n",
"E1 B84200 | loss=1.4464 | lr=0.000075 | step=10525 | 45904s | 12.0GB\n",
"E1 B84220 | loss=3.6208 | lr=0.000075 | step=10527 | 45916s | 12.0GB\n",
"E1 B84240 | loss=3.1878 | lr=0.000075 | step=10530 | 45926s | 12.0GB\n",
"E1 B84260 | loss=3.6956 | lr=0.000075 | step=10532 | 45937s | 12.0GB\n",
"E1 B84280 | loss=1.1194 | lr=0.000075 | step=10535 | 45948s | 12.0GB\n",
"E1 B84300 | loss=2.7751 | lr=0.000075 | step=10537 | 45959s | 12.0GB\n",
"E1 B84320 | loss=3.9186 | lr=0.000075 | step=10540 | 45970s | 12.0GB\n",
"E1 B84340 | loss=3.9909 | lr=0.000075 | step=10542 | 45981s | 12.0GB\n",
"E1 B84360 | loss=4.1888 | lr=0.000074 | step=10545 | 45992s | 12.0GB\n",
"E1 B84380 | loss=3.7813 | lr=0.000074 | step=10547 | 46003s | 12.0GB\n",
"E1 B84400 | loss=3.2111 | lr=0.000074 | step=10550 | 46014s | 12.0GB\n",
"E1 B84420 | loss=3.9105 | lr=0.000074 | step=10552 | 46025s | 12.0GB\n",
"E1 B84440 | loss=4.1016 | lr=0.000074 | step=10555 | 46036s | 12.0GB\n",
"E1 B84460 | loss=3.9241 | lr=0.000074 | step=10557 | 46047s | 12.0GB\n",
"E1 B84480 | loss=3.5857 | lr=0.000074 | step=10560 | 46057s | 12.0GB\n",
"E1 B84500 | loss=3.9282 | lr=0.000074 | step=10562 | 46068s | 12.0GB\n",
"E1 B84520 | loss=3.9756 | lr=0.000074 | step=10565 | 46079s | 12.0GB\n",
"E1 B84540 | loss=3.8981 | lr=0.000074 | step=10567 | 46090s | 12.0GB\n",
"E1 B84560 | loss=3.6915 | lr=0.000074 | step=10570 | 46101s | 12.0GB\n",
"E1 B84580 | loss=3.3759 | lr=0.000074 | step=10572 | 46112s | 12.0GB\n",
"E1 B84600 | loss=4.0181 | lr=0.000074 | step=10575 | 46123s | 12.0GB\n",
"E1 B84620 | loss=3.9899 | lr=0.000074 | step=10577 | 46134s | 12.0GB\n",
"E1 B84640 | loss=2.7491 | lr=0.000074 | step=10580 | 46144s | 12.0GB\n",
"E1 B84660 | loss=3.8926 | lr=0.000074 | step=10582 | 46156s | 12.0GB\n",
"E1 B84680 | loss=3.9660 | lr=0.000074 | step=10585 | 46166s | 12.0GB\n",
"E1 B84700 | loss=3.7381 | lr=0.000074 | step=10587 | 46177s | 12.0GB\n",
"E1 B84720 | loss=3.6865 | lr=0.000074 | step=10590 | 46188s | 12.0GB\n",
"E1 B84740 | loss=4.0988 | lr=0.000074 | step=10592 | 46199s | 12.0GB\n",
"E1 B84760 | loss=3.9796 | lr=0.000074 | step=10595 | 46210s | 12.0GB\n",
"E1 B84780 | loss=3.4667 | lr=0.000074 | step=10597 | 46221s | 12.0GB\n",
"E1 B84800 | loss=3.3943 | lr=0.000074 | step=10600 | 46232s | 12.0GB\n",
"E1 B84820 | loss=4.2667 | lr=0.000074 | step=10602 | 46243s | 12.0GB\n",
"E1 B84840 | loss=3.4909 | lr=0.000074 | step=10605 | 46254s | 12.0GB\n",
"E1 B84860 | loss=3.9001 | lr=0.000074 | step=10607 | 46265s | 12.0GB\n",
"E1 B84880 | loss=3.2266 | lr=0.000073 | step=10610 | 46276s | 12.0GB\n",
"E1 B84900 | loss=3.8690 | lr=0.000073 | step=10612 | 46286s | 12.0GB\n",
"E1 B84920 | loss=3.6407 | lr=0.000073 | step=10615 | 46297s | 12.0GB\n",
"E1 B84940 | loss=3.8192 | lr=0.000073 | step=10617 | 46308s | 12.0GB\n",
"E1 B84960 | loss=3.7542 | lr=0.000073 | step=10620 | 46319s | 12.0GB\n",
"E1 B84980 | loss=4.0122 | lr=0.000073 | step=10622 | 46330s | 12.0GB\n",
"E1 B85000 | loss=4.2165 | lr=0.000073 | step=10625 | 46341s | 12.0GB\n",
"E1 B85020 | loss=3.0944 | lr=0.000073 | step=10627 | 46352s | 12.0GB\n",
"E1 B85040 | loss=3.9340 | lr=0.000073 | step=10630 | 46363s | 12.0GB\n",
"E1 B85060 | loss=3.3452 | lr=0.000073 | step=10632 | 46374s | 12.0GB\n",
"E1 B85080 | loss=3.7265 | lr=0.000073 | step=10635 | 46385s | 12.0GB\n",
"E1 B85100 | loss=3.2035 | lr=0.000073 | step=10637 | 46395s | 12.0GB\n",
"E1 B85120 | loss=3.4126 | lr=0.000073 | step=10640 | 46407s | 12.0GB\n",
"E1 B85140 | loss=3.9198 | lr=0.000073 | step=10642 | 46417s | 12.0GB\n",
"E1 B85160 | loss=3.5839 | lr=0.000073 | step=10645 | 46428s | 12.0GB\n",
"E1 B85180 | loss=3.8842 | lr=0.000073 | step=10647 | 46439s | 12.0GB\n",
"E1 B85200 | loss=3.9351 | lr=0.000073 | step=10650 | 46450s | 12.0GB\n",
"E1 B85220 | loss=3.3622 | lr=0.000073 | step=10652 | 46461s | 12.0GB\n",
"E1 B85240 | loss=3.7592 | lr=0.000073 | step=10655 | 46472s | 12.0GB\n",
"E1 B85260 | loss=2.7395 | lr=0.000073 | step=10657 | 46483s | 12.0GB\n",
"E1 B85280 | loss=3.8370 | lr=0.000073 | step=10660 | 46494s | 12.0GB\n",
"E1 B85300 | loss=4.0697 | lr=0.000073 | step=10662 | 46505s | 12.0GB\n",
"E1 B85320 | loss=3.1111 | lr=0.000073 | step=10665 | 46516s | 12.0GB\n",
"E1 B85340 | loss=3.9165 | lr=0.000073 | step=10667 | 46527s | 12.0GB\n",
"E1 B85360 | loss=4.0334 | lr=0.000073 | step=10670 | 46538s | 12.0GB\n",
"E1 B85380 | loss=3.8077 | lr=0.000073 | step=10672 | 46548s | 12.0GB\n",
"E1 B85400 | loss=4.0962 | lr=0.000072 | step=10675 | 46559s | 12.0GB\n",
"E1 B85420 | loss=3.7642 | lr=0.000072 | step=10677 | 46570s | 12.0GB\n",
"E1 B85640 | loss=3.8509 | lr=0.000072 | step=10705 | 46690s | 12.0GB\n",
"E1 B85660 | loss=3.9086 | lr=0.000072 | step=10707 | 46701s | 12.0GB\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Token indices sequence length is longer than the specified maximum sequence length for this model (2156 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (1047 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2619 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (1225 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (3242 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2371 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (3373 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (3054 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (1690 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (1364 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (1551 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (5217 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (1419 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (1271 > 1024). Running this sequence through the model will result in indexing errors\n",
"Token indices sequence length is longer than the specified maximum sequence length for this model (2637 > 1024). Running this sequence through the model will result in indexing errors\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"E1 B 0 | loss=10.8247 | lr=0.000001 | step=0 | 3s | 10.7GB\n",
"E1 B 20 | loss=10.8169 | lr=0.000003 | step=2 | 14s | 12.0GB\n",
"E1 B 40 | loss=10.7877 | lr=0.000006 | step=5 | 25s | 12.0GB\n",
"E1 B 60 | loss=10.7531 | lr=0.000008 | step=7 | 36s | 12.0GB\n",
"E1 B 80 | loss=10.6827 | lr=0.000011 | step=10 | 47s | 12.0GB\n",
"E1 B 100 | loss=10.6539 | lr=0.000013 | step=12 | 57s | 12.0GB\n",
"E1 B 120 | loss=10.5372 | lr=0.000016 | step=15 | 68s | 12.0GB\n",
"E1 B 140 | loss=10.5912 | lr=0.000018 | step=17 | 79s | 12.0GB\n",
"E1 B 160 | loss=10.4607 | lr=0.000021 | step=20 | 90s | 12.0GB\n",
"E1 B 180 | loss=10.3397 | lr=0.000023 | step=22 | 101s | 12.0GB\n",
"E1 B 200 | loss=9.9626 | lr=0.000026 | step=25 | 112s | 12.0GB\n",
"E1 B 220 | loss=10.2695 | lr=0.000028 | step=27 | 123s | 12.0GB\n",
"E1 B 240 | loss=10.2457 | lr=0.000031 | step=30 | 133s | 12.0GB\n",
"E1 B 260 | loss=10.1893 | lr=0.000033 | step=32 | 144s | 12.0GB\n",
"E1 B 280 | loss=10.1125 | lr=0.000036 | step=35 | 155s | 12.0GB\n",
"E1 B 300 | loss=10.0175 | lr=0.000038 | step=37 | 166s | 12.0GB\n",
"E1 B 320 | loss=9.9452 | lr=0.000041 | step=40 | 177s | 12.0GB\n",
"E1 B 340 | loss=9.8831 | lr=0.000043 | step=42 | 188s | 12.0GB\n",
"E1 B 360 | loss=9.4486 | lr=0.000046 | step=45 | 199s | 12.0GB\n",
"E1 B 380 | loss=9.6624 | lr=0.000048 | step=47 | 210s | 12.0GB\n",
"E1 B 400 | loss=9.5566 | lr=0.000051 | step=50 | 221s | 12.0GB\n",
"E1 B 420 | loss=9.5670 | lr=0.000053 | step=52 | 232s | 12.0GB\n",
"E1 B 440 | loss=9.6506 | lr=0.000056 | step=55 | 242s | 12.0GB\n",
"E1 B 460 | loss=9.2208 | lr=0.000058 | step=57 | 253s | 12.0GB\n",
"E1 B 480 | loss=9.1615 | lr=0.000061 | step=60 | 264s | 12.0GB\n",
"E1 B 500 | loss=9.1667 | lr=0.000063 | step=62 | 275s | 12.0GB\n",
"E1 B 520 | loss=8.0803 | lr=0.000066 | step=65 | 286s | 12.0GB\n",
"E1 B 540 | loss=8.9027 | lr=0.000068 | step=67 | 297s | 12.0GB\n",
"E1 B 560 | loss=8.8101 | lr=0.000071 | step=70 | 308s | 12.0GB\n",
"E1 B 580 | loss=8.6734 | lr=0.000073 | step=72 | 319s | 12.0GB\n",
"E1 B 600 | loss=8.6437 | lr=0.000076 | step=75 | 329s | 12.0GB\n",
"E1 B 620 | loss=8.5093 | lr=0.000078 | step=77 | 341s | 12.0GB\n",
"E1 B 640 | loss=8.3063 | lr=0.000081 | step=80 | 351s | 12.0GB\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 94\u001b[39m\n\u001b[32m 90\u001b[39m \u001b[38;5;66;03m# Backward\u001b[39;00m\n\u001b[32m 91\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m scaler.is_enabled():\n\u001b[32m 92\u001b[39m scaler.scale(loss).backward()\n\u001b[32m 93\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m94\u001b[39m loss.backward()\n\u001b[32m 95\u001b[39m \n\u001b[32m 96\u001b[39m real_loss = loss.item() * GRAD_ACCUM\n\u001b[32m 97\u001b[39m total_loss += real_loss\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.venv/lib/python3.12/site-packages/torch/_tensor.py:581\u001b[39m, in \u001b[36mTensor.backward\u001b[39m\u001b[34m(self, gradient, retain_graph, create_graph, inputs)\u001b[39m\n\u001b[32m 571\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 572\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[32m 573\u001b[39m Tensor.backward,\n\u001b[32m 574\u001b[39m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[32m (...)\u001b[39m\u001b[32m 579\u001b[39m inputs=inputs,\n\u001b[32m 580\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m581\u001b[39m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mautograd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 582\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs\u001b[49m\n\u001b[32m 583\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.venv/lib/python3.12/site-packages/torch/autograd/__init__.py:347\u001b[39m, in \u001b[36mbackward\u001b[39m\u001b[34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[39m\n\u001b[32m 342\u001b[39m retain_graph = create_graph\n\u001b[32m 344\u001b[39m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[32m 345\u001b[39m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[32m 346\u001b[39m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m347\u001b[39m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 348\u001b[39m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 349\u001b[39m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 350\u001b[39m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 351\u001b[39m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 352\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 353\u001b[39m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 354\u001b[39m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 355\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/.venv/lib/python3.12/site-packages/torch/autograd/graph.py:825\u001b[39m, in \u001b[36m_engine_run_backward\u001b[39m\u001b[34m(t_outputs, *args, **kwargs)\u001b[39m\n\u001b[32m 823\u001b[39m unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[32m 824\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m825\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_execution_engine\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[32m 826\u001b[39m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[32m 827\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[32m 828\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 829\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
"\u001b[31mKeyboardInterrupt\u001b[39m: "
]
}
],
"source": [
"# ======================== TRAINING LOOP ========================\n",
"\n",
"os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n",
"\n",
"# Separate param groups: NO weight decay on norms, biases, embeddings\n",
"decay_params = []\n",
"no_decay_params = []\n",
"for name, p in model.named_parameters():\n",
" if not p.requires_grad:\n",
" continue\n",
" if p.dim() == 1 or \"norm\" in name or \"gamma\" in name or \"bias\" in name or \"embed\" in name:\n",
" no_decay_params.append(p)\n",
" else:\n",
" decay_params.append(p)\n",
"\n",
"print(f\"Params with weight decay : {sum(p.numel() for p in decay_params):,}\")\n",
"print(f\"Params without weight decay: {sum(p.numel() for p in no_decay_params):,}\")\n",
"\n",
"optimizer = torch.optim.AdamW([\n",
" {\"params\": decay_params, \"weight_decay\": WEIGHT_DECAY},\n",
" {\"params\": no_decay_params, \"weight_decay\": 0.0},\n",
"], lr=LEARNING_RATE, betas=(0.9, 0.95))\n",
"\n",
"# BF16 on A100 (no GradScaler needed for BF16)\n",
"use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()\n",
"amp_dtype = torch.bfloat16 if use_bf16 else torch.float16\n",
"scaler = GradScaler(enabled=(not use_bf16)) # Only needed for FP16\n",
"\n",
"print(f\"Mixed precision: {'BF16' if use_bf16 else 'FP16'}\")\n",
"print(f\"GradScaler: {'OFF (BF16)' if use_bf16 else 'ON (FP16)'}\")\n",
"\n",
"model.train()\n",
"global_step = 0\n",
"best_loss = float('inf')\n",
"last_ckpt_time = time.time()\n",
"\n",
"# Set initial LR (warmup starts from near-zero)\n",
"for pg in optimizer.param_groups:\n",
" pg['lr'] = get_lr(0)\n",
"\n",
"print(f\"\\nStarting training...\")\n",
"print(f\" Epochs: {EPOCHS}, Batch: {BATCH_SIZE}, Accum: {GRAD_ACCUM}\")\n",
"print(f\" Effective batch: {BATCH_SIZE * GRAD_ACCUM}\")\n",
"print(f\" Peak LR: {LEARNING_RATE}, Warmup: {WARMUP_STEPS} steps\")\n",
"print(f\" Est total steps: {TOTAL_STEPS:,}\")\n",
"print()\n",
"\n",
"for epoch in range(EPOCHS):\n",
" total_loss = 0.0\n",
" batch_count = 0\n",
" micro_count = 0\n",
" epoch_start = time.time()\n",
"\n",
" optimizer.zero_grad(set_to_none=True)\n",
"\n",
" for i, batch in enumerate(train_loader):\n",
" # Auto-checkpoint every 2 hours\n",
" if time.time() - last_ckpt_time >= 7200:\n",
" avg = total_loss / max(batch_count, 1)\n",
" save_checkpoint(model, optimizer, epoch, global_step, avg, best_loss,\n",
" CHECKPOINT_DIR, f\"auto_epoch{epoch+1}_step{global_step}.pt\")\n",
" last_ckpt_time = time.time()\n",
"\n",
" input_ids = batch[\"input_ids\"].to(device, non_blocking=True)\n",
" labels = batch[\"labels\"].to(device, non_blocking=True)\n",
"\n",
" # Forward\n",
" with autocast(device_type=\"cuda\", dtype=amp_dtype):\n",
" x = model.tgt_embed(input_ids)\n",
" for layer in model.decoder.layers:\n",
" x, _ = layer(x, tgt_mask=None, use_cache=False)\n",
" x = model.decoder.norm(x)\n",
" logits = model.project(x)\n",
"\n",
" shift_logits = logits[..., :-1, :].contiguous()\n",
" shift_labels = labels[..., 1:].contiguous()\n",
" loss = nn.CrossEntropyLoss(ignore_index=-100)(\n",
" shift_logits.view(-1, shift_logits.size(-1)),\n",
" shift_labels.view(-1)\n",
" )\n",
" loss = loss / GRAD_ACCUM\n",
"\n",
" # Skip bad batches\n",
" if not torch.isfinite(loss):\n",
" print(f\"[WARN] Non-finite loss at batch {i}, skipping\")\n",
" optimizer.zero_grad(set_to_none=True)\n",
" micro_count = 0\n",
" continue\n",
"\n",
" # Backward\n",
" if scaler.is_enabled():\n",
" scaler.scale(loss).backward()\n",
" else:\n",
" loss.backward()\n",
"\n",
" real_loss = loss.item() * GRAD_ACCUM\n",
" total_loss += real_loss\n",
" batch_count += 1\n",
" micro_count += 1\n",
"\n",
" # Optimizer step every GRAD_ACCUM micro-batches\n",
" if micro_count >= GRAD_ACCUM:\n",
" if scaler.is_enabled():\n",
" scaler.unscale_(optimizer)\n",
"\n",
" grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
"\n",
" if scaler.is_enabled():\n",
" scaler.step(optimizer)\n",
" scaler.update()\n",
" else:\n",
" optimizer.step()\n",
"\n",
" optimizer.zero_grad(set_to_none=True)\n",
" global_step += 1\n",
" micro_count = 0\n",
"\n",
" # Update LR\n",
" lr = get_lr(global_step)\n",
" for pg in optimizer.param_groups:\n",
" pg['lr'] = lr\n",
"\n",
" # Log to WandB\n",
" if USE_WANDB:\n",
" log_dict = {\n",
" \"train/loss\": real_loss,\n",
" \"train/lr\": lr,\n",
" \"train/grad_norm\": float(grad_norm),\n",
" \"train/step\": global_step,\n",
" \"train/epoch\": epoch + 1,\n",
" }\n",
" if torch.cuda.is_available():\n",
" log_dict[\"system/gpu_gb\"] = torch.cuda.memory_allocated() / 1e9\n",
" wandb.log(log_dict, step=global_step)\n",
"\n",
" # Print progress\n",
" if i % 20 == 0:\n",
" elapsed = time.time() - epoch_start\n",
" lr_now = optimizer.param_groups[0]['lr']\n",
" gpu_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0\n",
" print(f\"E{epoch+1} B{i:>5} | loss={real_loss:.4f} | lr={lr_now:.6f} | \"\n",
" f\"step={global_step} | {elapsed:.0f}s | {gpu_gb:.1f}GB\")\n",
"\n",
" if i % 50 == 0 and torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
"\n",
" # Flush leftover micro-batches\n",
" if micro_count > 0:\n",
" if scaler.is_enabled():\n",
" scaler.unscale_(optimizer)\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
" if scaler.is_enabled():\n",
" scaler.step(optimizer)\n",
" scaler.update()\n",
" else:\n",
" optimizer.step()\n",
" optimizer.zero_grad(set_to_none=True)\n",
" global_step += 1\n",
"\n",
" # End of epoch\n",
" avg_loss = total_loss / max(batch_count, 1)\n",
" duration = time.time() - epoch_start\n",
" print(f\"\\nEpoch {epoch+1}/{EPOCHS} done | avg_loss={avg_loss:.4f} | {duration:.0f}s\")\n",
"\n",
" if USE_WANDB:\n",
" wandb.log({\"epoch/avg_loss\": avg_loss, \"epoch/duration\": duration}, step=global_step)\n",
"\n",
" if avg_loss < best_loss:\n",
" best_loss = avg_loss\n",
" save_checkpoint(model, optimizer, epoch, global_step, avg_loss, best_loss,\n",
" CHECKPOINT_DIR, \"best_model.pt\")\n",
" if USE_WANDB:\n",
" wandb.log({\"train/best_loss\": best_loss}, step=global_step)\n",
"\n",
" save_checkpoint(model, optimizer, epoch, global_step, avg_loss, best_loss,\n",
" CHECKPOINT_DIR, f\"epoch_{epoch+1}.pt\")\n",
"\n",
"# Final\n",
"print(\"\\nTraining complete!\")\n",
"save_checkpoint(model, optimizer, EPOCHS-1, global_step, avg_loss, best_loss,\n",
" CHECKPOINT_DIR, \"final_model.pt\")\n",
"if USE_WANDB:\n",
" wandb.finish()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ======================== RESUME FROM CHECKPOINT ========================\n",
"# Uncomment to resume:\n",
"\"\"\"\n",
"ckpt = torch.load(os.path.join(CHECKPOINT_DIR, \"best_model.pt\"), map_location=device, weights_only=False)\n",
"model.load_state_dict(ckpt['model_state_dict'])\n",
"optimizer.load_state_dict(ckpt['optimizer_state_dict'])\n",
"global_step = ckpt['global_step']\n",
"best_loss = ckpt['best_loss']\n",
"print(f\"Resumed from step {global_step}\")\n",
"\"\"\"\n",
"print(\"Resume cell ready (uncomment to use).\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv (3.12.3)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
