"""
Museum AI Studio - Unified Interactive Experience

Backend:  FastAPI + WebSocket + ONNX style transfer
Frontend: MediaPipe + Canvas (see static/index.html)

Modes:
  1. Anime Studio        - Real-time style transfer (styles configured in
                           MODELS: Hayao, Shinkai)
  2. Pose Challenge      - Strike poses, get scored
  3. Face Filters        - AR masks (glasses, cat, crown, mustache, anime eyes)
  4. Hand Painter        - Gesture-controlled drawing
  5. Rock-Paper-Scissors - Play against AI

Only Anime Studio does heavy work on the server. The other modes just mirror
the frame and draw a caption here; detection (e.g. MediaPipe pose) runs in
the browser.
"""
import os
import io
import base64
import json
import asyncio
import random
import threading
from typing import Optional, Dict, List, Tuple
from contextlib import asynccontextmanager, suppress

import cv2
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from huggingface_hub import hf_hub_download

# ─── Configuration ─────────────────────────────────────────────
MODELS = {
    "hayao": {
        "repo": "vumichien/AnimeGANv2_Hayao",
        "file": "AnimeGANv2_Hayao.onnx",
        "name": "Hayao / Ghibli",
    },
    "shinkai": {
        "repo": "vumichien/AnimeGANv2_Shinkai",
        "file": "AnimeGANv2_Shinkai.onnx",
        "name": "Shinkai / Your Name",
    },
}

MODELS_DIR = os.environ.get("MODELS_DIR", "./models")
os.makedirs(MODELS_DIR, exist_ok=True)

# Directory holding index.html and other frontend assets.
STATIC_DIR = os.environ.get("STATIC_DIR", "static")

# Side length fed to the style model when its declared input height/width
# are dynamic (a model with fixed spatial dims always gets those instead).
STYLE_TRANSFER_SIZE = int(os.environ.get("STYLE_SIZE", "512"))
JPEG_QUALITY = int(os.environ.get("JPEG_QUALITY", "75"))

# Rock-Paper-Scissors: each move maps to the move it defeats.
RPS_BEATS = {"Rock": "Scissors", "Paper": "Rock", "Scissors": "Paper"}

# Captions drawn on the mirrored frame for the lightweight modes:
# (text, origin, font scale, BGR color, thickness).
MODE_CAPTIONS: Dict[str, List[Tuple[str, Tuple[int, int], float, Tuple[int, int, int], int]]] = {
    "pose": [
        ("Pose Challenge Mode", (10, 30), 0.7, (0, 255, 255), 2),
        ("Strike a pose!", (10, 60), 0.6, (0, 255, 0), 1),
    ],
    "face": [
        ("AR Face Filters Mode", (10, 30), 0.7, (0, 255, 255), 2),
        ("Select filter in UI", (10, 60), 0.6, (0, 255, 0), 1),
    ],
    "painter": [
        ("Hand Painter Mode", (10, 30), 0.7, (0, 255, 255), 2),
        ("Draw with your hand!", (10, 60), 0.6, (0, 255, 0), 1),
    ],
    "rps": [
        ("Rock-Paper-Scissors", (10, 30), 0.7, (0, 255, 255), 2),
    ],
}


