File size: 17,648 Bytes
91c1980 | 1 | {"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"name":"python","version":"3.12.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"accelerate_config":{"num_processes":2}},"nbformat_minor":5,"nbformat":4,"cells":[{"id":"932ffe2c-915d-4254-b4dd-1d32bfeb87db","cell_type":"markdown","source":"# Java Code Optimization β CodeT5-small Fine-tuning\n**Kaggle T4Γ2 Β· Dual-GPU DataParallel Β· ~6 K train / 680 val pairs**\n\nPipeline:\n1. Load dataset from `dataset_train.jsonl` / `dataset_val.jsonl`\n2. Tokenize with `Salesforce/codet5-small` tokenizer \n3. Fine-tune with `Seq2SeqTrainer` + both GPUs via `DataParallel`\n4. Evaluate with BLEU + CodeBLEU proxies \n5. Push to HuggingFace Hub (optional) and save artefacts","metadata":{}},{"id":"10c9c8fb-0f60-4876-9110-7df1fc24b0aa","cell_type":"code","source":"import subprocess, sys\n\npkgs = [\n \"transformers==4.41.2\",\n \"datasets==2.20.0\",\n \"evaluate==0.4.2\",\n \"sacrebleu==2.4.3\",\n \"accelerate==0.33.0\",\n \"peft==0.11.1\",\n \"sentencepiece==0.2.0\",\n \"rouge_score==0.1.2\",\n \"tokenizers==0.19.1\",\n]\n\nsubprocess.run(\n [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--force-reinstall\"] + pkgs\n)\n\nprint(\"β
Install done. RESTART KERNEL NOW.\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"e8ea6ed7-c9ef-4a0f-8cbb-457c1b599466","cell_type":"code","source":"import os\n\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # β
FORCE SINGLE GPU\nos.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\" # debug + stability","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T12:59:28.032986Z","iopub.execute_input":"2026-04-03T12:59:28.033907Z","iopub.status.idle":"2026-04-03T12:59:28.038196Z","shell.execute_reply.started":"2026-04-03T12:59:28.033868Z","shell.execute_reply":"2026-04-03T12:59:28.037332Z"}},"outputs":[],"execution_count":14},{"id":"5712ecaa-b2af-4446-a009-e210d4897a12","cell_type":"code","source":"import os, json, random, time\nfrom pathlib import Path\nfrom dataclasses import dataclass\n\nimport numpy as np\nimport torch\n\nfrom transformers import (\n T5ForConditionalGeneration,\n RobertaTokenizer,\n Seq2SeqTrainer,\n Seq2SeqTrainingArguments,\n DataCollatorForSeq2Seq,\n)\n\nfrom datasets import Dataset as HFDataset, DatasetDict\n\nprint(\"CUDA:\", torch.cuda.is_available())\nprint(\"GPUs:\", torch.cuda.device_count())","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T12:59:28.247024Z","iopub.execute_input":"2026-04-03T12:59:28.247805Z","iopub.status.idle":"2026-04-03T12:59:28.253552Z","shell.execute_reply.started":"2026-04-03T12:59:28.247769Z","shell.execute_reply":"2026-04-03T12:59:28.252760Z"}},"outputs":[{"name":"stdout","text":"CUDA: True\nGPUs: 1\n","output_type":"stream"}],"execution_count":15},{"id":"ad8ca487-c537-49ba-a1ff-0b124837ced8","cell_type":"code","source":"@dataclass\nclass CFG:\n model_name: str = \"Salesforce/codet5-small\"\n\n train_file: str = \"/kaggle/input/datasets/suhaskoheda/java-optimisation/dataset_train.jsonl\"\n val_file: str = \"/kaggle/input/datasets/suhaskoheda/java-optimisation/dataset_val.jsonl\"\n\n max_source_length: int = 128\n max_target_length: int = 128\n\n num_train_epochs: int = 5 # π₯ fast\n\n per_device_train_batch_size: int = 16\n per_device_eval_batch_size: int = 32\n\n learning_rate: float = 5e-4\n weight_decay: float = 0.01\n warmup_ratio: float = 0.05\n\n output_dir: str = \"/kaggle/working/codet5-fast\"\n\ncfg = CFG()\nPath(cfg.output_dir).mkdir(parents=True, exist_ok=True)\n\nprint(\"β
Config ready\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T12:59:28.411634Z","iopub.execute_input":"2026-04-03T12:59:28.412174Z","iopub.status.idle":"2026-04-03T12:59:28.419593Z","shell.execute_reply.started":"2026-04-03T12:59:28.412144Z","shell.execute_reply":"2026-04-03T12:59:28.418710Z"}},"outputs":[{"name":"stdout","text":"β
Config ready\n","output_type":"stream"}],"execution_count":16},{"id":"2a3aae51-9c02-4b13-b4f7-7e95e3c5084e","cell_type":"code","source":"def load_jsonl(path):\n data = []\n with open(path, encoding=\"utf-8\") as f:\n for line in f:\n if line.strip():\n data.append(json.loads(line))\n return data\n\ntrain_raw = load_jsonl(cfg.train_file)\nval_raw = load_jsonl(cfg.val_file)\n\nprint(\"Train:\", len(train_raw))\nprint(\"Val:\", len(val_raw))\n\nprint(\"\\nSample:\")\nprint(train_raw[0][\"input\"][:200])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T12:59:28.560595Z","iopub.execute_input":"2026-04-03T12:59:28.561419Z","iopub.status.idle":"2026-04-03T12:59:28.617144Z","shell.execute_reply.started":"2026-04-03T12:59:28.561339Z","shell.execute_reply":"2026-04-03T12:59:28.616172Z"}},"outputs":[{"name":"stdout","text":"Train: 6115\nVal: 680\n\nSample:\nString alertLevel;\nswitch (sensor.getValue()) {\n case 0:\n case 1:\n alertLevel = \"GREEN\";\n break;\n case 2:\n alertLevel = \"YELLOW\";\n break;\n case 3:\n alert\n","output_type":"stream"}],"execution_count":17},{"id":"4ead70f3-795f-42e9-863d-ac14024611d7","cell_type":"code","source":"from transformers import AutoTokenizer\n\ntokenizer = AutoTokenizer.from_pretrained(\n cfg.model_name,\n use_fast=False\n)\n\nTASK_PREFIX = \"Optimize Java: \"\n\ndef tokenize_batch(examples):\n inputs = [TASK_PREFIX + x for x in examples[\"input\"]]\n targets = examples[\"output\"]\n\n model_inputs = tokenizer(\n inputs,\n max_length=cfg.max_source_length,\n truncation=True,\n padding=\"max_length\",\n )\n\n labels = tokenizer(\n targets,\n max_length=cfg.max_target_length,\n truncation=True,\n padding=\"max_length\",\n )\n\n # π₯ CRITICAL FIX\n cleaned_labels = []\n for label in labels[\"input_ids\"]:\n cleaned = []\n for token in label:\n if token == tokenizer.pad_token_id:\n cleaned.append(-100)\n else:\n cleaned.append(int(token)) # π₯ force python int\n cleaned_labels.append(cleaned)\n\n model_inputs[\"labels\"] = cleaned_labels\n\n return model_inputs\n \nprint(\"β
Tokenizer ready\")","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"id":"94508e52-a595-452b-89f2-cf85003d2109","cell_type":"code","source":"def to_hf(rows):\n return HFDataset.from_dict({\n \"input\": [r[\"input\"] for r in rows],\n \"output\": [r[\"output\"] for r in rows],\n })\n\ndataset = DatasetDict({\n \"train\": to_hf(train_raw),\n \"val\": to_hf(val_raw),\n})\n\ntokenized_ds = dataset.map(\n tokenize_batch,\n batched=True,\n remove_columns=[\"input\", \"output\"],\n)\n\ntokenized_ds.set_format(\"torch\")\ntokenized_ds = tokenized_ds.with_format(\"torch\")\nprint(\"β
Dataset ready\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T13:02:56.439164Z","iopub.execute_input":"2026-04-03T13:02:56.440280Z","iopub.status.idle":"2026-04-03T13:03:01.939341Z","shell.execute_reply.started":"2026-04-03T13:02:56.440196Z","shell.execute_reply":"2026-04-03T13:03:01.938429Z"}},"outputs":[{"output_type":"display_data","data":{"text/plain":"Map: 0%| | 0/6115 [00:00<?, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"d5a572e0b3ff4df6a689234b7c46a72d"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"Map: 0%| | 0/680 [00:00<?, ? examples/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f74452d8df16462ba281f71eb59989bd"}},"metadata":{}},{"name":"stdout","text":"β
Dataset ready\n","output_type":"stream"}],"execution_count":25},{"id":"508d26e1-5754-41a5-88ae-1b70cb9fc14a","cell_type":"code","source":"model = T5ForConditionalGeneration.from_pretrained(cfg.model_name)\n\nprint(\"Params:\", sum(p.numel() for p in model.parameters()) / 1e6, \"M\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T13:03:01.940689Z","iopub.execute_input":"2026-04-03T13:03:01.941138Z","iopub.status.idle":"2026-04-03T13:03:02.573573Z","shell.execute_reply.started":"2026-04-03T13:03:01.941109Z","shell.execute_reply":"2026-04-03T13:03:02.572813Z"}},"outputs":[{"name":"stdout","text":"Params: 60.492288 M\n","output_type":"stream"}],"execution_count":26},{"id":"3a5dfcc8-9f91-414a-932f-62fb9d1f3998","cell_type":"code","source":"from transformers import DataCollatorForSeq2Seq\n\nclass FastCollator(DataCollatorForSeq2Seq):\n def __call__(self, features):\n for f in features:\n if \"labels\" in f:\n f[\"labels\"] = list(f[\"labels\"]) # π₯ ensure list\n return super().__call__(features)\n\ndata_collator = FastCollator(\n tokenizer=tokenizer,\n model=model,\n label_pad_token_id=-100,\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T13:03:02.574574Z","iopub.execute_input":"2026-04-03T13:03:02.575012Z","iopub.status.idle":"2026-04-03T13:03:02.580743Z","shell.execute_reply.started":"2026-04-03T13:03:02.574963Z","shell.execute_reply":"2026-04-03T13:03:02.580006Z"}},"outputs":[],"execution_count":27},{"id":"ad9b3c55-9cb8-4b67-98a4-a85724e3b6b0","cell_type":"code","source":"training_args = Seq2SeqTrainingArguments(\n output_dir=cfg.output_dir,\n\n num_train_epochs=cfg.num_train_epochs,\n\n per_device_train_batch_size=cfg.per_device_train_batch_size,\n per_device_eval_batch_size=cfg.per_device_eval_batch_size,\n\n learning_rate=cfg.learning_rate,\n weight_decay=cfg.weight_decay,\n warmup_ratio=cfg.warmup_ratio,\n\n # π₯ CRITICAL FIXES\n fp16=False, # β disable fp16\n optim=\"adamw_torch\", # β no fused optimizer\n\n predict_with_generate=False,\n\n logging_steps=200,\n report_to=\"none\",\n\n save_strategy=\"epoch\",\n\n dataloader_num_workers=2,\n dataloader_pin_memory=True,\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T13:10:17.704999Z","iopub.execute_input":"2026-04-03T13:10:17.705790Z","iopub.status.idle":"2026-04-03T13:10:17.747022Z","shell.execute_reply.started":"2026-04-03T13:10:17.705752Z","shell.execute_reply":"2026-04-03T13:10:17.746340Z"}},"outputs":[],"execution_count":31},{"id":"c429c4c1-9626-4889-a1f4-124115c31bdc","cell_type":"code","source":"from transformers import TrainerCallback\nfrom tqdm.auto import tqdm\n\nclass ProgressBarCallback(TrainerCallback):\n def __init__(self):\n self.pbar = None\n\n def on_train_begin(self, args, state, control, **kwargs):\n self.pbar = tqdm(\n total=state.max_steps,\n desc=\"π Training\",\n unit=\"step\"\n )\n\n def on_step_end(self, args, state, control, **kwargs):\n if self.pbar:\n self.pbar.update(1)\n\n def on_train_end(self, args, state, control, **kwargs):\n if self.pbar:\n self.pbar.close()\n\ntrainer = Seq2SeqTrainer(\n model=model,\n args=training_args,\n train_dataset=tokenized_ds[\"train\"],\n eval_dataset=tokenized_ds[\"val\"],\n tokenizer=tokenizer,\n data_collator=data_collator,\n callbacks=[ProgressBarCallback()], # π₯ ADD THIS\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T13:10:17.884885Z","iopub.execute_input":"2026-04-03T13:10:17.885120Z","iopub.status.idle":"2026-04-03T13:10:17.902139Z","shell.execute_reply.started":"2026-04-03T13:10:17.885097Z","shell.execute_reply":"2026-04-03T13:10:17.901497Z"}},"outputs":[],"execution_count":32},{"id":"97318834-7e71-4716-bdb9-3e00aa27570c","cell_type":"code","source":"\nprint(\"π Training...\")\n\nt0 = time.time()\n\ntrainer.train()\n\nprint(f\"β
Done in {(time.time()-t0)/60:.1f} min\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T13:10:19.398505Z","iopub.execute_input":"2026-04-03T13:10:19.399280Z","iopub.status.idle":"2026-04-03T13:22:11.117488Z","shell.execute_reply.started":"2026-04-03T13:10:19.399231Z","shell.execute_reply":"2026-04-03T13:22:11.116791Z"}},"outputs":[{"name":"stdout","text":"π Training...\n","output_type":"stream"},{"output_type":"display_data","data":{"text/plain":"π Training: 0%| | 0/1915 [00:00<?, ?step/s]","application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"f0fef2302c9149f39803f045ffb08556"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":"<IPython.core.display.HTML object>","text/html":"\n <div>\n \n <progress value='1915' max='1915' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [1915/1915 11:50, Epoch 5/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Step</th>\n <th>Training Loss</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <td>200</td>\n <td>0.145300</td>\n </tr>\n <tr>\n <td>400</td>\n <td>0.149900</td>\n </tr>\n <tr>\n <td>600</td>\n <td>0.112700</td>\n </tr>\n <tr>\n <td>800</td>\n <td>0.092400</td>\n </tr>\n <tr>\n <td>1000</td>\n <td>0.074900</td>\n </tr>\n <tr>\n <td>1200</td>\n <td>0.073200</td>\n </tr>\n <tr>\n <td>1400</td>\n <td>0.070000</td>\n </tr>\n <tr>\n <td>1600</td>\n <td>0.061400</td>\n </tr>\n <tr>\n <td>1800</td>\n <td>0.044600</td>\n </tr>\n </tbody>\n</table><p>"},"metadata":{}},{"name":"stdout","text":"β
Done in 11.9 min\n","output_type":"stream"}],"execution_count":33},{"id":"77c6ffc2-efae-437f-9f3d-fb9db0a07c76","cell_type":"code","source":"trainer.save_model(cfg.output_dir)\ntokenizer.save_pretrained(cfg.output_dir)\n\nprint(\"β
Model saved\")","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T13:22:11.118944Z","iopub.execute_input":"2026-04-03T13:22:11.119273Z","iopub.status.idle":"2026-04-03T13:22:11.697695Z","shell.execute_reply.started":"2026-04-03T13:22:11.119244Z","shell.execute_reply":"2026-04-03T13:22:11.696982Z"}},"outputs":[{"name":"stdout","text":"β
Model saved\n","output_type":"stream"}],"execution_count":34},{"id":"6537f70f-3c8d-444d-9dd8-f563abc2af0c","cell_type":"code","source":"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nmodel = model.to(device)\nmodel.eval()\n\ndef optimize_java(code):\n inputs = tokenizer(\n TASK_PREFIX + code,\n return_tensors=\"pt\",\n truncation=True,\n max_length=cfg.max_source_length,\n ).to(device)\n\n with torch.no_grad():\n outputs = model.generate(\n **inputs,\n max_new_tokens=128,\n )\n\n return tokenizer.decode(outputs[0], skip_special_tokens=True)\n\n\n# TEST\nprint(optimize_java(\"\"\"\nString alertLevel;\nswitch (sensor.getValue()) {\n case 0:\n case 1:\n alertLevel = \"GREEN\";\n break;\n case 2:\n alertLevel = \"YELLOW\";\n break;\n case 3:\n alertLevel = \"RED\";\n break;\n default:\n alertLevel = \"CRITICAL\";\n}\n\"\"\"))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T14:05:35.081973Z","iopub.execute_input":"2026-04-03T14:05:35.082240Z","iopub.status.idle":"2026-04-03T14:05:35.749507Z","shell.execute_reply.started":"2026-04-03T14:05:35.082214Z","shell.execute_reply":"2026-04-03T14:05:35.748745Z"}},"outputs":[{"name":"stdout","text":"String alertLevel = switch (sensor.getValue()) {\n case 0, 1 -> \"1\";\n case 2 -> \"2\";\n case 3 -> \"3\";\n default -> \"4\";\n};\n","output_type":"stream"}],"execution_count":37},{"id":"cb063d9a-5425-40cd-89de-c23edd1a60f3","cell_type":"code","source":"print(optimize_java(\"\"\"\nList<Payment> payments = new ArrayList<>(pendingPayments);\nIterator<Payment> iterator = payments.iterator();\nwhile (iterator.hasNext()) {\n Payment p = iterator.next();\n if (p.isExpired()) {\n iterator.remove();\n }\n}\"\"\"))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T14:06:17.153683Z","iopub.execute_input":"2026-04-03T14:06:17.154494Z","iopub.status.idle":"2026-04-03T14:06:17.546252Z","shell.execute_reply.started":"2026-04-03T14:06:17.154454Z","shell.execute_reply":"2026-04-03T14:06:17.545289Z"}},"outputs":[{"name":"stdout","text":"List<Payment> payments = new ArrayList<>(pendingPayments);\npayments.removeIf(Payment::isExpired);\n","output_type":"stream"}],"execution_count":38},{"id":"ae99af62-1c8a-4adc-bdef-7c74c2da7696","cell_type":"code","source":"print(optimize_java(\"\"\"\nString result = \"\";\nfor (int i = 0; i < 100; i++) {\n result += \"Value: \" + i + \"\\n\";\n}\"\"\"))","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2026-04-03T14:08:05.318144Z","iopub.execute_input":"2026-04-03T14:08:05.318944Z","iopub.status.idle":"2026-04-03T14:08:06.065047Z","shell.execute_reply.started":"2026-04-03T14:08:05.318907Z","shell.execute_reply":"2026-04-03T14:08:06.064234Z"}},"outputs":[{"name":"stdout","text":"StringBuilder sb = new StringBuilder(100 * 8);\nfor (int i = 0; i < 100; i++) {\n sb.append(\"Value: \").append(i).append(\"\n\");\n}\nString result = sb.toString();\n","output_type":"stream"}],"execution_count":39}]} |