"""
Common inference wrapper for MiniCPM-o 4.5.
MiniCPM-o's API is `model.chat(msgs=[...], tokenizer=...)` where `msgs` is a
list of `{"role": ..., "content": [image, audio, ..., text]}`. This module
hides that detail behind `run_inference(model, tokenizer, video, audio,
prompt)` so the 6 benchmark eval scripts can share one inference code path.
Also runs the compatibility patcher on import so users who haven't run
`setup_env.sh` still get a working model.
"""
from __future__ import annotations

import os
import subprocess
import tempfile
from typing import Any, List, Optional

import numpy as np

# ---------------------------------------------------------------------------
# Apply transformers>=4.52 compatibility patches lazily on import.
# Safe to call multiple times; idempotent.
# ---------------------------------------------------------------------------
def _maybe_patch_once() -> None:
    try:
        from patch_minicpmo import (
            _find_modeling_file,
            _find_processing_file,
            patch_file,
            patch_processing_file,
        )
    except ImportError:
        return
    path = _find_modeling_file()
    if path is not None:
        try:
            patch_file(path)
        except Exception as exc:  # pragma: no cover
            print(f"[minicpmo] (warn) patch failed: {exc}")
    proc = _find_processing_file()
    if proc is not None:
        try:
            patch_processing_file(proc)
        except Exception as exc:  # pragma: no cover
            print(f"[minicpmo] (warn) processing patch failed: {exc}")


_maybe_patch_once()


def _max_inp_length_for_chat(model: Any, max_new_tokens: int) -> int:
    """Upper bound for ``model.chat(..., max_inp_length=...)`` (defaults to 8192).

    Many frames × per-frame image placeholders can exceed 8k text tokens; the
    processor then truncates ``input_ids``, the image start/end token counts
    diverge, and ``processing_minicpmo._convert`` raises ``RuntimeError``.
    """
    reserve = int(max_new_tokens) + 1024
    best = 32768
    for cfg in (
        getattr(model, "config", None),
        getattr(getattr(model, "llm", None), "config", None),
    ):
        if cfg is None:
            continue
        npos = getattr(cfg, "max_position_embeddings", None)
        if isinstance(npos, int) and npos > 8192:
            best = min(best, max(npos - reserve, 16384))
    return best
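
# Illustrative arithmetic (a sketch, assuming a hypothetical 32k-context
# backbone):
#   max_position_embeddings = 32768, max_new_tokens = 256
#   reserve = 256 + 1024 = 1280
#   result  = min(32768, max(32768 - 1280, 16384)) = 31488
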
# ---------------------------------------------------------------------------
# Frame / audio loaders
# ---------------------------------------------------------------------------
def load_video_frames(video_path: str, max_frames: int = 32,
                      fps: float = 1.0) -> List:
    """Sample PIL RGB frames uniformly from a video.

    MiniCPM-o expects a list of PIL Images (not a tensor). `fps=1.0,
    max_frames=32` covers ~32 s at one frame per second; longer videos get
    sparser sampling.
    """
    from PIL import Image
    import decord

    vr = decord.VideoReader(video_path, num_threads=1)
    total_frames = len(vr)
    video_fps = vr.get_avg_fps()
    duration = total_frames / max(video_fps, 1e-6)
    target = max(int(round(fps * duration)), 2)
    target = min(target, max_frames)
    target = min(target, total_frames)
    idx = np.linspace(0, total_frames - 1, target).round().astype(int).tolist()
    frames = vr.get_batch(idx).asnumpy()
    return [Image.fromarray(f).convert("RGB") for f in frames]
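
# Sampling arithmetic (a sketch at the defaults fps=1.0, max_frames=32):
#   20 s video -> target = min(max(round(20), 2), 32) = 20 frames (~1 fps)
#   90 s video -> target = min(round(90), 32) = 32 frames (~0.36 fps, sparser)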


def load_audio_waveform(audio_path: str, target_sr: int = 16000) -> np.ndarray:
    """Load audio as float32 numpy in [-1, 1] at `target_sr`."""
    import librosa

    y, _ = librosa.load(audio_path, sr=target_sr, mono=True)
    return y.astype(np.float32)


def extract_audio_from_video(video_path: str, target_sr: int = 16000,
                             tmp_dir: Optional[str] = None) -> Optional[str]:
    """Extract the audio track from a video file to a temp .wav via ffmpeg.

    Returns the path to the .wav file, or None if the video has no audio
    track or extraction fails. The caller is responsible for cleanup.
    """
    tmp_dir = tmp_dir or tempfile.mkdtemp(prefix="mo_audio_")
    out = os.path.join(tmp_dir, "audio.wav")
    try:
        subprocess.run(
            ["ffmpeg", "-y", "-loglevel", "error", "-i", video_path,
             "-vn", "-ac", "1", "-ar", str(target_sr), out],
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            timeout=120,
        )
    except Exception:
        return None
    if not os.path.isfile(out) or os.path.getsize(out) < 64:
        return None
    return out
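
# Typical extract-then-load flow (a sketch; "clip.mp4" is hypothetical):
#   wav = extract_audio_from_video("clip.mp4")   # None if no audio track
#   if wav is not None:
#       waveform = load_audio_waveform(wav)      # float32 mono @ 16 kHz
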
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model(model_id: str = "openbmb/MiniCPM-o-4_5",
               device: str = "cuda",
               dtype: str = "bfloat16",
               init_audio: bool = True,
               attn_implementation: str = "flash_attention_2"):
    """Load the MiniCPM-o model + tokenizer. Returns (model, tokenizer).

    Tries `attn_implementation` first; if flash_attention_2 isn't installed or
    the backbone doesn't support it, falls back to sdpa automatically.
    """
    import torch
    from transformers import AutoModel, AutoTokenizer

    torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
                   "float32": torch.float32}[dtype]

    def _try_load(attn: str):
        print(f"[minicpmo] Loading {model_id} (dtype={dtype}, device={device}, "
              f"init_audio={init_audio}, attn={attn})...")
        return AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            attn_implementation=attn,
            torch_dtype=torch_dtype,
            init_vision=True,
            init_audio=init_audio,
            init_tts=False,
        )

    try:
        model = _try_load(attn_implementation)
    except Exception as exc:
        if attn_implementation != "sdpa":
            print(f"[minicpmo] (warn) {attn_implementation} failed ({exc}); "
                  "falling back to sdpa.")
            model = _try_load("sdpa")
        else:
            raise
    model = model.eval().to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    print("[minicpmo] Model ready.")
    return model, tokenizer
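
# Example (a sketch): benchmarks with silent video can skip the audio tower to
# save memory, and force sdpa where flash-attn isn't installed:
#   model, tokenizer = load_model(init_audio=False, attn_implementation="sdpa")
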
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def run_inference(
    model,
    tokenizer,
    video_path: Optional[str],
    audio_path: Optional[str],
    prompt: str,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    max_frames: int = 32,
    fps: float = 1.0,
    use_audio_from_video: bool = False,
) -> str:
    """Run MiniCPM-o chat inference.

    Args:
        video_path: optional path to an mp4/etc. file.
        audio_path: optional path to a wav file. If `use_audio_from_video` is
            True and `audio_path` is None, audio is extracted from the video.
        prompt: user instruction text.
        temperature: 0 means greedy decoding.
        use_audio_from_video: if True, extract audio from the video
            automatically (useful for WorldSense / Daily-Omni, where the video
            has embedded audio but no separate wav is provided).
    """
    content: List[Any] = []
    tmp_audio_dir: Optional[str] = None
    if video_path is not None:
        frames = load_video_frames(video_path, max_frames=max_frames, fps=fps)
        content.extend(frames)
    if audio_path is None and use_audio_from_video and video_path is not None:
        tmp_audio_dir = tempfile.mkdtemp(prefix="mo_audio_")
        audio_path = extract_audio_from_video(video_path, tmp_dir=tmp_audio_dir)
    if audio_path is not None:
        try:
            audio = load_audio_waveform(audio_path, target_sr=16000)
            if audio.size > 0:
                content.append(audio)
        except Exception as exc:
            print(f" [minicpmo] (warn) audio load failed: {exc}")
    content.append(prompt)
    msgs = [{"role": "user", "content": content}]

    # Critical defaults for video understanding (see the MiniCPM-o 4.5 HF
    # README, "Chat with Video"): without ``use_image_id=False,
    # max_slice_nums=1`` the processor treats each frame as an independent HD
    # image, slicing it into multiple sub-images with per-image ID tokens.
    # That token distribution is OOD for the video-trained model and produces
    # degenerate output (repeated training-data fragments, e.g. "the image
    # description of the first image you see as a brief description ...").
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        temperature=temperature if temperature > 0 else 1.0,
        top_p=0.9 if temperature > 0 else 1.0,
        max_inp_length=_max_inp_length_for_chat(model, max_new_tokens),
        use_tts_template=False,
        enable_thinking=False,
    )
    if video_path is not None:
        gen_kwargs["use_image_id"] = False
        gen_kwargs["max_slice_nums"] = 1
    if use_audio_from_video and video_path is not None:
        gen_kwargs.setdefault("omni_mode", True)
    try:
        res = model.chat(msgs=msgs, tokenizer=tokenizer, **gen_kwargs)
    except TypeError:
        # Older remote code may not accept some of these kwargs; retry with
        # the model's own defaults.
        res = model.chat(msgs=msgs, tokenizer=tokenizer)
    if tmp_audio_dir is not None:
        import shutil
        shutil.rmtree(tmp_audio_dir, ignore_errors=True)
    if isinstance(res, tuple):
        res = res[0]
    return str(res).strip()
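

if __name__ == "__main__":
    # Minimal smoke test (a sketch): `python minicpmo_inference.py video.mp4`.
    # Assumes a CUDA GPU and network access to fetch the checkpoint; the video
    # path is supplied by the user.
    import sys

    _model, _tokenizer = load_model()
    print(run_inference(
        _model, _tokenizer,
        video_path=sys.argv[1],
        audio_path=None,
        prompt="Describe this video briefly.",
        use_audio_from_video=True,
    ))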