import os
import torch
import re
import av
import uuid
import copy
import threading
import time
import shutil
from PIL import Image
from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration, TextIteratorStreamer
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
import logging

# Silence asyncio noise from ZeroGPU cleanup
logging.getLogger("asyncio").setLevel(logging.CRITICAL)

from starlette.middleware import Middleware
import hashlib
import base64
import json


# ---------- Logging Middleware ----------
def _headers_from_asgi(raw_headers) -> list[dict]:
    """Decode raw ASGI header pairs into ``{"name": ..., "value": ...}`` dicts.

    Each pair is decoded as latin-1 with replacement, so malformed bytes never
    raise. A ``None``/empty input yields an empty list.
    """
    return [
        {
            "name": raw_key.decode("latin-1", errors="replace"),
            "value": raw_value.decode("latin-1", errors="replace"),
        }
        for raw_key, raw_value in raw_headers or []
    ]


def _header_value(headers: list[dict], name: str) -> str:
    """Return the first value whose header name matches *name* case-insensitively.

    Returns "" when no header matches.
    """
    name = name.lower()
    for header in headers:
        if header["name"].lower() == name:
            return header["value"]
    return ""


def _body_text(data: bytes, content_type: str) -> str | None:
    """Decode *data* as UTF-8 when *content_type* looks textual.

    Returns "" for an empty body, the decoded text for text/*, JSON or
    form-urlencoded content types, and None for anything else (treated as
    binary).
    """
    if not data:
        return ""
    lower_type = (content_type or "").lower()
    if "text/" in lower_type or "json" in lower_type or "x-www-form-urlencoded" in lower_type:
        return data.decode("utf-8", errors="replace")
    return None


def _body_record(data: bytes, content_type: str) -> dict:
    """Summarize a body as size, SHA-256 hex digest, base64 payload and text.

    NOTE(review): not referenced by HTTPRequestLogMiddleware in this file.
    """
    return {
        "size": len(data),
        "sha256": hashlib.sha256(data).hexdigest() if data else "",
        "base64": base64.b64encode(data).decode("ascii") if data else "",
        "text": _body_text(data, content_type),
    }


def _append_http_log(record: dict) -> None:
    """Append *record* as one compact JSON line to HTTP_LOG_FILE.

    HTTP_LOG_FILE and HTTP_LOG_LOCK are module globals assigned later in this
    file; they are looked up at call time, after the module has loaded.
    """
    os.makedirs(os.path.dirname(HTTP_LOG_FILE), exist_ok=True)
    line = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
    with HTTP_LOG_LOCK:
        with open(HTTP_LOG_FILE, "a", encoding="utf-8") as f:
            f.write(line + "\n")


