Spaces:
Sleeping
Sleeping
| """ | |
| Museum AI Studio - Unified Interactive Experience | |
| Backend: FastAPI + WebSocket + ONNX Style Transfer | |
| Frontend: MediaPipe + Canvas (see index.html) | |
| Modes: | |
| 1. Anime Studio - Real-time style transfer (Hayao, Shinkai, Disney, Cyberpunk) | |
| 2. Pose Challenge - Strike poses, get scored | |
| 3. Face Filters - AR masks (glasses, cat, crown, mustache, anime eyes) | |
| 4. Hand Painter - Gesture-controlled drawing | |
| 5. Rock-Paper-Scissors - Play against AI | |
| """ | |
| import os | |
| import io | |
| import base64 | |
| import json | |
| import asyncio | |
| from typing import Optional, Dict | |
| from contextlib import asynccontextmanager | |
| import cv2 | |
| import numpy as np | |
| import onnxruntime as ort | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse | |
| from huggingface_hub import hf_hub_download | |
| # ─── Configuration ───────────────────────────────────────────── | |
# Registry of selectable styles. Each entry names the Hugging Face repo that
# hosts the model, the ONNX file inside that repo, and the display name that
# is reported back to clients.
MODELS = dict(
    hayao=dict(
        repo="vumichien/AnimeGANv2_Hayao",
        file="AnimeGANv2_Hayao.onnx",
        name="Hayao / Ghibli",
    ),
    shinkai=dict(
        repo="vumichien/AnimeGANv2_Shinkai",
        file="AnimeGANv2_Shinkai.onnx",
        name="Shinkai / Your Name",
    ),
)

# Directory holding the downloaded ONNX files; created at import time.
MODELS_DIR = os.getenv("MODELS_DIR", "./models")
os.makedirs(MODELS_DIR, exist_ok=True)

