# sofa-match-ai / app.py
# (Hugging Face Space page header, kept as comments: uploaded by Leelu1002,
#  "Update app.py", commit 8a3a070, verified)
import os
import pickle
import numpy as np
import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from transformers import CLIPProcessor, CLIPModel
# Faster generation model (Turbo)
from diffusers import AutoPipelineForText2Image
# ==========================================================
# CONFIG
# ==========================================================
# HF dataset the retrieval results are drawn from.
DATASET_ID = "MAY199/synthetic-sofa-images"
# Precomputed ViT embeddings for that dataset (built offline, shipped with the app).
PICKLE_PATH = "sofa_embeddings_for_app.pkl"
# Fast model for quick generation
SD_MODEL_ID = "stabilityai/sd-turbo"
# CLIP is used only to pick the best style phrase for the uploaded image.
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
# Make sure it generates a SOFA (avoid stripes/patterns)
NEGATIVE_PROMPT_BASE = (
    "low quality, blurry, distorted, watermark, text, logo, deformed, bad proportions, "
    "cartoon, painting, sketch, bad anatomy, extra legs, extra cushions, "
    "abstract, pattern, texture, stripes, lines, wallpaper, fabric sample, swatch, "
    "gradient, color blocks, flat image"
)
# The three banks below are combined (style x material x shot) into the
# candidate prompts that CLIP scores against the input image.
STYLE_BANK = [
    "modern minimalist",
    "scandinavian",
    "mid century modern",
    "contemporary",
    "japandi",
    "industrial modern",
    "luxury modern",
    "coastal minimal",
]
MATERIAL_BANK = [
    "linen fabric",
    "velvet fabric",
    "boucle fabric",
    "leather",
    "microfiber fabric",
]
SHOT_BANK = [
    "studio product photo, isolated object, seamless white background, softbox lighting, sharp focus, ultra realistic",
    "clean catalog photo, product cutout, white seamless background, softbox lighting, sharp focus, ultra realistic",
]
# ==========================================================
# LOAD DATASET (HF)
# ==========================================================
# Recommendations index into the "train" split by row number.
ds = load_dataset(DATASET_ID)
train_ds = ds["train"]
# ==========================================================
# LOAD EMBEDDINGS (ViT) for FAST retrieval
# ==========================================================
# NOTE(review): pickle is only safe for trusted files; this artifact is
# assumed to ship with the app — never point PICKLE_PATH at user input.
with open(PICKLE_PATH, "rb") as f:
    data = pickle.load(f)
vit_model_id = data["model_id"]  # encoder that produced the embeddings
emb_matrix = data["embeddings"].astype(np.float32)  # (N, D); presumably L2-normalized rows — TODO confirm
image_indices = data["image_indices"]  # embedding row -> index into train_ds
# Guard against loading embeddings built with a non-ViT encoder.
if "vit" not in vit_model_id.lower():
    raise ValueError(f"This app expects a ViT model_id for retrieval, got: {vit_model_id}")
