| import os |
| import io |
| import base64 |
| import time |
| import logging |
| import threading |
| import uuid |
| from datetime import datetime |
| from pathlib import Path |
| from collections import deque |
| from typing import Dict, Optional, Tuple |
|
|
| import gradio as gr |
| from gradio_client import Client |
| from PIL import Image |
|
|
| |
# Root logging configuration; INFO level so queue and backend events are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Queue tuning constants.
MAX_QUEUE_SIZE = 50  # add_request() rejects new work once this many requests are waiting
MAX_CONCURRENT_REQUESTS = 1  # upper bound on requests held in the "processing" state
AVERAGE_PROCESSING_TIME = 15  # seconds; initial value of the running average used for wait estimates


# Hugging Face token passed to the backend client; the module refuses to import without it.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")
|
|
| |
class QueueManager:
    """Thread-safe FIFO request queue with per-request status tracking.

    A request moves from ``queue`` to ``processing`` and ends up in either
    ``completed`` or ``failed``. Every public method takes ``self.lock``, so
    the manager can be shared between the UI callbacks and the worker thread.

    Args:
        max_queue_size: Maximum number of waiting requests. Defaults to the
            module-level ``MAX_QUEUE_SIZE``.
        max_concurrent: Maximum number of requests in ``processing`` at once.
            Defaults to the module-level ``MAX_CONCURRENT_REQUESTS``.
        avg_processing_time: Initial running average in seconds. Defaults to
            the module-level ``AVERAGE_PROCESSING_TIME``.
        max_retained_results: How many finished results (completed and failed
            are counted separately) are kept for status lookups. Older
            entries are evicted and then report ``not_found``.
    """

    def __init__(
        self,
        max_queue_size: Optional[int] = None,
        max_concurrent: Optional[int] = None,
        avg_processing_time: Optional[float] = None,
        max_retained_results: int = 100,
    ):
        # Module constants are only consulted when no explicit value is given.
        self.max_queue_size = MAX_QUEUE_SIZE if max_queue_size is None else max_queue_size
        self.max_concurrent = MAX_CONCURRENT_REQUESTS if max_concurrent is None else max_concurrent
        initial_avg = AVERAGE_PROCESSING_TIME if avg_processing_time is None else avg_processing_time
        # At least one result must be retained, otherwise a request could
        # never be observed as finished.
        self.max_retained_results = max(1, max_retained_results)

        self.queue = deque()   # (request_id, user_data, enqueue_time)
        self.processing = {}   # request_id -> processing start time
        self.completed = {}    # request_id -> result (insertion-ordered)
        self.failed = {}       # request_id -> error message (insertion-ordered)
        self.lock = threading.Lock()
        self.stats = {
            'total_processed': 0,
            'total_failed': 0,
            'avg_processing_time': initial_avg,
        }
        # logging.getLogger caches by name, so this is the module's logger.
        self._log = logging.getLogger(__name__)

    def _retain(self, store: dict, request_id: str, value) -> None:
        """Store a finished result and evict the oldest entries beyond the cap.

        Results can hold large base64 payloads, so an unbounded dict would
        grow for the lifetime of the process.
        """
        store[request_id] = value
        while len(store) > self.max_retained_results:
            # dicts preserve insertion order, so the first key is the oldest.
            del store[next(iter(store))]

    def add_request(self, request_id: str, user_data: dict) -> Tuple[int, float]:
        """Append a request to the queue.

        Returns:
            ``(position, estimated_wait)``: the 1-based queue position and
            the estimated wait in seconds before processing starts.

        Raises:
            RuntimeError: if the queue already holds ``max_queue_size`` items.
        """
        with self.lock:
            if len(self.queue) >= self.max_queue_size:
                raise RuntimeError("Queue is full. Please try again later.")

            self.queue.append((request_id, user_data, time.time()))
            position = len(self.queue)

            # Everything that must finish first: requests queued ahead plus
            # the ones currently running. Counting queued requests even when
            # nothing is running avoids reporting a zero wait for a request
            # that is not at the head of the queue.
            ahead = (position - 1) + len(self.processing)
            estimated_wait = ahead * self.stats['avg_processing_time']

        self._log.info(
            f"Request {request_id} added to queue. "
            f"Position: {position}, Est. wait: {estimated_wait:.0f}s"
        )
        return position, estimated_wait

    def get_next_requests(self):
        """Claim the next queued request, if a processing slot is free.

        Returns:
            ``[(request_id, user_data)]`` for the claimed request, or ``[]``
            when the queue is empty or ``max_concurrent`` requests are
            already being processed.
        """
        with self.lock:
            if len(self.processing) >= self.max_concurrent or not self.queue:
                return []

            request_id, user_data, _ = self.queue.popleft()
            self.processing[request_id] = time.time()
            return [(request_id, user_data)]

    def complete_request(self, request_id: str, result):
        """Move a processing request to ``completed`` and update statistics.

        Does nothing when ``request_id`` is not currently processing.
        """
        with self.lock:
            started = self.processing.pop(request_id, None)
            if started is None:
                return
            processing_time = time.time() - started
            self._retain(self.completed, request_id, result)
            self.stats['total_processed'] += 1

            # Exponential moving average: 80% history, 20% latest sample.
            current_avg = self.stats['avg_processing_time']
            self.stats['avg_processing_time'] = (current_avg * 0.8) + (processing_time * 0.2)

        self._log.info(f"Request {request_id} completed in {processing_time:.1f}s")

    def fail_request(self, request_id: str, error_msg: str):
        """Move a processing request to ``failed``.

        Does nothing when ``request_id`` is not currently processing.
        """
        with self.lock:
            if self.processing.pop(request_id, None) is None:
                return
            self._retain(self.failed, request_id, error_msg)
            self.stats['total_failed'] += 1

        self._log.error(f"Request {request_id} failed: {error_msg}")

    def get_request_status(self, request_id: str) -> dict:
        """Describe the current state of a request.

        Returns:
            A dict whose ``'status'`` is one of ``'completed'`` (with
            ``'result'``), ``'failed'`` (with ``'error'``), ``'processing'``
            (with elapsed ``'time'`` in seconds), ``'queued'`` (with 1-based
            ``'position'``) or ``'not_found'``.
        """
        with self.lock:
            if request_id in self.completed:
                return {'status': 'completed', 'result': self.completed[request_id]}
            if request_id in self.failed:
                return {'status': 'failed', 'error': self.failed[request_id]}
            if request_id in self.processing:
                elapsed = time.time() - self.processing[request_id]
                return {'status': 'processing', 'time': elapsed}

            for index, (queued_id, _, _) in enumerate(self.queue):
                if queued_id == request_id:
                    return {'status': 'queued', 'position': index + 1}
            return {'status': 'not_found'}
