userisuser's picture
Align model-call logging for Space app
8757ce8
import os
import torch
import re
import av
import uuid
import copy
import threading
import time
import shutil
from PIL import Image
from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration, TextIteratorStreamer
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
import logging
# Silence asyncio noise from ZeroGPU cleanup
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
from starlette.middleware import Middleware
import hashlib
import base64
import json
import spaces
from typing import Generator
# ---------- Media types & logging configuration ----------
# Extensions used to classify a file as a still image or as a video.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp", ".gif"}
VIDEO_EXTENSIONS = {".mp4", ".mkv", ".mov", ".avi", ".flv", ".wmv", ".webm", ".m4v"}

# Request header whose value is copied into the "client_id" field of HTTP log records.
CLIENT_ID_HEADER = "x-v46-client-id"

# Directory containing this module; used as the fallback location for logs.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Default to /data when that directory exists, otherwise log next to the code.
if not os.path.isdir("/data"):
    DEFAULT_LOG_DIR = os.path.join(PROJECT_ROOT, "logs")
    DEFAULT_UPLOAD_LOG_DIR = os.path.join(DEFAULT_LOG_DIR, "uploads")
else:
    DEFAULT_LOG_DIR = "/data/logs"
    DEFAULT_UPLOAD_LOG_DIR = "/data/uploads"

# Each location can be overridden through its V46_* environment variable.
LOG_DIR = os.getenv("V46_LOG_DIR", DEFAULT_LOG_DIR)
UPLOAD_LOG_DIR = os.getenv("V46_UPLOAD_LOG_DIR", DEFAULT_UPLOAD_LOG_DIR)
HTTP_LOG_FILE = os.getenv(
    "V46_HTTP_LOG_FILE",
    os.path.join(LOG_DIR, "http_requests.jsonl"),
)
RAW_OUTPUT_LOG_FILE = os.getenv(
    "V46_RAW_OUTPUT_LOG_FILE",
    os.path.join(LOG_DIR, "raw_model_outputs.jsonl"),
)

# Full HTTP request/response logging is opt-in: only the exact value "1" enables it.
LOG_ALL_HTTP_REQUESTS = os.getenv("V46_LOG_ALL_HTTP_REQUESTS", "0") == "1"

# One lock per log file so concurrent writers append whole lines.
HTTP_LOG_LOCK = threading.Lock()
RAW_OUTPUT_LOG_LOCK = threading.Lock()
# ---------- Logging Middleware ----------
def _headers_from_asgi(raw_headers) -> list[dict]:
headers = []
for raw_key, raw_value in raw_headers or []:
headers.append({
"name": raw_key.decode("latin-1", errors="replace"),
"value": raw_value.decode("latin-1", errors="replace"),
})
return headers
def _header_value(headers: list[dict], name: str) -> str:
name = name.lower()
for header in headers:
if header["name"].lower() == name:
return header["value"]
return ""
def _body_text(data: bytes, content_type: str) -> str | None:
if not data:
return ""
lower_type = (content_type or "").lower()
if "text/" in lower_type or "json" in lower_type or "x-www-form-urlencoded" in lower_type:
return data.decode("utf-8", errors="replace")
return None
def _body_record(data: bytes, content_type: str) -> dict:
    """Describe a captured body for the HTTP log.

    The record holds the byte length, a SHA-256 hex digest, the base64
    encoding, and the text form produced by ``_body_text``. Digest and
    base64 are empty strings when ``data`` is empty.
    """
    if data:
        digest = hashlib.sha256(data).hexdigest()
        encoded = base64.b64encode(data).decode("ascii")
    else:
        digest = ""
        encoded = ""
    return {
        "size": len(data),
        "sha256": digest,
        "base64": encoded,
        "text": _body_text(data, content_type),
    }