# ==========================================================
# DEVICE
# ==========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
# ==========================================================
# LOAD ViT (retrieval encoder)
# ==========================================================
# Must be the SAME encoder that produced the pickled embeddings so the
# query vector and the dataset vectors live in the same space.
vit_processor = AutoImageProcessor.from_pretrained(vit_model_id, use_fast=True)
vit_model = AutoModel.from_pretrained(vit_model_id).to(device)
vit_model.eval()
def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale *x* to unit L2 norm; the epsilon keeps a zero vector finite."""
    norm = np.linalg.norm(x)
    return x / (norm + 1e-12)
@torch.no_grad()
def embed_image_vit(img: Image.Image) -> np.ndarray:
    """Encode *img* with the ViT backbone into a unit-norm feature vector.

    Mean-pools the final hidden states over all tokens, L2-normalizes the
    pooled vector, and returns it as a 1-D float32 numpy array on CPU.
    """
    batch = vit_processor(images=img, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    hidden = vit_model(**batch).last_hidden_state
    pooled = hidden.mean(dim=1)
    unit = pooled / pooled.norm(dim=-1, keepdim=True)
    return unit.squeeze(0).float().cpu().numpy()
# ==========================================================
# LOAD CLIP (style selection only)
# ==========================================================
# CLIP scores the input image against text prompts; it is NOT used for
# retrieval (the ViT above handles that).
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(device)
clip_model.eval()
@torch.no_grad()
def choose_style_with_clip(input_img: Image.Image, style_prompts: list[str]) -> str:
    """Return the prompt in *style_prompts* that CLIP scores highest for *input_img*."""
    batch = clip_processor(
        text=style_prompts,
        images=input_img,
        padding=True,
        return_tensors="pt",
    )
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # logits_per_image is (1, num_prompts): similarity of the image to each prompt.
    logits = clip_model(**batch).logits_per_image.squeeze(0)
    winner = int(logits.argmax().item())
    return style_prompts[winner]
# ==========================================================
# COLOR EXTRACTION (fast and simple): use mean RGB -> nearest names
# ==========================================================
# Reference RGB anchors for human-readable color names; matched by
# squared Euclidean distance in nearest_named_colors.
NAMED_COLORS_RGB = {
    "white": (245, 245, 245),
    "black": (20, 20, 20),
    "gray": (128, 128, 128),
    "beige": (215, 198, 170),
    "brown": (120, 75, 45),
    "tan": (180, 140, 95),
    "cream": (240, 230, 205),
    "red": (210, 40, 40),
    "pink": (235, 145, 175),
    "magenta": (200, 60, 160),
    "purple": (120, 70, 170),
    "orange": (230, 140, 50),
    "yellow": (235, 210, 60),
    "green": (60, 160, 90),
    "olive": (120, 140, 60),
    "teal": (50, 150, 150),
    "cyan": (70, 190, 210),
    "blue": (60, 110, 205),
    "navy": (35, 55, 110),
}
def nearest_named_colors(img: Image.Image, top_n: int = 2):
    """Return the *top_n* color names closest to the image's mean RGB.

    The image is downscaled to 64x64 for speed; closeness is squared
    Euclidean distance against the NAMED_COLORS_RGB anchors (ties broken
    alphabetically by name, as before).
    """
    small = img.convert("RGB").resize((64, 64))
    pixels = np.array(small).reshape(-1, 3).astype(np.float32)
    mean_rgb = pixels.mean(axis=0)
    ranked = sorted(
        NAMED_COLORS_RGB.items(),
        key=lambda item: (float(((mean_rgb - np.array(item[1])) ** 2).sum()), item[0]),
    )
    return [name for name, _ in ranked[:top_n]]
# ==========================================================
# SD Turbo txt2img (lazy-load)
# ==========================================================
# Loaded on first use so startup stays fast and the retrieval-only path
# never pays the diffusion-model download/load cost.
sd_pipe = None
def get_sd_pipe():
    """Return the cached SD-Turbo text-to-image pipeline, loading it once.

    Uses fp16 on CUDA and fp32 on CPU. Attention slicing is enabled on CUDA
    best-effort (lowers peak VRAM; ignored if unsupported).
    """
    global sd_pipe
    if sd_pipe is not None:
        return sd_pipe
    dtype = torch.float16 if device == "cuda" else torch.float32
    p = AutoPipelineForText2Image.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)
    # Turbo-friendly defaults
    if device == "cuda":
        try:
            p.enable_attention_slicing()
        except Exception:
            # Best-effort memory optimization only; never fail generation over it.
            pass
    sd_pipe = p
    return sd_pipe
# ==========================================================
# FAST RETRIEVAL (dataset) using ViT embeddings
# ==========================================================
def recommend_fast(img: Image.Image):
    """Return the 3 dataset sofa images most similar to the uploaded image.

    Similarity is the dot product between the L2-normalized ViT embedding of
    the query and the precomputed dataset embeddings. Near-perfect matches
    (sim > 0.999) are skipped so an image that is itself in the dataset does
    not just return itself.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    q = l2_normalize(embed_image_vit(img)).astype(np.float32)
    sims = emb_matrix @ q
    top_idx = np.argsort(-sims)

    results = []
    seen = set()
    # First pass: best matches, excluding near-duplicates of the query.
    for j in top_idx:
        if sims[j] > 0.999:
            continue
        ds_idx = int(image_indices[j])
        results.append(train_ds[ds_idx]["image"])
        seen.add(ds_idx)
        if len(results) == 3:
            break
    # Fallback when the filter left fewer than 3: take the remaining top
    # matches without the near-duplicate filter. (The original fallback
    # re-read top_idx[len(results)] and could append the same item twice,
    # or IndexError on a tiny dataset.)
    for j in top_idx:
        if len(results) == 3:
            break
        ds_idx = int(image_indices[j])
        if ds_idx not in seen:
            results.append(train_ds[ds_idx]["image"])
            seen.add(ds_idx)
    # Dataset smaller than 3 items: pad with the last match so the three
    # output slots are always filled.
    while results and len(results) < 3:
        results.append(results[-1])
    return results[0], results[1], results[2]