|
|
| |
# Single queue instance shared by the UI callbacks and the background worker.
queue_manager = QueueManager()


# Cached backend connection state; rewritten by check_backend_connection().
backend_status = {
    "client": None,  # the Client object after a successful connection, else None
    "connected": False,
    "last_check": None,  # time.time() of the most recent connection attempt
    "error_message": ""  # str() of the exception from the last failed attempt
}
|
|
def check_backend_connection():
    """Connect to the backend HF Space and cache the outcome in ``backend_status``.

    On success the new client is stored under ``backend_status["client"]``;
    on failure the client is cleared and the error text is recorded.

    Returns:
        ``(ok, message)`` where ``ok`` is True when the client was created
        and ``message`` is a short user-facing status line.
    """
    try:
        test_client = Client("milliyin/backend", hf_token=HF_TOKEN)
    except Exception as e:
        backend_status.update({
            "client": None,
            "connected": False,
            "last_check": time.time(),
            "error_message": str(e),
        })
        logger.warning(f"Backend connection failed: {e}")
        err = str(e).lower()
        # A timeout is reported as "still starting" rather than as an error.
        if "timeout" in err or "read operation timed out" in err:
            return False, "🟡 Model is starting up. Please wait 3–4 min."
        return False, f"🔴 Backend error: {e}"

    backend_status.update({
        "client": test_client,
        "connected": True,
        "error_message": "",
        "last_check": time.time(),
    })
    logger.info("✅ Backend connection established")
    return True, "🟢 Model is ready"
