# Hugging Face Spaces app — "Running on Zero" (ZeroGPU hardware).
| import os | |
| import torch | |
| import re | |
| import av | |
| import uuid | |
| import copy | |
| import threading | |
| import time | |
| import shutil | |
| from PIL import Image | |
| from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration, TextIteratorStreamer | |
| from gradio import Server | |
| from gradio.data_classes import FileData | |
| from fastapi.responses import HTMLResponse | |
| import logging | |
# Silence asyncio noise from ZeroGPU cleanup: records below CRITICAL from the
# "asyncio" logger are dropped from here on.
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
| from starlette.middleware import Middleware | |
| import hashlib | |
| import base64 | |
| import json | |
| # ---------- Logging Middleware ---------- | |
| def _headers_from_asgi(raw_headers) -> list[dict]: | |
| headers = [] | |
| for raw_key, raw_value in raw_headers or []: | |
| headers.append({ | |
| "name": raw_key.decode("latin-1", errors="replace"), | |
| "value": raw_value.decode("latin-1", errors="replace"), | |
| }) | |
| return headers | |
| def _header_value(headers: list[dict], name: str) -> str: | |
| name = name.lower() | |
| for header in headers: | |
| if header["name"].lower() == name: | |
| return header["value"] | |
| return "" | |
| def _body_text(data: bytes, content_type: str) -> str | None: | |
| if not data: return "" | |
| lower_type = (content_type or "").lower() | |
| if "text/" in lower_type or "json" in lower_type or "x-www-form-urlencoded" in lower_type: | |
| return data.decode("utf-8", errors="replace") | |
| return None | |
def _body_record(data: bytes, content_type: str) -> dict:
    """Summarize a body as its byte size, SHA-256 hex digest, base64 and decoded text.

    The digest and base64 fields are empty strings for empty data; ``text``
    comes from ``_body_text`` (``None`` for non-textual content types).
    """
    if data:
        digest = hashlib.sha256(data).hexdigest()
        encoded = base64.b64encode(data).decode("ascii")
    else:
        digest = encoded = ""
    return {
        "size": len(data),
        "sha256": digest,
        "base64": encoded,
        "text": _body_text(data, content_type),
    }
def _append_http_log(record: dict) -> None:
    """Append *record* as one compact JSON line to HTTP_LOG_FILE.

    The log directory is created on demand and the write is serialized
    with HTTP_LOG_LOCK.
    """
    os.makedirs(os.path.dirname(HTTP_LOG_FILE), exist_ok=True)
    serialized = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
    with HTTP_LOG_LOCK, open(HTTP_LOG_FILE, "a", encoding="utf-8") as log_file:
        log_file.write(f"{serialized}\n")
