"""Sofa recommendation + AI generation Gradio app.

Retrieval: ViT embeddings over a Hugging Face image dataset (fast nearest-
neighbour search against precomputed embeddings). Generation: SD-Turbo
text-to-image driven by an auto-built prompt derived from the input image only.
"""

import os
import pickle

import numpy as np
import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, CLIPModel, CLIPProcessor

# Faster generation model (Turbo)
from diffusers import AutoPipelineForText2Image

# ==========================================================
# CONFIG
# ==========================================================
DATASET_ID = "MAY199/synthetic-sofa-images"
PICKLE_PATH = "sofa_embeddings_for_app.pkl"

# Fast model for quick generation
SD_MODEL_ID = "stabilityai/sd-turbo"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"

# Make sure it generates a SOFA (avoid stripes/patterns)
NEGATIVE_PROMPT_BASE = (
    "low quality, blurry, distorted, watermark, text, logo, deformed, bad proportions, "
    "cartoon, painting, sketch, bad anatomy, extra legs, extra cushions, "
    "abstract, pattern, texture, stripes, lines, wallpaper, fabric sample, swatch, "
    "gradient, color blocks, flat image"
)

STYLE_BANK = [
    "modern minimalist",
    "scandinavian",
    "mid century modern",
    "contemporary",
    "japandi",
    "industrial modern",
    "luxury modern",
    "coastal minimal",
]

MATERIAL_BANK = [
    "linen fabric",
    "velvet fabric",
    "boucle fabric",
    "leather",
    "microfiber fabric",
]

SHOT_BANK = [
    "studio product photo, isolated object, seamless white background, softbox lighting, sharp focus, ultra realistic",
    "clean catalog photo, product cutout, white seamless background, softbox lighting, sharp focus, ultra realistic",
]

# ==========================================================
# LOAD DATASET (HF)
# ==========================================================
ds = load_dataset(DATASET_ID)
train_ds = ds["train"]

# ==========================================================
# LOAD EMBEDDINGS (ViT) for FAST retrieval
# ==========================================================
# NOTE(review): pickle.load executes arbitrary code from the file — acceptable
# only because PICKLE_PATH is a local artifact shipped with this app, never
# user-supplied. Do not point this at untrusted input.
with open(PICKLE_PATH, "rb") as f:
    data = pickle.load(f)

vit_model_id = data["model_id"]
# float32 so the matrix-vector similarity below stays in a single fast dtype.
emb_matrix = data["embeddings"].astype(np.float32)
# Row j of emb_matrix corresponds to dataset index image_indices[j].
image_indices = data["image_indices"]

# The retrieval path mean-pools ViT patch tokens; refuse non-ViT checkpoints.
if "vit" not in vit_model_id.lower():
    raise ValueError(f"This app expects a ViT model_id for retrieval, got: {vit_model_id}")

