MAY199 commited on
Commit
989e7e8
·
verified ·
1 Parent(s): a4b7ce1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ from datasets import load_dataset
6
+ from transformers import CLIPProcessor, CLIPModel
7
+
8
+ # Load dataset
9
+ ds = load_dataset("amaye15/landscapes")
10
+ train_ds = ds["train"]
11
+
12
+ # Load embeddings
13
+ df = pd.read_parquet("image_embeddings_clip.parquet")
14
+ image_indices = df["image_index"].values
15
+ emb_matrix = df.drop(columns=["image_index"]).values.astype(np.float32)
16
+
17
+ # Load CLIP
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ model_name = "openai/clip-vit-base-patch32"
20
+ processor = CLIPProcessor.from_pretrained(model_name)
21
+ model = CLIPModel.from_pretrained(model_name).to(device)
22
+ model.eval()
23
+
24
+ def l2_normalize(x):
25
+ return x / np.linalg.norm(x)
26
+
27
+ @torch.no_grad()
28
+ def embed_image(img):
29
+ inputs = processor(images=img, return_tensors="pt")
30
+ inputs = {k: v.to(device) for k, v in inputs.items()}
31
+ feats = model.get_image_features(**inputs)
32
+ feats = feats / feats.norm(dim=-1, keepdim=True)
33
+ return feats.squeeze(0).cpu().numpy()
34
+
35
+ def recommend(img):
36
+ q_emb = embed_image(img)
37
+ sims = emb_matrix @ l2_normalize(q_emb)
38
+ top = np.argsort(-sims)[1:4]
39
+ results = []
40
+ for i in top:
41
+ results.append(train_ds[int(image_indices[i])]["pixel_values"])
42
+ return results
43
+
44
+ demo = gr.Interface(
45
+ fn=recommend,
46
+ inputs=gr.Image(type="pil", label="Upload a landscape image"),
47
+ outputs=[
48
+ gr.Image(label="Recommendation 1"),
49
+ gr.Image(label="Recommendation 2"),
50
+ gr.Image(label="Recommendation 3"),
51
+ ],
52
+ title="Landscape Image Recommendation System",
53
+ description="Upload a landscape image and receive visually similar recommendations."
54
+ )
55
+
56
+ demo.launch()