# Tunables read once from the environment (STYLE_SIZE, JPEG_QUALITY).
STYLE_TRANSFER_SIZE = int(os.getenv("STYLE_SIZE", "512"))
JPEG_QUALITY = int(os.getenv("JPEG_QUALITY", "75"))
| # ─── Model Manager ───────────────────────────────────────────── | |
class ONNXStyleTransfer:
    """Lazy loader and runner for the ONNX style-transfer models in MODELS.

    An inference session is created the first time a style key is used and
    is cached afterwards, together with the name and declared shape of the
    model's first input.
    """

    def __init__(self):
        # One cached inference session per style key.
        self.sessions: Dict[str, ort.InferenceSession] = {}
        # Name of each model's first input (key of the feed dict).
        self.input_names: Dict[str, str] = {}
        # Declared shape of each model's first input, e.g. (1, 3, 512, 512)
        # or (1, 512, 512, 3). Dynamic axes show up as non-integer entries.
        self.input_shapes: Dict[str, tuple] = {}

    @staticmethod
    def _resolve_layout(shape, default_size: int = 512):
        """Derive the tensor layout and spatial size from a declared input shape.

        Args:
            shape: Declared input shape (sequence of 4 entries) or None.
                Entries that are not positive integers are treated as
                dynamic axes.
            default_size: Edge length used for any dynamic spatial axis.

        Returns:
            A ``(layout, height, width)`` tuple where ``layout`` is
            ``"NCHW"`` or ``"NHWC"``. Anything that is not recognisably
            channels-first falls back to ``"NHWC"``.
        """
        dims = tuple(shape) if shape is not None else ()
        if len(dims) == 4 and dims[1] == 3 and dims[3] != 3:
            layout, height, width = "NCHW", dims[2], dims[3]
        elif len(dims) == 4:
            layout, height, width = "NHWC", dims[1], dims[2]
        else:
            layout, height, width = "NHWC", None, None

        def _static(dim):
            # Symbolic / unknown axes are reported as str or None.
            return dim if isinstance(dim, int) and dim > 0 else default_size

        return layout, _static(height), _static(width)

    def _load(self, key: str):
        """Ensure the model for *key* is downloaded and has a cached session.

        Args:
            key: Style key; must be present in MODELS.

        Raises:
            KeyError: If *key* is not a known style.
        """
        if key in self.sessions:
            return
        cfg = MODELS[key]
        local_path = os.path.join(MODELS_DIR, cfg["file"])
        # Download if missing
        if not os.path.exists(local_path):
            print(f"[Model] Downloading {cfg['name']}...")
            # Use the path the hub client reports instead of assuming where
            # the file was placed.
            local_path = hf_hub_download(
                repo_id=cfg["repo"],
                filename=cfg["file"],
                local_dir=MODELS_DIR,
                local_dir_use_symlinks=False,
            )
        print(f"[Model] Loading {cfg['name']} ONNX...")
        sess = ort.InferenceSession(
            local_path,
            providers=["CPUExecutionProvider"],
        )
        inp = sess.get_inputs()[0]
        self.sessions[key] = sess
        self.input_names[key] = inp.name
        self.input_shapes[key] = tuple(inp.shape)
        print(f"[Model] {cfg['name']} ready. Input shape: {inp.shape}")

    def stylize(self, frame_bgr: np.ndarray, key: str) -> np.ndarray:
        """Run the style model *key* on a BGR frame.

        The frame is converted to RGB, resized to the model's input size,
        scaled to [-1, 1] and arranged in the layout the model declares.
        The output is mapped back to uint8 BGR at the original frame size.

        Args:
            frame_bgr: Input image in OpenCV BGR channel order.
            key: Style key; must be present in MODELS.

        Returns:
            The stylized image, BGR uint8, same height/width as the input.
        """
        self._load(key)
        sess = self.sessions[key]
        inp_name = self.input_names[key]

        # Honour the shape recorded at load time; dynamic axes fall back to
        # the configured STYLE_TRANSFER_SIZE (512 unless overridden).
        layout, target_h, target_w = self._resolve_layout(
            self.input_shapes.get(key), STYLE_TRANSFER_SIZE
        )

        img = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (target_w, target_h))
        # Normalise pixel values to [-1, 1].
        img = img.astype(np.float32) / 127.5 - 1.0
        if layout == "NCHW":
            img = np.transpose(img, (2, 0, 1))
        # Add the batch axis; make the buffer contiguous after a transpose.
        img = np.ascontiguousarray(img[np.newaxis, ...])

        out = sess.run(None, {inp_name: img})[0][0]
        # A channels-first model is expected to answer channels-first too.
        if layout == "NCHW" and out.ndim == 3 and out.shape[0] == 3 and out.shape[-1] != 3:
            out = np.transpose(out, (1, 2, 0))
        out = (out + 1.0) * 127.5
        out = np.clip(out, 0, 255).astype(np.uint8)

        # Resize back to original
        h, w = frame_bgr.shape[:2]
        out = cv2.resize(out, (w, h))
        return cv2.cvtColor(out, cv2.COLOR_RGB2BGR)
| # ─── Global State ────────────────────────────────────────────── | |
# Single style-transfer engine for the whole process, so loaded sessions are
# reused by every connection.
style_engine: ONNXStyleTransfer = ONNXStyleTransfer()
| # ─── Utility: Frame ↔ Base64 ───────────────────────────────── | |
def frame_to_base64(frame: np.ndarray) -> str:
    """JPEG-encode *frame* at JPEG_QUALITY and return the bytes as base64 text."""
    encode_params = [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY]
    _ok, jpeg_buf = cv2.imencode(".jpg", frame, encode_params)
    encoded = base64.b64encode(jpeg_buf)
    return encoded.decode("utf-8")
def base64_to_frame(data: str) -> Optional[np.ndarray]:
    """Decode a base64-encoded image into a BGR frame.

    Args:
        data: Base64 text of an encoded image. A ``data:...;base64,`` URL
            prefix is stripped if present.

    Returns:
        The decoded image as returned by ``cv2.imdecode`` with
        ``IMREAD_COLOR``, or None when the payload is not valid base64, is
        empty, or cannot be decoded as an image.
    """
    # Header characters of a data URL would otherwise be decoded as payload.
    if isinstance(data, str) and data.startswith("data:"):
        data = data.partition(",")[2]
    try:
        raw = base64.b64decode(data)
    except (ValueError, TypeError):
        # binascii.Error is a ValueError; TypeError covers non-text input.
        return None
    if not raw:
        # imdecode rejects an empty buffer; report "no frame" instead.
        return None
    arr = np.frombuffer(raw, dtype=np.uint8)
    return cv2.imdecode(arr, cv2.IMREAD_COLOR)