# ==========================================================
# DEVICE
# ==========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# ==========================================================
# LOAD ViT (retrieval encoder)
# ==========================================================
vit_processor = AutoImageProcessor.from_pretrained(vit_model_id, use_fast=True)
vit_model = AutoModel.from_pretrained(vit_model_id).to(device)
vit_model.eval()


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Return x scaled to unit L2 norm (epsilon guards the zero vector)."""
    return x / (np.linalg.norm(x) + 1e-12)


@torch.no_grad()
def embed_image_vit(img: Image.Image) -> np.ndarray:
    """Embed a PIL image with the ViT encoder.

    Mean-pools the last hidden state over tokens, L2-normalizes, and returns
    a 1-D float32 numpy vector (so dot products equal cosine similarity).
    """
    inputs = vit_processor(images=img, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = vit_model(**inputs)
    feats = outputs.last_hidden_state.mean(dim=1)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.squeeze(0).float().cpu().numpy()


# ==========================================================
# LOAD CLIP (style selection only)
# ==========================================================
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(device)
clip_model.eval()


@torch.no_grad()
def choose_style_with_clip(input_img: Image.Image, style_prompts: list[str]) -> str:
    """Return the prompt in style_prompts that CLIP scores highest for input_img.

    Uses image-to-text logits (logits_per_image) over the full prompt batch.
    """
    inputs = clip_processor(
        text=style_prompts, images=input_img, return_tensors="pt", padding=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = clip_model(**inputs)
    scores = outputs.logits_per_image.squeeze(0)
    best_idx = int(torch.argmax(scores).item())
    return style_prompts[best_idx]
# ==========================================================
# COLOR EXTRACTION (fast and simple): use mean RGB -> nearest names
# ==========================================================
NAMED_COLORS_RGB = {
    "white": (245, 245, 245),
    "black": (20, 20, 20),
    "gray": (128, 128, 128),
    "beige": (215, 198, 170),
    "brown": (120, 75, 45),
    "tan": (180, 140, 95),
    "cream": (240, 230, 205),
    "red": (210, 40, 40),
    "pink": (235, 145, 175),
    "magenta": (200, 60, 160),
    "purple": (120, 70, 170),
    "orange": (230, 140, 50),
    "yellow": (235, 210, 60),
    "green": (60, 160, 90),
    "olive": (120, 140, 60),
    "teal": (50, 150, 150),
    "cyan": (70, 190, 210),
    "blue": (60, 110, 205),
    "navy": (35, 55, 110),
}


def nearest_named_colors(img: Image.Image, top_n: int = 2) -> list[str]:
    """Return the top_n color names closest to the image's mean RGB.

    Downscales to 64x64 for speed, averages all pixels, then ranks the named
    palette by squared euclidean distance (vectorized — one NumPy pass instead
    of a Python loop over the palette).
    """
    im = img.convert("RGB").resize((64, 64))
    arr = np.array(im).reshape(-1, 3).astype(np.float32)
    mean_rgb = arr.mean(axis=0)

    names = list(NAMED_COLORS_RGB)
    palette = np.array([NAMED_COLORS_RGB[n] for n in names], dtype=np.float32)
    d2 = ((palette - mean_rgb) ** 2).sum(axis=1)
    order = np.argsort(d2, kind="stable")
    return [names[i] for i in order[:top_n]]


# ==========================================================
# SD Turbo txt2img (lazy-load)
# ==========================================================
sd_pipe = None  # cached pipeline; loading is heavy, so defer until first use


def get_sd_pipe():
    """Load the SD-Turbo text2img pipeline once and cache it module-wide."""
    global sd_pipe
    if sd_pipe is not None:
        return sd_pipe
    dtype = torch.float16 if device == "cuda" else torch.float32
    p = AutoPipelineForText2Image.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)
    # Turbo-friendly defaults
    if device == "cuda":
        try:
            p.enable_attention_slicing()
        except Exception:
            # best-effort memory optimization; not fatal if unsupported
            pass
    sd_pipe = p
    return sd_pipe


# ==========================================================
# FAST RETRIEVAL (dataset) using ViT embeddings
# ==========================================================
def recommend_fast(img: Image.Image):
    """Return the 3 dataset images most similar to img (ViT cosine similarity).

    Near-exact matches (similarity > 0.999) are skipped so an image from the
    dataset itself does not just return itself; if that leaves fewer than 3
    hits, the result is padded with the overall best matches. Unlike the
    previous fallback, padding deduplicates by dataset index so the three
    returned images are distinct.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    q = l2_normalize(embed_image_vit(img)).astype(np.float32)
    # Embeddings are unit-norm, so a dot product is cosine similarity.
    sims = emb_matrix @ q
    order = np.argsort(-sims)

    chosen: list[int] = []
    # First pass: best non-duplicate matches.
    for j in order:
        if sims[j] > 0.999:
            continue
        chosen.append(int(image_indices[j]))
        if len(chosen) == 3:
            break
    # Fallback pad: overall best matches, skipping indices already chosen.
    for j in order:
        if len(chosen) == 3:
            break
        ds_idx = int(image_indices[j])
        if ds_idx not in chosen:
            chosen.append(ds_idx)

    results = [train_ds[i]["image"] for i in chosen]
    return results[0], results[1], results[2]
