# Hugging Face Space: sofa recommendation + AI sofa generation demo
import os
import pickle

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
# AutoPipelineForText2Image drives the fast SD-Turbo generation model.
from diffusers import AutoPipelineForText2Image
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, CLIPModel, CLIPProcessor
# ==========================================================
# CONFIG
# ==========================================================
# Hugging Face dataset holding the synthetic sofa catalog images.
DATASET_ID = "MAY199/synthetic-sofa-images"
# Precomputed ViT embeddings for the dataset (built offline).
PICKLE_PATH = "sofa_embeddings_for_app.pkl"
# SD-Turbo: a distilled Stable Diffusion variant chosen for generation speed.
SD_MODEL_ID = "stabilityai/sd-turbo"
# CLIP is used only to pick the best style/material/shot prompt for the input.
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"

# Negative prompt steering generation toward an actual sofa and away from
# flat patterns/stripes/fabric swatches.
NEGATIVE_PROMPT_BASE = (
    "low quality, blurry, distorted, watermark, text, logo, deformed, bad proportions, "
    "cartoon, painting, sketch, bad anatomy, extra legs, extra cushions, "
    "abstract, pattern, texture, stripes, lines, wallpaper, fabric sample, swatch, "
    "gradient, color blocks, flat image"
)

# Candidate vocabularies crossed together into CLIP style prompts.
STYLE_BANK = [
    "modern minimalist",
    "scandinavian",
    "mid century modern",
    "contemporary",
    "japandi",
    "industrial modern",
    "luxury modern",
    "coastal minimal",
]

MATERIAL_BANK = [
    "linen fabric",
    "velvet fabric",
    "boucle fabric",
    "leather",
    "microfiber fabric",
]

SHOT_BANK = [
    "studio product photo, isolated object, seamless white background, softbox lighting, sharp focus, ultra realistic",
    "clean catalog photo, product cutout, white seamless background, softbox lighting, sharp focus, ultra realistic",
]
# ==========================================================
# LOAD DATASET (HF)
# ==========================================================
ds = load_dataset(DATASET_ID)
train_ds = ds["train"]

# ==========================================================
# LOAD EMBEDDINGS (ViT) for FAST retrieval
# ==========================================================
# NOTE(security): pickle.load executes arbitrary code from the file.
# Only ship this app with a pickle produced by your own embedding script.
with open(PICKLE_PATH, "rb") as f:
    data = pickle.load(f)

vit_model_id = data["model_id"]
# Rows are presumably already L2-normalized (retrieval uses a plain dot
# product as cosine similarity) — confirm in the embedding-building script.
emb_matrix = data["embeddings"].astype(np.float32)
image_indices = data["image_indices"]

# Fail fast if the pickle was built with a non-ViT encoder.
if "vit" not in vit_model_id.lower():
    raise ValueError(f"This app expects a ViT model_id for retrieval, got: {vit_model_id}")
