Buckets:
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Pre-Training on SlimPajama-6B (Azure A100)\n", | |
| "\n", | |
| "**Model**: 154M Decoder-only Transformer (GQA + REPO-Attention + Flash-Attention) \n", | |
| "**Dataset**: SlimPajama-6B via Oxen \n", | |
| "**Tracking**: Weights & Biases" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Uncomment on Azure if needed\n", | |
| "# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n", | |
| "# !pip install oxen wandb transformers tokenizers pyarrow" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/home/spedrox/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
| " from .autonotebook import tqdm as notebook_tqdm\n", | |
| "Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Pad token: [PAD], ID: 50257\n", | |
| "Added 3 special tokens\n", | |
| "PyTorch: 2.5.1+cu121\n", | |
| "CUDA: True\n", | |
| "GPU: NVIDIA A100 80GB PCIe\n", | |
| "VRAM: 85.0 GB\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import os, sys, time, math\n", | |
| "import torch\n", | |
| "import torch.nn as nn\n", | |
| "from torch.amp import autocast, GradScaler\n", | |
| "from torch.utils.data import DataLoader\n", | |
| "from datetime import datetime\n", | |
| "import wandb\n", | |
| "\n", | |
| "PROJECT_ROOT = os.getcwd()  # notebooks define no __file__; dirname(abspath(\"__file__\")) resolved to the CWD anyway\n", | |
| "TRAIN_DIR = os.path.join(PROJECT_ROOT, \"train\")\n", | |
| "sys.path.insert(0, PROJECT_ROOT)\n", | |
| "sys.path.insert(0, TRAIN_DIR)\n", | |
| "\n", | |
| "from transformer.build_transformer import build_transformer\n", | |
| "from dataset_define import SlimPajamaDataset\n", | |
| "from save_checkpoint import save_checkpoint\n", | |
| "from tokenizer import tokenizer\n", | |
| "\n", | |
| "print(f\"PyTorch: {torch.__version__}\")\n", | |
| "print(f\"CUDA: {torch.cuda.is_available()}\")\n", | |
| "if torch.cuda.is_available():\n", | |
| " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n", | |
| " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Vocab: 50260, Pad ID: 50257\n", | |
| "Tokens/step: 360,448\n", | |
| "Est steps/epoch: 16,645\n", | |
| "Total steps: 16,645\n", | |
| "Warmup: 200 steps\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# ======================== CONFIG ========================\n", | |
| "DATASET_DIR = os.path.join(PROJECT_ROOT, \"SlimPajama-6B\")\n", | |
| "CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, \"checkpoints\")\n", | |
| "\n", | |
| "# Model\n", | |
| "D_MODEL = 768\n", | |
| "NUM_LAYERS = 12\n", | |
| "NUM_HEADS = 12\n", | |
| "KV_HEADS = 4\n", | |
| "D_FF = 3072\n", | |
| "DROPOUT = 0.1\n", | |
| "MAX_SEQ_LEN = 2048\n", | |
| "USE_REPO = True\n", | |
| "USE_FLASH = True\n", | |
| "\n", | |
| "# Training\n", | |
| "EPOCHS = 1\n", | |
| "BATCH_SIZE = 22\n", | |
| "GRAD_ACCUM = 8\n", | |
| "LEARNING_RATE = 2e-4\n", | |
| "MIN_LR = 2e-5 # 10% of peak\n", | |
| "WEIGHT_DECAY = 0.01\n", | |
| "MAX_GRAD_NORM = 1.0\n", | |
| "WARMUP_STEPS = 200 # Short warmup so model starts learning fast\n", | |
| "\n", | |
| "# Estimated steps (for cosine schedule)\n", | |
| "TOKENS_IN_DATASET = 6_000_000_000\n", | |
| "TOKENS_PER_STEP = BATCH_SIZE * GRAD_ACCUM * MAX_SEQ_LEN\n", | |
| "EST_STEPS_PER_EPOCH = TOKENS_IN_DATASET // TOKENS_PER_STEP\n", | |
| "TOTAL_STEPS = EST_STEPS_PER_EPOCH * EPOCHS\n", | |
| "\n", | |
| "# WandB\n", | |
| "WANDB_PROJECT = \"Spedrox_llm\"\n", | |
| "USE_WANDB = True\n", | |
| "\n", | |
| "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | |
| "VOCAB_SIZE = len(tokenizer)\n", | |
| "PAD_ID = tokenizer.pad_token_id\n", | |
| "\n", | |
| "print(f\"Vocab: {VOCAB_SIZE}, Pad ID: {PAD_ID}\")\n", | |
| "print(f\"Tokens/step: {TOKENS_PER_STEP:,}\")\n", | |
| "print(f\"Est steps/epoch: {EST_STEPS_PER_EPOCH:,}\")\n", | |
| "print(f\"Total steps: {TOTAL_STEPS:,}\")\n", | |
| "print(f\"Warmup: {WARMUP_STEPS} steps\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Dataset exists at /home/spedrox/Transformers/SlimPajama-6B\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# ======================== CLONE DATASET ========================\n", | |
| "import oxen\n", | |
| "\n", | |
| "if not os.path.exists(DATASET_DIR):\n", | |
| " print(\"Cloning SlimPajama-6B...\")\n", | |
| " oxen.clone(\"https://hub.oxen.ai/datasets/SlimPajama-6B\", DATASET_DIR)\n", | |
| " print(\"Done!\")\n", | |
| "else:\n", | |
| " print(f\"Dataset exists at {DATASET_DIR}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Parameters: 154,582,996 (154.6M)\n", | |
| "REPO: ON, Flash: ON\n", | |
| "[OK] RMSNorm gamma parameters are non-zero\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# ======================== BUILD MODEL ========================\n", | |
| "model = build_transformer(\n", | |
| " src_vocab_size=VOCAB_SIZE, tgt_vocab_size=VOCAB_SIZE,\n", | |
| " src_seq_len=MAX_SEQ_LEN, tgt_seq_len=MAX_SEQ_LEN,\n", | |
| " d_model=D_MODEL, N=NUM_LAYERS, h=NUM_HEADS, kv_h=KV_HEADS,\n", | |
| " dropout=DROPOUT, d_ff=D_FF, use_repo=USE_REPO, use_flash=USE_FLASH,\n", | |
| ")\n", | |
| "model = model.to(device)\n", | |
| "\n", | |
| "total_params = sum(p.numel() for p in model.parameters())\n", | |
| "print(f\"Parameters: {total_params:,} ({total_params/1e6:.1f}M)\")\n", | |
| "print(f\"REPO: {'ON' if USE_REPO else 'OFF'}, Flash: {'ON' if USE_FLASH else 'OFF'}\")\n", | |
| "\n", | |
| "# Verify RMSNorm gamma is NOT zero\n", | |
| "for name, p in model.named_parameters():\n", | |
| " if 'gamma' in name:\n", | |
| " assert p.abs().sum() > 0, f\"FATAL: {name} is all zeros!\"\n", | |
| "print(\"[OK] RMSNorm gamma parameters are non-zero\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Running sanity check...\n", | |
| "Output shape: torch.Size([2, 128, 50260])\n", | |
| "Output range: [-0.8724, 0.9890]\n", | |
| "Output std: 0.1734\n", | |
| "[OK] Sanity check passed\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# ======================== SANITY CHECK ========================\n", | |
| "print(\"Running sanity check...\")\n", | |
| "model.train()\n", | |
| "dummy = torch.randint(0, VOCAB_SIZE, (2, 128), device=device)\n", | |
| "\n", | |
| "with torch.no_grad():\n", | |
| " x = model.tgt_embed(dummy)\n", | |
| " for layer in model.decoder.layers:\n", | |
| " x, _ = layer(x, tgt_mask=None, use_cache=False)\n", | |
| " x = model.decoder.norm(x)\n", | |
| " logits = model.project(x)\n", | |
| "\n", | |
| "print(f\"Output shape: {logits.shape}\")\n", | |
| "print(f\"Output range: [{logits.min().item():.4f}, {logits.max().item():.4f}]\")\n", | |
| "print(f\"Output std: {logits.std().item():.4f}\")\n", | |
| "assert logits.std().item() > 0.01, \"FATAL: Model output has near-zero variance!\"\n", | |
| "print(\"[OK] Sanity check passed\")\n", | |
| "\n", | |
| "del dummy, x, logits\n", | |
| "torch.cuda.empty_cache()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Found 50 parquet files in /home/spedrox/Transformers/SlimPajama-6B\n", | |
| "DataLoader ready (batch=22, workers=16)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# ======================== DATASET & DATALOADER ========================\n", | |
| "train_dataset = SlimPajamaDataset(\n", | |
| " data_dir=DATASET_DIR, tokenizer=tokenizer, max_length=MAX_SEQ_LEN,\n", | |
| ")\n", | |
| "train_loader = DataLoader(\n", | |
| " train_dataset, batch_size=BATCH_SIZE,\n", | |
| " num_workers=16, pin_memory=True, prefetch_factor=2,\n", | |
| ")\n", | |
| "print(f\"DataLoader ready (batch={BATCH_SIZE}, workers=16)\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n", | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n", | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: [wandb.login()] Using explicit session credentials for https://api.wandb.ai.\n", | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /home/spedrox/.netrc\n", | |
| "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mdinmaybrahmaofficial\u001b[0m (\u001b[33mdinmaybrahmaofficial-indian-institute-of-technology\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/html": [], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "Tracking run with wandb version 0.25.1" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "Run data is saved locally in <code>/home/spedrox/Transformers/wandb/run-20260402_174756-yomr6qws</code>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "Syncing run <strong><a href='https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm/runs/yomr6qws' target=\"_blank\">slimpajama_0402_1747</a></strong> to <a href='https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| " View project at <a href='https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm' target=\"_blank\">https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm</a>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| " View run at <a href='https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm/runs/yomr6qws' target=\"_blank\">https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm/runs/yomr6qws</a>" | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "WandB: https://wandb.ai/dinmaybrahmaofficial-indian-institute-of-technology/Spedrox_llm/runs/yomr6qws\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# ======================== WANDB ========================\n", | |
| "wandb.login(key=os.environ.get(\"WANDB_API_KEY\"))  # never hardcode API keys; set WANDB_API_KEY in the environment (falls back to stored credentials / `wandb login` if unset)\n", | |
| "\n", | |
| "if USE_WANDB:\n", | |
| " wandb.init(\n", | |
| " project=WANDB_PROJECT,\n", | |
| " name=f\"slimpajama_{datetime.now().strftime('%m%d_%H%M')}\",\n", | |
| " config={\n", | |
| " \"model\": \"decoder_only_transformer\",\n", | |
| " \"params\": total_params,\n", | |
| " \"d_model\": D_MODEL, \"layers\": NUM_LAYERS,\n", | |
| " \"heads\": NUM_HEADS, \"kv_heads\": KV_HEADS,\n", | |
| " \"d_ff\": D_FF, \"seq_len\": MAX_SEQ_LEN,\n", | |
| " \"batch_size\": BATCH_SIZE, \"grad_accum\": GRAD_ACCUM,\n", | |
| " \"effective_batch\": BATCH_SIZE * GRAD_ACCUM,\n", | |
| " \"lr\": LEARNING_RATE, \"min_lr\": MIN_LR,\n", | |
| " \"warmup\": WARMUP_STEPS, \"total_steps\": TOTAL_STEPS,\n", | |
| " \"dataset\": \"SlimPajama-6B\",\n", | |
| " \"features\": [\"GQA\", \"REPO-Attention\", \"Flash-Attention\", \"RMSNorm\"],\n", | |
| " },\n", | |
| " tags=[\"pre-training\", \"slimpajama\", \"a100\"]\n", | |
| " )\n", | |
| " wandb.watch(model, log=\"all\", log_freq=200)\n", | |
| " print(f\"WandB: {wandb.run.url}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| " Step 0: LR = 0.00000100\n", | |
| " Step 10: LR = 0.00001100\n", | |
| " Step 50: LR = 0.00005100\n", | |
| " Step 100: LR = 0.00010100\n", | |
| " Step 200: LR = 0.00020000\n", | |
| " Step 500: LR = 0.00019985\n", | |
| " Step 1000: LR = 0.00019895\n", | |
| " Step 5000: LR = 0.00016474\n", | |
| " Step 10000: LR = 0.00008329\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# ======================== LR SCHEDULE ========================\n", | |
| "def get_lr(step):\n", | |
| " \"\"\"Cosine schedule with linear warmup.\"\"\"\n", | |
| " if step < WARMUP_STEPS:\n", | |
| " # Linear warmup\n", | |
| " return LEARNING_RATE * (step + 1) / WARMUP_STEPS\n", | |
| " # Cosine decay\n", | |
| " progress = (step - WARMUP_STEPS) / max(1, TOTAL_STEPS - WARMUP_STEPS)\n", | |
| " progress = min(progress, 1.0)\n", | |
| " cosine = 0.5 * (1.0 + math.cos(math.pi * progress))\n", | |
| " return MIN_LR + (LEARNING_RATE - MIN_LR) * cosine\n", | |
| "\n", | |
| "# Quick preview\n", | |
| "for s in [0, 10, 50, 100, 200, 500, 1000, 5000, 10000]:\n", | |
| " print(f\" Step {s:>6}: LR = {get_lr(s):.8f}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Params with weight decay : 115,867,776\n", | |
| "Params without weight decay: 38,715,220\n", | |
| "Mixed precision: BF16\n", | |
| "GradScaler: OFF (BF16)\n", | |
| "\n", | |
| "Starting training...\n", | |
| " Epochs: 1, Batch: 22, Accum: 8\n", | |
| " Effective batch: 176\n", | |
| " Peak LR: 0.0002, Warmup: 200 steps\n", | |
| " Est total steps: 16,645\n", | |
| "\n", | |
| "E1 B17000 | loss=5.0662 | lr=0.000194 | step=2125 | 9266s | 12.0GB\n", | |
| "E1 B17020 | loss=5.0566 | lr=0.000194 | step=2127 | 9277s | 12.0GB\n", | |
| "E1 B17040 | loss=4.9669 | lr=0.000194 | step=2130 | 9288s | 12.0GB\n", | |
| "E1 B17060 | loss=4.8286 | lr=0.000194 | step=2132 | 9299s | 12.0GB\n", | |
| "E1 B17080 | loss=4.9056 | lr=0.000194 | step=2135 | 9310s | 12.0GB\n", | |
| "E1 B17100 | loss=4.5468 | lr=0.000194 | step=2137 | 9320s | 12.0GB\n", | |
| "E1 B17120 | loss=5.1199 | lr=0.000194 | step=2140 | 9332s | 12.0GB\n", | |
| "E1 B17140 | loss=3.7074 | lr=0.000194 | step=2142 | 9342s | 12.0GB\n", | |
| "E1 B17160 | loss=4.3470 | lr=0.000194 | step=2145 | 9353s | 12.0GB\n", | |
| "E1 B17180 | loss=4.9766 | lr=0.000194 | step=2147 | 9364s | 12.0GB\n", | |
| "E1 B17200 | loss=5.0189 | lr=0.000194 | step=2150 | 9375s | 12.0GB\n", | |
| "E1 B17220 | loss=5.0189 | lr=0.000194 | step=2152 | 9386s | 12.0GB\n", | |
| "E1 B17240 | loss=5.0028 | lr=0.000194 | step=2155 | 9397s | 12.0GB\n", | |
| "E1 B17260 | loss=4.9502 | lr=0.000194 | step=2157 | 9408s | 12.0GB\n", | |
| "E1 B17280 | loss=5.0946 | lr=0.000194 | step=2160 | 9419s | 12.0GB\n", | |
| "E1 B17300 | loss=5.2172 | lr=0.000194 | step=2162 | 9430s | 12.0GB\n", | |
| "E1 B17320 | loss=4.7954 | lr=0.000194 | step=2165 | 9441s | 12.0GB\n", | |
| "E1 B17340 | loss=4.9832 | lr=0.000194 | step=2167 | 9451s | 12.0GB\n", | |
| "E1 B17360 | loss=4.9154 | lr=0.000194 | step=2170 | 9463s | 12.0GB\n", | |
| "E1 B17380 | loss=4.7914 | lr=0.000194 | step=2172 | 9473s | 12.0GB\n", | |
| "E1 B17400 | loss=4.9853 | lr=0.000194 | step=2175 | 9484s | 12.0GB\n", | |
| "E1 B17420 | loss=4.9046 | lr=0.000194 | step=2177 | 9495s | 12.0GB\n", | |
| "E1 B17440 | loss=4.7686 | lr=0.000194 | step=2180 | 9506s | 12.0GB\n", | |
| "E1 B17460 | loss=4.9282 | lr=0.000194 | step=2182 | 9517s | 12.0GB\n", | |
| "E1 B17480 | loss=5.2533 | lr=0.000194 | step=2185 | 9528s | 12.0GB\n", | |
| "E1 B17500 | loss=4.7660 | lr=0.000194 | step=2187 | 9538s | 12.0GB\n", | |
| "E1 B17520 | loss=4.8343 | lr=0.000194 | step=2190 | 9550s | 12.0GB\n", | |
| "E1 B17540 | loss=3.9227 | lr=0.000194 | step=2192 | 9560s | 12.0GB\n", | |
| "E1 B17560 | loss=5.1872 | lr=0.000194 | step=2195 | 9571s | 12.0GB\n", | |
| "E1 B17580 | loss=5.2201 | lr=0.000194 | step=2197 | 9582s | 12.0GB\n", | |
| "E1 B17600 | loss=4.9061 | lr=0.000194 | step=2200 | 9593s | 12.0GB\n", | |
| "E1 B17620 | loss=4.5560 | lr=0.000193 | step=2202 | 9604s | 12.0GB\n", | |
| "E1 B17640 | loss=5.0290 | lr=0.000193 | step=2205 | 9615s | 12.0GB\n", | |
| "E1 B83360 | loss=4.0679 | lr=0.000076 | step=10420 | 45447s | 12.0GB\n", | |
| "E1 B83380 | loss=3.8272 | lr=0.000076 | step=10422 | 45458s | 12.0GB\n", | |
| "E1 B83400 | loss=3.8961 | lr=0.000076 | step=10425 | 45469s | 12.0GB\n", | |
| "E1 B83420 | loss=3.6188 | lr=0.000076 | step=10427 | 45480s | 12.0GB\n", | |
| "E1 B83440 | loss=3.6587 | lr=0.000076 | step=10430 | 45490s | 12.0GB\n", | |
| "E1 B83460 | loss=3.7851 | lr=0.000076 | step=10432 | 45501s | 12.0GB\n", | |
| "E1 B83480 | loss=3.4612 | lr=0.000076 | step=10435 | 45512s | 12.0GB\n", | |
| "E1 B83500 | loss=3.9424 | lr=0.000076 | step=10437 | 45523s | 12.0GB\n", | |
| "E1 B83520 | loss=4.2557 | lr=0.000076 | step=10440 | 45534s | 12.0GB\n", | |
| "E1 B83540 | loss=4.2505 | lr=0.000076 | step=10442 | 45545s | 12.0GB\n", | |
| "E1 B83560 | loss=4.0652 | lr=0.000076 | step=10445 | 45556s | 12.0GB\n", | |
| "E1 B83580 | loss=3.4681 | lr=0.000076 | step=10447 | 45567s | 12.0GB\n", | |
| "E1 B83600 | loss=3.1749 | lr=0.000076 | step=10450 | 45578s | 12.0GB\n", | |
| "E1 B83620 | loss=2.6673 | lr=0.000076 | step=10452 | 45589s | 12.0GB\n", | |
| "E1 B83640 | loss=3.8833 | lr=0.000076 | step=10455 | 45599s | 12.0GB\n", | |
| "E1 B83660 | loss=3.7466 | lr=0.000076 | step=10457 | 45610s | 12.0GB\n", | |
| "E1 B83680 | loss=3.7969 | lr=0.000076 | step=10460 | 45621s | 12.0GB\n", | |
| "E1 B83700 | loss=4.0344 | lr=0.000076 | step=10462 | 45632s | 12.0GB\n", | |
| "E1 B83720 | loss=3.9571 | lr=0.000076 | step=10465 | 45643s | 12.0GB\n", | |
| "E1 B83740 | loss=4.1347 | lr=0.000076 | step=10467 | 45654s | 12.0GB\n", | |
| "E1 B83760 | loss=4.0433 | lr=0.000076 | step=10470 | 45665s | 12.0GB\n", | |
| "E1 B83780 | loss=3.6965 | lr=0.000076 | step=10472 | 45675s | 12.0GB\n", | |
| "E1 B83800 | loss=4.0063 | lr=0.000076 | step=10475 | 45686s | 12.0GB\n", | |
| "E1 B83820 | loss=3.9130 | lr=0.000076 | step=10477 | 45697s | 12.0GB\n", | |
| "E1 B83840 | loss=3.6705 | lr=0.000076 | step=10480 | 45708s | 12.0GB\n", | |
| "E1 B83860 | loss=4.1480 | lr=0.000075 | step=10482 | 45719s | 12.0GB\n", | |
| "E1 B83880 | loss=3.7317 | lr=0.000075 | step=10485 | 45730s | 12.0GB\n", | |
| "E1 B83900 | loss=4.1277 | lr=0.000075 | step=10487 | 45741s | 12.0GB\n", | |
| "E1 B83920 | loss=4.2789 | lr=0.000075 | step=10490 | 45752s | 12.0GB\n", | |
| "E1 B83940 | loss=4.2223 | lr=0.000075 | step=10492 | 45763s | 12.0GB\n", | |
| "E1 B83960 | loss=3.6111 | lr=0.000075 | step=10495 | 45774s | 12.0GB\n", | |
| "E1 B83980 | loss=3.0651 | lr=0.000075 | step=10497 | 45784s | 12.0GB\n", | |
| "E1 B84000 | loss=3.9246 | lr=0.000075 | step=10500 | 45795s | 12.0GB\n", | |
| "E1 B84020 | loss=2.6241 | lr=0.000075 | step=10502 | 45806s | 12.0GB\n", | |
| "E1 B84040 | loss=3.7022 | lr=0.000075 | step=10505 | 45817s | 12.0GB\n", | |
| "E1 B84060 | loss=3.1323 | lr=0.000075 | step=10507 | 45828s | 12.0GB\n", | |
| "E1 B84080 | loss=3.9758 | lr=0.000075 | step=10510 | 45839s | 12.0GB\n", | |
| "E1 B84100 | loss=2.6221 | lr=0.000075 | step=10512 | 45850s | 12.0GB\n", | |
| "E1 B84120 | loss=3.9541 | lr=0.000075 | step=10515 | 45861s | 12.0GB\n", | |
| "E1 B84140 | loss=4.1421 | lr=0.000075 | step=10517 | 45872s | 12.0GB\n", | |
| "E1 B84160 | loss=3.3570 | lr=0.000075 | step=10520 | 45883s | 12.0GB\n", | |
| "E1 B84180 | loss=3.7867 | lr=0.000075 | step=10522 | 45894s | 12.0GB\n", | |
| "E1 B84200 | loss=1.4464 | lr=0.000075 | step=10525 | 45904s | 12.0GB\n", | |
| "E1 B84220 | loss=3.6208 | lr=0.000075 | step=10527 | 45916s | 12.0GB\n", | |
| "E1 B84240 | loss=3.1878 | lr=0.000075 | step=10530 | 45926s | 12.0GB\n", | |
| "E1 B84260 | loss=3.6956 | lr=0.000075 | step=10532 | 45937s | 12.0GB\n", | |
| "E1 B84280 | loss=1.1194 | lr=0.000075 | step=10535 | 45948s | 12.0GB\n", | |
| "E1 B84300 | loss=2.7751 | lr=0.000075 | step=10537 | 45959s | 12.0GB\n", | |
| "E1 B84320 | loss=3.9186 | lr=0.000075 | step=10540 | 45970s | 12.0GB\n", | |
| "E1 B84340 | loss=3.9909 | lr=0.000075 | step=10542 | 45981s | 12.0GB\n", | |
| "E1 B84360 | loss=4.1888 | lr=0.000074 | step=10545 | 45992s | 12.0GB\n", | |
| "E1 B84380 | loss=3.7813 | lr=0.000074 | step=10547 | 46003s | 12.0GB\n", | |
| "E1 B84400 | loss=3.2111 | lr=0.000074 | step=10550 | 46014s | 12.0GB\n", | |
| "E1 B84420 | loss=3.9105 | lr=0.000074 | step=10552 | 46025s | 12.0GB\n", | |
| "E1 B84440 | loss=4.1016 | lr=0.000074 | step=10555 | 46036s | 12.0GB\n", | |
| "E1 B84460 | loss=3.9241 | lr=0.000074 | step=10557 | 46047s | 12.0GB\n", | |
| "E1 B84480 | loss=3.5857 | lr=0.000074 | step=10560 | 46057s | 12.0GB\n", | |
| "E1 B84500 | loss=3.9282 | lr=0.000074 | step=10562 | 46068s | 12.0GB\n", | |
| "E1 B84520 | loss=3.9756 | lr=0.000074 | step=10565 | 46079s | 12.0GB\n", | |
| "E1 B84540 | loss=3.8981 | lr=0.000074 | step=10567 | 46090s | 12.0GB\n", | |
| "E1 B84560 | loss=3.6915 | lr=0.000074 | step=10570 | 46101s | 12.0GB\n", | |
| "E1 B84580 | loss=3.3759 | lr=0.000074 | step=10572 | 46112s | 12.0GB\n", | |
| "E1 B84600 | loss=4.0181 | lr=0.000074 | step=10575 | 46123s | 12.0GB\n", | |
| "E1 B84620 | loss=3.9899 | lr=0.000074 | step=10577 | 46134s | 12.0GB\n", | |
| "E1 B84640 | loss=2.7491 | lr=0.000074 | step=10580 | 46144s | 12.0GB\n", | |
| "E1 B84660 | loss=3.8926 | lr=0.000074 | step=10582 | 46156s | 12.0GB\n", | |
| "E1 B84680 | loss=3.9660 | lr=0.000074 | step=10585 | 46166s | 12.0GB\n", | |
| "E1 B84700 | loss=3.7381 | lr=0.000074 | step=10587 | 46177s | 12.0GB\n", | |
| "E1 B84720 | loss=3.6865 | lr=0.000074 | step=10590 | 46188s | 12.0GB\n", | |
| "E1 B84740 | loss=4.0988 | lr=0.000074 | step=10592 | 46199s | 12.0GB\n", | |
| "E1 B84760 | loss=3.9796 | lr=0.000074 | step=10595 | 46210s | 12.0GB\n", | |
| "E1 B84780 | loss=3.4667 | lr=0.000074 | step=10597 | 46221s | 12.0GB\n", | |
| "E1 B84800 | loss=3.3943 | lr=0.000074 | step=10600 | 46232s | 12.0GB\n", | |
| "E1 B84820 | loss=4.2667 | lr=0.000074 | step=10602 | 46243s | 12.0GB\n", | |
| "E1 B84840 | loss=3.4909 | lr=0.000074 | step=10605 | 46254s | 12.0GB\n", | |
| "E1 B84860 | loss=3.9001 | lr=0.000074 | step=10607 | 46265s | 12.0GB\n", | |
| "E1 B84880 | loss=3.2266 | lr=0.000073 | step=10610 | 46276s | 12.0GB\n", | |
| "E1 B84900 | loss=3.8690 | lr=0.000073 | step=10612 | 46286s | 12.0GB\n", | |
| "E1 B84920 | loss=3.6407 | lr=0.000073 | step=10615 | 46297s | 12.0GB\n", | |
| "E1 B84940 | loss=3.8192 | lr=0.000073 | step=10617 | 46308s | 12.0GB\n", | |
| "E1 B84960 | loss=3.7542 | lr=0.000073 | step=10620 | 46319s | 12.0GB\n", | |
| "E1 B84980 | loss=4.0122 | lr=0.000073 | step=10622 | 46330s | 12.0GB\n", | |
| "E1 B85000 | loss=4.2165 | lr=0.000073 | step=10625 | 46341s | 12.0GB\n", | |
| "E1 B85020 | loss=3.0944 | lr=0.000073 | step=10627 | 46352s | 12.0GB\n", | |
| "E1 B85040 | loss=3.9340 | lr=0.000073 | step=10630 | 46363s | 12.0GB\n", | |
| "E1 B85060 | loss=3.3452 | lr=0.000073 | step=10632 | 46374s | 12.0GB\n", | |
| "E1 B85080 | loss=3.7265 | lr=0.000073 | step=10635 | 46385s | 12.0GB\n", | |
| "E1 B85100 | loss=3.2035 | lr=0.000073 | step=10637 | 46395s | 12.0GB\n", | |
| "E1 B85120 | loss=3.4126 | lr=0.000073 | step=10640 | 46407s | 12.0GB\n", | |
| "E1 B85140 | loss=3.9198 | lr=0.000073 | step=10642 | 46417s | 12.0GB\n", | |
| "E1 B85160 | loss=3.5839 | lr=0.000073 | step=10645 | 46428s | 12.0GB\n", | |
| "E1 B85180 | loss=3.8842 | lr=0.000073 | step=10647 | 46439s | 12.0GB\n", | |
| "E1 B85200 | loss=3.9351 | lr=0.000073 | step=10650 | 46450s | 12.0GB\n", | |
| "E1 B85220 | loss=3.3622 | lr=0.000073 | step=10652 | 46461s | 12.0GB\n", | |
| "E1 B85240 | loss=3.7592 | lr=0.000073 | step=10655 | 46472s | 12.0GB\n", | |
| "E1 B85260 | loss=2.7395 | lr=0.000073 | step=10657 | 46483s | 12.0GB\n", | |
| "E1 B85280 | loss=3.8370 | lr=0.000073 | step=10660 | 46494s | 12.0GB\n", | |
| "E1 B85300 | loss=4.0697 | lr=0.000073 | step=10662 | 46505s | 12.0GB\n", | |
| "E1 B85320 | loss=3.1111 | lr=0.000073 | step=10665 | 46516s | 12.0GB\n", | |
| "E1 B85340 | loss=3.9165 | lr=0.000073 | step=10667 | 46527s | 12.0GB\n", | |
| "E1 B85360 | loss=4.0334 | lr=0.000073 | step=10670 | 46538s | 12.0GB\n", | |
| "E1 B85380 | loss=3.8077 | lr=0.000073 | step=10672 | 46548s | 12.0GB\n", | |
| "E1 B85400 | loss=4.0962 | lr=0.000072 | step=10675 | 46559s | 12.0GB\n", | |
| "E1 B85420 | loss=3.7642 | lr=0.000072 | step=10677 | 46570s | 12.0GB\n", | |
| "E1 B85640 | loss=3.8509 | lr=0.000072 | step=10705 | 46690s | 12.0GB\n", | |
| "E1 B85660 | loss=3.9086 | lr=0.000072 | step=10707 | 46701s | 12.0GB\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (2156 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (1047 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (2619 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (1225 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (3242 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (2371 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (1314 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (3373 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (3054 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (1690 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (1364 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (1551 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (5217 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (1419 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (1271 > 1024). Running this sequence through the model will result in indexing errors\n", | |
| "Token indices sequence length is longer than the specified maximum sequence length for this model (2637 > 1024). Running this sequence through the model will result in indexing errors\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "E1 B 0 | loss=10.8247 | lr=0.000001 | step=0 | 3s | 10.7GB\n", | |
| "E1 B 20 | loss=10.8169 | lr=0.000003 | step=2 | 14s | 12.0GB\n", | |
| "E1 B 40 | loss=10.7877 | lr=0.000006 | step=5 | 25s | 12.0GB\n", | |
| "E1 B 60 | loss=10.7531 | lr=0.000008 | step=7 | 36s | 12.0GB\n", | |
| "E1 B 80 | loss=10.6827 | lr=0.000011 | step=10 | 47s | 12.0GB\n", | |
| "E1 B 100 | loss=10.6539 | lr=0.000013 | step=12 | 57s | 12.0GB\n", | |
| "E1 B 120 | loss=10.5372 | lr=0.000016 | step=15 | 68s | 12.0GB\n", | |
| "E1 B 140 | loss=10.5912 | lr=0.000018 | step=17 | 79s | 12.0GB\n", | |
| "E1 B 160 | loss=10.4607 | lr=0.000021 | step=20 | 90s | 12.0GB\n", | |
| "E1 B 180 | loss=10.3397 | lr=0.000023 | step=22 | 101s | 12.0GB\n", | |
| "E1 B 200 | loss=9.9626 | lr=0.000026 | step=25 | 112s | 12.0GB\n", | |
| "E1 B 220 | loss=10.2695 | lr=0.000028 | step=27 | 123s | 12.0GB\n", | |
| "E1 B 240 | loss=10.2457 | lr=0.000031 | step=30 | 133s | 12.0GB\n", | |
| "E1 B 260 | loss=10.1893 | lr=0.000033 | step=32 | 144s | 12.0GB\n", | |
| "E1 B 280 | loss=10.1125 | lr=0.000036 | step=35 | 155s | 12.0GB\n", | |
| "E1 B 300 | loss=10.0175 | lr=0.000038 | step=37 | 166s | 12.0GB\n", | |
| "E1 B 320 | loss=9.9452 | lr=0.000041 | step=40 | 177s | 12.0GB\n", | |
| "E1 B 340 | loss=9.8831 | lr=0.000043 | step=42 | 188s | 12.0GB\n", | |
| "E1 B 360 | loss=9.4486 | lr=0.000046 | step=45 | 199s | 12.0GB\n", | |
| "E1 B 380 | loss=9.6624 | lr=0.000048 | step=47 | 210s | 12.0GB\n", | |
| "E1 B 400 | loss=9.5566 | lr=0.000051 | step=50 | 221s | 12.0GB\n", | |
| "E1 B 420 | loss=9.5670 | lr=0.000053 | step=52 | 232s | 12.0GB\n", | |
| "E1 B 440 | loss=9.6506 | lr=0.000056 | step=55 | 242s | 12.0GB\n", | |
| "E1 B 460 | loss=9.2208 | lr=0.000058 | step=57 | 253s | 12.0GB\n", | |
| "E1 B 480 | loss=9.1615 | lr=0.000061 | step=60 | 264s | 12.0GB\n", | |
| "E1 B 500 | loss=9.1667 | lr=0.000063 | step=62 | 275s | 12.0GB\n", | |
| "E1 B 520 | loss=8.0803 | lr=0.000066 | step=65 | 286s | 12.0GB\n", | |
| "E1 B 540 | loss=8.9027 | lr=0.000068 | step=67 | 297s | 12.0GB\n", | |
| "E1 B 560 | loss=8.8101 | lr=0.000071 | step=70 | 308s | 12.0GB\n", | |
| "E1 B 580 | loss=8.6734 | lr=0.000073 | step=72 | 319s | 12.0GB\n", | |
| "E1 B 600 | loss=8.6437 | lr=0.000076 | step=75 | 329s | 12.0GB\n", | |
| "E1 B 620 | loss=8.5093 | lr=0.000078 | step=77 | 341s | 12.0GB\n", | |
| "E1 B 640 | loss=8.3063 | lr=0.000081 | step=80 | 351s | 12.0GB\n" | |
| ] | |
| }, | |
| { | |
| "ename": "KeyboardInterrupt", | |
| "evalue": "", | |
| "output_type": "error", | |
| "traceback": [ | |
| "\u001b[31m---------------------------------------------------------------------------\u001b[39m", | |
| "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", | |
| "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[10]\u001b[39m\u001b[32m, line 94\u001b[39m\n\u001b[32m 90\u001b[39m \u001b[38;5;66;03m# Backward\u001b[39;00m\n\u001b[32m 91\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m scaler.is_enabled():\n\u001b[32m 92\u001b[39m scaler.scale(loss).backward()\n\u001b[32m 93\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m94\u001b[39m loss.backward()\n\u001b[32m 95\u001b[39m \n\u001b[32m 96\u001b[39m real_loss = loss.item() * GRAD_ACCUM\n\u001b[32m 97\u001b[39m total_loss += real_loss\n", | |
| "\u001b[36mFile \u001b[39m\u001b[32m~/.venv/lib/python3.12/site-packages/torch/_tensor.py:581\u001b[39m, in \u001b[36mTensor.backward\u001b[39m\u001b[34m(self, gradient, retain_graph, create_graph, inputs)\u001b[39m\n\u001b[32m 571\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 572\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[32m 573\u001b[39m Tensor.backward,\n\u001b[32m 574\u001b[39m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[32m (...)\u001b[39m\u001b[32m 579\u001b[39m inputs=inputs,\n\u001b[32m 580\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m581\u001b[39m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mautograd\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 582\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m=\u001b[49m\u001b[43minputs\u001b[49m\n\u001b[32m 583\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", | |
| "\u001b[36mFile \u001b[39m\u001b[32m~/.venv/lib/python3.12/site-packages/torch/autograd/__init__.py:347\u001b[39m, in \u001b[36mbackward\u001b[39m\u001b[34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[39m\n\u001b[32m 342\u001b[39m retain_graph = create_graph\n\u001b[32m 344\u001b[39m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[32m 345\u001b[39m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[32m 346\u001b[39m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m347\u001b[39m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 348\u001b[39m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 349\u001b[39m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 350\u001b[39m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 351\u001b[39m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 352\u001b[39m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 353\u001b[39m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 354\u001b[39m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 355\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", | |
| "\u001b[36mFile \u001b[39m\u001b[32m~/.venv/lib/python3.12/site-packages/torch/autograd/graph.py:825\u001b[39m, in \u001b[36m_engine_run_backward\u001b[39m\u001b[34m(t_outputs, *args, **kwargs)\u001b[39m\n\u001b[32m 823\u001b[39m unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[32m 824\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m825\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_execution_engine\u001b[49m\u001b[43m.\u001b[49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[32m 826\u001b[39m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\n\u001b[32m 827\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[32m 828\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 829\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n", | |
| "\u001b[31mKeyboardInterrupt\u001b[39m: " | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# ======================== TRAINING LOOP ========================\n", | |
| "\n", | |
| "# Ensure the checkpoint directory exists before any save_checkpoint call below.\n", | |
| "os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n", | |
| "\n", | |
| "# Separate param groups: NO weight decay on norms, biases, embeddings\n", | |
| "# 1-D tensors (norm scales, biases) and name-matched params skip decay;\n", | |
| "# everything else (the 2-D weight matrices) gets WEIGHT_DECAY.\n", | |
| "decay_params = []\n", | |
| "no_decay_params = []\n", | |
| "for name, p in model.named_parameters():\n", | |
| " if not p.requires_grad:\n", | |
| " continue\n", | |
| " if p.dim() == 1 or \"norm\" in name or \"gamma\" in name or \"bias\" in name or \"embed\" in name:\n", | |
| " no_decay_params.append(p)\n", | |
| " else:\n", | |
| " decay_params.append(p)\n", | |
| "\n", | |
| "print(f\"Params with weight decay : {sum(p.numel() for p in decay_params):,}\")\n", | |
| "print(f\"Params without weight decay: {sum(p.numel() for p in no_decay_params):,}\")\n", | |
| "\n", | |
| "# AdamW over the two groups; betas=(0.9, 0.95) is a common LLM pre-training choice.\n", | |
| "optimizer = torch.optim.AdamW([\n", | |
| " {\"params\": decay_params, \"weight_decay\": WEIGHT_DECAY},\n", | |
| " {\"params\": no_decay_params, \"weight_decay\": 0.0},\n", | |
| "], lr=LEARNING_RATE, betas=(0.9, 0.95))\n", | |
| "\n", | |
| "# BF16 on A100 (no GradScaler needed for BF16)\n", | |
| "# NOTE(review): when CUDA is unavailable this falls through to FP16 with the\n", | |
| "# scaler enabled, but the autocast in the loop is hard-coded to device_type=cuda;\n", | |
| "# confirm this cell only runs on a CUDA box.\n", | |
| "use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()\n", | |
| "amp_dtype = torch.bfloat16 if use_bf16 else torch.float16\n", | |
| "scaler = GradScaler(enabled=(not use_bf16)) # Only needed for FP16\n", | |
| "\n", | |
| "print(f\"Mixed precision: {'BF16' if use_bf16 else 'FP16'}\")\n", | |
| "print(f\"GradScaler: {'OFF (BF16)' if use_bf16 else 'ON (FP16)'}\")\n", | |
| "\n", | |
| "# ---- Training state ----\n", | |
| "model.train()\n", | |
| "global_step = 0 # counts optimizer steps, not micro-batches\n", | |
| "best_loss = float('inf')\n", | |
| "last_ckpt_time = time.time() # drives the 2-hour auto-checkpoint timer in the loop\n", | |
| "\n", | |
| "# Set initial LR (warmup starts from near-zero)\n", | |
| "# get_lr is defined in an earlier cell; it maps optimizer-step -> learning rate.\n", | |
| "for pg in optimizer.param_groups:\n", | |
| " pg['lr'] = get_lr(0)\n", | |
| "\n", | |
| "print(f\"\\nStarting training...\")\n", | |
| "print(f\" Epochs: {EPOCHS}, Batch: {BATCH_SIZE}, Accum: {GRAD_ACCUM}\")\n", | |
| "print(f\" Effective batch: {BATCH_SIZE * GRAD_ACCUM}\")\n", | |
| "print(f\" Peak LR: {LEARNING_RATE}, Warmup: {WARMUP_STEPS} steps\")\n", | |
| "print(f\" Est total steps: {TOTAL_STEPS:,}\")\n", | |
| "print()\n", | |
| "\n", | |
| "# Main loop: one optimizer step per GRAD_ACCUM micro-batches (gradient accumulation).\n", | |
| "for epoch in range(EPOCHS):\n", | |
| " total_loss = 0.0\n", | |
| " batch_count = 0\n", | |
| " micro_count = 0 # micro-batches accumulated since the last optimizer step\n", | |
| " epoch_start = time.time()\n", | |
| "\n", | |
| " optimizer.zero_grad(set_to_none=True)\n", | |
| "\n", | |
| " for i, batch in enumerate(train_loader):\n", | |
| " # Auto-checkpoint every 2 hours\n", | |
| " if time.time() - last_ckpt_time >= 7200:\n", | |
| " avg = total_loss / max(batch_count, 1)\n", | |
| " save_checkpoint(model, optimizer, epoch, global_step, avg, best_loss,\n", | |
| " CHECKPOINT_DIR, f\"auto_epoch{epoch+1}_step{global_step}.pt\")\n", | |
| " last_ckpt_time = time.time()\n", | |
| "\n", | |
| " input_ids = batch[\"input_ids\"].to(device, non_blocking=True)\n", | |
| " labels = batch[\"labels\"].to(device, non_blocking=True)\n", | |
| "\n", | |
| " # Forward\n", | |
| " # Runs the decoder stack manually (embed -> layers -> norm -> project)\n", | |
| " # instead of model(...), with no KV cache during training.\n", | |
| " with autocast(device_type=\"cuda\", dtype=amp_dtype):\n", | |
| " x = model.tgt_embed(input_ids)\n", | |
| " for layer in model.decoder.layers:\n", | |
| " x, _ = layer(x, tgt_mask=None, use_cache=False)\n", | |
| " x = model.decoder.norm(x)\n", | |
| " logits = model.project(x)\n", | |
| "\n", | |
| " # Next-token objective: logits at position t score the label at t+1.\n", | |
| " shift_logits = logits[..., :-1, :].contiguous()\n", | |
| " shift_labels = labels[..., 1:].contiguous()\n", | |
| " loss = nn.CrossEntropyLoss(ignore_index=-100)(\n", | |
| " shift_logits.view(-1, shift_logits.size(-1)),\n", | |
| " shift_labels.view(-1)\n", | |
| " )\n", | |
| " loss = loss / GRAD_ACCUM # scale so accumulated grads average over the step\n", | |
| "\n", | |
| " # Skip bad batches\n", | |
| " # (also drops any grads already accumulated toward this optimizer step)\n", | |
| " if not torch.isfinite(loss):\n", | |
| " print(f\"[WARN] Non-finite loss at batch {i}, skipping\")\n", | |
| " optimizer.zero_grad(set_to_none=True)\n", | |
| " micro_count = 0\n", | |
| " continue\n", | |
| "\n", | |
| " # Backward\n", | |
| " if scaler.is_enabled():\n", | |
| " scaler.scale(loss).backward()\n", | |
| " else:\n", | |
| " loss.backward()\n", | |
| "\n", | |
| " # Undo the accumulation scaling for logging.\n", | |
| " # NOTE(review): .item() synchronizes with the GPU on every micro-batch.\n", | |
| " real_loss = loss.item() * GRAD_ACCUM\n", | |
| " total_loss += real_loss\n", | |
| " batch_count += 1\n", | |
| " micro_count += 1\n", | |
| "\n", | |
| " # Optimizer step every GRAD_ACCUM micro-batches\n", | |
| " if micro_count >= GRAD_ACCUM:\n", | |
| " # FP16 path: unscale grads before clipping so the norm is measured\n", | |
| " # in true gradient units.\n", | |
| " if scaler.is_enabled():\n", | |
| " scaler.unscale_(optimizer)\n", | |
| "\n", | |
| " grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n", | |
| "\n", | |
| " if scaler.is_enabled():\n", | |
| " scaler.step(optimizer)\n", | |
| " scaler.update()\n", | |
| " else:\n", | |
| " optimizer.step()\n", | |
| "\n", | |
| " optimizer.zero_grad(set_to_none=True)\n", | |
| " global_step += 1\n", | |
| " micro_count = 0\n", | |
| "\n", | |
| " # Update LR\n", | |
| " # (sets the rate the NEXT optimizer step will use)\n", | |
| " lr = get_lr(global_step)\n", | |
| " for pg in optimizer.param_groups:\n", | |
| " pg['lr'] = lr\n", | |
| "\n", | |
| " # Log to WandB\n", | |
| " # NOTE(review): logs the last micro-batch loss, not the mean over the\n", | |
| " # accumulated step.\n", | |
| " if USE_WANDB:\n", | |
| " log_dict = {\n", | |
| " \"train/loss\": real_loss,\n", | |
| " \"train/lr\": lr,\n", | |
| " \"train/grad_norm\": float(grad_norm),\n", | |
| " \"train/step\": global_step,\n", | |
| " \"train/epoch\": epoch + 1,\n", | |
| " }\n", | |
| " if torch.cuda.is_available():\n", | |
| " log_dict[\"system/gpu_gb\"] = torch.cuda.memory_allocated() / 1e9\n", | |
| " wandb.log(log_dict, step=global_step)\n", | |
| "\n", | |
| " # Print progress\n", | |
| " if i % 20 == 0:\n", | |
| " elapsed = time.time() - epoch_start\n", | |
| " lr_now = optimizer.param_groups[0]['lr']\n", | |
| " gpu_gb = torch.cuda.memory_allocated() / 1e9 if torch.cuda.is_available() else 0\n", | |
| " print(f\"E{epoch+1} B{i:>5} | loss={real_loss:.4f} | lr={lr_now:.6f} | \"\n", | |
| " f\"step={global_step} | {elapsed:.0f}s | {gpu_gb:.1f}GB\")\n", | |
| "\n", | |
| " # NOTE(review): empty_cache releases cached blocks back to the driver;\n", | |
| " # it does not free live tensors and may cost a little throughput.\n", | |
| " if i % 50 == 0 and torch.cuda.is_available():\n", | |
| " torch.cuda.empty_cache()\n", | |
| "\n", | |
| " # Flush leftover micro-batches\n", | |
| " # (same step logic as above, but without the LR update or WandB logging)\n", | |
| " if micro_count > 0:\n", | |
| " if scaler.is_enabled():\n", | |
| " scaler.unscale_(optimizer)\n", | |
| " torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n", | |
| " if scaler.is_enabled():\n", | |
| " scaler.step(optimizer)\n", | |
| " scaler.update()\n", | |
| " else:\n", | |
| " optimizer.step()\n", | |
| " optimizer.zero_grad(set_to_none=True)\n", | |
| " global_step += 1\n", | |
| "\n", | |
| " # End of epoch\n", | |
| " avg_loss = total_loss / max(batch_count, 1)\n", | |
| " duration = time.time() - epoch_start\n", | |
| " print(f\"\\nEpoch {epoch+1}/{EPOCHS} done | avg_loss={avg_loss:.4f} | {duration:.0f}s\")\n", | |
| "\n", | |
| " if USE_WANDB:\n", | |
| " wandb.log({\"epoch/avg_loss\": avg_loss, \"epoch/duration\": duration}, step=global_step)\n", | |
| "\n", | |
| " # Track the best epoch-average loss and checkpoint it separately.\n", | |
| " if avg_loss < best_loss:\n", | |
| " best_loss = avg_loss\n", | |
| " save_checkpoint(model, optimizer, epoch, global_step, avg_loss, best_loss,\n", | |
| " CHECKPOINT_DIR, \"best_model.pt\")\n", | |
| " if USE_WANDB:\n", | |
| " wandb.log({\"train/best_loss\": best_loss}, step=global_step)\n", | |
| "\n", | |
| " # Per-epoch checkpoint, saved unconditionally.\n", | |
| " save_checkpoint(model, optimizer, epoch, global_step, avg_loss, best_loss,\n", | |
| " CHECKPOINT_DIR, f\"epoch_{epoch+1}.pt\")\n", | |
| "\n", | |
| "# Final\n", | |
| "# Always save a final checkpoint, even if it is not the best epoch.\n", | |
| "print(\"\\nTraining complete!\")\n", | |
| "save_checkpoint(model, optimizer, EPOCHS-1, global_step, avg_loss, best_loss,\n", | |
| " CHECKPOINT_DIR, \"final_model.pt\")\n", | |
| "if USE_WANDB:\n", | |
| " wandb.finish()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# ======================== RESUME FROM CHECKPOINT ========================\n", | |
| "# Uncomment to resume:\n", | |
| "# NOTE(review): weights_only=False deserializes with pickle, which can execute\n", | |
| "# arbitrary code - only load checkpoints you produced yourself.\n", | |
| "# NOTE(review): the epoch counter is not restored here, so a resumed run\n", | |
| "# restarts the epoch loop from 0 - confirm that is intended.\n", | |
| "\"\"\"\n", | |
| "ckpt = torch.load(os.path.join(CHECKPOINT_DIR, \"best_model.pt\"), map_location=device, weights_only=False)\n", | |
| "model.load_state_dict(ckpt['model_state_dict'])\n", | |
| "optimizer.load_state_dict(ckpt['optimizer_state_dict'])\n", | |
| "global_step = ckpt['global_step']\n", | |
| "best_loss = ckpt['best_loss']\n", | |
| "print(f\"Resumed from step {global_step}\")\n", | |
| "\"\"\"\n", | |
| "print(\"Resume cell ready (uncomment to use).\")" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": ".venv (3.12.3)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "name": "python", | |
| "version": "3.12.3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } | |
Xet Storage Details
- Size:
- 46.9 kB
- Xet hash:
- 081b178d4e10267905d2423d3aed52a1dd702bc85bd4c939b70630472ce86a41
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.