class HTTPRequestLogMiddleware:
    """ASGI middleware that appends one JSON line per HTTP request via ``_append_http_log``.

    Each record contains the UTC start timestamp, a random 12-hex-char request
    id, the ``x-v46-client-id`` request header, method, path, response status
    and duration in milliseconds. Non-HTTP scopes are passed straight through.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return

        started = time.time()
        # Monotonic clock for the duration so wall-clock adjustments can't skew it.
        started_monotonic = time.monotonic()
        request_id = uuid.uuid4().hex[:12]
        status_code = None

        async def send_wrapper(message):
            # Only the status code goes into the record. Bodies are deliberately
            # NOT buffered: the previous version accumulated every request and
            # response body (and headers) without ever logging them, so large or
            # streamed responses were held in memory for the whole request.
            nonlocal status_code
            if message.get("type") == "http.response.start":
                status_code = message.get("status")
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            request_headers = _headers_from_asgi(scope.get("headers", []))
            record = {
                "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(started)),
                "request_id": request_id,
                "client_id": _header_value(request_headers, "x-v46-client-id"),
                "method": scope.get("method"),
                "path": scope.get("path"),
                "status_code": status_code,
                "duration_ms": round((time.monotonic() - started_monotonic) * 1000, 2),
            }
            # Best effort: a logging failure must not mask the request's outcome.
            try:
                _append_http_log(record)
            except Exception as e:
                print(f"Logging error: {e}")
| import spaces | |
| from typing import Generator | |
# ---------- Globals & Model Loading ----------
MODEL_ID = "openbmb/MiniCPM-V-4.6"

# Processor first (tokenizer + image/video preprocessing), then the weights.
print(f"Loading processor: {MODEL_ID}")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

print(f"Loading model: {MODEL_ID}")
model = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    trust_remote_code=True,
    device_map="cuda",
)
# nn.Module.eval() switches the module to inference mode (and returns it).
model.eval()
# ---------- Logging & Helper Functions ----------
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
# Upload copies (per session) plus two JSON-lines logs, all under LOG_DIR.
UPLOAD_LOG_DIR, HTTP_LOG_FILE, RAW_OUTPUT_LOG_FILE = (
    os.path.join(LOG_DIR, leaf)
    for leaf in ("uploads", "http_requests.jsonl", "raw_model_outputs.jsonl")
)
# One lock per log file so appends to each file are serialized independently.
HTTP_LOG_LOCK, RAW_OUTPUT_LOG_LOCK = threading.Lock(), threading.Lock()
def _append_raw_output_log(record: dict) -> None:
    """Append *record* as a single compact JSON line to RAW_OUTPUT_LOG_FILE.

    The log directory is created on demand; writes are serialized with
    RAW_OUTPUT_LOG_LOCK.
    """
    log_dir = os.path.dirname(RAW_OUTPUT_LOG_FILE)
    os.makedirs(log_dir, exist_ok=True)
    entry = json.dumps(record, ensure_ascii=False, separators=(",", ":")) + "\n"
    with RAW_OUTPUT_LOG_LOCK, open(RAW_OUTPUT_LOG_FILE, "a", encoding="utf-8") as sink:
        sink.write(entry)
def log_raw_model_output(session_id: str, **record) -> None:
    """Best-effort append of a timestamped, session-tagged record to the raw-output log.

    The payload starts with ``ts`` (UTC) and ``session_id``, followed by the
    caller's keyword fields. Any failure is printed instead of raised, so a
    logging problem never propagates to the caller.
    """
    payload = dict(
        ts=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        session_id=session_id,
    )
    payload.update(record)
    try:
        _append_raw_output_log(payload)
    except Exception as e:
        print(f"Logging error: {e}")
def load_video(video_path, max_frames=64):
    """Decode up to ``max_frames`` frames, evenly spaced in time, from a video file.

    Uses PyAV timestamp seeking when the stream reports a duration; otherwise
    decodes every frame and subsamples evenly by index.

    Args:
        video_path: Path (or anything ``av.open`` accepts) of the video.
        max_frames: Maximum number of frames to return.

    Returns:
        A list of PIL images (possibly fewer than ``max_frames``), or ``None``
        if the file could not be opened or decoded.
    """
    try:
        # ``with`` closes the container on every path, including the early
        # return below and decode errors (the old version leaked it there).
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            stream.thread_count = 8
            duration = stream.duration
            if duration is None or duration <= 0:
                # No usable duration metadata: decode everything, then subsample.
                frames = [f.to_image() for f in container.decode(video=0)]
                if len(frames) > max_frames:
                    indices = [int(i * len(frames) / max_frames) for i in range(max_frames)]
                    return [frames[i] for i in indices]
                return frames

            # duration and start_time are in stream.time_base units, which is
            # also the unit seek() uses when a stream is given.
            start = stream.start_time or 0
            frames = []
            for i in range(max_frames):
                target = start + int(i * duration / max_frames)
                container.seek(target, stream=stream)
                # seek() lands on the keyframe at/before target; decode forward to
                # the first frame at/after target so samples between keyframes are
                # not all duplicates of the same keyframe.
                picked = None
                for frame in container.decode(video=0):
                    picked = frame
                    if frame.pts is None or frame.pts >= target:
                        break
                if picked is not None:
                    frames.append(picked.to_image())
            return frames
    except Exception as e:
        print(f"Error loading video: {e}")
        return None
| def persist_uploaded_files(files: list, session_id: str) -> list: | |
| """Copy Gradio temp uploads into the project log directory.""" | |
| if not files: return [] | |
| dest_dir = os.path.join(UPLOAD_LOG_DIR, session_id or "session") | |
| os.makedirs(dest_dir, exist_ok=True) | |
| persisted = [] | |
| for f in files: | |
| src = f["path"] if isinstance(f, dict) else f | |
| if not os.path.isfile(src): | |
| persisted.append(src) | |
| continue | |
| base = os.path.basename(src) | |
| stamp = time.strftime("%Y%m%dT%H%M%SZ", time.gmtime()) | |
| dest = os.path.join(dest_dir, f"{stamp}-{uuid.uuid4().hex[:8]}-{base}") | |
| shutil.copy2(src, dest) | |
| persisted.append(dest) | |
| return persisted | |
def normalize_response_text(text: str) -> str:
    """Turn literal backslash-escaped line breaks in model output into real newlines.

    Rules (applied outside inline code spans):
      * a run of 2+ escaped breaks (``\\n``, ``\\r``, ``\\r\\n``) becomes one real
        newline per break, with ``\\r\\n`` counting as a single break;
      * a lone ``\\r\\n`` becomes one newline;
      * a lone ``\\n`` becomes a newline unless followed by an ASCII letter, so
        LaTeX commands such as ``\\nabla``, ``\\neq`` or ``\\nu`` survive;
      * escapes preceded by another backslash are left alone.

    Fenced ```` ``` ```` blocks are converted with the same rules and then shielded
    from the inline-code pass; inline `` `code` `` spans are left verbatim.
    Non-strings and strings without a backslash are returned unchanged.
    """
    if not isinstance(text, str) or "\\" not in text:
        return text

    protected = {}

    def _convert(v):
        # Count "\r\n" as ONE break; counting "\r" and "\n" separately used to
        # double every CRLF (even a lone "\r\n" matched here as two repetitions).
        v = re.sub(
            r"(?<!\\)(?:\\r\\n|\\n|\\r){2,}",
            lambda m: "\n" * len(re.findall(r"\\r\\n|\\n|\\r", m.group(0))),
            v,
        )
        v = re.sub(r"(?<!\\)\\r\\n", "\n", v)
        v = re.sub(r"(?<!\\)\\n(?![a-zA-Z])", "\n", v)
        return v

    def _protect(s):
        key = f"\x00P{len(protected)}\x00"
        protected[key] = s
        return key

    res = re.sub(r"```[\s\S]*?```", lambda m: _protect(_convert(m.group(0))), text)
    res = re.sub(r"`[^`]+`", lambda m: _protect(m.group(0)), res)
    res = _convert(res)
    # Restore newest-first: a later placeholder (e.g. inline code wrapping a
    # fenced block) may contain an earlier key, never the other way round.
    for key in reversed(protected):
        res = res.replace(key, protected[key])
    return res
# ---------- Inference Endpoint ----------
# Gradio Server instance; started by demo.launch() in the __main__ guard.
# NOTE(review): nothing visible in this file registers predict() or homepage()
# on `demo` — confirm how the endpoint and the HTML route are attached.
demo = Server()
# File extensions treated as video without first trying to open them as images.
_VIDEO_EXTENSIONS = {".mp4", ".mkv", ".mov", ".avi", ".webm"}


def _file_path(f) -> str:
    """Return the filesystem path of an upload entry (dict with "path", object with .path, or str)."""
    if isinstance(f, dict):
        return f["path"]
    return getattr(f, "path", f)


def _load_media_part(path: str, max_frames: int):
    """Build an image/video chat content part for *path*, or return None if it can't be decoded.

    Known video extensions go straight to load_video(); anything else is tried
    as an image first and then as a video.
    """
    if os.path.splitext(path)[1].lower() not in _VIDEO_EXTENSIONS:
        try:
            # Context manager closes the file handle Image.open keeps for lazy loading;
            # convert() returns a new, fully loaded image.
            with Image.open(path) as img:
                return {"type": "image", "image": img.convert("RGB")}
        except Exception:
            pass  # not a decodable image; fall through to video decoding
    # Manual frame extraction (bypasses the processor's own video loading).
    frames = load_video(path, max_frames=max_frames)
    if frames:
        return {"type": "video", "video": frames}
    return None


def _history_to_messages(history, max_frames: int) -> list:
    """Convert ``[user_text, assistant_text, optional_file_paths]`` turns into chat messages."""
    messages = []
    for turn in history or []:
        user_text, assistant_text = turn[0], turn[1]
        turn_files = turn[2] if len(turn) > 2 else []
        content = []
        for f_path in turn_files or []:
            # Undecodable history files are still passed by path for the processor to handle.
            content.append(_load_media_part(f_path, max_frames) or {"type": "video", "path": f_path})
        if user_text:
            content.append({"type": "text", "text": user_text})
        if content:
            messages.append({"role": "user", "content": content})
        if assistant_text:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_text}]})
    return messages


