# Source: Hugging Face upload by AliHamza852 — "Upload 4 files" (commit 9f93006, verified)
import gradio as gr
import torch
import torch.nn as nn
import pickle
from torchvision import models, transforms
from PIL import Image
class Config:
    """Model hyperparameters; must match the values used at training time,
    otherwise the checkpoint in best_model.pth will not load."""
    embed_size = 300    # word-embedding dimension fed to the LSTM
    hidden_size = 512   # LSTM hidden-state size (also the encoder's output size)
    num_layers = 1      # number of stacked LSTM layers
    feature_dim = 2048  # length of the pooled ResNet-50 feature vector
class Encoder(nn.Module):
    """Projects pooled CNN feature vectors into the decoder's hidden space.

    Submodule names (linear/bn/relu/dropout) are preserved so previously
    saved state dicts keep loading.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Return a (batch, hidden_dim) projection of the input features."""
        projected = self.bn(self.linear(images))
        return self.dropout(self.relu(projected))
class Decoder(nn.Module):
    """LSTM language model producing per-step vocabulary logits.

    Submodule names (embed/lstm/linear) are preserved so previously saved
    state dicts keep loading.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Teacher-forced decode of a caption batch.

        Args:
            features: encoder output, shape (batch, hidden_size) — or already
                shaped (num_layers, batch, hidden_size). Used as BOTH the
                initial hidden and cell state, mirroring the greedy inference
                loop in this file.
            captions: LongTensor of token ids, shape (batch, seq_len).

        Returns:
            Unnormalized vocabulary logits, shape (batch, seq_len, vocab_size).

        Bug fix: the original body was `return None`, making the module
        unusable via its standard call path even though its submodules were
        driven manually at inference time.
        NOTE(review): the 2-D branch assumes num_layers == 1 (as Config sets);
        confirm before training with more layers.
        """
        if features.dim() == 2:
            # (batch, hidden) -> (1, batch, hidden) for the LSTM state tuple.
            features = features.unsqueeze(0)
        embeddings = self.embed(captions)
        hiddens, _ = self.lstm(embeddings, (features, features))
        return self.linear(hiddens)
class Seq2Seq(nn.Module):
    """Bundles the feature Encoder and caption Decoder under one state dict.

    Attribute names (encoder/decoder) match the checkpoint's key prefixes,
    so they must not be renamed.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, feature_dim):
        super().__init__()
        # Encoder maps CNN features into the decoder's hidden space.
        self.encoder = Encoder(feature_dim, hidden_size)
        # Decoder generates vocabulary logits step by step.
        self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers)
# ---- One-time inference setup (runs at import) ----

# CPU-only inference; checkpoint tensors are remapped accordingly.
device = torch.device("cpu")

# Vocabulary saved at training time: itos maps index -> token,
# stoi maps token -> index.
# NOTE(review): pickle.load is only acceptable because vocab_safe.pkl ships
# with this app — never unpickle untrusted files.
with open('vocab_safe.pkl', 'rb') as f:
    vocab_data = pickle.load(f)
itos = vocab_data['itos']
stoi = vocab_data['stoi']
vocab_size = len(itos)

# Restore the captioning model from the training checkpoint and freeze it.
model = Seq2Seq(Config.embed_size, Config.hidden_size, vocab_size, Config.num_layers, Config.feature_dim)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# ResNet-50 feature extractor: drop the final FC layer so the forward pass
# ends at the pooled feature vector (flattened to Config.feature_dim later).
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1]).to(device)
resnet.eval()

# Standard ImageNet preprocessing matching the pretrained ResNet-50 weights.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
def generate_caption(image):
    """Greedily decode a caption for a PIL image.

    Args:
        image: PIL image from the Gradio widget, or None if nothing uploaded.

    Returns:
        The caption as a display string; any failure is reported as an
        "Error: ..." string so the UI never raises.
    """
    try:
        if image is None:
            return "Please upload an image first."

        # Preprocess to a 1-image batch and extract pooled CNN features.
        rgb = image.convert('RGB')
        batch = transform(rgb).unsqueeze(0).to(device)

        with torch.no_grad():
            feats = resnet(batch).view(1, -1)
            # Encoder output seeds BOTH the hidden and cell states.
            state = model.encoder(feats).unsqueeze(0)
            h, c = state, state

            token = torch.tensor(stoi['<start>']).view(1).to(device)
            words = []
            for _ in range(20):  # hard cap on caption length
                step = model.decoder.embed(token).view(1, 1, -1)
                out, (h, c) = model.decoder.lstm(step, (h, c))
                logits = model.decoder.linear(out)
                best = logits.argmax(2).item()
                if best == stoi['<end>']:
                    break
                words.append(itos.get(best, "<unk>"))
                token = torch.tensor(best).view(1).to(device)

        sentence = " ".join(words).strip().capitalize()
        if sentence:
            sentence += "."
        return sentence
    except Exception as e:
        # Deliberate catch-all: surface the error in the UI textbox.
        return f"Error: {str(e)}"
# ---- Gradio UI ----
# `demo` must stay a module-level name (Hugging Face Spaces imports it).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ Image Captioning Generator
        Upload an image to generate a descriptive caption.
        """
    )

    with gr.Row():
        # Left column: input image and trigger button.
        with gr.Column():
            uploaded_image = gr.Image(type="pil", label="Upload Image")
            caption_button = gr.Button("✨ Generate Caption", variant="primary")
        # Right column: read-only caption output.
        with gr.Column():
            caption_box = gr.Textbox(label="Generated Caption", lines=4, interactive=False)

    caption_button.click(fn=generate_caption, inputs=uploaded_image, outputs=caption_box)

if __name__ == "__main__":
    demo.launch()