import os
import torch
import re
import av
import uuid
import copy  # NOTE(review): not referenced in the visible text (may be used in the corrupted span inside normalize_response_text).
import threading
import time
import shutil
from PIL import Image
from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration, TextIteratorStreamer
from gradio import Server  # NOTE(review): not referenced in the visible text; `demo` is presumably built from it in the corrupted span.
from gradio.data_classes import FileData  # NOTE(review): not referenced in the visible text.
from fastapi.responses import HTMLResponse
import logging

# Silence asyncio noise from ZeroGPU cleanup
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
from starlette.middleware import Middleware
import hashlib
import base64
import json
import spaces  # NOTE(review): not referenced in the visible text; presumably used as a decorator in the corrupted span.
from typing import Generator

# ---------- Logging configuration ----------
# File-type sets used by _file_kind / _infer_upload_extension / persist_uploaded_files,
# plus log and upload locations. Every location can be overridden with a V46_* env var.
# The defaults live under /data when that directory exists, otherwise under PROJECT_ROOT/logs.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp", ".gif"}
VIDEO_EXTENSIONS = {".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v"}
# Request header whose value is copied into each HTTP log record as "client_id".
CLIENT_ID_HEADER = "x-v46-client-id"
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
if os.path.isdir("/data"):
    DEFAULT_LOG_DIR = "/data/logs"
    DEFAULT_UPLOAD_LOG_DIR = "/data/uploads"
else:
    DEFAULT_LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
    DEFAULT_UPLOAD_LOG_DIR = os.path.join(DEFAULT_LOG_DIR, "uploads")
LOG_DIR = os.environ.get("V46_LOG_DIR", DEFAULT_LOG_DIR)
UPLOAD_LOG_DIR = os.environ.get("V46_UPLOAD_LOG_DIR", DEFAULT_UPLOAD_LOG_DIR)
HTTP_LOG_FILE = os.environ.get("V46_HTTP_LOG_FILE", os.path.join(LOG_DIR, "http_requests.jsonl"))
RAW_OUTPUT_LOG_FILE = os.environ.get("V46_RAW_OUTPUT_LOG_FILE", os.path.join(LOG_DIR, "raw_model_outputs.jsonl"))
# Per-request HTTP logging is opt-in: only the exact string "1" enables it.
LOG_ALL_HTTP_REQUESTS = os.environ.get("V46_LOG_ALL_HTTP_REQUESTS", "0") == "1"
# threading.Lock only serializes writers inside this process, not across processes.
HTTP_LOG_LOCK = threading.Lock()
RAW_OUTPUT_LOG_LOCK = threading.Lock()


# ---------- Logging Middleware ----------
def _headers_from_asgi(raw_headers) -> list[dict]:
    """Convert ASGI raw header pairs into a list of {"name", "value"} dicts.

    Both name and value are decoded as latin-1 with errors replaced.
    `raw_headers` may be None or empty, in which case [] is returned.
    """
    headers = []
    for raw_key, raw_value in raw_headers or []:
        headers.append({
            "name": raw_key.decode("latin-1", errors="replace"),
            "value": raw_value.decode("latin-1", errors="replace"),
        })
    return headers


def _header_value(headers: list[dict], name: str) -> str:
    """Return the value of the first header matching `name` case-insensitively, or "" if absent."""
    name = name.lower()
    for header in headers:
        if header["name"].lower() == name:
            return header["value"]
    return ""


def _body_text(data: bytes, content_type: str) -> str | None:
    """Return a UTF-8 text rendering of a body whose content type looks textual.

    Returns "" for an empty body. Returns the decoded text (errors replaced)
    when the content type contains "text/", "json" or
    "x-www-form-urlencoded". Returns None for any other (presumably binary)
    content type.
    """
    if not data:
        return ""
    lower_type = (content_type or "").lower()
    if "text/" in lower_type or "json" in lower_type or "x-www-form-urlencoded" in lower_type:
        return data.decode("utf-8", errors="replace")
    return None


def _body_record(data: bytes, content_type: str) -> dict:
    """Describe a request/response body for the HTTP log.

    The record holds the size, SHA-256 hex digest, base64 of the full body and
    an optional text rendering. Digest and base64 are "" for an empty body.
    """
    return {
        "size": len(data),
        "sha256": hashlib.sha256(data).hexdigest() if data else "",
        "base64": base64.b64encode(data).decode("ascii") if data else "",
        "text": _body_text(data, content_type),
    }


