Spaces:
Running on Zero
Running on Zero
| import os | |
| import torch | |
| import re | |
| import av | |
| import uuid | |
| import copy | |
| import threading | |
| import time | |
| import shutil | |
| from PIL import Image | |
| from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration, TextIteratorStreamer | |
| from gradio import Server | |
| from gradio.data_classes import FileData | |
| from fastapi.responses import HTMLResponse | |
| import logging | |
# Silence asyncio noise from ZeroGPU cleanup: raise the "asyncio" logger threshold to
# CRITICAL so its lower-severity records (including ERROR) are dropped.
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
| from starlette.middleware import Middleware | |
| import hashlib | |
| import base64 | |
| import json | |
| import spaces | |
| from typing import Generator | |
| # ---------- Logging configuration ---------- | |
# Recognised media file extensions (lower-case, including the leading dot).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp", ".gif"}
VIDEO_EXTENSIONS = {".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v"}
# Request header whose value is copied into each HTTP log record as "client_id".
CLIENT_ID_HEADER = "x-v46-client-id"
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Use a /data volume when one exists (presumably persistent Space storage - verify);
# otherwise keep logs and uploads inside the project directory.
if not os.path.isdir("/data"):
    DEFAULT_LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
    DEFAULT_UPLOAD_LOG_DIR = os.path.join(DEFAULT_LOG_DIR, "uploads")
else:
    DEFAULT_LOG_DIR = "/data/logs"
    DEFAULT_UPLOAD_LOG_DIR = "/data/uploads"

# Every location can be overridden through a V46_* environment variable.
LOG_DIR = os.getenv("V46_LOG_DIR", DEFAULT_LOG_DIR)
UPLOAD_LOG_DIR = os.getenv("V46_UPLOAD_LOG_DIR", DEFAULT_UPLOAD_LOG_DIR)
HTTP_LOG_FILE = os.getenv("V46_HTTP_LOG_FILE", os.path.join(LOG_DIR, "http_requests.jsonl"))
RAW_OUTPUT_LOG_FILE = os.getenv("V46_RAW_OUTPUT_LOG_FILE", os.path.join(LOG_DIR, "raw_model_outputs.jsonl"))
# Full request/response logging is opt-in: only the exact string "1" enables it.
LOG_ALL_HTTP_REQUESTS = os.getenv("V46_LOG_ALL_HTTP_REQUESTS", "0") == "1"

# One lock per JSONL file so appends from concurrent threads are serialized.
HTTP_LOG_LOCK = threading.Lock()
RAW_OUTPUT_LOG_LOCK = threading.Lock()
| # ---------- Logging Middleware ---------- | |
def _headers_from_asgi(raw_headers) -> list[dict]:
    """Turn ASGI ``(name_bytes, value_bytes)`` header pairs into ``{"name", "value"}`` dicts.

    Both halves are decoded as latin-1 (with ``errors="replace"``); a ``None`` or empty
    input yields an empty list.
    """
    def _text(raw: bytes) -> str:
        return raw.decode("latin-1", errors="replace")

    return [{"name": _text(key), "value": _text(val)} for key, val in (raw_headers or [])]
def _header_value(headers: list[dict], name: str) -> str:
    """Return the value of the first header whose name equals *name* case-insensitively, else ``""``."""
    wanted = name.lower()
    return next((entry["value"] for entry in headers if entry["name"].lower() == wanted), "")
| def _body_text(data: bytes, content_type: str) -> str | None: | |
| if not data: | |
| return "" | |
| lower_type = (content_type or "").lower() | |
| if "text/" in lower_type or "json" in lower_type or "x-www-form-urlencoded" in lower_type: | |
| return data.decode("utf-8", errors="replace") | |
| return None | |
def _body_record(data: bytes, content_type: str) -> dict:
    """Summarize a captured HTTP body for the JSONL log.

    The result holds the byte ``size``, a hex ``sha256`` digest, the raw bytes as
    ``base64`` and a best-effort ``text`` rendering from ``_body_text``. Digest and
    base64 are empty strings when the body is empty.
    """
    record = {"size": len(data)}
    if data:
        record["sha256"] = hashlib.sha256(data).hexdigest()
        record["base64"] = base64.b64encode(data).decode("ascii")
    else:
        record["sha256"] = ""
        record["base64"] = ""
    record["text"] = _body_text(data, content_type)
    return record