|
|
| |
# Probe the backend once at import time so backend_status is populated up front.
check_backend_connection()
|
|
| |
def queue_worker():
    """Run forever, pulling one request at a time from ``queue_manager``.

    Sleeps 1s when there is nothing to do, 0.5s after each processed request
    and 5s after an unexpected error. Never returns; intended to run in a
    daemon thread.
    """
    while True:
        request_id = None  # set once a request has been claimed this iteration
        try:
            requests = queue_manager.get_next_requests()

            if not requests:
                time.sleep(1)
                continue

            request_id, user_data = requests[0]
            logger.info(f"Starting processing request {request_id}")

            process_single_request(request_id, user_data)
            time.sleep(0.5)

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"Queue worker error: {e}")
            if request_id is not None:
                # Release the processing slot; otherwise a claimed request
                # would stay in "processing" and block the queue for good.
                # fail_request ignores ids that are no longer processing.
                queue_manager.fail_request(request_id, str(e))
            time.sleep(5)
|
|
def process_single_request(request_id: str, user_data: dict):
    """Send one queued request to the backend and record the outcome.

    Args:
        request_id: Identifier under which the result or error is stored.
        user_data: Dict providing ``'image_b64'``, ``'category'`` and
            ``'gender'``.

    Every failure is reported through ``queue_manager.fail_request``; this
    function does not raise for ordinary exceptions.
    """
    try:
        img_b64 = user_data['image_b64']
        category = user_data['category']
        gender = user_data['gender']

        if not backend_status["connected"]:
            check_backend_connection()
            if not backend_status["connected"]:
                raise Exception("Backend not available")

        client = backend_status["client"]
        start_time = time.time()

        try:
            result = client.predict(
                img_b64,
                category,
                gender,
                api_name="/predict",
            )
        except Exception:
            # The cached client may be stale. Mark the backend disconnected
            # so the next request reconnects instead of reusing it.
            backend_status["connected"] = False
            raise

        processing_time = time.time() - start_time

        if not result or len(result) < 4:
            raise ValueError("Invalid response structure from backend")

        # Items 1..3 are used; the first item is ignored. Slicing tolerates
        # responses that carry extra trailing items.
        overlay_b64, bg_b64, status = result[1:4]

        final_result = {
            'overlay_b64': overlay_b64,
            'bg_b64': bg_b64,
            'status': status,
            'processing_time': processing_time,
        }

        queue_manager.complete_request(request_id, final_result)

    except Exception as e:
        queue_manager.fail_request(request_id, str(e))
|
|
| |
# Start the queue worker at import time; daemon=True so it does not keep the process alive on exit.
worker_thread = threading.Thread(target=queue_worker, daemon=True)
worker_thread.start()
|
|
| |
def image_to_base64(image: Image.Image) -> str:
    """Encode an image as a base64 string of its PNG bytes.

    Returns ``""`` when ``image`` is None. Images whose mode is not RGB are
    converted to RGB before saving.
    """
    if image is None:
        return ""
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    with io.BytesIO() as png_buffer:
        rgb_image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
|
|
def base64_to_image(b64: str) -> Optional[Image.Image]:
    """Decode a base64 string into an RGB image.

    Returns:
        The decoded image, or None when ``b64`` is empty or cannot be
        decoded. Failures are logged, not raised.
    """
    if not b64:
        return None
    try:
        raw = base64.b64decode(b64)
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"Failed to decode base64 → image: {e}")
        return None