# ==========================================================
# DEVICE
# ==========================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# ==========================================================
# LOAD ViT (retrieval encoder)
# ==========================================================
vit_processor = AutoImageProcessor.from_pretrained(vit_model_id, use_fast=True)
vit_model = AutoModel.from_pretrained(vit_model_id).to(device)
vit_model.eval()  # inference only; disables dropout etc.
def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale *x* to unit Euclidean length; the epsilon guards the zero vector."""
    norm = np.linalg.norm(x)
    return x / (norm + 1e-12)
def embed_image_vit(img: Image.Image) -> np.ndarray:
    """Encode a PIL image into an L2-normalized ViT feature vector.

    Mean-pools the last hidden state over all tokens, normalizes the pooled
    vector, and returns it as a 1-D float32 numpy array on the CPU.
    """
    inputs = vit_processor(images=img, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # FIX: the original ran the forward pass with autograd enabled, building
    # an unnecessary graph on every request. Inference-only mode keeps memory
    # flat and is faster.
    with torch.no_grad():
        outputs = vit_model(**inputs)
        feats = outputs.last_hidden_state.mean(dim=1)
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.squeeze(0).float().cpu().numpy()
# ==========================================================
# LOAD CLIP (style selection only)
# ==========================================================
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_ID)
clip_model = CLIPModel.from_pretrained(CLIP_MODEL_ID).to(device)
clip_model.eval()  # inference only
def choose_style_with_clip(input_img: Image.Image, style_prompts: list[str]) -> str:
    """Return the prompt from *style_prompts* that CLIP scores highest for *input_img*.

    Scores every candidate text against the single image in one batched
    forward pass and picks the argmax of ``logits_per_image``.
    """
    inputs = clip_processor(
        text=style_prompts,
        images=input_img,
        return_tensors="pt",
        padding=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # FIX: run without autograd — this is pure inference; the original built
    # a gradient graph for nothing on every call.
    with torch.no_grad():
        outputs = clip_model(**inputs)
        scores = outputs.logits_per_image.squeeze(0)
        best_idx = int(torch.argmax(scores).item())
    return style_prompts[best_idx]
# ==========================================================
# COLOR EXTRACTION (fast and simple): use mean RGB -> nearest names
# ==========================================================
# Reference palette: human-readable colour names mapped to representative
# RGB anchors; the input image's mean colour is snapped to the nearest ones.
NAMED_COLORS_RGB = {
    "white": (245, 245, 245),
    "black": (20, 20, 20),
    "gray": (128, 128, 128),
    "beige": (215, 198, 170),
    "brown": (120, 75, 45),
    "tan": (180, 140, 95),
    "cream": (240, 230, 205),
    "red": (210, 40, 40),
    "pink": (235, 145, 175),
    "magenta": (200, 60, 160),
    "purple": (120, 70, 170),
    "orange": (230, 140, 50),
    "yellow": (235, 210, 60),
    "green": (60, 160, 90),
    "olive": (120, 140, 60),
    "teal": (50, 150, 150),
    "cyan": (70, 190, 210),
    "blue": (60, 110, 205),
    "navy": (35, 55, 110),
}
def nearest_named_colors(img: Image.Image, top_n: int = 2):
    """Return the *top_n* palette names closest to the image's mean colour.

    The image is shrunk to 64x64 for speed, its channel-wise mean RGB is
    computed, and palette entries are ranked by squared Euclidean distance
    (no sqrt needed for ranking; ties break alphabetically by name).
    """
    small = img.convert("RGB").resize((64, 64))
    mean_rgb = np.array(small).reshape(-1, 3).astype(np.float32).mean(axis=0)
    ranked = sorted(
        (float(((mean_rgb - np.array(rgb)) ** 2).sum()), name)
        for name, rgb in NAMED_COLORS_RGB.items()
    )
    return [name for _, name in ranked[:top_n]]
# ==========================================================
# SD Turbo txt2img (lazy-load)
# ==========================================================
# Pipeline is created on first use so app startup stays fast.
sd_pipe = None

def get_sd_pipe():
    """Create the SD-Turbo text2img pipeline on first call and cache it globally."""
    global sd_pipe
    if sd_pipe is None:
        # Half precision only makes sense on GPU; CPU runs in float32.
        dtype = torch.float16 if device == "cuda" else torch.float32
        pipe = AutoPipelineForText2Image.from_pretrained(
            SD_MODEL_ID,
            torch_dtype=dtype,
            safety_checker=None,
        ).to(device)
        if device == "cuda":
            # Attention slicing lowers VRAM use; best-effort since older
            # diffusers versions may not expose it on this pipeline.
            try:
                pipe.enable_attention_slicing()
            except Exception:
                pass
        sd_pipe = pipe
    return sd_pipe
# ==========================================================
# FAST RETRIEVAL (dataset) using ViT embeddings
# ==========================================================
def recommend_fast(img: Image.Image):
    """Return the 3 dataset sofa images most similar to *img*.

    Similarity is the dot product between the normalized query embedding and
    the precomputed embedding matrix. Near-exact matches (sim > 0.999, i.e.
    the query image itself if it came from the dataset) are skipped first;
    if that leaves fewer than 3 results, remaining slots are filled from the
    overall best matches.

    Raises gr.Error when no image was uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    q = l2_normalize(embed_image_vit(img)).astype(np.float32)
    sims = emb_matrix @ q
    top_idx = np.argsort(-sims)

    results = []
    used = set()
    # First pass: best matches that are not near-duplicates of the query.
    for j in top_idx:
        if sims[j] > 0.999:
            continue
        used.add(int(j))
        results.append(train_ds[int(image_indices[j])]["image"])
        if len(results) == 3:
            break

    # FIX: the original fallback re-read top_idx[len(results)], which could
    # append an image already present in `results` (duplicate recommendation)
    # and could IndexError on very small datasets. Fill remaining slots with
    # the best matches not yet shown instead.
    for j in top_idx:
        if len(results) == 3:
            break
        if int(j) in used:
            continue
        used.add(int(j))
        results.append(train_ds[int(image_indices[j])]["image"])

    # Degenerate case (dataset smaller than 3 rows): pad with repeats rather
    # than crashing the UI.
    while results and len(results) < 3:
        results.append(results[-1])

    return results[0], results[1], results[2]