# ==========================================================
# AI GENERATION (FAST + "SOFA GUARANTEE")
# - strong sofa geometry words
# - strong negative for patterns/stripes
# - Turbo few steps for speed
# ==========================================================
def build_style_prompts() -> list[str]:
    """Cartesian product of style x material x shot banks as CLIP candidates."""
    return [
        f"{style}, {material}, {shot}"
        for style in STYLE_BANK
        for material in MATERIAL_BANK
        for shot in SHOT_BANK
    ]


STYLE_PROMPTS = build_style_prompts()


def generate_new_sofa_from_input_only(img: Image.Image, seed: int = 0):
    """Generate a NEW sofa image influenced only by the input image.

    Colors come from the input's mean RGB, the style from CLIP scoring of
    STYLE_PROMPTS against the input. Returns (generated image, prompt used,
    stringified color-name list) for display in the UI.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    p = get_sd_pipe()

    # colors from input only (fast)
    color_names = nearest_named_colors(img, top_n=2)
    color_phrase = " and ".join(color_names)

    # style from CLIP (input-only)
    chosen_style = choose_style_with_clip(img, STYLE_PROMPTS)

    # Strong geometry anchors to avoid "stripes"
    prompt = (
        "a realistic three-dimensional sofa, seating furniture, "
        "with visible cushions, armrests, and legs, "
        f"{color_phrase} color, {chosen_style}, "
        "single sofa only, isolated object, product cutout, "
        "no room, no living room, no interior, no background scene"
    )
    negative_prompt = NEGATIVE_PROMPT_BASE

    # seed == 0 means "random": leave the generator unset.
    generator = None
    if seed and int(seed) > 0:
        generator = torch.Generator(device=device).manual_seed(int(seed))

    # SD Turbo fast settings.
    # NOTE(review): SD-Turbo is distilled for guidance_scale ~= 0; at 1.5 CFG
    # (and therefore the negative prompt) is active — confirm this trade-off
    # against output quality.
    out = p(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=4,
        guidance_scale=1.5,
        height=512,
        width=512,
        generator=generator,
    )
    return out.images[0], prompt, str(color_names)


# ==========================================================
# UI
# ==========================================================
with gr.Blocks() as app:
    gr.Markdown("# Sofa Recommendation System + AI Generation")
    gr.Markdown(
        "Submit (Fast) returns 3 similar sofas from the dataset using ViT embeddings. "
        "Generate AI creates a NEW sofa image influenced ONLY by the input image via auto prompt."
    )

    inp = gr.Image(type="pil", label="Upload a living room or sofa image")

    gr.Markdown("## Quick Starters (1-click examples)")
    gr.Examples(
        examples=[
            "examples/starter1.jpeg",
            "examples/starter2.jpeg",
            "examples/starter3.jpeg",
        ],
        inputs=inp,
        label="Quick Starters",
    )

    with gr.Row():
        btn_submit = gr.Button("Submit (Fast) - Dataset Retrieval")
        btn_ai = gr.Button("Generate AI (New Sofa) - Input Only")

    with gr.Row():
        seed_inp = gr.Slider(minimum=0, maximum=999999, step=1, value=0, label="Seed (0 = random)")
        palette_box = gr.Textbox(label="Detected colors (names from INPUT only)", lines=1)
        chosen_prompt_box = gr.Textbox(label="Auto prompt used for generation (built from INPUT only)", lines=4)

    with gr.Row():
        out1 = gr.Image(label="Recommendation 1 (dataset)")
        out2 = gr.Image(label="Recommendation 2 (dataset)")
        out3 = gr.Image(label="Recommendation 3 (dataset)")

    out_ai = gr.Image(label="AI Generated Sofa (new image, no dataset image used)")

    btn_submit.click(fn=recommend_fast, inputs=inp, outputs=[out1, out2, out3])
    btn_ai.click(
        fn=generate_new_sofa_from_input_only,
        inputs=[inp, seed_inp],
        outputs=[out_ai, chosen_prompt_box, palette_box],
    )

app.launch()