Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import pickle | |
| from torchvision import models, transforms | |
| from PIL import Image | |
class Config:
    """Model hyper-parameters; must match the values used at training time."""
    embed_size = 300    # word-embedding dimension fed to the decoder LSTM
    hidden_size = 512   # LSTM hidden-state size (also the encoder's output dim)
    num_layers = 1      # number of stacked LSTM layers
    feature_dim = 2048  # length of the pooled ResNet-50 feature vector
class Encoder(nn.Module):
    """Project pooled CNN image features into the decoder's hidden dimension.

    Pipeline: Linear -> BatchNorm1d -> ReLU -> Dropout(0.5). Dropout and
    batch-norm behave correctly only after ``model.eval()`` at inference.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Attribute names are load-bearing: they are the state_dict keys.
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Map (batch, input_dim) features to (batch, hidden_dim) activations."""
        projected = self.bn(self.linear(images))
        return self.dropout(self.relu(projected))
class Decoder(nn.Module):
    """LSTM language model that emits vocabulary logits for caption tokens.

    The encoder's projected image feature serves as the initial hidden and
    cell state — mirroring how inference drives ``embed``/``lstm``/``linear``
    directly with the encoder output as (h, c).
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(Decoder, self).__init__()
        # Attribute names are load-bearing: they are the state_dict keys.
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Teacher-forced forward pass (was an unimplemented stub returning None).

        Args:
            features: (num_layers, batch, hidden_size) encoder output used as
                the initial LSTM hidden and cell state.
            captions: (batch, seq_len) LongTensor of token ids.

        Returns:
            (batch, seq_len, vocab_size) unnormalized logits.
        """
        state = (features, features)
        embeddings = self.embed(captions)
        outputs, _ = self.lstm(embeddings, state)
        return self.linear(outputs)
class Seq2Seq(nn.Module):
    """Container pairing the image-feature Encoder with the caption Decoder.

    At inference the two submodules are invoked separately (see
    ``generate_caption``); this wrapper exists so one state_dict holds both.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, feature_dim):
        super().__init__()
        self.encoder = Encoder(feature_dim, hidden_size)
        self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers)
# Inference runs on CPU.
device = torch.device("cpu")

# Vocabulary: `itos` maps token-id -> word (dict, used with .get later),
# `stoi` maps word -> token-id.
# NOTE(review): pickle.load is only acceptable because vocab_safe.pkl ships
# with the app; never load pickles from untrusted sources.
with open('vocab_safe.pkl', 'rb') as f:
    vocab_data = pickle.load(f)

itos = vocab_data['itos']
stoi = vocab_data['stoi']
vocab_size = len(itos)

# Restore the trained captioning model. The checkpoint holds a state_dict
# (we call load_state_dict on the result), so weights_only=True safely
# restricts torch.load's unpickler to tensor data.
model = Seq2Seq(Config.embed_size, Config.hidden_size, vocab_size,
                Config.num_layers, Config.feature_dim)
model.load_state_dict(torch.load('best_model.pth', map_location=device,
                                 weights_only=True))
model.eval()

# Frozen ResNet-50 backbone with its classification head removed; the final
# avg-pool output (2048-d, matching Config.feature_dim) is the image feature.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1]).to(device)
resnet.eval()

# Standard ImageNet preprocessing to match the backbone's training statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def generate_caption(image):
    """Greedy-decode a caption for a PIL image uploaded via the UI.

    Returns the caption string, a prompt when no image was given, or an
    ``Error: ...`` message if anything raises (keeps the UI responsive).
    """
    try:
        if image is None:
            return "Please upload an image first."

        # Preprocess to the (1, 3, 224, 224) tensor the backbone expects.
        batch = transform(image.convert('RGB')).unsqueeze(0).to(device)

        with torch.no_grad():
            # Pooled backbone feature, projected and used as the LSTM's
            # initial hidden AND cell state.
            feats = resnet(batch).view(1, -1)
            hidden = model.encoder(feats).unsqueeze(0)
            state = (hidden, hidden)

            token = torch.tensor(stoi['<start>']).view(1).to(device)
            words = []
            # Greedy decoding, capped at 20 tokens.
            for _ in range(20):
                emb = model.decoder.embed(token).view(1, 1, -1)
                lstm_out, state = model.decoder.lstm(emb, state)
                scores = model.decoder.linear(lstm_out)
                best = scores.argmax(2).item()
                if best == stoi['<end>']:
                    break
                words.append(itos.get(best, "<unk>"))
                token = torch.tensor(best).view(1).to(device)

        sentence = " ".join(words).strip().capitalize()
        if sentence:
            sentence += "."
        return sentence
    except Exception as e:
        # Deliberate catch-all: surface the failure in the UI textbox
        # instead of crashing the Gradio worker.
        return f"Error: {str(e)}"
# ---- Gradio UI: image upload on the left, generated caption on the right ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ Image Captioning Generator
        Upload an image to generate a descriptive caption.
        """
    )
    with gr.Row():
        with gr.Column():
            uploader = gr.Image(type="pil", label="Upload Image")
            run_button = gr.Button("✨ Generate Caption", variant="primary")
        with gr.Column():
            result_box = gr.Textbox(label="Generated Caption", lines=4, interactive=False)

    run_button.click(fn=generate_caption, inputs=uploader, outputs=result_box)

if __name__ == "__main__":
    demo.launch()