|
|
| |
# NOTE(review): not referenced anywhere else in the visible source; possibly a leftover. Confirm before removing.
active_requests = {}
|
|
def submit_request(input_image: Image.Image, category: str, gender: str):
    """Encode the uploaded image and enqueue a generation request.

    Returns:
        A 5-tuple ``(overlay, background, status_message, button_update,
        request_id)``. Both images are always None here. ``request_id`` is
        ``""`` when nothing was queued, and the button is re-enabled in
        that case.
    """
    if input_image is None:
        return None, None, "❌ Please upload an image.", gr.update(interactive=True), ""

    try:
        request_id = str(uuid.uuid4())

        user_data = {
            'image_b64': image_to_base64(input_image),
            'category': category,
            'gender': gender,
            'timestamp': time.time(),
        }

        position, estimated_wait = queue_manager.add_request(request_id, user_data)

        # NOTE(review): the leading glyph here was unreadable; confirm the intended emoji.
        status_msg = f"📋 Request submitted! Position in queue: #{position}"
        # Only the values returned by add_request() are used: they were
        # computed under the queue lock. Reading queue_manager.processing
        # here would race with the worker claiming this same request.
        if position == 1 and estimated_wait == 0:
            status_msg += " | Starting processing now..."
        elif estimated_wait > 0:
            status_msg += f" | Estimated wait: {estimated_wait:.0f}s"

        return None, None, status_msg, gr.update(interactive=False), request_id

    except Exception as e:
        return None, None, f"❌ {str(e)}", gr.update(interactive=True), ""
|
|
def check_request_status(request_id: str):
    """Translate a request's queue status into UI outputs.

    Returns:
        A 4-tuple ``(overlay_image, background_image, status_message,
        button_update)``. Images are only set for a completed request. The
        button is disabled while the request is queued or processing.
    """
    if not request_id:
        return None, None, "No active request", gr.update(interactive=True)

    status_info = queue_manager.get_request_status(request_id)
    state = status_info['status']

    if state == 'completed':
        result = status_info['result']
        overlay_img = base64_to_image(result['overlay_b64'])
        bg_img = base64_to_image(result['bg_b64'])
        status_msg = f"✅ {result['status']} (⏱ {result['processing_time']:.1f}s)"
        return overlay_img, bg_img, status_msg, gr.update(interactive=True)

    if state == 'failed':
        return None, None, f"❌ {status_info['error']}", gr.update(interactive=True)

    if state == 'processing':
        processing_time = status_info['time']
        return None, None, f"⚡ Processing... ({processing_time:.1f}s)", gr.update(interactive=False)

    if state == 'queued':
        position = status_info['position']
        estimated_wait = position * queue_manager.stats['avg_processing_time']
        wait_msg = ""
        # Short waits are not worth showing.
        if estimated_wait > 30:
            minutes, seconds = divmod(int(estimated_wait), 60)
            wait_msg = f" | Est. wait: {minutes}m {seconds}s"
        return None, None, f"⏳ In queue, position #{position}{wait_msg}", gr.update(interactive=False)

    return None, None, "❌ Request not found", gr.update(interactive=True)
|
|
def disable_button():
    """Return a Gradio update that makes the target component non-interactive."""
    locked_state = gr.update(interactive=False)
    return locked_state