# ==========================================================
# AI GENERATION (FAST + "SOFA GUARANTEE")
# - strong sofa geometry words
# - strong negative for patterns/stripes
# - Turbo few steps for speed
# ==========================================================
def build_style_prompts():
    """Cross STYLE_BANK x MATERIAL_BANK x SHOT_BANK into candidate CLIP prompts."""
    return [
        f"{style}, {material}, {shot}"
        for style in STYLE_BANK
        for material in MATERIAL_BANK
        for shot in SHOT_BANK
    ]

# Precomputed once at startup; CLIP ranks the input image against all of them.
STYLE_PROMPTS = build_style_prompts()
def generate_new_sofa_from_input_only(img: Image.Image, seed: int = 0):
    """Generate a brand-new sofa image conditioned only on the uploaded photo.

    The prompt is assembled automatically from the INPUT image: its two
    nearest named colours plus the style/material/shot combination CLIP
    scores highest. Returns (generated image, prompt used, detected colour
    names rendered as a string). Raises gr.Error when no image was uploaded.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    pipe = get_sd_pipe()

    # Colour cues come from the input only (fast mean-RGB palette match).
    color_names = nearest_named_colors(img, top_n=2)
    color_phrase = " and ".join(color_names)
    # Style cue: CLIP picks the best-matching candidate prompt for the input.
    chosen_style = choose_style_with_clip(img, STYLE_PROMPTS)

    # Explicit sofa-geometry wording steers generation away from flat
    # pattern/stripe outputs.
    prompt = (
        "a realistic three-dimensional sofa, seating furniture, "
        "with visible cushions, armrests, and legs, "
        f"{color_phrase} color, {chosen_style}, "
        "single sofa only, isolated object, product cutout, "
        "no room, no living room, no interior, no background scene"
    )
    negative_prompt = NEGATIVE_PROMPT_BASE

    # Seed 0 (the UI default) means "random": no fixed generator is used.
    generator = None
    if seed and int(seed) > 0:
        generator = torch.Generator(device=device).manual_seed(int(seed))

    # Few steps and low guidance: SD-Turbo is tuned for fast generation.
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=4,
        guidance_scale=1.5,
        height=512,
        width=512,
        generator=generator,
    )
    return result.images[0], prompt, str(color_names)
# ==========================================================
# UI
# ==========================================================
with gr.Blocks() as app:
    gr.Markdown("# Sofa Recommendation System + AI Generation")
    gr.Markdown(
        "Submit (Fast) returns 3 similar sofas from the dataset using ViT embeddings. "
        "Generate AI creates a NEW sofa image influenced ONLY by the input image via auto prompt."
    )

    # Single shared input image for both retrieval and generation.
    inp = gr.Image(type="pil", label="Upload a living room or sofa image")

    gr.Markdown("## Quick Starters (1-click examples)")
    gr.Examples(
        examples=[
            "examples/starter1.jpeg",
            "examples/starter2.jpeg",
            "examples/starter3.jpeg",
        ],
        inputs=inp,
        label="Quick Starters",
    )

    with gr.Row():
        btn_submit = gr.Button("Submit (Fast) - Dataset Retrieval")
        btn_ai = gr.Button("Generate AI (New Sofa) - Input Only")

    with gr.Row():
        seed_inp = gr.Slider(minimum=0, maximum=999999, step=1, value=0, label="Seed (0 = random)")

    # Debug/insight boxes: what the app detected and the prompt it built.
    palette_box = gr.Textbox(label="Detected colors (names from INPUT only)", lines=1)
    chosen_prompt_box = gr.Textbox(label="Auto prompt used for generation (built from INPUT only)", lines=4)

    with gr.Row():
        out1 = gr.Image(label="Recommendation 1 (dataset)")
        out2 = gr.Image(label="Recommendation 2 (dataset)")
        out3 = gr.Image(label="Recommendation 3 (dataset)")
    out_ai = gr.Image(label="AI Generated Sofa (new image, no dataset image used)")

    # Wiring: retrieval fills the three dataset slots; generation fills the
    # AI image plus the prompt/colour debug boxes.
    btn_submit.click(fn=recommend_fast, inputs=inp, outputs=[out1, out2, out3])
    btn_ai.click(
        fn=generate_new_sofa_from_input_only,
        inputs=[inp, seed_inp],
        outputs=[out_ai, chosen_prompt_box, palette_box],
    )

app.launch()