def _current_turn_content(message: str, files, max_frames: int) -> list:
    """Build the content list for the current user turn; undecodable files are skipped with a message."""
    content = []
    for f in files or []:
        file_path = _file_path(f)
        part = _load_media_part(file_path, max_frames)
        if part is None:
            print(f"Failed to load video: {file_path}")
        else:
            content.append(part)
    if message:
        content.append({"type": "text", "text": message})
    return content


def predict(
    message: str,
    history: list[list] = None,
    files: list[FileData] = None,
    thinking_mode: bool = True,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.8,
    top_k: int = 100,
    max_frames: int = 64,
    generation_mode: str = "Sampling"
) -> Generator[str, None, None]:
    """Streaming inference endpoint with history support.

    Args:
        message: User text for this turn (may be empty).
        history: Prior turns as ``[user_text, assistant_text, optional_file_paths]``.
        files: Uploads for this turn; each entry's path is read (dict ``"path"`` key).
        thinking_mode: Passed as ``enable_thinking`` to the chat template; also tags
            the raw-output log entry as "thinking" or "instruct".
        max_new_tokens: Generation length limit.
        temperature, top_p, top_k: Used only when ``generation_mode == "Sampling"``.
        max_frames: Maximum frames sampled per video.
        generation_mode: "Sampling" enables ``do_sample``; any other value uses
            ``do_sample=False`` with ``num_beams=1``.

    Yields:
        The cumulative response so far, passed through normalize_response_text().

    Raises:
        Any exception raised by ``model.generate`` in the worker thread, re-raised
        after streaming stops (previously the stream would block forever).
    """
    # NOTE(review): `spaces` (ZeroGPU) is imported by this file but no
    # @spaces.GPU decorator is visible here — confirm GPU allocation is handled.
    session_id = str(uuid.uuid4())
    # Persist files in the background so copying doesn't delay the response.
    if files:
        threading.Thread(target=persist_uploaded_files, args=(files, session_id), daemon=True).start()

    messages = _history_to_messages(history, max_frames)
    content = _current_turn_content(message, files, max_frames)
    if content:
        messages.append({"role": "user", "content": content})

    # Any video in the conversation switches slicing to max_slice_nums=1 with
    # use_image_id=False; image-only conversations use 9 and True.
    has_video = any(part.get("type") == "video" for msg in messages for part in msg["content"])
    with torch.no_grad():
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            enable_thinking=thinking_mode,
            processor_kwargs={
                "downsample_mode": "16x",
                "max_slice_nums": 1 if has_video else 9,
                "use_image_id": not has_video,
                "videos_kwargs": {
                    "max_num_frames": max_frames,
                    "do_sample_frames": False,  # Frames are already sampled by load_video
                    "stack_frames": 1,
                },
            },
        ).to(model.device)
        # Match the model's bfloat16 weights for all floating-point inputs.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                inputs[k] = v.to(dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    sampling = generation_mode == "Sampling"
    generate_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        "streamer": streamer,
        "downsample_mode": "16x",
    }
    if sampling:
        generate_kwargs.update({"temperature": temperature, "top_p": top_p, "top_k": top_k})
    else:
        generate_kwargs.update({"num_beams": 1})

    generation_errors = []

    def _generate():
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            # Push the stop signal; otherwise iterating the streamer blocks forever.
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    full_text = ""
    for new_text in streamer:
        full_text += new_text
        yield normalize_response_text(full_text)
    thread.join()

    log_raw_model_output(session_id, message=message, response=full_text, variant="thinking" if thinking_mode else "instruct")
    if generation_errors:
        raise generation_errors[0]
async def homepage():
    """Return the text of ``index.html`` stored next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    page_path = os.path.join(here, "index.html")
    with open(page_path, "r", encoding="utf-8") as handle:
        html = handle.read()
    return html
if __name__ == "__main__":
    # Install the request logger so every HTTP request is appended to the JSONL log.
    logging_middleware = [Middleware(HTTPRequestLogMiddleware)]
    demo.launch(show_error=True, app_kwargs={"middleware": logging_middleware})