# ─── Model Manager ─────────────────────────────────────────────
class ONNXStyleTransfer:
    """Lazily download, load and run the ONNX style-transfer models in MODELS.

    Sessions are cached per style key. Loading is guarded by a lock because
    ``stylize`` is invoked from executor worker threads by the WebSocket
    handler.
    """

    def __init__(self):
        self.sessions: Dict[str, ort.InferenceSession] = {}
        self.input_names: Dict[str, str] = {}
        # Declared input shape per model; each dim may be an int, a symbolic
        # name (str) or None, as reported by ONNX Runtime.
        self.input_shapes: Dict[str, tuple] = {}
        self._lock = threading.Lock()

    def _load(self, key: str):
        """Ensure the model for ``key`` is downloaded and its session cached.

        Raises:
            KeyError: if ``key`` is not a key of MODELS.
        """
        if key in self.sessions:
            return
        with self._lock:
            # Re-check: another thread may have finished loading while we waited.
            if key in self.sessions:
                return

            cfg = MODELS[key]
            local_path = os.path.join(MODELS_DIR, cfg["file"])

            # Download if missing
            if not os.path.exists(local_path):
                print(f"[Model] Downloading {cfg['name']}...")
                local_path = hf_hub_download(
                    repo_id=cfg["repo"],
                    filename=cfg["file"],
                    local_dir=MODELS_DIR,
                    local_dir_use_symlinks=False,
                )

            print(f"[Model] Loading {cfg['name']} ONNX...")
            sess = ort.InferenceSession(
                local_path,
                providers=["CPUExecutionProvider"],
            )
            inp = sess.get_inputs()[0]
            self.input_names[key] = inp.name
            self.input_shapes[key] = tuple(inp.shape)
            # Publish the session last: readers treat ``key in self.sessions``
            # as "fully loaded" without taking the lock.
            self.sessions[key] = sess
            print(f"[Model] {cfg['name']} ready. Input shape: {inp.shape}")

    def _input_layout(self, key: str) -> Tuple[int, int, bool]:
        """Return ``(height, width, channels_last)`` to use for model ``key``.

        A 4-D shape ending in 3 is treated as NHWC; one whose second axis is 3
        as NCHW; anything else defaults to NHWC (what AnimeGANv2 exports use).
        Positive integer spatial dims declared by the model are honoured;
        dynamic ones fall back to STYLE_TRANSFER_SIZE.
        """
        shape = self.input_shapes.get(key, ())
        channels_last = True
        dims = (None, None)
        if len(shape) == 4:
            if shape[-1] == 3:
                dims = (shape[1], shape[2])
            elif shape[1] == 3:
                channels_last = False
                dims = (shape[2], shape[3])

        def _pick(dim):
            return dim if isinstance(dim, int) and dim > 0 else STYLE_TRANSFER_SIZE

        return _pick(dims[0]), _pick(dims[1]), channels_last

    def stylize(self, frame_bgr: np.ndarray, key: str) -> np.ndarray:
        """Apply style ``key`` to a BGR uint8 frame.

        The frame is resized to the model's input size, scaled to [-1, 1]
        (AnimeGANv2 normalization), run through the model, mapped back to
        [0, 255] and resized to the original frame size.

        Returns:
            A BGR uint8 image with the same height/width as ``frame_bgr``.

        Raises:
            KeyError: if ``key`` is not a key of MODELS.
        """
        self._load(key)
        sess = self.sessions[key]
        inp_name = self.input_names[key]
        target_h, target_w, channels_last = self._input_layout(key)

        img = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (target_w, target_h))  # cv2 takes (width, height)
        img = img.astype(np.float32) / 127.5 - 1.0
        if not channels_last:
            img = img.transpose(2, 0, 1)
        batch = np.ascontiguousarray(img[np.newaxis, ...])

        out = sess.run(None, {inp_name: batch})[0][0]
        # Bring a channels-first output back to HWC for OpenCV.
        if out.ndim == 3 and out.shape[0] == 3 and out.shape[-1] != 3:
            out = np.ascontiguousarray(out.transpose(1, 2, 0))
        out = np.clip((out + 1.0) * 127.5, 0, 255).astype(np.uint8)

        # Resize back to original
        h, w = frame_bgr.shape[:2]
        out = cv2.resize(out, (w, h))
        return cv2.cvtColor(out, cv2.COLOR_RGB2BGR)


# ─── Global State ──────────────────────────────────────────────
style_engine = ONNXStyleTransfer()


# ─── Utility: Frame ↔ Base64 ─────────────────────────────────
def frame_to_base64(frame: np.ndarray) -> str:
    """Encode a BGR frame as base64 JPEG text at JPEG_QUALITY.

    Raises:
        ValueError: if OpenCV fails to encode the frame.
    """
    ok, buf = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY])
    if not ok:
        raise ValueError("JPEG encoding failed")
    return base64.b64encode(buf).decode("utf-8")