def _append_http_log(record: dict) -> None:
    """Append `record` as one compact JSON line to HTTP_LOG_FILE, creating its directory if needed.

    NOTE(review): if HTTP_LOG_FILE is a bare filename, os.path.dirname() is ""
    and os.makedirs("") raises. The caller catches this and only prints it.
    """
    os.makedirs(os.path.dirname(HTTP_LOG_FILE), exist_ok=True)
    line = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
    with HTTP_LOG_LOCK:
        with open(HTTP_LOG_FILE, "a", encoding="utf-8") as f:
            f.write(line + "\n")


class HTTPRequestLogMiddleware:
    """Raw ASGI middleware that records each HTTP exchange to HTTP_LOG_FILE.

    Non-HTTP scopes are passed straight through. For HTTP scopes it:
      * assigns a 12-hex-char request id and stores it in scope["v46_request_id"];
      * echoes that id back in an `x-v46-request-id` response header;
      * buffers the full request and response bodies in memory;
      * after the wrapped app returns or raises, appends one JSON line with
        headers, bodies (size/sha256/base64/text), status and duration.

    NOTE(review): bodies are kept whole in memory and logged base64-encoded.
    Large uploads and long streamed responses are therefore duplicated in RAM
    and on disk, and headers/bodies (including any credentials) are persisted
    verbatim. The log write is synchronous file I/O inside an async call, so it
    blocks the event loop while it runs.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return
        started = time.time()
        request_id = uuid.uuid4().hex[:12]
        scope["v46_request_id"] = request_id
        request_body = bytearray()
        response_headers = []
        response_body = bytearray()
        status_code = None

        async def receive_wrapper():
            # Tee incoming request body chunks into request_body.
            message = await receive()
            if message.get("type") == "http.request":
                chunk = message.get("body", b"")
                if chunk:
                    request_body.extend(chunk)
            return message

        async def send_wrapper(message):
            # Capture status/headers, inject the request-id header, and tee response body chunks.
            nonlocal status_code, response_headers
            if message.get("type") == "http.response.start":
                status_code = message.get("status")
                headers = list(message.get("headers", []))
                headers.append((b"x-v46-request-id", request_id.encode("ascii")))
                message["headers"] = headers
                response_headers = _headers_from_asgi(message.get("headers", []))
            elif message.get("type") == "http.response.body":
                chunk = message.get("body", b"")
                if chunk:
                    response_body.extend(chunk)
            await send(message)

        try:
            await self.app(scope, receive_wrapper, send_wrapper)
        finally:
            # Runs even when the app raised. status_code stays None if no response was started.
            request_headers = _headers_from_asgi(scope.get("headers", []))
            client = scope.get("client") or (None, None)
            request_content_type = _header_value(request_headers, "content-type")
            response_content_type = _header_value(response_headers, "content-type")
            record = {
                "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(started)),
                "request_id": request_id,
                "client_id": _header_value(request_headers, CLIENT_ID_HEADER),
                "client_host": client[0],
                "client_port": client[1],
                "method": scope.get("method"),
                "path": scope.get("path"),
                "query_string": (scope.get("query_string") or b"").decode("latin-1", errors="replace"),
                "http_version": scope.get("http_version"),
                "request_headers": request_headers,
                "request_body": _body_record(bytes(request_body), request_content_type),
                "status_code": status_code,
                "response_headers": response_headers,
                "response_body": _body_record(bytes(response_body), response_content_type),
                "duration_ms": round((time.time() - started) * 1000, 2),
            }
            # Logging is best-effort: a write failure is printed, never raised to the client.
            try:
                _append_http_log(record)
            except Exception as e:
                print(f"[http-log] failed to write request log: {e}", flush=True)


def http_request_logging_app_kwargs() -> dict:
    """Print the active log destinations and return extra kwargs for the app constructor.

    Returns {"middleware": [Middleware(HTTPRequestLogMiddleware)]} when
    LOG_ALL_HTTP_REQUESTS is set, otherwise {}.
    """
    print(f"[model-call-log] writing model calls to {RAW_OUTPUT_LOG_FILE}", flush=True)
    print(f"[upload-log] persisting inference uploads to {UPLOAD_LOG_DIR}", flush=True)
    if LOG_ALL_HTTP_REQUESTS:
        print(f"[http-log] writing all HTTP requests to {HTTP_LOG_FILE}", flush=True)
        return {"middleware": [Middleware(HTTPRequestLogMiddleware)]}
    print("[http-log] all-request logging disabled; set V46_LOG_ALL_HTTP_REQUESTS=1 to enable", flush=True)
    return {}


# ---------- Globals & Model Loading ----------
# The processor and model are loaded once, at import time.
# device_map="cuda" places the weights on a CUDA device, so importing this module needs a GPU.
MODEL_ID = "openbmb/MiniCPM-V-4.6"
print(f"Loading processor: {MODEL_ID}")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
print(f"Loading model: {MODEL_ID}")
model = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    trust_remote_code=True,
    device_map="cuda"
).eval()


# ---------- Logging & Helper Functions ----------
def _append_raw_output_log(record: dict) -> None:
    """Append `record` as one compact JSON line to RAW_OUTPUT_LOG_FILE, creating its directory if needed.

    NOTE(review): as with _append_http_log, a bare filename makes os.makedirs("") raise.
    """
    os.makedirs(os.path.dirname(RAW_OUTPUT_LOG_FILE), exist_ok=True)
    line = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
    with RAW_OUTPUT_LOG_LOCK:
        with open(RAW_OUTPUT_LOG_FILE, "a", encoding="utf-8") as f:
            f.write(line + "\n")


def _json_safe_for_log(value):
    """Recursively convert `value` into something json.dumps can serialize.

    Conversion rules:
      * str, int, float, bool and None are returned unchanged.
      * dict keys are stringified, and lists/tuples become lists.
      * PIL images are summarized as {"type", "mode", "size"} instead of pixel data.
      * Path-like objects (those with __fspath__) become strings.
      * Anything else is repr()'d.

    NOTE(review): NaN/inf floats pass through unchanged, and json.dumps writes them as the
    non-standard tokens NaN/Infinity.
    """
    if isinstance(value, (str, int, float, bool)) or value is None:
        return value
    if isinstance(value, dict):
        return {str(k): _json_safe_for_log(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe_for_log(v) for v in value]
    if isinstance(value, Image.Image):
        return {"type": "PIL.Image", "mode": value.mode, "size": list(value.size)}
    if hasattr(value, "__fspath__"):
        return os.fspath(value)
    return repr(value)


def log_raw_model_output(**record) -> None:
    """Write one model-call record to RAW_OUTPUT_LOG_FILE, best-effort.

    Every keyword argument is made JSON-safe and merged after an auto-generated
    "ts" (UTC) and "call_id". A caller-supplied "ts"/"call_id" would therefore
    override them. Write failures are printed, not raised.
    """
    safe_record = {key: _json_safe_for_log(value) for key, value in record.items()}
    payload = {
        "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "call_id": uuid.uuid4().hex[:12],
        **safe_record,
    }
    try:
        _append_raw_output_log(payload)
    except Exception as e:
        print(f"[raw-output-log] failed to write raw output log: {e}", flush=True)


def load_video(video_path, max_frames=64):
    """Fast video loading using PyAV timestamp seeking.

    Returns up to `max_frames` PIL images from the first video stream. Returns
    None if anything raises; the error is printed.

    * When the stream reports a positive `duration` (presumably in
      stream.time_base units, which is what seek(..., stream=stream) takes),
      it seeks to `max_frames` evenly spaced offsets in [0, duration). It keeps
      the first decoded frame after each seek, so the result can be shorter
      than `max_frames`.
    * Otherwise it decodes every frame into memory, then subsamples evenly.

    NOTE(review): `container.close()` is only reached on the seek path. The
    no-duration early return and the exception path leave the container open.
    NOTE(review): seek offsets start at 0 and ignore `stream.start_time`.
    PyAV's seek lands on a keyframe by default, so nearby offsets may yield the
    same frame. Confirm this sampling is acceptable.
    """
    try:
        container = av.open(video_path)
        stream = container.streams.video[0]
        stream.thread_count = 8
        duration = stream.duration
        if duration is None or duration <= 0:
            frames = [f.to_image() for f in container.decode(video=0)]
            if len(frames) > max_frames:
                indices = [int(i * len(frames) / max_frames) for i in range(max_frames)]
                return [frames[i] for i in indices]
            return frames
        indices = [int(i * duration / max_frames) for i in range(max_frames)]
        frames = []
        for ts in indices:
            container.seek(ts, stream=stream)
            for frame in container.decode(video=0):
                frames.append(frame.to_image())
                break
        container.close()
        return frames
    except Exception as e:
        print(f"Error loading video: {e}")
        return None


def _file_path(file_obj) -> str:
    """Best-effort extraction of a filesystem path (or URL) from a Gradio file value.

    Resolution order:
      * A plain str is returned as-is.
      * For a dict, the first non-empty string among "path", "name",
        "orig_name", "url" is returned.
      * Otherwise, the first non-empty string attribute among `path`, `name`,
        `orig_name` is returned.
      * Failing all of the above, str(file_obj) is returned.
    """
    if isinstance(file_obj, str):
        return file_obj
    if isinstance(file_obj, dict):
        for key in ("path", "name", "orig_name", "url"):
            value = file_obj.get(key)
            if isinstance(value, str) and value:
                return value
    for key in ("path", "name", "orig_name"):
        value = getattr(file_obj, key, None)
        if isinstance(value, str) and value:
            return value
    return str(file_obj)


def _safe_upload_name(path: str) -> str:
    """Return a filesystem-safe version of `path`'s basename.

    Runs of characters outside [A-Za-z0-9._-] become "_", and leading/trailing
    "." and "_" are stripped. The result is truncated to 120 characters, which
    can cut off the extension. Falls back to "upload.bin".
    """
    name = os.path.basename(path) or "upload.bin"
    name = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._")
    return name[:120] or "upload.bin"


def _infer_upload_extension(path: str) -> str:
    """Guess a media extension for the file at `path`, or return "" if unknown.

    Checks, in order:
      1. the existing extension, if it is in IMAGE_EXTENSIONS or VIDEO_EXTENSIONS;
      2. the image format Pillow detects when opening the file;
      3. container magic bytes: an ISO-BMFF `ftyp` box at offset 4 maps to
         ".mp4" (this also covers .mov/.m4v), and the EBML header maps to
         ".webm" (this also covers .mkv).
    Every probing error is swallowed on purpose, so an unreadable file yields "".
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in IMAGE_EXTENSIONS or ext in VIDEO_EXTENSIONS:
        return ext
    try:
        with Image.open(path) as img:
            fmt = (img.format or "").lower()
            image_ext = {
                "jpeg": ".jpg",
                "jpg": ".jpg",
                "png": ".png",
                "gif": ".gif",
                "webp": ".webp",
                "bmp": ".bmp",
                "tiff": ".tiff",
            }.get(fmt)
            if image_ext:
                return image_ext
    except Exception:
        pass
    try:
        with open(path, "rb") as f:
            header = f.read(32)
        if len(header) >= 12 and header[4:8] == b"ftyp":
            return ".mp4"
        if header.startswith(b"\x1aE\xdf\xa3"):
            return ".webm"
    except Exception:
        pass
    return ""


