Upload train.ipynb
Browse files- train.ipynb +1 -0
train.ipynb
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"!git clone https://github.com/karpathy/nanoGPT.git\n%cd nanoGPT\n\n!pip install transformers datasets tiktoken wandb tqdm","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true,"execution":{"iopub.status.busy":"2026-03-23T16:13:04.163593Z","iopub.execute_input":"2026-03-23T16:13:04.163877Z","iopub.status.idle":"2026-03-23T16:13:10.085088Z","shell.execute_reply.started":"2026-03-23T16:13:04.163842Z","shell.execute_reply":"2026-03-23T16:13:10.084392Z"}},"outputs":[{"name":"stdout","text":"Cloning into 'nanoGPT'...\nremote: Enumerating objects: 689, done.\u001b[K\nremote: Total 689 (delta 0), reused 0 (delta 0), pack-reused 689 (from 1)\u001b[K\nReceiving objects: 100% (689/689), 975.25 KiB | 5.39 MiB/s, done.\nResolving deltas: 100% (382/382), done.\n/kaggle/working/nanoGPT\nRequirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (5.0.0)\nRequirement already satisfied: datasets in /usr/local/lib/python3.12/dist-packages (4.8.3)\nRequirement already satisfied: tiktoken in /usr/local/lib/python3.12/dist-packages (0.12.0)\nRequirement already satisfied: wandb in /usr/local/lib/python3.12/dist-packages (0.25.0)\nRequirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (4.67.3)\nRequirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from transformers) (3.24.3)\nRequirement already satisfied: 
huggingface-hub<2.0,>=1.3.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (1.4.1)\nRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2.0.2)\nRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (26.0)\nRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers) (6.0.3)\nRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers) (2025.11.3)\nRequirement already satisfied: tokenizers<=0.23.0,>=0.22.0 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.22.2)\nRequirement already satisfied: typer-slim in /usr/local/lib/python3.12/dist-packages (from transformers) (0.24.0)\nRequirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.12/dist-packages (from transformers) (0.7.0)\nRequirement already satisfied: pyarrow>=21.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (23.0.1)\nRequirement already satisfied: dill<0.4.2,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.4.1)\nRequirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets) (2.3.3)\nRequirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets) (2.32.4)\nRequirement already satisfied: httpx<1.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.28.1)\nRequirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets) (3.6.0)\nRequirement already satisfied: multiprocess<0.70.20 in /usr/local/lib/python3.12/dist-packages (from datasets) (0.70.16)\nRequirement already satisfied: fsspec<=2026.2.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (2026.2.0)\nRequirement already satisfied: click>=8.0.1 in /usr/local/lib/python3.12/dist-packages (from 
wandb) (8.3.1)\nRequirement already satisfied: gitpython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (3.1.46)\nRequirement already satisfied: platformdirs in /usr/local/lib/python3.12/dist-packages (from wandb) (4.9.2)\nRequirement already satisfied: protobuf!=4.21.0,!=5.28.0,<7,>=3.19.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (5.29.5)\nRequirement already satisfied: pydantic<3 in /usr/local/lib/python3.12/dist-packages (from wandb) (2.12.3)\nRequirement already satisfied: sentry-sdk>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from wandb) (2.53.0)\nRequirement already satisfied: typing-extensions<5,>=4.8 in /usr/local/lib/python3.12/dist-packages (from wandb) (4.15.0)\nRequirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (3.13.3)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.12/dist-packages (from gitpython!=3.1.29,>=1.0.0->wandb) (4.0.12)\nRequirement already satisfied: anyio in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets) (4.12.1)\nRequirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets) (2026.1.4)\nRequirement already satisfied: httpcore==1.* in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets) (1.0.9)\nRequirement already satisfied: idna in /usr/local/lib/python3.12/dist-packages (from httpx<1.0.0->datasets) (3.11)\nRequirement already satisfied: h11>=0.16 in /usr/local/lib/python3.12/dist-packages (from httpcore==1.*->httpx<1.0.0->datasets) (0.16.0)\nRequirement already satisfied: hf-xet<2.0.0,>=1.2.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (1.3.0)\nRequirement already satisfied: shellingham in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<2.0,>=1.3.0->transformers) (1.5.4)\nRequirement already satisfied: 
annotated-types>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (0.7.0)\nRequirement already satisfied: pydantic-core==2.41.4 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (2.41.4)\nRequirement already satisfied: typing-inspection>=0.4.2 in /usr/local/lib/python3.12/dist-packages (from pydantic<3->wandb) (0.4.2)\nRequirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (3.4.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets) (2.5.0)\nRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2.9.0.post0)\nRequirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.2)\nRequirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets) (2025.3)\nRequirement already satisfied: typer>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from typer-slim->transformers) (0.24.1)\nRequirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (2.6.1)\nRequirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (1.4.0)\nRequirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (25.4.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (1.8.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from 
aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (6.7.1)\nRequirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (0.4.1)\nRequirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<=2026.2.0,>=2023.1.0->datasets) (1.22.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb) (5.0.2)\nRequirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.17.0)\nRequirement already satisfied: rich>=12.3.0 in /usr/local/lib/python3.12/dist-packages (from typer>=0.24.0->typer-slim->transformers) (13.9.4)\nRequirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.12/dist-packages (from typer>=0.24.0->typer-slim->transformers) (0.0.4)\nRequirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.12/dist-packages (from rich>=12.3.0->typer>=0.24.0->typer-slim->transformers) (4.0.0)\nRequirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.12/dist-packages (from rich>=12.3.0->typer>=0.24.0->typer-slim->transformers) (2.19.2)\nRequirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.12/dist-packages (from markdown-it-py>=2.2.0->rich>=12.3.0->typer>=0.24.0->typer-slim->transformers) (0.1.2)\n","output_type":"stream"}],"execution_count":1},{"cell_type":"code","source":"!mkdir -p 
data/emails","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-03-23T16:13:24.273530Z","iopub.execute_input":"2026-03-23T16:13:24.274239Z","iopub.status.idle":"2026-03-23T16:13:24.391465Z","shell.execute_reply.started":"2026-03-23T16:13:24.274191Z","shell.execute_reply":"2026-03-23T16:13:24.390584Z"}},"outputs":[],"execution_count":2},{"cell_type":"code","source":"%%writefile data/emails/prepare.py\n# Prepare the email instruction dataset for nanoGPT: download, format into\n# Alpaca-style prompts, GPT-2-tokenize, and dump train.bin / val.bin next to\n# this script (nanoGPT's get_batch memmaps these files).\nimport os\nimport tqdm\nimport numpy as np\nimport tiktoken\nfrom datasets import load_dataset\n\nprint(\"Lade Kamisori-daijin/email-datasets-20k...\")\ndataset = load_dataset(\"Kamisori-daijin/email-datasets-20k\", split='train')\n\ndef get_text(example):\n # Render one example in the same prompt format the sampling cell uses.\n instr = str(example['instruction']).strip()\n resp = str(example['output']).strip()\n return f\"### Instruction:\\n{instr}\\n\\n### Response:\\n{resp}\\n<|endoftext|>\\n\"\n\nprint(\"Formatting data...\")\n# Single join instead of repeated 'all_text +=': += copies the whole\n# accumulated string each iteration, which is quadratic over ~20k examples.\nall_text = \"\".join(get_text(ex) for ex in tqdm.tqdm(dataset))\n\nenc = tiktoken.get_encoding(\"gpt2\")\n# 90/10 split on raw characters (may cut mid-example; fine for LM training).\ntrain_data = all_text[:int(len(all_text)*0.9)]\nval_data = all_text[int(len(all_text)*0.9):]\n\n# NOTE: encode_ordinary encodes <|endoftext|> as plain text, not as the\n# special token; the sampling cell splits on the literal string, so the two\n# sides are consistent.\ntrain_ids = enc.encode_ordinary(train_data)\nval_ids = enc.encode_ordinary(val_data)\n\nprint(f\"Train Tokens: {len(train_ids):,}\")\n\n# uint16 is safe here: the GPT-2 vocab (50257) fits below 65536.\ntrain_ids = np.array(train_ids, dtype=np.uint16)\nval_ids = np.array(val_ids, dtype=np.uint16)\ntrain_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))\nval_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))\nprint(\"Done!\")\n\nsample_ids = train_ids[:500]\n\ndecoded_text = enc.decode(sample_ids)\n\nprint(\"-\" * 30)\nprint(\"Preview of the first 500 tokens:\")\nprint(\"-\" * 30)\nprint(decoded_text)\nprint(\"-\" * 
30)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-03-23T16:17:55.501931Z","iopub.execute_input":"2026-03-23T16:17:55.502554Z","iopub.status.idle":"2026-03-23T16:17:55.507488Z","shell.execute_reply.started":"2026-03-23T16:17:55.502515Z","shell.execute_reply":"2026-03-23T16:17:55.506921Z"}},"outputs":[{"name":"stdout","text":"Overwriting data/emails/prepare.py\n","output_type":"stream"}],"execution_count":7},{"cell_type":"code","source":"!python3 data/emails/prepare.py","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-03-23T16:18:09.092816Z","iopub.execute_input":"2026-03-23T16:18:09.093115Z","iopub.status.idle":"2026-03-23T16:20:36.549533Z","shell.execute_reply.started":"2026-03-23T16:18:09.093089Z","shell.execute_reply":"2026-03-23T16:20:36.548858Z"}},"outputs":[{"name":"stdout","text":"Lade Kamisori-daijin/email-datasets-20k...\nWarning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.\nFormatting data...\n100%|ββββββββββββββββββββββββββββββββββββ| 19849/19849 [02:20<00:00, 140.80it/s]\nTrain Tokens: 4,416,222\nDone!\n------------------------------\nPreview of the first 500 tokens:\n------------------------------\n### Instruction:\nWrite a firm and authoritative business email from a Account Manager to a Technical Director regarding inviting to a beta testing program, specifically following a negative viral tweet.\n\n### Response:\n{\"subject\":\"Urgent: Invitation to Beta Testing - Addressing Recent Concerns\",\"body\":\"Dear [Technical Director Name],\\n\\nI\\u2019m writing to you following the recent viral tweet regarding [Product Name] and to proactively address the concerns raised. 
We acknowledge the issues highlighted and are taking immediate steps to rectify them.\\n\\nTo demonstrate our commitment to transparency and rapid improvement, we\\u2019d like to formally invite you to participate in a dedicated beta testing program for the upcoming update.\\n\\nYour technical expertise and understanding of [Product Name]'s core functionalities would be invaluable. We believe your feedback will be crucial in ensuring a successful and robust release. This program will provide you with exclusive early access to the updated version, alongside a direct line of communication with our development team for immediate feedback and issue reporting.\\n\\nWe understand the gravity of the situation and want to swiftly resolve these issues. Participation in this beta program is a key component of our strategy. To learn more about the program and sign up, please visit [Link to Beta Program].\\n\\nPlease confirm your participation by [Date].\\n\\nSincerely,\\n[Name]\\nAccount Manager\"}\n<|endoftext|>\n### Instruction:\nWrite a firm and authoritative business email from a Senior Engineer to a Internal Team regarding inviting to a beta testing program, specifically while the system is partially down.\n\n### Response:\n{\"subject\":\"Urgent: Beta Testing Invitation - System Partial Outage\",\"body\":\"Subject: Urgent: Beta Testing Invitation - System Partial Outage\\n\\nTeam,\\n\\nAs you are likely aware, we are currently experiencing a partial outage affecting [System Name]. 
While we work diligently to restore full functionality, we require immediate assistance with a critical beta testing program.\\n\\nDue to the ongoing instability, **participation in this beta testing program is strictly limited to experienced engineers with a proven understanding of [Relevant Technology] and a strong ability to diagnose and report issues effectively.** We need your expertise to quickly identify and document the scope of the problem.\\n\\nWe will be providing a dedicated workspace with\n------------------------------\n","output_type":"stream"}],"execution_count":8},{"cell_type":"code","source":"%%writefile config/train_emails_instr.py\nout_dir = 'out-emails-instr'\neval_interval = 250\neval_iters = 100\nlog_interval = 10\nwandb_log = False\n\ndataset = 'emails'\nbatch_size = 32 \ngradient_accumulation_steps = 4\nblock_size = 512 \n\nn_layer = 6\nn_head = 8\nn_embd = 512\ndropout = 0.1\nbias = False\n\nalways_save_checkpoint = False\n\nlearning_rate = 5e-4\nmax_iters = 2000\nlr_decay_iters = 2000\nmin_lr = 6e-5\nwarmup_iters = 200\nbeta1 = 0.9\nbeta2 = 0.95\nweight_decay = 1e-1\n\ndevice = 'cuda'\ndtype = 'float16'\ncompile = True","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-03-23T16:23:08.916901Z","iopub.execute_input":"2026-03-23T16:23:08.917571Z","iopub.status.idle":"2026-03-23T16:23:08.922905Z","shell.execute_reply.started":"2026-03-23T16:23:08.917536Z","shell.execute_reply":"2026-03-23T16:23:08.922156Z"}},"outputs":[{"name":"stdout","text":"Writing config/train_emails_instr.py\n","output_type":"stream"}],"execution_count":9},{"cell_type":"code","source":"!python train.py 
config/train_emails_instr.py","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-03-23T16:23:10.905431Z","iopub.execute_input":"2026-03-23T16:23:10.905761Z","iopub.status.idle":"2026-03-23T16:54:40.581344Z","shell.execute_reply.started":"2026-03-23T16:23:10.905732Z","shell.execute_reply":"2026-03-23T16:54:40.580686Z"}},"outputs":[{"name":"stdout","text":"Overriding config with config/train_emails_instr.py:\nout_dir = 'out-emails-instr'\neval_interval = 250\neval_iters = 100\nlog_interval = 10\nwandb_log = False\n\ndataset = 'emails'\nbatch_size = 32 \ngradient_accumulation_steps = 4\nblock_size = 512 \n\nn_layer = 6\nn_head = 8\nn_embd = 512\ndropout = 0.1\nbias = False\n\nalways_save_checkpoint = False\n\nlearning_rate = 5e-4\nmax_iters = 2000\nlr_decay_iters = 2000\nmin_lr = 6e-5\nwarmup_iters = 200\nbeta1 = 0.9\nbeta2 = 0.95\nweight_decay = 1e-1\n\ndevice = 'cuda'\ndtype = 'float16'\ncompile = True\n\ntokens per iteration will be: 65,536\nInitializing a new model from scratch\ndefaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)\nnumber of parameters: 44.64M\n/kaggle/working/nanoGPT/train.py:196: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))\nnum decayed parameter tensors: 26, with 44,892,160 parameters\nnum non-decayed parameter tensors: 13, with 6,656 parameters\nusing fused AdamW: True\ncompiling the model... 
(takes a ~minute)\nW0323 16:23:26.969000 238 torch/_inductor/utils.py:1679] [0/0] Not enough SMs to use max_autotune_gemm mode\nstep 0: train loss 10.9128, val loss 10.9125\niter 0: loss 10.9159, time 47776.24ms, mfu -100.00%\niter 10: loss 9.8689, time 1215.87ms, mfu 4.95%\niter 20: loss 9.0482, time 1258.20ms, mfu 4.94%\niter 30: loss 8.0331, time 1291.68ms, mfu 4.91%\niter 40: loss 7.0748, time 1351.83ms, mfu 4.86%\niter 50: loss 6.2711, time 1398.24ms, mfu 4.81%\niter 60: loss 5.4529, time 1431.68ms, mfu 4.75%\niter 70: loss 4.7636, time 1396.69ms, mfu 4.70%\niter 80: loss 4.1760, time 1371.29ms, mfu 4.67%\niter 90: loss 3.7997, time 1357.53ms, mfu 4.65%\niter 100: loss 3.4748, time 1362.87ms, mfu 4.63%\niter 110: loss 3.2527, time 1377.98ms, mfu 4.60%\niter 120: loss 3.1839, time 1377.34ms, mfu 4.58%\niter 130: loss 2.9877, time 1377.38ms, mfu 4.56%\niter 140: loss 2.7692, time 1375.16ms, mfu 4.54%\niter 150: loss 2.7559, time 1368.17ms, mfu 4.53%\niter 160: loss 2.7659, time 1379.79ms, mfu 4.51%\niter 170: loss 2.6179, time 1365.05ms, mfu 4.50%\niter 180: loss 2.5360, time 1364.52ms, mfu 4.49%\niter 190: loss 2.4574, time 1362.38ms, mfu 4.48%\niter 200: loss 2.4284, time 1363.14ms, mfu 4.48%\niter 210: loss 2.3530, time 1366.60ms, mfu 4.47%\niter 220: loss 2.3515, time 1357.93ms, mfu 4.47%\niter 230: loss 2.2095, time 1350.51ms, mfu 4.47%\niter 240: loss 2.0556, time 1352.70ms, mfu 4.46%\nstep 250: train loss 2.0681, val loss 2.0520\nsaving checkpoint to out-emails-instr\niter 250: loss 2.0986, time 25024.78ms, mfu 4.04%\niter 260: loss 1.9747, time 1355.30ms, mfu 4.08%\niter 270: loss 1.9650, time 1345.11ms, mfu 4.12%\niter 280: loss 1.9110, time 1343.25ms, mfu 4.16%\niter 290: loss 1.9005, time 1348.14ms, mfu 4.19%\niter 300: loss 1.9401, time 1348.59ms, mfu 4.22%\niter 310: loss 1.8480, time 1344.09ms, mfu 4.24%\niter 320: loss 1.7547, time 1339.44ms, mfu 4.27%\niter 330: loss 1.7241, time 1342.73ms, mfu 4.29%\niter 340: loss 1.8031, time 1340.91ms, mfu 
4.31%\niter 350: loss 1.9044, time 1343.58ms, mfu 4.33%\niter 360: loss 1.7794, time 1349.26ms, mfu 4.34%\niter 370: loss 1.5913, time 1348.82ms, mfu 4.35%\niter 380: loss 1.6579, time 1336.52ms, mfu 4.37%\niter 390: loss 1.6046, time 1347.98ms, mfu 4.38%\niter 400: loss 1.6202, time 1338.72ms, mfu 4.39%\niter 410: loss 1.5568, time 1338.26ms, mfu 4.40%\niter 420: loss 1.5788, time 1349.17ms, mfu 4.41%\niter 430: loss 1.5368, time 1336.73ms, mfu 4.42%\niter 440: loss 1.5459, time 1347.51ms, mfu 4.42%\niter 450: loss 1.5069, time 1334.54ms, mfu 4.43%\niter 460: loss 1.5526, time 1337.62ms, mfu 4.44%\niter 470: loss 1.5196, time 1334.64ms, mfu 4.45%\niter 480: loss 1.4389, time 1334.79ms, mfu 4.45%\niter 490: loss 1.4959, time 1336.81ms, mfu 4.46%\nstep 500: train loss 1.4014, val loss 1.4947\nsaving checkpoint to out-emails-instr\niter 500: loss 1.5449, time 25294.00ms, mfu 4.04%\niter 510: loss 1.3807, time 1337.07ms, mfu 4.08%\niter 520: loss 1.5042, time 1334.67ms, mfu 4.13%\niter 530: loss 1.4826, time 1337.26ms, mfu 4.16%\niter 540: loss 1.4377, time 1338.09ms, mfu 4.20%\niter 550: loss 1.4087, time 1340.98ms, mfu 4.23%\niter 560: loss 1.3443, time 1338.40ms, mfu 4.25%\niter 570: loss 1.3519, time 1348.60ms, mfu 4.27%\niter 580: loss 1.3948, time 1341.85ms, mfu 4.30%\niter 590: loss 1.4708, time 1349.24ms, mfu 4.31%\niter 600: loss 1.3638, time 1346.94ms, mfu 4.33%\niter 610: loss 1.3622, time 1338.57ms, mfu 4.35%\niter 620: loss 1.2946, time 1342.65ms, mfu 4.36%\niter 630: loss 1.3855, time 1348.48ms, mfu 4.37%\niter 640: loss 1.3450, time 1349.40ms, mfu 4.38%\niter 650: loss 1.3653, time 1338.21ms, mfu 4.39%\niter 660: loss 1.2414, time 1334.87ms, mfu 4.40%\niter 670: loss 1.2506, time 1338.53ms, mfu 4.41%\niter 680: loss 1.2816, time 1331.45ms, mfu 4.42%\niter 690: loss 1.3233, time 1329.18ms, mfu 4.43%\niter 700: loss 1.3626, time 1338.51ms, mfu 4.44%\niter 710: loss 1.2485, time 1334.65ms, mfu 4.45%\niter 720: loss 1.2556, time 1335.73ms, mfu 4.45%\niter 
730: loss 1.2642, time 1336.49ms, mfu 4.46%\niter 740: loss 1.2138, time 1333.10ms, mfu 4.47%\nstep 750: train loss 1.1685, val loss 1.3601\nsaving checkpoint to out-emails-instr\niter 750: loss 1.2102, time 25340.56ms, mfu 4.04%\niter 760: loss 1.2587, time 1326.33ms, mfu 4.09%\niter 770: loss 1.3652, time 1334.58ms, mfu 4.13%\niter 780: loss 1.2327, time 1336.02ms, mfu 4.17%\niter 790: loss 1.2060, time 1349.08ms, mfu 4.20%\niter 800: loss 1.2003, time 1333.83ms, mfu 4.23%\niter 810: loss 1.2041, time 1334.33ms, mfu 4.26%\niter 820: loss 1.1763, time 1336.58ms, mfu 4.28%\niter 830: loss 1.1559, time 1326.77ms, mfu 4.31%\niter 840: loss 1.2359, time 1331.97ms, mfu 4.33%\niter 850: loss 1.1655, time 1337.19ms, mfu 4.35%\niter 860: loss 1.1583, time 1331.47ms, mfu 4.37%\niter 870: loss 1.1543, time 1337.82ms, mfu 4.38%\niter 880: loss 1.1518, time 1339.48ms, mfu 4.39%\niter 890: loss 1.1260, time 1322.68ms, mfu 4.41%\niter 900: loss 1.1634, time 1332.59ms, mfu 4.42%\niter 910: loss 1.1818, time 1330.90ms, mfu 4.43%\niter 920: loss 1.1021, time 1318.95ms, mfu 4.44%\niter 930: loss 1.1188, time 1318.21ms, mfu 4.46%\niter 940: loss 1.1395, time 1329.56ms, mfu 4.46%\niter 950: loss 1.1560, time 1329.45ms, mfu 4.47%\niter 960: loss 1.2023, time 1324.69ms, mfu 4.48%\niter 970: loss 1.1176, time 1323.81ms, mfu 4.48%\niter 980: loss 1.1348, time 1325.15ms, mfu 4.49%\niter 990: loss 1.0540, time 1323.65ms, mfu 4.50%\nstep 1000: train loss 1.0256, val loss 1.3018\nsaving checkpoint to out-emails-instr\niter 1000: loss 1.1113, time 25323.54ms, mfu 4.07%\niter 1010: loss 1.0751, time 1320.36ms, mfu 4.12%\niter 1020: loss 1.0243, time 1334.81ms, mfu 4.16%\niter 1030: loss 1.1181, time 1334.37ms, mfu 4.19%\niter 1040: loss 1.1267, time 1331.80ms, mfu 4.23%\niter 1050: loss 1.0844, time 1335.55ms, mfu 4.26%\niter 1060: loss 1.0525, time 1337.28ms, mfu 4.28%\niter 1070: loss 1.0738, time 1348.78ms, mfu 4.30%\niter 1080: loss 1.0581, time 1334.13ms, mfu 4.32%\niter 1090: loss 
1.0636, time 1335.17ms, mfu 4.34%\niter 1100: loss 1.1201, time 1335.63ms, mfu 4.36%\niter 1110: loss 1.0981, time 1337.16ms, mfu 4.37%\niter 1120: loss 1.0996, time 1335.75ms, mfu 4.38%\niter 1130: loss 0.9831, time 1332.39ms, mfu 4.40%\niter 1140: loss 1.0614, time 1337.83ms, mfu 4.41%\niter 1150: loss 0.9644, time 1342.08ms, mfu 4.42%\niter 1160: loss 1.0201, time 1328.43ms, mfu 4.43%\niter 1170: loss 0.9824, time 1323.65ms, mfu 4.44%\niter 1180: loss 1.0494, time 1318.79ms, mfu 4.45%\niter 1190: loss 1.0182, time 1334.76ms, mfu 4.46%\niter 1200: loss 1.0209, time 1323.44ms, mfu 4.47%\niter 1210: loss 0.9818, time 1319.36ms, mfu 4.48%\niter 1220: loss 1.0966, time 1329.64ms, mfu 4.48%\niter 1230: loss 1.0098, time 1326.58ms, mfu 4.49%\niter 1240: loss 1.0307, time 1324.23ms, mfu 4.49%\nstep 1250: train loss 0.8988, val loss 1.2901\nsaving checkpoint to out-emails-instr\niter 1250: loss 1.0472, time 25317.10ms, mfu 4.07%\niter 1260: loss 1.0251, time 1323.84ms, mfu 4.12%\niter 1270: loss 0.9320, time 1317.91ms, mfu 4.16%\n^C\nTraceback (most recent call last):\n File \"/kaggle/working/nanoGPT/train.py\", line 303, in <module>\n X, Y = get_batch('train')\n ^^^^^^^^^^^^^^^^^^\n File \"/kaggle/working/nanoGPT/train.py\", line 124, in get_batch\n x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])\n ~~~~^^^^^^^^^^^^^^^^\n File \"/usr/local/lib/python3.12/dist-packages/numpy/_core/memmap.py\", line 347, in __getitem__\n def __getitem__(self, index):\n\nKeyboardInterrupt\n","output_type":"stream"}],"execution_count":10},{"cell_type":"code","source":"import os\nimport torch\nfrom model import GPTConfig, GPT\nimport tiktoken\n\nout_dir = 'out-emails-instr'\ndevice = 'cuda'\n\nprint(\"Loading checkpoint...\")\nckpt_path = os.path.join(out_dir, 'ckpt.pt')\ncheckpoint = torch.load(ckpt_path, map_location=device)\ngptconf = GPTConfig(**checkpoint['model_args'])\nmodel = GPT(gptconf)\nstate_dict = checkpoint['model']\n\nunwanted_prefix = 
'_orig_mod.'\nfor k,v in list(state_dict.items()):\n if k.startswith(unwanted_prefix):\n state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)\nmodel.load_state_dict(state_dict)\n\nmodel.to(device)\nmodel.eval()\n\nenc = tiktoken.get_encoding(\"gpt2\")\n\ninstruction = \"Write a polite refusal email\"\nprompt = f\"### Instruction:\\n{instruction}\\n\\n### Response:\\n\"\n\nstart_ids = enc.encode(prompt, allowed_special={\"\"})\nx = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]\n\nimport json\n\nprint(\"\\n--- Generating email ---\")\nwith torch.no_grad():\n y = model.generate(x, max_new_tokens=300, temperature=0.7, top_k=15)\n full_output = enc.decode(y[0].tolist())\n \n response_marker = \"### Response:\\n\"\n if response_marker in full_output:\n json_part = full_output.split(response_marker)[1].split(\"<|endoftext|>\")[0].strip()\n \n try:\n email_data = json.loads(json_part)\n \n print(f\"SUBJECT: {email_data.get('subject', 'No subject')}\")\n print(\"-\" * 30)\n print(f\"BODY:\\n{email_data.get('body', 'No body')}\")\n \n except json.JSONDecodeError:\n print(\"Failed parsing JSON. Raw output:\")\n print(json_part)\n else:\n print(\"No response marker found!\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-03-23T17:07:10.674328Z","iopub.execute_input":"2026-03-23T17:07:10.674815Z","iopub.status.idle":"2026-03-23T17:07:13.294489Z","shell.execute_reply.started":"2026-03-23T17:07:10.674786Z","shell.execute_reply":"2026-03-23T17:07:13.293776Z"}},"outputs":[{"name":"stdout","text":"Loading checkpoint...\nnumber of parameters: 44.64M\n\n--- Generating email ---\nFailed parsing JSON. Raw output:\nWrite a firm and authoritative business email(under 200 words) from a HR Manager to a Potential Partner regarding proposing a joint webinar, specifically right before a long holiday.\n","output_type":"stream"}],"execution_count":25}]}
|