def base64_to_frame(data: str) -> Optional[np.ndarray]:
    """Decode base64 image text into a BGR frame, or None if undecodable.

    A leading ``data:<mime>;base64,`` prefix (as produced by a canvas data
    URL) is stripped first, in case the browser sends one.
    """
    if isinstance(data, str) and data.startswith("data:"):
        data = data.partition(",")[2]
    try:
        buf = base64.b64decode(data)
    except (ValueError, TypeError):  # binascii.Error subclasses ValueError
        return None
    if not buf:
        # cv2.imdecode raises on an empty buffer instead of returning None.
        return None
    arr = np.frombuffer(buf, dtype=np.uint8)
    return cv2.imdecode(arr, cv2.IMREAD_COLOR)


# ─── WebSocket Connection Manager ────────────────────────────
class ConnectionManager:
    """Registry of open WebSockets keyed by client id."""

    def __init__(self):
        self.active: Dict[str, WebSocket] = {}

    async def connect(self, ws: WebSocket, client_id: str):
        """Accept ``ws`` and register it under ``client_id``.

        A second connection with the same id replaces the first in the registry.
        """
        await ws.accept()
        self.active[client_id] = ws
        print(f"[WS] Client {client_id} connected ({len(self.active)} active)")

    def disconnect(self, client_id: str, ws: Optional[WebSocket] = None):
        """Unregister ``client_id``.

        When ``ws`` is given, the entry is removed only if it is still that
        socket, so a stale connection closing cannot evict a newer one that
        reuses the same id.
        """
        if ws is not None and self.active.get(client_id) is not ws:
            return
        self.active.pop(client_id, None)
        print(f"[WS] Client {client_id} disconnected ({len(self.active)} active)")

    async def send_json(self, client_id: str, data: dict):
        """Send ``data`` to the socket registered under ``client_id``, if any."""
        ws = self.active.get(client_id)
        if ws:
            await ws.send_json(data)

    async def broadcast_json(self, data: dict):
        """Best-effort send of ``data`` to every registered socket."""
        for ws in list(self.active.values()):
            try:
                await ws.send_json(data)
            except Exception:
                # Deliberately ignored: one dead socket must not stop the broadcast.
                pass


manager = ConnectionManager()


# ─── FastAPI Lifecycle ───────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Preload every model at startup; failures are logged, not fatal.

    A model that fails here is retried lazily by ``stylize`` on first use.
    """
    print("[Startup] Museum AI Studio backend starting...")
    # Pre-download models
    for key in MODELS:
        try:
            style_engine._load(key)
        except Exception as e:
            print(f"[Startup] Warning: could not preload {key}: {e}")
    print("[Startup] Ready.")
    yield
    print("[Shutdown] Cleaning up...")


app = FastAPI(title="Museum AI Studio", lifespan=lifespan)

# Serve static frontend
app.mount("/static", StaticFiles(directory=STATIC_DIR, html=True), name="static")


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the single-page frontend."""
    with open(os.path.join(STATIC_DIR, "index.html"), "r", encoding="utf-8") as f:
        return f.read()


@app.get("/health")
async def health():
    """Report which models are loaded versus configured."""
    loaded = list(style_engine.sessions.keys())
    return {"status": "ok", "loaded_models": loaded, "available_models": list(MODELS.keys())}


# ─── Frame processing (runs in worker threads) ───────────────
def _render_frame(frame: np.ndarray, mode: str, style: str) -> np.ndarray:
    """Produce the frame to return for ``mode``.

    "studio" runs style transfer (not mirrored). Modes listed in MODE_CAPTIONS
    are mirrored horizontally and captioned. Unknown modes pass through as-is.
    """
    if mode == "studio":
        return style_engine.stylize(frame, style)
    captions = MODE_CAPTIONS.get(mode)
    if captions is None:
        return frame
    result = cv2.flip(frame, 1)
    for text, origin, scale, color, thickness in captions:
        cv2.putText(result, text, origin, cv2.FONT_HERSHEY_SIMPLEX, scale, color, thickness)
    return result


def _process_frame_b64(frame_b64: str, mode: str, style: str) -> Optional[str]:
    """Decode, render and re-encode one frame; None if it cannot be decoded.

    Blocking and CPU-bound; websocket_endpoint calls it via run_in_executor.
    """
    frame = base64_to_frame(frame_b64)
    if frame is None:
        return None
    return frame_to_base64(_render_frame(frame, mode, style))


