Files changed (1) hide show
  1. app.py +0 -175
app.py DELETED
@@ -1,175 +0,0 @@
1
- import os
2
- import pickle
3
- import numpy as np
4
- import gradio as gr
5
- import torch
6
- from datasets import load_dataset
7
- from PIL import Image
8
- from transformers import AutoImageProcessor, AutoModel
9
- from diffusers import StableDiffusionPipeline
10
-
11
- # ==========================================================
12
- # CONFIG
13
- # ==========================================================
14
- DATASET_ID = "MAY199/synthetic-sofa-images"
15
- PICKLE_PATH = "sofa_embeddings_for_app.pkl"
16
-
17
- AI_PROMPT = (
18
- "a modern minimalist sofa, neutral colors, scandinavian interior style, "
19
- "high quality product photo, soft natural light, sharp focus"
20
- )
21
- NEGATIVE_PROMPT = (
22
- "low quality, blurry, distorted, extra legs, extra cushions, watermark, text, logo, "
23
- "deformed, bad proportions, cartoon, painting, sketch"
24
- )
25
- SD_MODEL_ID = "runwayml/stable-diffusion-v1-5"
26
-
27
- # ==========================================================
28
- # LOAD DATASET (HF)
29
- # ==========================================================
30
- ds = load_dataset(DATASET_ID)
31
- train_ds = ds["train"]
32
-
33
- # ==========================================================
34
- # LOAD EMBEDDINGS (ViT)
35
- # ==========================================================
36
- with open(PICKLE_PATH, "rb") as f:
37
- data = pickle.load(f)
38
-
39
- model_id = data["model_id"]
40
- emb_matrix = data["embeddings"].astype(np.float32)
41
- image_indices = data["image_indices"]
42
-
43
- if "vit" not in model_id.lower():
44
- raise ValueError(f"This app expects a ViT model_id, got: {model_id}")
45
-
46
- # ==========================================================
47
- # LOAD VIT
48
- # ==========================================================
49
- device = "cuda" if torch.cuda.is_available() else "cpu"
50
- processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
51
- model = AutoModel.from_pretrained(model_id).to(device)
52
- model.eval()
53
-
54
- def l2_normalize(x):
55
- return x / (np.linalg.norm(x) + 1e-12)
56
-
57
- @torch.no_grad()
58
- def embed_image(img: Image.Image) -> np.ndarray:
59
- inputs = processor(images=img, return_tensors="pt")
60
- inputs = {k: v.to(device) for k, v in inputs.items()}
61
- outputs = model(**inputs)
62
- feats = outputs.last_hidden_state.mean(dim=1)
63
- feats = feats / feats.norm(dim=-1, keepdim=True)
64
- return feats.squeeze(0).float().cpu().numpy()
65
-
66
- # ==========================================================
67
- # LOAD SD (lazy-load to avoid blocking startup)
68
- # ==========================================================
69
- pipe = None
70
-
71
- def get_pipe():
72
- global pipe
73
- if pipe is not None:
74
- return pipe
75
-
76
- # print device so you can see in Logs if it's CPU!
77
- print("Stable Diffusion device:", device)
78
-
79
- sd_dtype = torch.float16 if device == "cuda" else torch.float32
80
- p = StableDiffusionPipeline.from_pretrained(
81
- SD_MODEL_ID,
82
- torch_dtype=sd_dtype,
83
- safety_checker=None,
84
- ).to(device)
85
-
86
- if device == "cuda":
87
- try:
88
- p.enable_attention_slicing()
89
- except Exception:
90
- pass
91
-
92
- pipe = p
93
- return pipe
94
-
95
- def generate_ai_recommendation():
96
- p = get_pipe()
97
- with torch.no_grad():
98
- img = p(
99
- prompt=AI_PROMPT,
100
- negative_prompt=NEGATIVE_PROMPT,
101
- num_inference_steps=12, # ↓ faster
102
- guidance_scale=6.0
103
- ).images[0]
104
- return img
105
-
106
- # ==========================================================
107
- # 1) FAST RETRIEVAL (Submit)
108
- # ==========================================================
109
- def recommend_fast(img: Image.Image):
110
- q = l2_normalize(embed_image(img)).astype(np.float32)
111
- sims = emb_matrix @ q
112
- top_idx = np.argsort(-sims)
113
-
114
- # debug
115
- print("Top 5 sims:", sims[top_idx[:5]])
116
-
117
- results = []
118
- for j in top_idx:
119
- if sims[j] > 0.999:
120
- continue
121
- ds_idx = image_indices[j]
122
- results.append(train_ds[int(ds_idx)]["image"])
123
- if len(results) == 3:
124
- break
125
-
126
- while len(results) < 3:
127
- ds_idx = image_indices[int(top_idx[len(results)])]
128
- results.append(train_ds[int(ds_idx)]["image"])
129
-
130
- return results[0], results[1], results[2]
131
-
132
- # ==========================================================
133
- # 2) SLOW GENERATION (button)
134
- # ==========================================================
135
- def recommend_ai():
136
- return generate_ai_recommendation()
137
-
138
- # ==========================================================
139
- # UI
140
- # ==========================================================
141
- with gr.Blocks() as app:
142
- gr.Markdown("# Sofa Recommendation System + AI Generation")
143
- gr.Markdown("Submit returns 3 similar sofas quickly. Generate AI creates a new sofa image (can be slow).")
144
-
145
- inp = gr.Image(type="pil", label="Upload a sofa image")
146
-
147
- # ======================================================
148
- # QUICK STARTERS (NEW)
149
- # ======================================================
150
- gr.Markdown("## Quick Starters (1-click examples)")
151
- gr.Markdown("Click an example image to auto-fill the input, then press **Submit (Fast)**.")
152
-
153
- gr.Examples(
154
- examples=[
155
- "examples/starter1.jpeg",
156
- "examples/starter2.jpeg",
157
- "examples/starter3.jpeg",
158
- ],
159
- inputs=inp,
160
- label="Quick Starters",
161
- )
162
-
163
- with gr.Row():
164
- btn_submit = gr.Button("Submit (Fast)")
165
- btn_ai = gr.Button("Generate AI Recommendation (Slow)")
166
-
167
- out1 = gr.Image(label="Recommendation 1")
168
- out2 = gr.Image(label="Recommendation 2")
169
- out3 = gr.Image(label="Recommendation 3")
170
- out_ai = gr.Image(label="AI Recommendation")
171
-
172
- btn_submit.click(fn=recommend_fast, inputs=inp, outputs=[out1, out2, out3])
173
- btn_ai.click(fn=recommend_ai, inputs=None, outputs=out_ai)
174
-
175
- app.launch()