def _append_http_log(record: dict) -> None:
    """Append ``record`` as one compact JSON line to ``HTTP_LOG_FILE``.

    The parent directory is created on demand. Serialization happens before
    the lock is taken so the critical section only covers the file append.

    Raises:
        OSError: if the directory or file cannot be created or written.
        TypeError / ValueError: if ``record`` is not JSON-serializable.
    """
    log_dir = os.path.dirname(HTTP_LOG_FILE)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; the current directory needs no creation anyway.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    line = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
    with HTTP_LOG_LOCK, open(HTTP_LOG_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")
class HTTPRequestLogMiddleware:
    """ASGI middleware that writes one JSON log record per HTTP exchange.

    For every ``http`` scope it assigns a short request id (stored in the
    scope as ``v46_request_id`` and added to the response as the
    ``x-v46-request-id`` header), buffers the complete request and response
    bodies in memory, and appends the record via ``_append_http_log`` once
    the wrapped app returns or raises. Non-HTTP scopes pass through untouched.
    """

    def __init__(self, app):
        # The wrapped ASGI application.
        self.app = app

    async def __call__(self, scope, receive, send):
        # Anything that is not plain HTTP is forwarded without logging.
        if scope.get("type") != "http":
            await self.app(scope, receive, send)
            return

        started = time.time()
        request_id = uuid.uuid4().hex[:12]
        scope["v46_request_id"] = request_id

        captured_request = bytearray()
        captured_response = bytearray()
        # Filled in by send_wrapper once the response starts.
        response_state = {"status": None, "headers": []}

        async def receive_wrapper():
            # Pass the message through while keeping a copy of body chunks.
            message = await receive()
            if message.get("type") == "http.request":
                chunk = message.get("body", b"")
                if chunk:
                    captured_request.extend(chunk)
            return message

        async def send_wrapper(message):
            message_type = message.get("type")
            if message_type == "http.response.start":
                response_state["status"] = message.get("status")
                # Tag the response so it can be correlated with the log line.
                message["headers"] = [
                    *message.get("headers", []),
                    (b"x-v46-request-id", request_id.encode("ascii")),
                ]
                response_state["headers"] = _headers_from_asgi(message["headers"])
            elif message_type == "http.response.body":
                chunk = message.get("body", b"")
                if chunk:
                    captured_response.extend(chunk)
            await send(message)

        try:
            await self.app(scope, receive_wrapper, send_wrapper)
        finally:
            # Runs on success and on failure, so errored requests are logged too.
            request_headers = _headers_from_asgi(scope.get("headers", []))
            response_headers = response_state["headers"]
            client = scope.get("client") or (None, None)
            raw_query = scope.get("query_string") or b""
            record = {
                "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(started)),
                "request_id": request_id,
                "client_id": _header_value(request_headers, CLIENT_ID_HEADER),
                "client_host": client[0],
                "client_port": client[1],
                "method": scope.get("method"),
                "path": scope.get("path"),
                "query_string": raw_query.decode("latin-1", errors="replace"),
                "http_version": scope.get("http_version"),
                "request_headers": request_headers,
                "request_body": _body_record(
                    bytes(captured_request),
                    _header_value(request_headers, "content-type"),
                ),
                "status_code": response_state["status"],
                "response_headers": response_headers,
                "response_body": _body_record(
                    bytes(captured_response),
                    _header_value(response_headers, "content-type"),
                ),
                "duration_ms": round((time.time() - started) * 1000, 2),
            }
            try:
                _append_http_log(record)
            except Exception as e:
                # Logging is best-effort and must never break the request.
                print(f"[http-log] failed to write request log: {e}", flush=True)
def http_request_logging_app_kwargs() -> dict:
    """Announce the logging setup and build the ``app_kwargs`` for launch.

    Returns a dict installing ``HTTPRequestLogMiddleware`` when
    ``LOG_ALL_HTTP_REQUESTS`` is set, and an empty dict otherwise.
    """
    print(f"[model-call-log] writing model calls to {RAW_OUTPUT_LOG_FILE}", flush=True)
    print(f"[upload-log] persisting inference uploads to {UPLOAD_LOG_DIR}", flush=True)
    if not LOG_ALL_HTTP_REQUESTS:
        print("[http-log] all-request logging disabled; set V46_LOG_ALL_HTTP_REQUESTS=1 to enable", flush=True)
        return {}
    print(f"[http-log] writing all HTTP requests to {HTTP_LOG_FILE}", flush=True)
    middleware_stack = [Middleware(HTTPRequestLogMiddleware)]
    return {"middleware": middleware_stack}
# ---------- Globals & Model Loading ----------
# Identifier passed to both from_pretrained() calls below.
MODEL_ID = "openbmb/MiniCPM-V-4.6"

# Processor and model are loaded once at import time and shared by all requests.
print(f"Loading processor: {MODEL_ID}")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

print(f"Loading model: {MODEL_ID}")
model = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    trust_remote_code=True,
    device_map="cuda",
)
# Inference only: put the module into evaluation mode.
model = model.eval()
# ---------- Logging & Helper Functions ----------
def _append_raw_output_log(record: dict) -> None:
    """Append ``record`` as one compact JSON line to ``RAW_OUTPUT_LOG_FILE``.

    The parent directory is created on demand. Serialization happens before
    the lock is taken so the critical section only covers the file append.

    Raises:
        OSError: if the directory or file cannot be created or written.
        TypeError / ValueError: if ``record`` is not JSON-serializable.
    """
    log_dir = os.path.dirname(RAW_OUTPUT_LOG_FILE)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; the current directory needs no creation anyway.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    line = json.dumps(record, ensure_ascii=False, separators=(",", ":"))
    with RAW_OUTPUT_LOG_LOCK, open(RAW_OUTPUT_LOG_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")
def _json_safe_for_log(value):
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, dict):
return {str(k): _json_safe_for_log(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return [_json_safe_for_log(v) for v in value]
if isinstance(value, Image.Image):
return {"type": "PIL.Image", "mode": value.mode, "size": list(value.size)}
if hasattr(value, "__fspath__"):
return os.fspath(value)
return repr(value)
def log_raw_model_output(**record) -> None:
    """Write one model-call record to the raw output log.

    Every keyword value is passed through ``_json_safe_for_log``. The entry
    starts with a UTC timestamp (``ts``) and a 12-character ``call_id``;
    caller-supplied keys with those names replace the generated values.
    Write failures are printed and swallowed so logging never fails a request.
    """
    sanitized = {}
    for field, raw_value in record.items():
        sanitized[field] = _json_safe_for_log(raw_value)

    payload = {
        "ts": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "call_id": uuid.uuid4().hex[:12],
    }
    payload.update(sanitized)

    try:
        _append_raw_output_log(payload)
    except Exception as e:
        print(f"[raw-output-log] failed to write raw output log: {e}", flush=True)
def load_video(video_path, max_frames=64):
    """Sample up to ``max_frames`` frames from a video as PIL images.

    When the video stream reports a positive duration, ``max_frames`` evenly
    spaced timestamps are sought and the first frame decoded after each seek
    is kept. Otherwise every frame is decoded and the list is subsampled
    down to ``max_frames`` entries.

    Args:
        video_path: Path (or other source accepted by ``av.open``) of the video.
        max_frames: Upper bound on the number of frames returned.

    Returns:
        A list of PIL images, or ``None`` if opening or decoding failed
        (the error is printed rather than raised).
    """
    try:
        # The context manager closes the container on every exit path,
        # including the early return below and any decoding error.
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            stream.thread_count = 8
            duration = stream.duration

            if duration is None or duration <= 0:
                # No usable duration to seek against: decode everything,
                # then thin the result out evenly.
                frames = [f.to_image() for f in container.decode(video=0)]
                if len(frames) > max_frames:
                    indices = [int(i * len(frames) / max_frames) for i in range(max_frames)]
                    return [frames[i] for i in indices]
                return frames

            # Seek targets use the same units as stream.duration because the
            # stream is passed to seek().
            timestamps = [int(i * duration / max_frames) for i in range(max_frames)]
            frames = []
            for ts in timestamps:
                container.seek(ts, stream=stream)
                # Keep only the first frame decoded after the seek point.
                for frame in container.decode(video=0):
                    frames.append(frame.to_image())
                    break
            return frames
    except Exception as e:
        print(f"Error loading video: {e}")
        return None
def _file_path(file_obj) -> str:
if isinstance(file_obj, str):
return file_obj
if isinstance(file_obj, dict):
for key in ("path", "name", "orig_name", "url"):
value = file_obj.get(key)
if isinstance(value, str) and value:
return value
for key in ("path", "name", "orig_name"):
value = getattr(file_obj, key, None)
if isinstance(value, str) and value:
return value
return str(file_obj)
def _safe_upload_name(path: str) -> str:
name = os.path.basename(path) or "upload.bin"
name = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._")
return name[:120] or "upload.bin"
def _infer_upload_extension(path: str) -> str:
    """Guess a media extension for ``path``, or return ``""`` if unknown.

    A recognized extension already on the path is returned lower-cased.
    Otherwise the file is probed first with PIL (image formats), then by
    its leading bytes: an ``ftyp`` box maps to ``.mp4`` and an EBML header
    maps to ``.webm``. Probe failures are ignored on purpose, because an
    unreadable file simply has no inferable extension.
    """
    suffix = os.path.splitext(path)[1].lower()
    if suffix in IMAGE_EXTENSIONS or suffix in VIDEO_EXTENSIONS:
        return suffix

    pil_format_to_ext = {
        "jpeg": ".jpg",
        "jpg": ".jpg",
        "png": ".png",
        "gif": ".gif",
        "webp": ".webp",
        "bmp": ".bmp",
        "tiff": ".tiff",
    }
    try:
        with Image.open(path) as probe:
            detected = pil_format_to_ext.get((probe.format or "").lower())
            if detected:
                return detected
    except Exception:
        # Not an image PIL can identify; fall through to header sniffing.
        pass

    try:
        with open(path, "rb") as fh:
            magic = fh.read(32)
    except Exception:
        return ""

    if len(magic) >= 12 and magic[4:8] == b"ftyp":
        return ".mp4"
    if magic.startswith(b"\x1aE\xdf\xa3"):
        return ".webm"
    return ""
def _file_kind(path: str) -> str | None:
    """Classify ``path`` as ``"image"`` or ``"video"``, or ``None`` if unknown.

    The path's own extension is checked first; the file content is probed
    through ``_infer_upload_extension`` only when the extension is not
    recognized.
    """

    def _classify(extension: str) -> str | None:
        if extension in IMAGE_EXTENSIONS:
            return "image"
        if extension in VIDEO_EXTENSIONS:
            return "video"
        return None

    kind = _classify(os.path.splitext(path)[1].lower())
    if kind is None:
        kind = _classify(_infer_upload_extension(path))
    return kind
def persist_uploaded_files(files: list, session_id: str) -> list[str]:
    """Copy uploaded files into ``UPLOAD_LOG_DIR`` and return the new paths.

    Files land in a per-session subdirectory (named from the sanitized
    ``session_id``) under a timestamped, randomized, sanitized file name.
    A media extension is appended when the name lacks a recognized one and
    one can be inferred. Entries already under the upload root, and entries
    that are not existing files, are returned without copying.
    """
    if not files:
        return []

    upload_root = os.path.abspath(UPLOAD_LOG_DIR)
    known_extensions = IMAGE_EXTENSIONS | VIDEO_EXTENSIONS

    session_dir = re.sub(r"[^A-Za-z0-9._-]+", "_", session_id or "session").strip("._")
    dest_dir = os.path.join(UPLOAD_LOG_DIR, session_dir or "session")
    os.makedirs(dest_dir, exist_ok=True)

    def _persist_one(entry) -> str:
        src = _file_path(entry)
        src_path = os.path.abspath(src)
        # Already stored under the upload root: nothing to copy.
        if src_path.startswith(upload_root + os.sep):
            return src_path
        # Not a local file: hand the original reference back.
        if not os.path.isfile(src_path):
            return src

        base = _safe_upload_name(src_path)
        inferred_ext = _infer_upload_extension(src_path)
        has_known_ext = os.path.splitext(base)[1].lower() in known_extensions
        if inferred_ext and not has_known_ext:
            base = f"{base}{inferred_ext}"

        stamp = time.strftime("%Y%m%dT%H%M%SZ", time.gmtime())
        dest = os.path.join(dest_dir, f"{stamp}-{uuid.uuid4().hex[:8]}-{base}")
        shutil.copy2(src_path, dest)
        return dest

    return [_persist_one(entry) for entry in files]
def normalize_response_text(text: str) -> str:
    r"""Turn literal ``\n`` / ``\r\n`` escape sequences into real newlines.

    Fenced code blocks are converted internally and then shielded, and
    inline code spans are shielded unchanged, so the final pass only touches
    prose. A literal ``\n`` directly followed by an ASCII letter is left
    alone so LaTeX commands such as ``\nabla`` survive.

    Args:
        text: Model output. Non-strings and strings without a backslash
            are returned unchanged.

    Returns:
        The normalized text.
    """
    if not isinstance(text, str) or "\\" not in text:
        return text

    # Placeholder token -> original text, in the order segments were shielded.
    protected = {}

    def _convert(v):
        # Runs of two or more escapes become one newline per \n or \r found.
        v = re.sub(
            r"(?<!\\)(?:\\r\\n|\\n|\\r){2,}",
            lambda m: "\n" * len(re.findall(r"\\n|\\r", m.group(0))),
            v,
        )
        v = re.sub(r"(?<!\\)\\r\\n", "\n", v)
        # The lookahead skips LaTeX-style commands that start with \n.
        v = re.sub(r"(?<!\\)\\n(?![a-zA-Z])", "\n", v)
        return v

    def _stash(segment):
        # NUL-delimited keys are unlikely to collide with real output.
        key = f"\x00P{len(protected)}\x00"
        protected[key] = segment
        return key

    res = re.sub(r"```[\s\S]*?```", lambda m: _stash(_convert(m.group(0))), text)
    res = re.sub(r"`[^`]+`", lambda m: _stash(m.group(0)), res)
    res = _convert(res)

    # Restore newest-first: an inline span shielded later may itself contain
    # the placeholder of an earlier fence, which only reappears in the text
    # once the outer span has been put back.
    for key, segment in reversed(protected.items()):
        res = res.replace(key, segment)
    return res
# ---------- Inference Endpoint ----------
def _load_media_item(path: str, max_frames: int) -> dict | None:
    """Build one chat-template content item for an image or video file.

    Files classified as video are decoded with ``load_video``. Everything
    else is first opened as an image, and decoded as a video only if that
    fails. Returns ``None`` when the file could not be loaded either way.
    """
    try:
        if _file_kind(path) != "video":
            # convert() produces an independent image, so the file handle
            # can be released as soon as the block exits.
            with Image.open(path) as img:
                return {"type": "image", "image": img.convert("RGB")}
    except Exception:
        # Deliberate fallback: anything PIL cannot open is retried as a video.
        pass
    frames = load_video(path, max_frames=max_frames)
    if frames:
        return {"type": "video", "video": frames}
    return None


demo = Server()


@demo.api()
@spaces.GPU(duration=120)
def predict(
    message: str,
    history: list[list] = None,
    files: list[FileData] = None,
    thinking_mode: bool = True,
    max_new_tokens: int = 1024,
    temperature: float = 0.7,
    top_p: float = 0.8,
    top_k: int = 100,
    max_frames: int = 64,
    generation_mode: str = "Sampling"
) -> Generator[str, None, None]:
    """Streaming inference endpoint with history support.

    Args:
        message: Text of the current user turn (may be empty).
        history: Previous turns, each ``[user_text, assistant_text]`` with an
            optional third element listing that turn's file paths.
        files: Uploads attached to the current turn; they are copied to
            persistent storage before being read.
        thinking_mode: Forwarded to the chat template as ``enable_thinking``.
        max_new_tokens: Generation length limit.
        temperature, top_p, top_k: Sampling parameters, used only when
            ``generation_mode`` is ``"Sampling"``.
        max_frames: Maximum number of frames extracted per video.
        generation_mode: ``"Sampling"`` enables sampling; any other value
            decodes without sampling, with ``num_beams=1``.

    Yields:
        The normalized accumulated response text after each streamed chunk.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has ended.
    """
    session_id = str(uuid.uuid4())
    files = persist_uploaded_files(files or [], session_id)
    messages = []

    # Rebuild earlier turns so the model sees the whole conversation.
    for turn in history or []:
        # A history turn is [user_text, assistant_text, [optional_file_paths]].
        user_text = turn[0]
        assistant_text = turn[1]
        turn_files = turn[2] if len(turn) > 2 else []

        h_content = []
        for f in turn_files or []:
            f_path = _file_path(f)
            item = _load_media_item(f_path, max_frames)
            # Unloadable history media is kept as a video path placeholder.
            h_content.append(item or {"type": "video", "path": f_path})
        if user_text:
            h_content.append({"type": "text", "text": user_text})
        if h_content:
            messages.append({"role": "user", "content": h_content})
        if assistant_text:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_text}]})

    # Current turn: media first, then the text.
    content = []
    for file_path in files:
        item = _load_media_item(file_path, max_frames)
        if item is None:
            print(f"Failed to load video: {file_path}")
        else:
            content.append(item)
    if message:
        content.append({"type": "text", "text": message})
    if content:
        messages.append({"role": "user", "content": content})

    # Slicing and image-id settings depend on whether any video is present
    # anywhere in the conversation, so compute that once.
    has_video = any(it.get("type") == "video" for msg in messages for it in msg["content"])

    # Prepare inputs with Advanced Parameters for MiniCPM-V 4.6
    with torch.no_grad():
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            enable_thinking=thinking_mode,
            processor_kwargs={
                "downsample_mode": "16x",
                "max_slice_nums": 1 if has_video else 9,
                "use_image_id": not has_video,
                "videos_kwargs": {
                    "max_num_frames": max_frames,
                    "do_sample_frames": False,  # Frames are already sampled by load_video
                    "stack_frames": 1,
                }
            }
        ).to(model.device)
        # Match the model's bfloat16 weights for every floating-point input.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                inputs[k] = v.to(dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    sampling = (generation_mode == "Sampling")
    generate_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        "streamer": streamer,
        "downsample_mode": "16x"
    }
    if sampling:
        generate_kwargs.update({
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
        })
    else:
        generate_kwargs.update({"num_beams": 1})

    generation_errors = []

    def _run_generation():
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            # generate() only signals end-of-stream on success; without this
            # the consumer loop below would wait on the streamer forever.
            streamer.end()

    thread = threading.Thread(target=_run_generation)
    thread.start()

    full_text = ""
    for new_text in streamer:
        full_text += new_text
        yield normalize_response_text(full_text)
    thread.join()

    # Surface a worker failure to the caller instead of ending silently.
    if generation_errors:
        raise generation_errors[0]

    normalized_text = normalize_response_text(full_text)
    log_raw_model_output(
        source="api",
        session_id=session_id,
        variant="thinking" if thinking_mode else "instruct",
        thinking_mode=bool(thinking_mode),
        generation_mode=generation_mode,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_frames=max_frames,
        user_text=message,
        user_files=files,
        history=history or [],
        model_messages=messages,
        context_turns=len(history or []),
        raw_full_text=full_text,
        normalized_text=normalized_text,
    )
@demo.get("/", response_class=HTMLResponse)
async def homepage():
    """Return the contents of the ``index.html`` located next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    index_file = os.path.join(module_dir, "index.html")
    with open(index_file, "r", encoding="utf-8") as page:
        markup = page.read()
    return markup
if __name__ == "__main__":
    # Build the optional middleware configuration, then start the server.
    launch_options = {
        "show_error": True,
        "app_kwargs": http_request_logging_app_kwargs(),
    }
    demo.launch(**launch_options)