# ─── WebSocket: Main Data Channel ─────────────────────────────
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(ws: WebSocket, client_id: str):
    """Per-client message loop.

    Each JSON object message has an ``action`` (default "frame"):
      * ``set_mode``  -> {"type": "mode_set", ...}
      * ``set_style`` -> {"type": "style_set", ...} or an error
      * ``frame``     -> {"type": "frame", ...} carrying the processed JPEG
      * ``rps_play``  -> {"type": "rps_result", ...} or an error
    Other actions are ignored. Bad messages get {"type": "error", ...} replies
    instead of dropping the connection. Replies go to this socket directly, so
    a duplicate ``client_id`` cannot redirect them to another connection.
    """
    await manager.connect(ws, client_id)
    current_mode = "studio"
    current_style = "hayao"
    loop = asyncio.get_running_loop()

    try:
        while True:
            try:
                msg = await ws.receive_json()
            except (ValueError, KeyError, TypeError):
                # Invalid JSON (JSONDecodeError) or a non-text frame.
                await ws.send_json({"type": "error", "message": "Invalid JSON message"})
                continue
            if not isinstance(msg, dict):
                await ws.send_json({"type": "error", "message": "Message must be a JSON object"})
                continue

            action = msg.get("action", "frame")

            # ── 1. Mode Switching ──────────────────────────────
            if action == "set_mode":
                mode = msg.get("mode", "studio")
                if not isinstance(mode, str):
                    await ws.send_json({"type": "error", "message": f"Invalid mode: {mode!r}"})
                    continue
                current_mode = mode
                await ws.send_json({"type": "mode_set", "mode": current_mode})

            # ── 2. Style Switching ───────────────────────────
            elif action == "set_style":
                new_style = msg.get("style", "hayao")
                if isinstance(new_style, str) and new_style in MODELS:
                    current_style = new_style
                    await ws.send_json({"type": "style_set", "style": new_style,
                                        "name": MODELS[new_style]["name"]})
                else:
                    await ws.send_json({"type": "error", "message": f"Unknown style: {new_style}"})

            # ── 3. Receive raw frame from browser ──────────────
            elif action == "frame":
                frame_b64 = msg.get("frame", "")
                if not frame_b64:
                    continue
                try:
                    # Decode + inference + encode are CPU-bound: run them off
                    # the event loop so other clients stay responsive.
                    out_b64 = await loop.run_in_executor(
                        None, _process_frame_b64, frame_b64, current_mode, current_style
                    )
                except Exception as e:
                    print(f"[WS] Frame processing failed for {client_id}: {e}")
                    await ws.send_json({"type": "error", "message": f"Frame processing failed: {e}"})
                    continue
                if out_b64 is None:
                    continue  # undecodable frame: skip it

                # Send back stylized/processed frame
                await ws.send_json({
                    "type": "frame",
                    "mode": current_mode,
                    "style": current_style if current_mode == "studio" else None,
                    "frame": out_b64,
                })

            # ── 4. RPS Game Logic (server-side) ───────────────
            elif action == "rps_play":
                raw_move = msg.get("move")
                # Accept any casing/whitespace, e.g. " rock " -> "Rock".
                player_move = raw_move.strip().capitalize() if isinstance(raw_move, str) else None
                if player_move not in RPS_BEATS:
                    await ws.send_json({"type": "error", "message": f"Invalid move: {raw_move!r}"})
                    continue
                ai_move = random.choice(list(RPS_BEATS))
                if player_move == ai_move:
                    result_text = "Draw!"
                elif RPS_BEATS[player_move] == ai_move:
                    result_text = "You Win!"
                else:
                    result_text = "AI Wins!"
                await ws.send_json({
                    "type": "rps_result",
                    "player": player_move,
                    "ai": ai_move,
                    "result": result_text,
                })

    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[WS] Error for {client_id}: {e}")
        # Best effort: close with 1011 (internal error). The socket may already
        # be unusable, in which case there is nothing more to do.
        with suppress(Exception):
            await ws.close(code=1011)
    finally:
        manager.disconnect(client_id, ws)


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    # Pass the app object (not "module:attr") so this works whatever the file
    # is named and the module is not imported a second time.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)