| # ─── WebSocket Connection Manager ──────────────────────────── | |
class ConnectionManager:
    """Registry of accepted WebSocket connections keyed by client id."""

    def __init__(self):
        # client_id -> accepted WebSocket
        self.active: Dict[str, WebSocket] = {}

    async def connect(self, ws: WebSocket, client_id: str):
        """Accept *ws* and register it under *client_id*.

        An existing entry with the same id is replaced.
        """
        await ws.accept()
        self.active[client_id] = ws
        print(f"[WS] Client {client_id} connected ({len(self.active)} active)")

    def disconnect(self, client_id: str):
        """Forget *client_id*; a no-op if it is not registered."""
        self.active.pop(client_id, None)
        print(f"[WS] Client {client_id} disconnected ({len(self.active)} active)")

    async def send_json(self, client_id: str, data: dict):
        """Send *data* to one client; silently skipped if the id is unknown."""
        ws = self.active.get(client_id)
        if ws is not None:
            await ws.send_json(data)

    async def broadcast_json(self, data: dict):
        """Send *data* to every registered client on a best-effort basis.

        A failure for one client is logged and does not stop delivery to the
        remaining clients.
        """
        # Iterate over a snapshot: the registry may change while awaiting.
        for client_id, ws in list(self.active.items()):
            try:
                await ws.send_json(data)
            except Exception as exc:
                print(f"[WS] Broadcast to {client_id} failed: {exc}")