class HTTPRequestLogMiddleware:
    """Pure ASGI middleware that appends one JSON line per HTTP request.

    Each record holds the UTC start time, a short random request id, the
    ``x-v46-client-id`` header, method, path, response status code and the
    elapsed time in milliseconds. Non-HTTP scopes pass through untouched.

    Request and response bodies are intentionally NOT buffered: they are not
    part of the record, and holding them would keep every upload and every
    streamed response in memory for the lifetime of the request.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return

        started = time.time()
        # Monotonic clock for the duration so wall-clock adjustments can't skew it.
        started_perf = time.perf_counter()
        request_id = uuid.uuid4().hex[:12]
        status_code = None

        async def send_wrapper(message):
            # Only the status code is needed from the response stream.
            nonlocal status_code
            if message.get("type") == "http.response.start":
                status_code = message.get("status")
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Runs even if the app raises; status_code stays None when no
            # response was started.
            request_headers = _headers_from_asgi(scope.get("headers", []))
            record = {
                "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(started)),
                "request_id": request_id,
                "client_id": _header_value(request_headers, "x-v46-client-id"),
                "method": scope.get("method"),
                "path": scope.get("path"),
                "status_code": status_code,
                "duration_ms": round((time.perf_counter() - started_perf) * 1000, 2),
            }
            try:
                _append_http_log(record)
            except Exception as e:
                # Logging is best-effort; never fail the request because of it.
                print(f"Logging error: {e}")


import spaces
from typing import Generator

# ---------- Globals & Model Loading ----------
MODEL_ID = "openbmb/MiniCPM-V-4.6"

print(f"Loading processor: {MODEL_ID}")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

print(f"Loading model: {MODEL_ID}")
model = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    trust_remote_code=True,
    device_map="cuda",
).eval()

# ---------- Logging & Helper Functions ----------
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
UPLOAD_LOG_DIR = os.path.join(LOG_DIR, "uploads")
HTTP_LOG_FILE = os.path.join(LOG_DIR, "http_requests.jsonl")
RAW_OUTPUT_LOG_FILE = os.path.join(LOG_DIR, "raw_model_outputs.jsonl") HTTP_LOG_LOCK = threading.Lock() RAW_OUTPUT_LOG_LOCK = threading.Lock() def _append_raw_output_log(record: dict) -> None: os.makedirs(os.path.dirname(RAW_OUTPUT_LOG_FILE), exist_ok=True) line = json.dumps(record, ensure_ascii=False, separators=(",", ":")) with RAW_OUTPUT_LOG_LOCK: with open(RAW_OUTPUT_LOG_FILE, "a", encoding="utf-8") as f: f.write(line + "\n") def log_raw_model_output(session_id: str, **record) -> None: payload = { "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "session_id": session_id, **record, } try: _append_raw_output_log(payload) except Exception as e: print(f"Logging error: {e}") def load_video(video_path, max_frames=64): """Fast video loading using PyAV timestamp seeking.""" try: container = av.open(video_path) stream = container.streams.video[0] stream.thread_count = 8 duration = stream.duration if duration is None or duration <= 0: frames = [f.to_image() for f in container.decode(video=0)] if len(frames) > max_frames: indices = [int(i * len(frames) / max_frames) for i in range(max_frames)] return [frames[i] for i in indices] return frames indices = [int(i * duration / max_frames) for i in range(max_frames)] frames = [] for ts in indices: container.seek(ts, stream=stream) for frame in container.decode(video=0): frames.append(frame.to_image()) break container.close() return frames except Exception as e: print(f"Error loading video: {e}") return None def persist_uploaded_files(files: list, session_id: str) -> list: """Copy Gradio temp uploads into the project log directory.""" if not files: return [] dest_dir = os.path.join(UPLOAD_LOG_DIR, session_id or "session") os.makedirs(dest_dir, exist_ok=True) persisted = [] for f in files: src = f["path"] if isinstance(f, dict) else f if not os.path.isfile(src): persisted.append(src) continue base = os.path.basename(src) stamp = time.strftime("%Y%m%dT%H%M%SZ", time.gmtime()) dest = os.path.join(dest_dir, 
f"{stamp}-{uuid.uuid4().hex[:8]}-{base}") shutil.copy2(src, dest) persisted.append(dest) return persisted def normalize_response_text(text: str) -> str: """Robust conversion of literal \n to newlines while protecting code/LaTeX.""" if not isinstance(text, str) or "\\" not in text: return text protected = {} counter = [0] def _convert(v): v = re.sub(r"(? Generator[str, None, None]: """ Streaming inference endpoint with history support. """ session_id = str(uuid.uuid4()) # Persist files in background to avoid blocking user (parity audit) if files: threading.Thread(target=persist_uploaded_files, args=(files, session_id), daemon=True).start() messages = [] # Process history if history: for turn in history: # history turn is [user_text, assistant_text, [optional_file_paths]] user_text = turn[0] assistant_text = turn[1] turn_files = turn[2] if len(turn) > 2 else [] h_content = [] if turn_files: for f_path in turn_files: # In history, we don't have mime_type, so we check extension ext = os.path.splitext(f_path)[1].lower() if ext in {".mp4", ".mkv", ".mov", ".avi", ".webm"}: v_frames = load_video(f_path, max_frames=max_frames) if v_frames: h_content.append({"type": "video", "video": v_frames}) else: h_content.append({"type": "video", "path": f_path}) else: try: img = Image.open(f_path).convert("RGB") h_content.append({"type": "image", "image": img}) except Exception: v_frames = load_video(f_path, max_frames=max_frames) if v_frames: h_content.append({"type": "video", "video": v_frames}) else: h_content.append({"type": "video", "path": f_path}) if user_text: h_content.append({"type": "text", "text": user_text}) if h_content: messages.append({"role": "user", "content": h_content}) if assistant_text: messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_text}]}) content = [] if files: for f in files: file_path = f["path"] try: # Try image first img = Image.open(file_path).convert("RGB") content.append({"type": "image", "image": img}) except 
Exception: # Fallback to manual video frame extraction (bypasses broken torchvision) v_frames = load_video(file_path, max_frames=max_frames) if v_frames: content.append({"type": "video", "video": v_frames}) else: print(f"Failed to load video: {file_path}") if message: content.append({"type": "text", "text": message}) if content: messages.append({"role": "user", "content": content}) # Prepare inputs with Advanced Parameters for MiniCPM-V 4.6 with torch.no_grad(): inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", enable_thinking=thinking_mode, processor_kwargs={ "downsample_mode": "16x", "max_slice_nums": 1 if any(it.get("type") == "video" for msg in messages for it in msg["content"]) else 9, "use_image_id": False if any(it.get("type") == "video" for msg in messages for it in msg["content"]) else True, "videos_kwargs": { "max_num_frames": max_frames, "do_sample_frames": False, # Frames are already sampled by load_video "stack_frames": 1, } } ).to(model.device) for k, v in inputs.items(): if isinstance(v, torch.Tensor) and torch.is_floating_point(v): inputs[k] = v.to(dtype=torch.bfloat16) streamer = TextIteratorStreamer( processor.tokenizer, skip_prompt=True, skip_special_tokens=True, ) sampling = (generation_mode == "Sampling") generate_kwargs = { **inputs, "max_new_tokens": max_new_tokens, "do_sample": sampling, "streamer": streamer, "downsample_mode": "16x" } if sampling: generate_kwargs.update({ "temperature": temperature, "top_p": top_p, "top_k": top_k, }) else: generate_kwargs.update({"num_beams": 1}) thread = threading.Thread(target=model.generate, kwargs=generate_kwargs) thread.start() full_text = "" for new_text in streamer: full_text += new_text yield normalize_response_text(full_text) log_raw_model_output(session_id, message=message, response=full_text, variant="thinking" if thinking_mode else "instruct") @demo.get("/", response_class=HTMLResponse) async def homepage(): 
html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html") with open(html_path, "r", encoding="utf-8") as f: return f.read() if __name__ == "__main__": demo.launch( show_error=True, app_kwargs={"middleware": [Middleware(HTTPRequestLogMiddleware)]} )