AliHamza852 committed on
Commit
9f93006
·
verified ·
1 Parent(s): a6ac61d

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +126 -0
  2. best_model.pth +3 -0
  3. requirements.txt +4 -0
  4. vocab_safe.pkl +3 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import pickle
5
+ from torchvision import models, transforms
6
+ from PIL import Image
7
+
8
class Config:
    """Hyperparameters shared by the captioning model components."""

    embed_size: int = 300    # word-embedding dimensionality
    hidden_size: int = 512   # encoder projection / LSTM hidden width
    num_layers: int = 1      # stacked LSTM layers in the decoder
    feature_dim: int = 2048  # ResNet-50 pooled feature size
13
+
14
class Encoder(nn.Module):
    """Project CNN image features into the decoder's hidden space.

    Pipeline: linear -> batch-norm -> ReLU -> dropout(0.5).
    Attribute names (`linear`, `bn`, `relu`, `dropout`) are kept exactly
    as in the trained checkpoint so ``load_state_dict`` still matches.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        """Return the (batch, hidden_dim) projection of `images`."""
        out = images
        # Apply the stages in fixed order; identical to chaining the calls.
        for stage in (self.linear, self.bn, self.relu, self.dropout):
            out = stage(out)
        return out
26
+
27
class Decoder(nn.Module):
    """Word-level LSTM language-model head.

    Holds the embedding table, LSTM, and vocabulary projection that the
    inference loop drives directly (``embed`` -> ``lstm`` -> ``linear``).
    ``forward`` is an unused stub that returns None.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers,
                            batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        # NOTE(review): deliberate stub — generation bypasses forward()
        # and calls the submodules directly, so this stays a no-op.
        return None
36
+
37
class Seq2Seq(nn.Module):
    """Container pairing the feature encoder with the LSTM decoder.

    Submodule names must stay ``encoder`` / ``decoder`` so the saved
    state_dict in ``best_model.pth`` loads without key remapping.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, feature_dim):
        super().__init__()
        self.encoder = Encoder(feature_dim, hidden_size)
        self.decoder = Decoder(embed_size, hidden_size, vocab_size, num_layers)
42
+
43
# Inference runs on CPU only.
device = torch.device("cpu")

# Vocabulary mappings: itos (index -> token) and stoi (token -> index).
# NOTE(review): pickle.load is only acceptable because vocab_safe.pkl
# ships with this repo — never unpickle untrusted data.
with open('vocab_safe.pkl', 'rb') as f:
    vocab_data = pickle.load(f)
itos = vocab_data['itos']
stoi = vocab_data['stoi']
vocab_size = len(itos)

# Build the captioning model and restore the trained weights.
# NOTE(review): consider torch.load(..., weights_only=True) to avoid
# arbitrary-object deserialization from the checkpoint file.
model = Seq2Seq(Config.embed_size, Config.hidden_size, vocab_size, Config.num_layers, Config.feature_dim)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# Frozen ResNet-50 feature extractor: dropping the final FC layer leaves
# the global-average-pooled feature map, flattened to 2048 dims per image
# by the caller.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = nn.Sequential(*list(resnet.children())[:-1]).to(device)
resnet.eval()

# Standard ImageNet preprocessing (matches the ResNet-50 training stats).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
64
+
65
def generate_caption(image, max_len=20):
    """Greedily decode a caption for a PIL image.

    Args:
        image: PIL image from the Gradio input, or None if nothing was
            uploaded.
        max_len: maximum number of tokens to emit (default 20, matching
            the previously hard-coded limit).

    Returns:
        The caption string, a prompt to upload an image, or an
        "Error: ..." message — this handler never raises, since the
        Gradio callback must always return a string.
    """
    try:
        if image is None:
            return "Please upload an image first."

        image = image.convert('RGB')
        img_tensor = transform(image).unsqueeze(0).to(device)

        # Run the entire pipeline without autograd: this is pure
        # inference, so tracking gradients through the decode loop
        # would only waste memory and time.
        with torch.no_grad():
            # (1, 2048) pooled CNN feature for the image.
            features = resnet(img_tensor).view(1, -1)

            # The encoder output seeds both LSTM states, shaped
            # (num_layers=1, batch=1, hidden).
            enc_out = model.encoder(features).unsqueeze(0)
            h, c = enc_out, enc_out

            word = torch.tensor(stoi['<start>']).view(1).to(device)
            caption = []

            # Greedy decoding: feed the argmax token back in until
            # <end> or the length limit.
            for _ in range(max_len):
                embed = model.decoder.embed(word).view(1, 1, -1)
                output, (h, c) = model.decoder.lstm(embed, (h, c))
                idx = model.decoder.linear(output).argmax(2).item()

                if idx == stoi['<end>']:
                    break

                caption.append(itos.get(idx, "<unk>"))
                word = torch.tensor(idx).view(1).to(device)

        final_caption = " ".join(caption).strip().capitalize()
        if final_caption:
            final_caption += "."
        return final_caption

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error: {str(e)}"
104
+
105
# --- Gradio UI: two-column layout, button-triggered inference. ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🖼️ Image Captioning Generator
        Upload an image to generate a descriptive caption.
        """
    )
    with gr.Row():
        with gr.Column():
            # Input side: image upload plus the trigger button.
            image_input = gr.Image(type="pil", label="Upload Image")
            generate_btn = gr.Button("✨ Generate Caption", variant="primary")
        with gr.Column():
            # Output side: read-only caption text.
            caption_output = gr.Textbox(label="Generated Caption", lines=4, interactive=False)

    # Wire the button to the inference function.
    generate_btn.click(
        fn=generate_caption,
        inputs=image_input,
        outputs=caption_output
    )

if __name__ == "__main__":
    demo.launch()
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88cf3415474790fc2aadfc6906bddd4a91a85ddccf536d70fb09cc6b8c40e01c
3
+ size 51560373
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ pillow
4
+ gradio
vocab_safe.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3787a6494232df5dcfa088f6b8d5efbd9e4f23507c0079acb26e85989e67c967
3
+ size 260287