# Process-wide registry of connected WebSocket clients.
manager: ConnectionManager = ConnectionManager()
| # ─── FastAPI Lifecycle ─────────────────────────────────────── | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: preload all style models, then run the app.

    Every key in MODELS is loaded before the app starts serving. A model
    that fails to load is logged and skipped so startup still completes.
    """
    print("[Startup] Museum AI Studio backend starting...")
    # Pre-download models
    for key in MODELS:
        try:
            style_engine._load(key)
        except Exception as e:
            print(f"[Startup] Warning: could not preload {key}: {e}")
    print("[Startup] Ready.")
    yield
    print("[Shutdown] Cleaning up...")
# ASGI application; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(title="Museum AI Studio", lifespan=lifespan)

# Serve the frontend bundle from ./static under the /static prefix.
app.mount(
    path="/static",
    app=StaticFiles(directory="static", html=True),
    name="static",
)
@app.get("/", response_class=HTMLResponse)
async def root() -> str:
    """Serve the frontend entry page (static/index.html) as HTML."""
    with open("static/index.html", "r", encoding="utf-8") as f:
        return f.read()
# NOTE(review): the "/health" path is assumed — confirm against whatever
# probes this service.
@app.get("/health")
async def health() -> dict:
    """Report service status plus the loaded and available style models."""
    loaded = list(style_engine.sessions)
    return {"status": "ok", "loaded_models": loaded, "available_models": list(MODELS)}
| # ─── WebSocket: Main Data Channel ───────────────────────────── | |
# Captions drawn on the mirrored frame for modes whose processing happens in
# the browser: mode -> (title, optional subtitle).
_MODE_CAPTIONS = {
    "pose": ("Pose Challenge Mode", "Strike a pose!"),
    "face": ("AR Face Filters Mode", "Select filter in UI"),
    "painter": ("Hand Painter Mode", "Draw with your hand!"),
    "rps": ("Rock-Paper-Scissors", None),
}

# Each Rock-Paper-Scissors move mapped to the move it defeats.
_RPS_BEATS = {"Rock": "Scissors", "Paper": "Rock", "Scissors": "Paper"}


def _mirror_with_caption(frame, mode):
    """Return a horizontally flipped copy of *frame* with the mode's captions."""
    title, subtitle = _MODE_CAPTIONS[mode]
    mirrored = cv2.flip(frame, 1)
    cv2.putText(mirrored, title, (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
    if subtitle is not None:
        cv2.putText(mirrored, subtitle, (10, 60),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 1)
    return mirrored


def _rps_verdict(player_move, ai_move):
    """Return the result text for one round, from the player's point of view."""
    if player_move == ai_move:
        return "Draw!"
    if _RPS_BEATS.get(player_move) == ai_move:
        return "You Win!"
    return "AI Wins!"


# NOTE(review): the "/ws/{client_id}" path is assumed — confirm against the
# URL the frontend (static/index.html) connects to.
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(ws: WebSocket, client_id: str):
    """Per-client WebSocket channel.

    Incoming JSON messages are dispatched on their ``action`` field
    (defaulting to ``"frame"``):

    * ``set_mode``  - switch the processing mode; answered with ``mode_set``.
    * ``set_style`` - switch the style model; answered with ``style_set`` or
      an ``error`` for an unknown style.
    * ``frame``     - process a base64 image according to the current mode
      and answer with a ``frame`` message; undecodable frames are skipped.
    * ``rps_play``  - play one Rock-Paper-Scissors round; answered with
      ``rps_result`` or an ``error`` for an unrecognised move.

    The client is always removed from the connection manager when the
    handler exits, whatever the reason.
    """
    import random  # local: not part of the module-level import block

    await manager.connect(ws, client_id)
    current_mode = "studio"
    current_style = "hayao"
    loop = asyncio.get_running_loop()
    try:
        while True:
            msg = await ws.receive_json()
            if not isinstance(msg, dict):
                await manager.send_json(client_id, {"type": "error", "message": "Message must be a JSON object"})
                continue
            action = msg.get("action", "frame")

            # -- 1. Mode switching --
            if action == "set_mode":
                current_mode = msg.get("mode", "studio")
                await manager.send_json(client_id, {"type": "mode_set", "mode": current_mode})

            # -- 2. Style switching --
            elif action == "set_style":
                new_style = msg.get("style", "hayao")
                if new_style in MODELS:
                    current_style = new_style
                    await manager.send_json(client_id, {"type": "style_set", "style": new_style, "name": MODELS[new_style]["name"]})
                else:
                    await manager.send_json(client_id, {"type": "error", "message": f"Unknown style: {new_style}"})

            # -- 3. Raw frame from the browser --
            elif action == "frame":
                frame_b64 = msg.get("frame", "")
                if not frame_b64:
                    continue
                try:
                    frame = base64_to_frame(frame_b64)
                except (ValueError, TypeError, cv2.error):
                    # One corrupt frame must not drop the connection.
                    frame = None
                if frame is None:
                    continue

                if current_mode == "studio":
                    # Inference is CPU-bound; run it off the event loop so
                    # other connections stay responsive.
                    result = await loop.run_in_executor(
                        None, style_engine.stylize, frame, current_style
                    )
                elif current_mode in _MODE_CAPTIONS:
                    # Detection for these modes is client-side; the server
                    # only mirrors the frame and labels it.
                    result = _mirror_with_caption(frame, current_mode)
                else:
                    result = frame

                # Send back stylized/processed frame
                out_b64 = frame_to_base64(result)
                await manager.send_json(client_id, {
                    "type": "frame",
                    "mode": current_mode,
                    "style": current_style if current_mode == "studio" else None,
                    "frame": out_b64,
                })

            # -- 4. Rock-Paper-Scissors round (server-side) --
            elif action == "rps_play":
                player_move = msg.get("move")
                # Accept any letter case, but reject anything that is not a
                # real move instead of scoring it as a loss.
                normalized = player_move.strip().capitalize() if isinstance(player_move, str) else None
                if normalized not in _RPS_BEATS:
                    await manager.send_json(client_id, {"type": "error", "message": f"Unknown move: {player_move}"})
                    continue
                ai_move = random.choice(list(_RPS_BEATS))
                await manager.send_json(client_id, {
                    "type": "rps_result",
                    "player": player_move,
                    "ai": ai_move,
                    "result": _rps_verdict(normalized, ai_move),
                })
    except WebSocketDisconnect:
        # Normal client departure; cleanup happens in `finally`.
        pass
    except Exception as e:
        print(f"[WS] Error for {client_id}: {e}")
    finally:
        manager.disconnect(client_id)
# Script entry point: serve the app with uvicorn on $PORT (7860 if unset).
if __name__ == "__main__":
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, reload=False)