""" Common inference wrapper for MiniCPM-o 4.5. MiniCPM-o's API is `model.chat(msgs=[...], tokenizer=...)` where `msgs` is a list of `{"role": ..., "content": [image, audio, ..., text]}`. This module hides that detail behind `run_inference(model, tokenizer, video, audio, prompt)` so the 6 benchmark eval scripts can share one inference code path. Also runs the compatibility patcher on import so users who haven't run `setup_env.sh` still get a working model. """ from __future__ import annotations import os import subprocess import tempfile from pathlib import Path from typing import Any, List, Optional, Tuple import numpy as np # --------------------------------------------------------------------------- # Apply transformers>=4.52 compatibility patches lazily on import. # Safe to call multiple times; idempotent. # --------------------------------------------------------------------------- def _maybe_patch_once() -> None: try: from patch_minicpmo import ( _find_modeling_file, _find_processing_file, patch_file, patch_processing_file, ) except ImportError: return path = _find_modeling_file() if path is not None: try: patch_file(path) except Exception as exc: # pragma: no cover print(f"[minicpmo] (warn) patch failed: {exc}") proc = _find_processing_file() if proc is not None: try: patch_processing_file(proc) except Exception as exc: # pragma: no cover print(f"[minicpmo] (warn) processing patch failed: {exc}") _maybe_patch_once() def _max_inp_length_for_chat(model: Any, max_new_tokens: int) -> int: """Upper bound for ``model.chat(..., max_inp_length=...)`` (defaults to 8192). Many frames × per-frame image placeholders can exceed 8k text tokens; the processor then truncates ``input_ids`` and image start/end counts diverge, causing ``RuntimeError`` in ``processing_minicpmo._convert``. """ reserve = int(max_new_tokens) + 1024 best = 32768 for cfg in ( getattr(model, "config", None), getattr(getattr(model, "llm", None), "config", None), ): if cfg is None: continue npos = getattr(cfg, "max_position_embeddings", None) if isinstance(npos, int) and npos > 8192: best = min(best, max(npos - reserve, 16384)) return best # --------------------------------------------------------------------------- # Frame / audio loaders # --------------------------------------------------------------------------- def load_video_frames(video_path: str, max_frames: int = 32, fps: float = 1.0) -> List: """Sample PIL RGB frames uniformly from a video. MiniCPM-o expects a list of PIL Images (not a tensor). `fps=1.0, max_frames=32` covers ~32s; longer videos get sparser sampling. """ from PIL import Image import decord vr = decord.VideoReader(video_path, num_threads=1) total_frames = len(vr) video_fps = vr.get_avg_fps() duration = total_frames / max(video_fps, 1e-6) target = max(int(round(fps * duration)), 2) target = min(target, max_frames) target = min(target, total_frames) idx = np.linspace(0, total_frames - 1, target).round().astype(int).tolist() frames = vr.get_batch(idx).asnumpy() return [Image.fromarray(f).convert("RGB") for f in frames] def load_audio_waveform(audio_path: str, target_sr: int = 16000) -> np.ndarray: """Load audio as float32 numpy in [-1, 1] at `target_sr`.""" import librosa y, _ = librosa.load(audio_path, sr=target_sr, mono=True) return y.astype(np.float32) def extract_audio_from_video(video_path: str, target_sr: int = 16000, tmp_dir: Optional[str] = None) -> Optional[str]: """Extract the audio track from a video file to a temp .wav via ffmpeg. Returns the path to the .wav file, or None if the video has no audio track or extraction fails. Caller is responsible for cleanup. """ tmp_dir = tmp_dir or tempfile.mkdtemp(prefix="mo_audio_") out = os.path.join(tmp_dir, "audio.wav") try: subprocess.run( ["ffmpeg", "-y", "-loglevel", "error", "-i", video_path, "-vn", "-ac", "1", "-ar", str(target_sr), out], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=120, ) except Exception: return None if not os.path.isfile(out) or os.path.getsize(out) < 64: return None return out # --------------------------------------------------------------------------- # Model loading # --------------------------------------------------------------------------- def load_model(model_id: str = "openbmb/MiniCPM-o-4_5", device: str = "cuda", dtype: str = "bfloat16", init_audio: bool = True, attn_implementation: str = "flash_attention_2"): """Load MiniCPM-o model + tokenizer. Returns (model, tokenizer). Tries `attn_implementation` first; if flash_attention_2 isn't installed or the backbone doesn't support it, falls back to sdpa automatically. """ import torch from transformers import AutoModel, AutoTokenizer torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[dtype] def _try_load(attn: str): print(f"[minicpmo] Loading {model_id} (dtype={dtype}, device={device}, " f"init_audio={init_audio}, attn={attn})...") return AutoModel.from_pretrained( model_id, trust_remote_code=True, attn_implementation=attn, torch_dtype=torch_dtype, init_vision=True, init_audio=init_audio, init_tts=False, ) try: model = _try_load(attn_implementation) except Exception as exc: if attn_implementation != "sdpa": print(f"[minicpmo] (warn) {attn_implementation} failed ({exc}); falling back to sdpa.") model = _try_load("sdpa") else: raise model = model.eval().to(device) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) print(f"[minicpmo] Model ready.") return model, tokenizer # --------------------------------------------------------------------------- # Inference # --------------------------------------------------------------------------- def run_inference( model, tokenizer, video_path: Optional[str], audio_path: Optional[str], prompt: str, max_new_tokens: int = 256, temperature: float = 0.0, max_frames: int = 32, fps: float = 1.0, use_audio_from_video: bool = False, ) -> str: """Run MiniCPM-o chat inference. Args: video_path: optional path to an mp4/etc. file. audio_path: optional path to a wav file. If `use_audio_from_video` is True and `audio_path` is None, we extract audio from the video. prompt: user instruction text. temperature: 0 means greedy. use_audio_from_video: if True, extract audio from the video automatically (useful for WorldSense / Daily-Omni where video has embedded audio but no separate wav is provided). """ content: List[Any] = [] tmp_audio_dir: Optional[str] = None if video_path is not None: frames = load_video_frames(video_path, max_frames=max_frames, fps=fps) content.extend(frames) if audio_path is None and use_audio_from_video and video_path is not None: tmp_audio_dir = tempfile.mkdtemp(prefix="mo_audio_") audio_path = extract_audio_from_video(video_path, tmp_dir=tmp_audio_dir) if audio_path is not None: try: audio = load_audio_waveform(audio_path, target_sr=16000) if audio.size > 0: content.append(audio) except Exception as exc: print(f" [minicpmo] (warn) audio load failed: {exc}") content.append(prompt) msgs = [{"role": "user", "content": content}] # Critical defaults for video understanding (see MiniCPM-o 4.5 HF README # "Chat with Video"): without ``use_image_id=False, max_slice_nums=1`` the # processor treats each frame as an independent HD image, slicing it into # multiple sub-images with per-image ID tokens. That token distribution is # OOD for the video-trained model and produces degenerate output (repeated # training-data fragments, e.g. "the image description of the first image # you see as a brief description ..."). gen_kwargs = dict( max_new_tokens=max_new_tokens, do_sample=temperature > 0, temperature=temperature if temperature > 0 else 1.0, top_p=0.9 if temperature > 0 else 1.0, max_inp_length=_max_inp_length_for_chat(model, max_new_tokens), use_tts_template=False, enable_thinking=False, ) if video_path is not None: gen_kwargs["use_image_id"] = False gen_kwargs["max_slice_nums"] = 1 if use_audio_from_video and video_path is not None: gen_kwargs.setdefault("omni_mode", True) try: res = model.chat(msgs=msgs, tokenizer=tokenizer, **gen_kwargs) except TypeError: res = model.chat(msgs=msgs, tokenizer=tokenizer) if tmp_audio_dir is not None: import shutil shutil.rmtree(tmp_audio_dir, ignore_errors=True) if isinstance(res, tuple): res = res[0] return str(res).strip()