|
|
| |
# Stylesheet passed to gr.Blocks(css=...).
# NOTE(review): the status banner markup uses the classes "status-banner",
# "status-ready", "status-starting" and "status-error", none of which has a
# rule below; confirm whether styles for them are missing.
custom_css = """
.gradio-container {
background: linear-gradient(135deg, #3b4371 0%, #2d1b69 25%, #673ab7 50%, #8e24aa 75%, #6a1b9a 100%);
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
min-height: 100vh;
}
.contain {
background: rgba(255, 255, 255, 0.95);
border-radius: 15px;
padding: 25px;
margin: 15px;
box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
backdrop-filter: blur(10px);
}
.title-container {
text-align: center;
margin-bottom: 25px;
padding: 20px;
background: linear-gradient(135deg, #673ab7, #8e24aa);
border-radius: 12px;
box-shadow: 0 5px 20px rgba(103, 58, 183, 0.4);
}
.title-container h1 {
color: white;
font-size: 2.2em;
font-weight: bold;
margin: 0;
text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.3);
}
.info-bar {
background: linear-gradient(135deg, #7c4dff, #6a1b9a);
padding: 12px;
border-radius: 8px;
margin-bottom: 20px;
color: white;
text-align: center;
font-weight: 500;
box-shadow: 0 3px 12px rgba(124, 77, 255, 0.3);
}
.section-header {
background: linear-gradient(135deg, #e1bee7, #d1c4e9);
padding: 12px;
border-radius: 8px;
margin-bottom: 15px;
border-left: 4px solid #673ab7;
}
.section-header h3 {
margin: 0;
color: #333;
font-weight: 600;
}
.input-group {
background: rgba(255, 255, 255, 0.85);
padding: 18px;
border-radius: 12px;
margin-bottom: 15px;
border: 1px solid rgba(103, 58, 183, 0.2);
box-shadow: 0 3px 12px rgba(103, 58, 183, 0.1);
}
.result-section {
background: rgba(255, 255, 255, 0.9);
padding: 18px;
border-radius: 12px;
border: 1px solid rgba(103, 58, 183, 0.2);
box-shadow: 0 3px 12px rgba(103, 58, 183, 0.1);
}
.tip-box {
background: linear-gradient(135deg, #f3e5f5, #e8eaf6);
padding: 10px;
border-radius: 6px;
margin: 8px 0;
border-left: 3px solid #673ab7;
color: #4a148c;
font-weight: 500;
}
button.primary {
background: linear-gradient(135deg, #673ab7, #8e24aa) !important;
border: none !important;
border-radius: 20px !important;
padding: 12px 25px !important;
color: white !important;
font-weight: bold !important;
font-size: 15px !important;
box-shadow: 0 5px 15px rgba(103, 58, 183, 0.4) !important;
}
button.primary:hover {
box-shadow: 0 8px 25px rgba(103, 58, 183, 0.6) !important;
opacity: 0.9 !important;
transform: translateY(-2px) !important;
}
label {
color: #4a148c !important;
font-weight: 600 !important;
}
input, textarea, select {
border: 1px solid rgba(103, 58, 183, 0.3) !important;
border-radius: 6px !important;
}
input:focus, textarea:focus, select:focus {
border-color: #673ab7 !important;
box-shadow: 0 0 0 2px rgba(103, 58, 183, 0.2) !important;
}
.gr-slider input[type="range"] {
accent-color: #673ab7 !important;
}
input[type="checkbox"] {
accent-color: #673ab7 !important;
}
.preserve-aspect-ratio img {
object-fit: contain !important;
width: auto !important;
max-height: 512px !important;
}
.social-links {
text-align: center;
margin: 20px 0;
}
.social-links a {
margin: 0 10px;
padding: 8px 16px;
background: #667eea;
color: white;
text-decoration: none;
border-radius: 8px;
transition: all 0.3s ease;
}
.social-links a:hover {
background: #764ba2;
transform: translateY(-2px);
}
.feature-box {
background: #f8fafc;
border: 1px solid #e2e8f0;
padding: 20px;
border-radius: 12px;
margin: 10px 0;
}
"""
|
|
| |
# Gradio UI: header, backend status banner, input and result panels,
# settings, footer, and the event wiring that connects them to the queue.
with gr.Blocks(css=custom_css, title="Jewellery Photography Preview") as demo:

    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
    <h1 style="font-size: 2.5em;">🎨 Raresence: AI-Powered Jewellery Photo Preview</h1>
    <p style="color: #666;">Upload a jewellery image, select model, and get professional photos instantly</p>
    </div>
    """)

    # --- Backend status banner -------------------------------------------
    status_html = gr.HTML()

    def _update_status():
        """Re-check the backend and render the result as an HTML banner."""
        import html  # local import: only needed for escaping the banner text

        ok, msg = check_backend_connection()
        if ok:
            cls = "status-ready"
        elif "starting up" in msg:
            # Match on the message text, not on its leading emoji, so the
            # classification does not depend on how the glyph is encoded.
            cls = "status-starting"
        else:
            cls = "status-error"
        # The message can embed raw exception text, so escape it for HTML.
        return f'<div class="status-banner {cls}">{html.escape(msg)}</div>'

    status_html.value = _update_status()
    # NOTE(review): the button's leading glyph was unreadable; confirm the intended emoji.
    gr.Button("🔄 Check Status").click(fn=_update_status, outputs=status_html)

    # --- Main layout ------------------------------------------------------
    with gr.Column():
        with gr.Row():

            # NOTE(review): Gradio documents `scale` as an integer; 0.4 is
            # kept as written so the layout proportions do not change.
            with gr.Column(scale=0.4):
                gr.HTML("""
                <div class="feature-box">
                <h3>🖼️ Upload Jewellery Image</h3>
                <p style="color: #666; font-size: 14px;">Select a clear jewellery image for best results</p>
                </div>
                """)
                # NOTE(review): these two Markdown blocks held a single
                # unreadable character and look like vertical spacers;
                # confirm the intended content.
                gr.Markdown("&nbsp;")
                gr.Markdown("&nbsp;")
                input_img = gr.Image(label="Upload image", type="pil", height=400)

            with gr.Column():
                gr.HTML("""
                <div class="feature-box">
                <h3>🎨 AI Generated Results</h3>
                <p style="color: #666; font-size: 14px;">Preview overlay detection and final professional background</p>
                </div>
                """)

                with gr.Tabs():
                    with gr.TabItem("Final result"):
                        info2 = gr.Markdown(value="### Final result")
                        out_bg = gr.Image(height=400)
                    with gr.TabItem("Detection overlay"):
                        info1 = gr.Markdown(value="### Detection overlay")
                        out_overlay = gr.Image(height=400)
                run_btn = gr.Button("🎯 Generate", elem_id="button", variant="primary")

        with gr.Row():
            with gr.Column(scale=0.4):
                gr.Markdown(value="Setting")
                category = gr.Dropdown(label="Jewellery category", choices=["Rings", "Bracelets", "Watches", "Earrings"], value="Bracelets")
                gender = gr.Dropdown(label="Model gender", choices=["male", "female"], value="female")

        # Shown up front because the timer no longer writes it on idle ticks.
        out_status = gr.Text(label="Status", value="Ready to generate", interactive=False)

    # --- Footer -----------------------------------------------------------
    # NOTE(review): the glyphs before "Powered by" and "Website" were unreadable; confirm the intended emoji.
    gr.HTML("""
    <div style="text-align:center;padding:40px 20px;background:#f8fafc;border:1px solid #e2e8f0;border-radius:16px;margin:30px 0;">
    <h3 style="color:#333;">🚀 Powered by Snapwear AI</h3>
    <p style="color:#666;">
    Experience the future of virtual fashion and garment visualization.
    </p>
    <div class="social-links">
    <a href="https://snapwear.io" target="_blank">🌐 Website</a>
    <a href="https://www.instagram.com/snapwearai/" target="_blank">📸 Instagram</a>
    <a href="https://huggingface.co/spaces/SnapwearAI/Snapwear-Texture-Transfer" target="_blank">🎨 Pattern Transfer</a>
    </div>
    <p style="font-size:12px;color:#999;margin-top:20px;">
    © 2024 Snapwear AI. Professional AI tools for fashion and design.
    </p>
    </div>
    """)

    # Per-session id of the request currently being tracked ("" = none).
    current_request_id = gr.State("")

    # --- Event wiring -----------------------------------------------------
    # Disable the button first so a double click cannot enqueue twice, then
    # submit the request.
    run_btn.click(
        fn=disable_button,
        inputs=None,
        outputs=run_btn
    ).then(
        fn=submit_request,
        inputs=[input_img, category, gender],
        outputs=[out_overlay, out_bg, out_status, run_btn, current_request_id],
        show_progress=True,
    )

    def auto_status_check(request_id):
        """Timer callback: refresh the outputs for the tracked request."""
        if request_id:
            return check_request_status(request_id)
        # No tracked request: leave every output untouched. Writing values
        # here would wipe a just-shown error message and clear the images
        # on every tick.
        return gr.update(), gr.update(), gr.update(), gr.update()

    demo.load(lambda: None)

    # Poll the queue every 2 seconds.
    timer = gr.Timer(2)
    timer.tick(
        fn=auto_status_check,
        inputs=[current_request_id],
        outputs=[out_overlay, out_bg, out_status, run_btn]
    )
|
|
| |
if __name__ == "__main__":
    # Gradio's own event queue is sized 10 above the application queue limit;
    # share=False, so no public share link is requested.
    demo.queue(max_size=MAX_QUEUE_SIZE + 10, default_concurrency_limit=1).launch(share=False)