def _file_kind(path: str) -> str | None:
    """Classify `path` as "image", "video", or None.

    The file extension is checked first. If it is unknown, the content is
    sniffed via _infer_upload_extension.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in IMAGE_EXTENSIONS:
        return "image"
    if ext in VIDEO_EXTENSIONS:
        return "video"
    inferred_ext = _infer_upload_extension(path)
    if inferred_ext in IMAGE_EXTENSIONS:
        return "image"
    if inferred_ext in VIDEO_EXTENSIONS:
        return "video"
    return None


def persist_uploaded_files(files: list, session_id: str) -> list[str]:
    """Copy Gradio temp uploads into persistent storage before inference.

    Each entry of `files` is resolved with _file_path. The returned list has
    the same length and order as `files`, and each entry is:
      * for a path already under UPLOAD_LOG_DIR: its absolute path, with no copy made;
      * for an entry that is not an existing file (e.g. a URL): the resolved string, unchanged;
      * for anything else: the path of a copy made with shutil.copy2 into
        UPLOAD_LOG_DIR/<sanitized session_id>/<UTC stamp>-<8 hex>-<safe name>.
        A sniffed extension is appended when the safe name lacks a known one.

    NOTE(review): any existing server-side path supplied in `files` is copied
    and later opened for inference. Confirm the caller restricts inputs to
    Gradio's upload cache; otherwise this can read any file the server can read.
    """
    if not files:
        return []
    session_dir = re.sub(r"[^A-Za-z0-9._-]+", "_", session_id or "session").strip("._")
    dest_dir = os.path.join(UPLOAD_LOG_DIR, session_dir or "session")
    os.makedirs(dest_dir, exist_ok=True)
    persisted = []
    upload_root = os.path.abspath(UPLOAD_LOG_DIR)
    for f in files:
        src = _file_path(f)
        src_path = os.path.abspath(src)
        # Already persisted (e.g. a file re-sent from history): reuse it without copying again.
        if src_path.startswith(upload_root + os.sep):
            persisted.append(src_path)
            continue
        if not os.path.isfile(src_path):
            persisted.append(src)
            continue
        base = _safe_upload_name(src_path)
        inferred_ext = _infer_upload_extension(src_path)
        if inferred_ext and os.path.splitext(base)[1].lower() not in (IMAGE_EXTENSIONS | VIDEO_EXTENSIONS):
            base = f"{base}{inferred_ext}"
        stamp = time.strftime("%Y%m%dT%H%M%SZ", time.gmtime())
        dest = os.path.join(dest_dir, f"{stamp}-{uuid.uuid4().hex[:8]}-{base}")
        shutil.copy2(src_path, dest)
        persisted.append(dest)
    return persisted


def normalize_response_text(text: str) -> str:
    """Robust conversion of literal \n to newlines while protecting code/LaTeX."""
    # Non-strings and strings without any backslash are returned unchanged.
    if not isinstance(text, str) or "\\" not in text:
        return text
    protected = {}
    counter = [0]

    def _convert(v):
        # NOTE(review): SOURCE IS CORRUPTED ON THE NEXT LINE.
        # The text between `r"(?` and `Generator[str, None, None]:` is missing,
        # so this line is not valid Python.
        # The statements after it use names defined nowhere in the visible
        # text: `demo`, `message`, `files` (read before assignment), `history`,
        # `thinking_mode`, `generation_mode`, `max_new_tokens`, `temperature`,
        # `top_p`, `top_k`, `max_frames`.
        # The missing span therefore presumably held:
        #   * the rest of normalize_response_text;
        #   * the creation of `demo` (likely from the imported `Server`);
        #   * the decorators and signature of a streaming generator endpoint
        #     taking those parameters.
        # Restore it from version control; do not reconstruct it by hand. TODO confirm.
        v = re.sub(r"(? Generator[str, None, None]:
    """
    Streaming inference endpoint with history support.
    """
    # Per-call id. It is used only to name the upload directory and to tag the log record.
    session_id = str(uuid.uuid4())
    files = persist_uploaded_files(files or [], session_id)
    messages = []

    # Process history
    if history:
        for turn in history:
            # history turn is [user_text, assistant_text, [optional_file_paths]]
            user_text = turn[0]
            assistant_text = turn[1]
            turn_files = turn[2] if len(turn) > 2 else []
            h_content = []
            if turn_files:
                for f in turn_files:
                    f_path = _file_path(f)
                    # In history, we don't have mime_type, so we check extension
                    # (_file_kind also sniffs the file contents when the extension is unknown).
                    if _file_kind(f_path) == "video":
                        v_frames = load_video(f_path, max_frames=max_frames)
                        if v_frames:
                            h_content.append({"type": "video", "video": v_frames})
                        else:
                            # NOTE(review): unlike the current-turn path below, a history video
                            # that fails to load is still forwarded, as a bare path.
                            h_content.append({"type": "video", "path": f_path})
                    else:
                        try:
                            img = Image.open(f_path).convert("RGB")
                            h_content.append({"type": "image", "image": img})
                        except Exception:
                            # Not decodable as an image: try it as a video instead.
                            v_frames = load_video(f_path, max_frames=max_frames)
                            if v_frames:
                                h_content.append({"type": "video", "video": v_frames})
                            else:
                                h_content.append({"type": "video", "path": f_path})
            if user_text:
                h_content.append({"type": "text", "text": user_text})
            if h_content:
                messages.append({"role": "user", "content": h_content})
            if assistant_text:
                messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_text}]})

    # Current turn: media first, then the text message.
    content = []
    if files:
        for file_path in files:
            try:
                if _file_kind(file_path) == "video":
                    v_frames = load_video(file_path, max_frames=max_frames)
                    if v_frames:
                        content.append({"type": "video", "video": v_frames})
                    else:
                        print(f"Failed to load video: {file_path}")
                else:
                    img = Image.open(file_path).convert("RGB")
                    content.append({"type": "image", "image": img})
            except Exception:
                # Fallback to manual video frame extraction (bypasses broken torchvision)
                v_frames = load_video(file_path, max_frames=max_frames)
                if v_frames:
                    content.append({"type": "video", "video": v_frames})
                else:
                    print(f"Failed to load video: {file_path}")
    if message:
        content.append({"type": "text", "text": message})
    # NOTE(review): with no history, no usable files and an empty message, `messages`
    # stays empty and is still passed to apply_chat_template.
    if content:
        messages.append({"role": "user", "content": content})

    # Prepare inputs with Advanced Parameters for MiniCPM-V 4.6
    # Any video anywhere in the conversation switches to one slice per image and
    # disables image ids. NOTE(review): the same any(...) scan is evaluated twice.
    with torch.no_grad():
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            enable_thinking=thinking_mode,
            processor_kwargs={
                "downsample_mode": "16x",
                "max_slice_nums": 1 if any(it.get("type") == "video" for msg in messages for it in msg["content"]) else 9,
                "use_image_id": False if any(it.get("type") == "video" for msg in messages for it in msg["content"]) else True,
                "videos_kwargs": {
                    "max_num_frames": max_frames,
                    # NOTE(review): {"type": "video", "path": ...} history entries were NOT
                    # sampled by load_video. Confirm how the processor handles them.
                    "do_sample_frames": False,  # Frames are already sampled by load_video
                    "stack_frames": 1,
                }
            }
        ).to(model.device)
        # Cast floating-point inputs (presumably pixel tensors) to bfloat16 to match the model's torch_dtype.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                inputs[k] = v.to(dtype=torch.bfloat16)

    # skip_prompt drops the echoed prompt; skip_special_tokens strips special tokens from decoded text.
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    sampling = (generation_mode == "Sampling")
    generate_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        "streamer": streamer,
        # Same value as processor_kwargs["downsample_mode"]; presumably the two must agree.
        "downsample_mode": "16x"
    }
    if sampling:
        generate_kwargs.update({
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        })
    else:
        # do_sample=False with a single beam is greedy decoding.
        generate_kwargs.update({"num_beams": 1})
    # generate() runs in a background thread while this generator drains the streamer.
    # NOTE(review): the thread is never joined, and the streamer has no timeout. If
    # generate() raises, the loop below can block forever. If the client disconnects,
    # generation keeps running.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    full_text = ""
    for new_text in streamer:
        full_text += new_text
        # Yields the full accumulated text (not a delta), re-normalized on every chunk,
        # so the total normalization work is quadratic in the output length.
        yield normalize_response_text(full_text)
    normalized_text = normalize_response_text(full_text)
    # One record per completed call. Not reached if the generator is closed early or raises.
    log_raw_model_output(
        source="api",
        session_id=session_id,
        variant="thinking" if thinking_mode else "instruct",
        thinking_mode=bool(thinking_mode),
        generation_mode=generation_mode,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_frames=max_frames,
        user_text=message,
        user_files=files,
        history=history or [],
        model_messages=messages,
        context_turns=len(history or []),
        raw_full_text=full_text,
        normalized_text=normalized_text,
    )


@demo.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the index.html that sits next to this file.

    NOTE(review): this is a blocking file read inside an async handler, and the
    directory computation duplicates PROJECT_ROOT.
    """
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return f.read()


if __name__ == "__main__":
    # The HTTP logging middleware is attached only when V46_LOG_ALL_HTTP_REQUESTS=1.
    demo.launch(
        show_error=True,
        app_kwargs=http_request_logging_app_kwargs()
    )