def _append_http_log(record: dict) -> None:
    """Append *record* as one compact JSON line to ``HTTP_LOG_FILE``.

    The parent directory is created on demand. Writes are serialized with
    ``HTTP_LOG_LOCK`` so concurrent threads in this process do not interleave lines.
    Errors (unserializable record, I/O failure) propagate to the caller.
    """
    log_dir = os.path.dirname(HTTP_LOG_FILE)
    # A bare file name (e.g. V46_HTTP_LOG_FILE=http.jsonl) has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError; the current directory already exists.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    line = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
    with HTTP_LOG_LOCK, open(HTTP_LOG_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")
class HTTPRequestLogMiddleware:
    """ASGI middleware that records each HTTP request/response pair via ``_append_http_log``.

    For every ``http`` scope it:
    * generates a 12-hex-char request id, stores it in ``scope["v46_request_id"]`` and
      appends it to the response as the ``x-v46-request-id`` header;
    * tees the request body (as the app reads it) and the response body (as the app
      sends it) into in-memory buffers;
    * after the downstream app returns, raises or is cancelled, writes one JSON record
      with headers, bodies (see ``_body_record``), status code and duration.
    Scopes of any other type are passed straight through without logging.

    NOTE(review): whole bodies are buffered in memory and base64-encoded into the log,
    and the log write is synchronous file I/O on the event loop; large uploads or long
    streaming responses may be costly - confirm this is acceptable when enabling it.
    """

    def __init__(self, app):
        # The wrapped (downstream) ASGI application.
        self.app = app

    async def __call__(self, scope, receive, send):
        # Only plain HTTP traffic is logged; other scope types bypass the middleware.
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return
        started = time.time()
        request_id = uuid.uuid4().hex[:12]
        # Expose the id to downstream handlers through the scope.
        scope["v46_request_id"] = request_id
        request_body = bytearray()
        response_headers = []
        response_body = bytearray()
        status_code = None

        async def receive_wrapper():
            # Pass each message through unchanged, copying request body chunks aside.
            message = await receive()
            if message.get("type") == "http.request":
                chunk = message.get("body", b"")
                if chunk: request_body.extend(chunk)
            return message

        async def send_wrapper(message):
            nonlocal status_code, response_headers
            if message.get("type") == "http.response.start":
                status_code = message.get("status")
                # Copy first so the app's original header list object is not mutated.
                headers = list(message.get("headers", []))
                headers.append((b"x-v46-request-id", request_id.encode("ascii")))
                message["headers"] = headers
                # Snapshot taken after the append, so the logged headers include the id.
                response_headers = _headers_from_asgi(message.get("headers", []))
            elif message.get("type") == "http.response.body":
                chunk = message.get("body", b"")
                if chunk: response_body.extend(chunk)
            await send(message)

        try:
            await self.app(scope, receive_wrapper, send_wrapper)
        finally:
            # Runs on success, error and cancellation; any exception from the app
            # still propagates after the record is written. status_code stays None
            # if the app never started a response.
            request_headers = _headers_from_asgi(scope.get("headers", []))
            client = scope.get("client") or (None, None)
            request_content_type = _header_value(request_headers, "content-type")
            response_content_type = _header_value(response_headers, "content-type")
            record = {
                "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(started)),
                "request_id": request_id,
                "client_id": _header_value(request_headers, CLIENT_ID_HEADER),
                "client_host": client[0],
                "client_port": client[1],
                "method": scope.get("method"),
                "path": scope.get("path"),
                "query_string": (scope.get("query_string") or b"").decode("latin-1", errors="replace"),
                "http_version": scope.get("http_version"),
                "request_headers": request_headers,
                "request_body": _body_record(bytes(request_body), request_content_type),
                "status_code": status_code,
                "response_headers": response_headers,
                "response_body": _body_record(bytes(response_body), response_content_type),
                "duration_ms": round((time.time() - started) * 1000, 2),
            }
            try:
                _append_http_log(record)
            except Exception as e:
                # Logging is best-effort: a write failure is reported, never raised.
                print(f"[http-log] failed to write request log: {e}", flush=True)
def http_request_logging_app_kwargs() -> dict:
    """Announce the active log destinations and build the ``app_kwargs`` for ``demo.launch``.

    Always prints where model calls and uploads are persisted. When
    ``LOG_ALL_HTTP_REQUESTS`` is set, returns ``{"middleware": [...]}`` installing
    ``HTTPRequestLogMiddleware``; otherwise prints how to enable it and returns ``{}``.
    """
    for banner in (
        f"[model-call-log] writing model calls to {RAW_OUTPUT_LOG_FILE}",
        f"[upload-log] persisting inference uploads to {UPLOAD_LOG_DIR}",
    ):
        print(banner, flush=True)
    if not LOG_ALL_HTTP_REQUESTS:
        print("[http-log] all-request logging disabled; set V46_LOG_ALL_HTTP_REQUESTS=1 to enable", flush=True)
        return {}
    print(f"[http-log] writing all HTTP requests to {HTTP_LOG_FILE}", flush=True)
    return {"middleware": [Middleware(HTTPRequestLogMiddleware)]}
| # ---------- Globals & Model Loading ---------- | |
MODEL_ID = "openbmb/MiniCPM-V-4.6"

# Processor (tokenizer + image/video preprocessing) for the checkpoint.
print(f"Loading processor: {MODEL_ID}")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# bfloat16 weights with SDPA attention, placed on CUDA, switched to inference mode.
print(f"Loading model: {MODEL_ID}")
model = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="cuda",
)
model.eval()
| # ---------- Logging & Helper Functions ---------- | |
def _append_raw_output_log(record: dict) -> None:
    """Append *record* as one compact JSON line to ``RAW_OUTPUT_LOG_FILE``.

    The parent directory is created on demand. Writes are serialized with
    ``RAW_OUTPUT_LOG_LOCK`` so concurrent threads in this process do not interleave
    lines. Errors (unserializable record, I/O failure) propagate to the caller.
    """
    log_dir = os.path.dirname(RAW_OUTPUT_LOG_FILE)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; the current directory already exists.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    line = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
    with RAW_OUTPUT_LOG_LOCK, open(RAW_OUTPUT_LOG_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")
| def _json_safe_for_log(value): | |
| if isinstance(value, (str, int, float, bool)) or value is None: | |
| return value | |
| if isinstance(value, dict): | |
| return {str(k): _json_safe_for_log(v) for k, v in value.items()} | |
| if isinstance(value, (list, tuple)): | |
| return [_json_safe_for_log(v) for v in value] | |
| if isinstance(value, Image.Image): | |
| return {"type": "PIL.Image", "mode": value.mode, "size": list(value.size)} | |
| if hasattr(value, "__fspath__"): | |
| return os.fspath(value) | |
| return repr(value) | |
def log_raw_model_output(**record) -> None:
    """Append one JSON record describing a model call to the raw-output log.

    Each keyword argument is converted with ``_json_safe_for_log`` and merged after a
    UTC ``ts`` timestamp and a random 12-hex-char ``call_id``; caller-supplied keys of
    the same names override those two. A failed write is printed, never raised.
    """
    converted = {name: _json_safe_for_log(raw) for name, raw in record.items()}
    payload = {
        "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "call_id": uuid.uuid4().hex[:12],
    }
    payload.update(converted)
    try:
        _append_raw_output_log(payload)
    except Exception as e:
        print(f"[raw-output-log] failed to write raw output log: {e}", flush=True)
def load_video(video_path, max_frames=64):
    """Decode up to *max_frames* roughly evenly spaced frames of *video_path* as PIL images.

    When the first video stream reports a positive duration, the function seeks to
    *max_frames* evenly spaced timestamps (offset by the stream's start time) and keeps
    the first frame decoded after each seek. PyAV's default seek lands on a keyframe at
    or before the target, so frames may be keyframe-aligned rather than exact.
    Without a usable duration, every frame is decoded and evenly subsampled.

    Returns:
        A list of PIL images (possibly empty), or ``None`` if opening or decoding
        failed; the error is printed.
    """
    try:
        # The context manager closes the container on every exit path, including
        # the early return below and any decoding error.
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            stream.thread_count = 8
            duration = stream.duration
            if duration is None or duration <= 0:
                frames = [f.to_image() for f in container.decode(video=0)]
                if len(frames) > max_frames:
                    indices = [int(i * len(frames) / max_frames) for i in range(max_frames)]
                    return [frames[i] for i in indices]
                return frames
            # duration and start_time share the stream's time_base, and the stream's
            # timestamps begin at start_time, not at 0.
            start = stream.start_time or 0
            frames = []
            for i in range(max_frames):
                container.seek(start + int(i * duration / max_frames), stream=stream)
                first = next(container.decode(video=0), None)
                if first is not None:
                    frames.append(first.to_image())
            return frames
    except Exception as e:
        print(f"Error loading video: {e}")
        return None
def _file_path(file_obj) -> str:
    """Extract a filesystem path (or URL) from a Gradio-style file reference.

    Strings are returned as-is. Dicts are searched for the first non-empty string among
    ``path``, ``name``, ``orig_name`` and ``url``. Failing that, attributes ``path``,
    ``name`` and ``orig_name`` are tried. As a last resort ``str(file_obj)`` is returned.
    """
    def _first_nonempty(lookup, keys):
        for key in keys:
            candidate = lookup(key)
            if isinstance(candidate, str) and candidate:
                return candidate
        return None

    if isinstance(file_obj, str):
        return file_obj
    if isinstance(file_obj, dict):
        hit = _first_nonempty(file_obj.get, ("path", "name", "orig_name", "url"))
        if hit is not None:
            return hit
    hit = _first_nonempty(lambda key: getattr(file_obj, key, None), ("path", "name", "orig_name"))
    return hit if hit is not None else str(file_obj)
def _safe_upload_name(path: str) -> str:
    """Derive a filesystem-safe file name (at most 120 chars) from *path*'s basename.

    Runs of characters outside ``[A-Za-z0-9._-]`` become ``_``, and leading/trailing
    dots and underscores are stripped. ``"upload.bin"`` is used when nothing is left.
    """
    fallback = "upload.bin"
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", os.path.basename(path) or fallback)
    cleaned = cleaned.strip("._")[:120]
    return cleaned or fallback
def _infer_upload_extension(path: str) -> str:
    """Guess a media file extension for *path*, or return ``""`` if it cannot be determined.

    Resolution order:
    1. the path's own extension, when it is already a known image/video extension;
    2. the image format Pillow reports after opening the file;
    3. magic bytes at the start of the file: an ISO-BMFF ``ftyp`` box maps to
       ``.mp4``, and an EBML (Matroska/WebM) header maps to ``.webm``.
    Failures while opening or reading the file are ignored and the next step is tried.
    """
    ext = os.path.splitext(path)[1].lower()
    if ext in IMAGE_EXTENSIONS or ext in VIDEO_EXTENSIONS:
        return ext
    try:
        with Image.open(path) as img:
            fmt = (img.format or "").lower()
        image_ext = {
            "jpeg": ".jpg",
            "jpg": ".jpg",
            # Pillow reports multi-picture JPEGs (common from phones/cameras) as "MPO".
            "mpo": ".jpg",
            "png": ".png",
            "gif": ".gif",
            "webp": ".webp",
            "bmp": ".bmp",
            "tiff": ".tiff",
        }.get(fmt)
        if image_ext:
            return image_ext
    except Exception:
        pass  # not an image Pillow can identify; fall through to magic-byte sniffing
    try:
        with open(path, "rb") as f:
            header = f.read(32)
    except (OSError, ValueError):
        return ""
    if len(header) >= 12 and header[4:8] == b"ftyp":
        return ".mp4"
    if header.startswith(b"\x1aE\xdf\xa3"):
        return ".webm"
    return ""
| def _file_kind(path: str) -> str | None: | |
| ext = os.path.splitext(path)[1].lower() | |
| if ext in IMAGE_EXTENSIONS: | |
| return "image" | |
| if ext in VIDEO_EXTENSIONS: | |
| return "video" | |
| inferred_ext = _infer_upload_extension(path) | |
| if inferred_ext in IMAGE_EXTENSIONS: | |
| return "image" | |
| if inferred_ext in VIDEO_EXTENSIONS: | |
| return "video" | |
| return None | |
def persist_uploaded_files(files: list, session_id: str) -> list[str]:
    """Copy Gradio temp uploads into persistent storage before inference.

    Each file is copied to ``UPLOAD_LOG_DIR/<sanitized session_id>/`` under a
    timestamped, uniquified, sanitized name. If the name lacks a known media extension,
    the inferred one is appended.

    Returns:
        One path per input, in order:
        * the existing path, if the file already lives under ``UPLOAD_LOG_DIR``;
        * the raw reference (e.g. a URL), if it is not an existing file;
        * the original file path, if copying fails, so inference can still use it;
        * the new persistent path otherwise.
    """
    if not files:
        return []
    session_dir = re.sub(r"[^A-Za-z0-9._-]+", "_", session_id or "session").strip("._")
    dest_dir = os.path.join(UPLOAD_LOG_DIR, session_dir or "session")
    upload_root = os.path.abspath(UPLOAD_LOG_DIR)
    known_exts = IMAGE_EXTENSIONS | VIDEO_EXTENSIONS
    persisted = []
    for f in files:
        src = _file_path(f)
        src_path = os.path.abspath(src)
        if src_path.startswith(upload_root + os.sep):
            persisted.append(src_path)
            continue
        if not os.path.isfile(src_path):
            persisted.append(src)
            continue
        base = _safe_upload_name(src_path)
        inferred_ext = _infer_upload_extension(src_path)
        if inferred_ext and os.path.splitext(base)[1].lower() not in known_exts:
            base = f"{base}{inferred_ext}"
        stamp = time.strftime("%Y%m%dT%H%M%SZ", time.gmtime())
        dest = os.path.join(dest_dir, f"{stamp}-{uuid.uuid4().hex[:8]}-{base}")
        try:
            os.makedirs(dest_dir, exist_ok=True)
            shutil.copy2(src_path, dest)
        except OSError as e:
            # Persistence only serves logging; a full or unwritable log volume must
            # not block inference, so fall back to the original upload.
            print(f"[upload-log] failed to persist {src_path}: {e}; using original file", flush=True)
            persisted.append(src_path)
            continue
        persisted.append(dest)
    return persisted
def normalize_response_text(text: str) -> str:
    r"""Turn literal escaped line breaks (``\n``, ``\r\n``) emitted by the model into real newlines.

    Rules:
    * A run of two or more escaped breaks becomes one newline per break. An escaped
      CRLF pair counts as a single break.
    * A single ``\r\n`` becomes one newline.
    * A single ``\n`` is converted only when not followed by an ASCII letter, so
      LaTeX commands such as ``\nabla`` or ``\neq`` survive.
    * A backslash-escaped sequence (``\\n``) is left alone.
    * Inline code spans (`` `...` ``) are protected and left untouched.
    * Fenced code blocks (```` ```...``` ````) are converted, then protected so the
      inline-code pass cannot match inside them.
    Non-strings and strings without a backslash are returned unchanged.
    """
    if not isinstance(text, str) or "\\" not in text:
        return text

    protected: dict[str, str] = {}

    def _convert(segment: str) -> str:
        # Count \r\n as ONE break; counting \r and \n separately doubled every CRLF.
        segment = re.sub(
            r"(?<!\\)(?:\\r\\n|\\n|\\r){2,}",
            lambda m: "\n" * len(re.findall(r"\\r\\n|\\n|\\r", m.group(0))),
            segment,
        )
        segment = re.sub(r"(?<!\\)\\r\\n", "\n", segment)
        segment = re.sub(r"(?<!\\)\\n(?![a-zA-Z])", "\n", segment)
        return segment

    def _protect(segment: str) -> str:
        # NUL-delimited placeholder; restored verbatim after the global conversion.
        key = f"\x00P{len(protected)}\x00"
        protected[key] = segment
        return key

    res = re.sub(r"```[\s\S]*?```", lambda m: _protect(_convert(m.group(0))), text)
    res = re.sub(r"`[^`]+`", lambda m: _protect(m.group(0)), res)
    res = _convert(res)
    for key, segment in protected.items():
        res = res.replace(key, segment)
    return res
| # ---------- Inference Endpoint ---------- | |
# Gradio ``Server`` application instance; started by ``demo.launch(...)`` in the
# ``__main__`` guard at the end of the file.
# NOTE(review): no visible code registers ``predict`` or ``homepage`` on ``demo`` (no
# decorators or route calls), and ``spaces`` / ``HTMLResponse`` are imported but unused
# here - confirm the endpoint wiring was not lost.
demo = Server()
def predict(
    message: str,
    history: list[list] = None,
    files: list[FileData] = None,
    thinking_mode: bool = True,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.8,
    top_k: int = 100,
    max_frames: int = 64,
    generation_mode: str = "Sampling"
) -> Generator[str, None, None]:
    """Streaming inference endpoint with history support.

    Args:
        message: Text of the current user turn (may be empty when only files are sent).
        history: Previous turns as ``[user_text, assistant_text, optional_file_paths]``
            sequences; missing trailing elements are treated as empty.
        files: Uploads for the current turn. They are copied into persistent storage by
            ``persist_uploaded_files``, then loaded as images or sampled video frames.
        thinking_mode: Forwarded to the chat template as ``enable_thinking``.
        max_new_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k: Sampling parameters, used only when
            *generation_mode* is ``"Sampling"``.
        max_frames: Maximum number of frames sampled per video.
        generation_mode: ``"Sampling"`` enables sampling; any other value decodes
            greedily with ``num_beams=1``.

    Yields:
        The cumulative reply so far, passed through ``normalize_response_text``.

    Raises:
        Any exception raised by ``model.generate`` in the background thread is
        re-raised here after the partial output has been streamed.
    """
    session_id = str(uuid.uuid4())
    files = persist_uploaded_files(files or [], session_id)

    messages = _history_to_messages(history, max_frames)
    content = _current_turn_content(message, files, max_frames)
    if content:
        messages.append({"role": "user", "content": content})

    # With any video in the conversation, slicing is limited to 1 and image ids are
    # disabled; image-only chats allow up to 9 slices with image ids.
    has_video = any(item.get("type") == "video" for msg in messages for item in msg["content"])
    with torch.no_grad():
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            enable_thinking=thinking_mode,
            processor_kwargs={
                "downsample_mode": "16x",
                "max_slice_nums": 1 if has_video else 9,
                "use_image_id": not has_video,
                "videos_kwargs": {
                    "max_num_frames": max_frames,
                    "do_sample_frames": False,  # Frames are already sampled by load_video
                    "stack_frames": 1,
                },
            },
        ).to(model.device)
    # Match the model's bfloat16 weights for every floating-point tensor input.
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor) and torch.is_floating_point(value):
            inputs[key] = value.to(dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    sampling = (generation_mode == "Sampling")
    generate_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        "streamer": streamer,
        "downsample_mode": "16x",
    }
    if sampling:
        generate_kwargs.update({
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        })
    else:
        generate_kwargs.update({"num_beams": 1})

    generation_errors: list[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            # If generate fails it never signals the streamer, and the loop below would
            # block forever; record the error and end the stream ourselves.
            generation_errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    full_text = ""
    for new_text in streamer:
        full_text += new_text
        yield normalize_response_text(full_text)
    thread.join()
    if generation_errors:
        raise generation_errors[0]

    normalized_text = normalize_response_text(full_text)
    log_raw_model_output(
        source="api",
        session_id=session_id,
        variant="thinking" if thinking_mode else "instruct",
        thinking_mode=bool(thinking_mode),
        generation_mode=generation_mode,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_frames=max_frames,
        user_text=message,
        user_files=files,
        history=history or [],
        model_messages=messages,
        context_turns=len(history or []),
        raw_full_text=full_text,
        normalized_text=normalized_text,
    )


def _media_item(path: str, max_frames: int) -> dict | None:
    """Load *path* as an image content item or as sampled video frames; ``None`` if neither works."""
    if _file_kind(path) == "video":
        frames = load_video(path, max_frames=max_frames)
        return {"type": "video", "video": frames} if frames else None
    try:
        image = Image.open(path).convert("RGB")
    except Exception:
        # Fallback to manual video frame extraction (bypasses broken torchvision)
        frames = load_video(path, max_frames=max_frames)
        return {"type": "video", "video": frames} if frames else None
    return {"type": "image", "image": image}


def _history_to_messages(history, max_frames: int) -> list[dict]:
    """Rebuild chat-template messages from ``[user_text, assistant_text, files]`` history turns."""
    messages = []
    for turn in history or []:
        user_text = turn[0] if len(turn) > 0 else None
        assistant_text = turn[1] if len(turn) > 1 else None
        turn_files = turn[2] if len(turn) > 2 else []
        h_content = []
        for f in turn_files or []:
            f_path = _file_path(f)
            # History carries no MIME type; undecodable media is passed on by path.
            h_content.append(_media_item(f_path, max_frames) or {"type": "video", "path": f_path})
        if user_text:
            h_content.append({"type": "text", "text": user_text})
        if h_content:
            messages.append({"role": "user", "content": h_content})
        if assistant_text:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_text}]})
    return messages


def _current_turn_content(message: str, files: list, max_frames: int) -> list[dict]:
    """Build the current user turn's content: decodable media items first, then the text."""
    content = []
    for file_path in files or []:
        item = _media_item(file_path, max_frames)
        if item:
            content.append(item)
        else:
            print(f"Failed to load video: {file_path}")
    if message:
        content.append({"type": "text", "text": message})
    return content
async def homepage():
    """Return the text of ``index.html`` from this module's directory (read as UTF-8)."""
    here = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(here, "index.html"), encoding="utf-8") as page:
        html = page.read()
    return html
if __name__ == "__main__":
    # Surface errors to clients; optionally install the HTTP logging middleware.
    launch_kwargs = {"show_error": True, "app_kwargs": http_request_logging_app_kwargs()}
    demo.launch(**launch_kwargs)