{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "machine_shape": "hm", "gpuType": "G4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "\"\"\"\n", "=============================================================\n", " IMAGE CAPTIONING — Vanilla RNN (One-to-Many)\n", " Dataset : Flickr8k (kagglehub version)\n", " Encoder : NONE — raw flattened pixels → linear projection\n", " Decoder : Vanilla RNN (manually implemented)\n", " No CNN, no ResNet, no pretrained weights.\n", "=============================================================\n", "\"\"\"\n", "\n", "import re, math, time, random, os\n", "import numpy as np\n", "import pandas as pd\n", "from PIL import Image\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision import transforms\n", "from collections import Counter\n", "import kagglehub\n", "\n", "# ─────────────────────────────────────────────\n", "# Setup\n", "# ─────────────────────────────────────────────\n", "SEED = 42\n", "random.seed(SEED)\n", "np.random.seed(SEED)\n", "torch.manual_seed(SEED)\n", "\n", "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Using device: {DEVICE}\")\n", "\n", "# ─────────────────────────────────────────────\n", "# Hyperparameters\n", "# ─────────────────────────────────────────────\n", "IMG_SIZE = 32\n", "EMBED_DIM = 256\n", "HIDDEN_DIM = 512\n", "BATCH_SIZE = 64\n", "EPOCHS = 20\n", "LR = 3e-4\n", "MAX_SEQ_LEN = 30\n", "MIN_WORD_FREQ = 2\n", "GRAD_CLIP = 5.0\n", "SAVE_PATH = \"vanilla_rnn_captioning.pth\"\n", "\n", "PIXEL_DIM = IMG_SIZE * IMG_SIZE * 3\n", "\n", "# ─────────────────────────────────────────────\n", "# Vocabulary\n", "# ─────────────────────────────────────────────\n", "PAD, SOS, EOS, UNK = \"\", \"\", \"\", \"\"\n", "\n", "class Vocabulary:\n", " def __init__(self, min_freq=MIN_WORD_FREQ):\n", " self.min_freq = min_freq\n", " self.word2idx = {}\n", " self.idx2word = {}\n", "\n", " for i, tok in enumerate([PAD, SOS, EOS, UNK]):\n", " self.word2idx[tok] = i\n", " self.idx2word[i] = tok\n", "\n", " def tokenize(self, text):\n", " return re.sub(r\"[^a-z0-9' ]\", \"\", str(text).lower()).split()\n", "\n", " def build(self, captions):\n", " counter = Counter(w for cap in captions for w in self.tokenize(cap))\n", " for word, freq in counter.items():\n", " if freq >= self.min_freq:\n", " idx = len(self.word2idx)\n", " self.word2idx[word] = idx\n", " self.idx2word[idx] = word\n", "\n", " def encode(self, text):\n", " return (\n", " [self.word2idx[SOS]] +\n", " [self.word2idx.get(w, self.word2idx[UNK]) for w in self.tokenize(text)] +\n", " [self.word2idx[EOS]]\n", " )\n", "\n", " def decode(self, indices):\n", " words = []\n", " for i in indices:\n", " w = self.idx2word.get(i, UNK)\n", " if w == EOS:\n", " break\n", " if w not in (PAD, SOS):\n", " words.append(w)\n", " return \" \".join(words)\n", "\n", " def __len__(self):\n", " return len(self.word2idx)\n", "\n", "# ─────────────────────────────────────────────\n", "# Dataset\n", "# ─────────────────────────────────────────────\n", "class Flickr8kDataset(Dataset):\n", " def __init__(self, df, img_dir, vocab, transform):\n", " self.vocab = vocab\n", " self.transform = transform\n", " self.img_dir = img_dir\n", " self.samples = list(zip(df['image'], df['caption']))\n", "\n", " def __len__(self):\n", " return len(self.samples)\n", "\n", " def __getitem__(self, idx):\n", " img_name, cap = self.samples[idx]\n", " img_path = os.path.join(self.img_dir, img_name)\n", "\n", " img = Image.open(img_path).convert(\"RGB\")\n", " img = self.transform(img)\n", " img = img.view(-1)\n", "\n", " ids = self.vocab.encode(cap)\n", " ids = ids[:MAX_SEQ_LEN] + [self.vocab.word2idx[PAD]] * max(0, MAX_SEQ_LEN - len(ids))\n", "\n", " return img, torch.tensor(ids, dtype=torch.long)\n", "\n", "# ─────────────────────────────────────────────\n", "# Vanilla RNN Cell\n", "# ─────────────────────────────────────────────\n", "class VanillaRNNCell(nn.Module):\n", " def __init__(self, input_dim, hidden_dim):\n", " super().__init__()\n", " self.W_ih = nn.Linear(input_dim, hidden_dim)\n", " self.W_hh = nn.Linear(hidden_dim, hidden_dim, bias=False)\n", "\n", " def forward(self, x, h):\n", " return torch.tanh(self.W_ih(x) + self.W_hh(h))\n", "\n", "# ─────────────────────────────────────────────\n", "# Model\n", "# ─────────────────────────────────────────────\n", "class VanillaRNNCaptioner(nn.Module):\n", " def __init__(self, vocab_size, pixel_dim, embed_dim, hidden_dim):\n", " super().__init__()\n", "\n", " self.img_proj = nn.Linear(pixel_dim, hidden_dim)\n", " self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)\n", " self.cell = VanillaRNNCell(embed_dim, hidden_dim)\n", " self.fc_out = nn.Linear(hidden_dim, vocab_size)\n", " self.dropout = nn.Dropout(0.3)\n", "\n", " def forward(self, pixels, captions):\n", " h = torch.tanh(self.img_proj(pixels))\n", " outputs = []\n", "\n", " for t in range(captions.size(1) - 1):\n", " x = self.dropout(self.embed(captions[:, t]))\n", " h = self.cell(x, h)\n", " outputs.append(self.fc_out(h))\n", "\n", " return torch.stack(outputs, dim=1)\n", "\n", " @torch.no_grad()\n", " def generate(self, pixels, vocab, max_len=MAX_SEQ_LEN):\n", " B = pixels.size(0)\n", " h = torch.tanh(self.img_proj(pixels))\n", "\n", " inp = torch.full((B,), vocab.word2idx[SOS], dtype=torch.long, device=DEVICE)\n", " result = []\n", "\n", " for _ in range(max_len):\n", " x = self.embed(inp)\n", " h = self.cell(x, h)\n", " pred = self.fc_out(h).argmax(dim=-1)\n", "\n", " result.append(pred)\n", " inp = pred\n", "\n", " return [vocab.decode(torch.stack(result, dim=1)[i].tolist()) for i in range(B)]\n", "\n", "# ─────────────────────────────────────────────\n", "# Train / Eval\n", "# ─────────────────────────────────────────────\n", "def train_epoch(model, loader, optimizer, criterion, vocab):\n", " model.train()\n", " total_loss, total_tok = 0.0, 0\n", " pad = vocab.word2idx[PAD]\n", "\n", " for pixels, caps in loader:\n", " pixels, caps = pixels.to(DEVICE), caps.to(DEVICE)\n", "\n", " optimizer.zero_grad()\n", " logits = model(pixels, caps)\n", "\n", " B, T, V = logits.shape\n", " loss = criterion(logits.reshape(B*T, V), caps[:, 1:].reshape(B*T))\n", "\n", " loss.backward()\n", " nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n", " optimizer.step()\n", "\n", " mask = caps[:, 1:] != pad\n", " total_loss += loss.item() * mask.sum().item()\n", " total_tok += mask.sum().item()\n", "\n", " return total_loss / total_tok\n", "\n", "\n", "@torch.no_grad()\n", "def eval_epoch(model, loader, criterion, vocab):\n", " model.eval()\n", " total_loss, total_tok = 0.0, 0\n", " pad = vocab.word2idx[PAD]\n", "\n", " for pixels, caps in loader:\n", " pixels, caps = pixels.to(DEVICE), caps.to(DEVICE)\n", "\n", " logits = model(pixels, caps)\n", " B, T, V = logits.shape\n", "\n", " loss = criterion(logits.reshape(B*T, V), caps[:, 1:].reshape(B*T))\n", "\n", " mask = caps[:, 1:] != pad\n", " total_loss += loss.item() * mask.sum().item()\n", " total_tok += mask.sum().item()\n", "\n", " return total_loss / total_tok\n", "\n", "# ─────────────────────────────────────────────\n", "# Main\n", "# ─────────────────────────────────────────────\n", "def main():\n", " print(\"Downloading dataset...\")\n", " data_dir = kagglehub.dataset_download(\"adityajn105/flickr8k\")\n", "\n", " print(\"Dataset path:\", data_dir)\n", "\n", " img_dir = os.path.join(data_dir, \"Images\")\n", " csv_path = os.path.join(data_dir, \"captions.txt\")\n", "\n", " # fallback search\n", " if not os.path.exists(csv_path):\n", " for root, dirs, files in os.walk(data_dir):\n", " if \"captions.txt\" in files:\n", " csv_path = os.path.join(root, \"captions.txt\")\n", " if \"Images\" in dirs:\n", " img_dir = os.path.join(root, \"Images\")\n", "\n", " print(\"Images dir:\", img_dir)\n", " print(\"Captions file:\", csv_path)\n", "\n", " if not os.path.exists(csv_path):\n", " raise FileNotFoundError(\"captions.txt not found!\")\n", "\n", " df = pd.read_csv(csv_path)\n", "\n", " if df.shape[1] == 1:\n", " df = pd.read_csv(csv_path, sep=\",\", names=[\"image\", \"caption\"], skiprows=1)\n", "\n", " df[\"image\"] = df[\"image\"].apply(lambda x: x.split(\"#\")[0])\n", "\n", " # split\n", " n = len(df)\n", " n_train = int(0.9 * n)\n", "\n", " train_df = df.iloc[:n_train]\n", " val_df = df.iloc[n_train:]\n", "\n", " # vocab\n", " vocab = Vocabulary()\n", " vocab.build(train_df[\"caption\"].tolist())\n", "\n", " print(f\"Vocab size: {len(vocab)}\")\n", "\n", " # transforms\n", " tfm = transforms.Compose([\n", " transforms.Resize((IMG_SIZE, IMG_SIZE)),\n", " transforms.ToTensor()\n", " ])\n", "\n", " # datasets\n", " train_set = Flickr8kDataset(train_df, img_dir, vocab, tfm)\n", " val_set = Flickr8kDataset(val_df, img_dir, vocab, tfm)\n", "\n", " train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True, drop_last=True)\n", " val_loader = DataLoader(val_set, BATCH_SIZE)\n", "\n", " # model\n", " model = VanillaRNNCaptioner(len(vocab), PIXEL_DIM, EMBED_DIM, HIDDEN_DIM).to(DEVICE)\n", "\n", " optimizer = optim.Adam(model.parameters(), lr=LR)\n", " scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)\n", " criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx[PAD])\n", "\n", " best_val = math.inf\n", "\n", " # training loop\n", " for epoch in range(1, EPOCHS + 1):\n", " t0 = time.time()\n", "\n", " train_loss = train_epoch(model, train_loader, optimizer, criterion, vocab)\n", " val_loss = eval_epoch(model, val_loader, criterion, vocab)\n", "\n", " scheduler.step()\n", "\n", " print(f\"Epoch {epoch:02d} | Train {train_loss:.4f} | Val {val_loss:.4f}\")\n", "\n", " if val_loss < best_val:\n", " best_val = val_loss\n", " torch.save({\"model\": model.state_dict(), \"vocab\": vocab}, SAVE_PATH)\n", " print(\"Saved model\")\n", "\n", "if __name__ == \"__main__\":\n", " main()" ], "metadata": { "id": "vnrnDRn3I-Cu", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "1df4e60b-6ea2-420c-c016-66159273ad3c" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Using device: cuda\n", "Downloading dataset...\n", "Using Colab cache for faster access to the 'flickr8k' dataset.\n", "Dataset path: /kaggle/input/flickr8k\n", "Images dir: /kaggle/input/flickr8k/Images\n", "Captions file: /kaggle/input/flickr8k/captions.txt\n", "Vocab size: 5001\n", "Epoch 01 | Train 4.3251 | Val 3.6760\n", "Saved model\n", "Epoch 02 | Train 3.6269 | Val 3.4343\n", "Saved model\n", "Epoch 03 | Train 3.4227 | Val 3.3320\n", "Saved model\n", "Epoch 04 | Train 3.2967 | Val 3.2652\n", "Saved model\n", "Epoch 05 | Train 3.2038 | Val 3.2251\n", "Saved model\n", "Epoch 06 | Train 3.1303 | Val 3.1986\n", "Saved model\n", "Epoch 07 | Train 3.0692 | Val 3.1696\n", "Saved model\n", "Epoch 08 | Train 3.0177 | Val 3.1493\n", "Saved model\n", "Epoch 09 | Train 2.9719 | Val 3.1336\n", "Saved model\n", "Epoch 10 | Train 2.9328 | Val 3.1252\n", "Saved model\n", "Epoch 11 | Train 2.8984 | Val 3.1244\n", "Saved model\n", "Epoch 12 | Train 2.8679 | Val 3.1102\n", "Saved model\n", "Epoch 13 | Train 2.8444 | Val 3.1026\n", "Saved model\n", "Epoch 14 | Train 2.8215 | Val 3.1084\n", "Epoch 15 | Train 2.8054 | Val 3.1011\n", "Saved model\n", "Epoch 16 | Train 2.7908 | Val 3.1013\n", "Epoch 17 | Train 2.7802 | Val 3.1018\n", "Epoch 18 | Train 2.7714 | Val 3.0995\n", "Saved model\n", "Epoch 19 | Train 2.7661 | Val 3.0990\n", "Saved model\n", "Epoch 20 | Train 2.7650 | Val 3.0995\n" ] } ] }, { "cell_type": "code", "source": [ "# ─────────────────────────────────────────────\n", "# UI + Inference Cell\n", "# ─────────────────────────────────────────────\n", "import gradio as gr\n", "\n", "# Load trained model\n", "def load_model():\n", " checkpoint = torch.load(SAVE_PATH, map_location=DEVICE, weights_only=False)\n", " vocab = checkpoint[\"vocab\"]\n", "\n", " model = VanillaRNNCaptioner(\n", " len(vocab),\n", " PIXEL_DIM,\n", " EMBED_DIM,\n", " HIDDEN_DIM\n", " ).to(DEVICE)\n", "\n", " model.load_state_dict(checkpoint[\"model\"])\n", " model.eval()\n", "\n", " return model, vocab\n", "\n", "model, vocab = load_model()\n", "\n", "# Transform (same as training)\n", "tfm = transforms.Compose([\n", " transforms.Resize((IMG_SIZE, IMG_SIZE)),\n", " transforms.ToTensor()\n", "])\n", "\n", "# Prediction function\n", "def predict(image):\n", " image = image.convert(\"RGB\")\n", " image = tfm(image)\n", " image = image.view(1, -1).to(DEVICE)\n", "\n", " caption = model.generate(image, vocab)[0]\n", " return caption\n", "\n", "# Gradio UI\n", "demo = gr.Interface(\n", " fn=predict,\n", " inputs=gr.Image(type=\"pil\"),\n", " outputs=\"text\",\n", " title=\"Image Captioning (Vanilla RNN)\",\n", " description=\"Upload an image → get caption\"\n", ")\n", "\n", "demo.launch()\n" ], "metadata": { "id": "QEycwtwaSTy4", "colab": { "base_uri": "https://localhost:8080/", "height": 648 }, "outputId": "e9ab980b-e575-4926-ca47-b396171bae50" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n", "\n", "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n", "* Running on public URL: https://dd5611461a17776d59.gradio.live\n", "\n", "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [] }, "metadata": {}, "execution_count": 3 } ] } ] }