# ==========================================================
# AI GENERATION (FAST + "SOFA GUARANTEE")
# - strong sofa geometry words
# - strong negative for patterns/stripes
# - Turbo few steps for speed
# ==========================================================
def build_style_prompts(styles=None, materials=None, shots=None):
    """Build candidate prompts as the cartesian product style x material x shot.

    Args:
        styles: iterable of style phrases; defaults to STYLE_BANK.
        materials: iterable of material phrases; defaults to MATERIAL_BANK.
        shots: iterable of shot/lighting phrases; defaults to SHOT_BANK.

    Returns:
        List of "style, material, shot" strings in bank order.
    """
    styles = STYLE_BANK if styles is None else styles
    materials = MATERIAL_BANK if materials is None else materials
    shots = SHOT_BANK if shots is None else shots
    return [
        f"{style}, {material}, {shot}"
        for style in styles
        for material in materials
        for shot in shots
    ]
# Precomputed once; CLIP scores the input image against this fixed bank.
STYLE_PROMPTS = build_style_prompts()
def generate_new_sofa_from_input_only(img: Image.Image, seed: int = 0):
    """Generate a brand-new sofa image conditioned ONLY on the input image.

    The prompt is assembled automatically from (a) the two named colors
    nearest to the input's mean RGB and (b) the style phrase CLIP scores
    highest for the input. No dataset image is passed to the generator.

    Args:
        img: uploaded PIL image; None raises gr.Error.
        seed: values > 0 seed the torch RNG for reproducible output;
            0 means unseeded (random).

    Returns:
        (generated PIL image, the prompt string used, detected color names as str).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    p = get_sd_pipe()
    # colors from input only (fast)
    color_names = nearest_named_colors(img, top_n=2)
    color_phrase = " and ".join(color_names)
    # style from CLIP (input-only)
    chosen_style = choose_style_with_clip(img, STYLE_PROMPTS)
    # Strong geometry anchors to avoid "stripes"
    prompt = (
        "a realistic three-dimensional sofa, seating furniture, "
        "with visible cushions, armrests, and legs, "
        f"{color_phrase} color, {chosen_style}, "
        "single sofa only, isolated object, product cutout, "
        "no room, no living room, no interior, no background scene"
    )
    negative_prompt = NEGATIVE_PROMPT_BASE
    generator = None
    if seed and int(seed) > 0:
        generator = torch.Generator(device=device).manual_seed(int(seed))
    # SD Turbo fast settings
    # NOTE(review): sd-turbo is distilled for guidance_scale=0.0; at 1.5 the
    # negative prompt takes effect but quality may drift — confirm intended.
    out = p(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=4,
        guidance_scale=1.5,
        height=512,
        width=512,
        generator=generator,
    )
    img_out = out.images[0]
    return img_out, prompt, str(color_names)
# ==========================================================
# UI
# ==========================================================
with gr.Blocks() as app:
    gr.Markdown("# Sofa Recommendation System + AI Generation")
    gr.Markdown(
        "Submit (Fast) returns 3 similar sofas from the dataset using ViT embeddings. "
        "Generate AI creates a NEW sofa image influenced ONLY by the input image via auto prompt."
    )
    # Single shared input: either button below reads from this component.
    inp = gr.Image(type="pil", label="Upload a living room or sofa image")
    gr.Markdown("## Quick Starters (1-click examples)")
    # assumes these files ship in the repo's examples/ folder — TODO confirm
    gr.Examples(
        examples=[
            "examples/starter1.jpeg",
            "examples/starter2.jpeg",
            "examples/starter3.jpeg",
        ],
        inputs=inp,
        label="Quick Starters",
    )
    with gr.Row():
        btn_submit = gr.Button("Submit (Fast) - Dataset Retrieval")
        btn_ai = gr.Button("Generate AI (New Sofa) - Input Only")
    with gr.Row():
        # Seed only affects the AI-generation path, not retrieval.
        seed_inp = gr.Slider(minimum=0, maximum=999999, step=1, value=0, label="Seed (0 = random)")
        palette_box = gr.Textbox(label="Detected colors (names from INPUT only)", lines=1)
    chosen_prompt_box = gr.Textbox(label="Auto prompt used for generation (built from INPUT only)", lines=4)
    with gr.Row():
        out1 = gr.Image(label="Recommendation 1 (dataset)")
        out2 = gr.Image(label="Recommendation 2 (dataset)")
        out3 = gr.Image(label="Recommendation 3 (dataset)")
    out_ai = gr.Image(label="AI Generated Sofa (new image, no dataset image used)")
    # Wire buttons to the two independent pipelines defined above.
    btn_submit.click(fn=recommend_fast, inputs=inp, outputs=[out1, out2, out3])
    btn_ai.click(fn=generate_new_sofa_from_input_only, inputs=[inp, seed_inp], outputs=[out_ai, chosen_prompt_box, palette_box])
app.launch()