#!/usr/bin/env python3
"""Patch MiniCPM-o 4.5 custom code in the Hugging Face modules cache.
``modeling_minicpmo.py`` (transformers >= 4.52):
1. `WhisperEncoderLayer.forward` unpacks 3 values from `self.self_attn(...)`,
but new `WhisperAttention.forward` returns 2 values.
2. `prepare_inputs_for_generation` reads `past_key_values.seen_tokens`, which
was removed from `DynamicCache`.
3. `chat()` force-sets ``use_tts_template = True`` whenever audio is in the
``content`` list. That appends ``<|tts_bos|>`` to the assistant prefix
and the model then generates **audio (TTS codec) ids**; decoded as text
they look like ``<think>`` floods / gibberish. We want audio-in +
**text-out** for benchmark eval, so respect the caller's kwarg instead.
``processing_minicpmo.py``:
4. `_convert` used ``max(len(image_start_idx), len(image_end_idx))`` when
building ``image_bounds``; after ``max_length`` truncation start/end counts
can differ by one and ``torch.hstack`` raises (common with many video
frames under the default ``chat(..., max_inp_length=8192)``). Use ``min``.
Idempotent. Also downloads model code on demand so files exist before patching.
"""
from __future__ import annotations
import os
import shutil
import sys
from pathlib import Path
MODEL_ID = "openbmb/MiniCPM-o-4_5"
def _find_modeling_file() -> Path | None:
"""Locate the cached modeling_minicpmo.py (matches HF's module dir naming)."""
home = Path(os.path.expanduser("~"))
candidates = [
home / ".cache" / "huggingface" / "modules" / "transformers_modules",
]
hits: list[Path] = []
for root in candidates:
if not root.exists():
continue
for p in root.rglob("modeling_minicpmo.py"):
hits.append(p)
if not hits:
return None
    # Prefer the deepest hit: the snapshot-hashed modules-cache copy, which is
    # what actually gets imported.
hits.sort(key=lambda p: len(p.parts), reverse=True)
return hits[0]
def _find_processing_file() -> Path | None:
"""``processing_minicpmo.py`` lives next to the cached ``modeling_minicpmo.py``."""
modeling = _find_modeling_file()
if modeling is None:
return None
proc = modeling.parent / "processing_minicpmo.py"
return proc if proc.is_file() else None
def _download_model_code() -> None:
"""Force HF to download MiniCPM-o's custom code so the file is cached.
We only need the Python files + config (not weights) for patching. We use
`hf_hub_download` for the individual code files to avoid fetching the
multi-GB safetensors shards just to patch a .py file.
"""
try:
from huggingface_hub import hf_hub_download
except ImportError:
print("[patch] huggingface_hub not installed; skipping auto-download.")
return
for fn in [
"config.json",
"configuration_minicpm.py",
"modeling_minicpmo.py",
"modeling_navit_siglip.py",
"processing_minicpmo.py",
"resampler.py",
"utils.py",
]:
try:
hf_hub_download(repo_id=MODEL_ID, filename=fn)
except Exception as exc:
# Some files may not exist in every revision; that's fine.
print(f"[patch] (warn) could not fetch {fn}: {exc}")
def patch_whisper_unpack(text: str) -> tuple[str, bool]:
"""Fix #1: WhisperAttention now returns 2 values, not 3."""
OLD = (
" hidden_states, attn_weights, past_key_values = self.self_attn(\n"
" hidden_states=hidden_states,\n"
" attention_mask=attention_mask,\n"
" layer_head_mask=layer_head_mask,\n"
" output_attentions=output_attentions,\n"
" past_key_value=past_key_values,\n"
" )"
)
NEW = (
" _attn_out = self.self_attn(\n"
" hidden_states=hidden_states,\n"
" attention_mask=attention_mask,\n"
" layer_head_mask=layer_head_mask,\n"
" output_attentions=output_attentions,\n"
" past_key_value=past_key_values,\n"
" )\n"
" if len(_attn_out) == 3:\n"
" hidden_states, attn_weights, past_key_values = _attn_out\n"
" else:\n"
" hidden_states, attn_weights = _attn_out"
)
if NEW.split("\n", 1)[0] in text:
return text, False # already patched
if OLD not in text:
return text, False # not applicable (different revision?)
return text.replace(OLD, NEW), True
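# For reference: on transformers >= 4.52, WhisperAttention.forward returns
# (hidden_states, attn_weights); earlier releases returned a third
# past_key_value element, so the shim above accepts either arity.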
def patch_seen_tokens(text: str) -> tuple[str, bool]:
"""Fix #2: DynamicCache.seen_tokens was removed in newer transformers."""
OLD = (
" cache_length = past_key_values.get_seq_length()\n"
" past_length = past_key_values.seen_tokens"
)
NEW = (
" cache_length = past_key_values.get_seq_length()\n"
" past_length = getattr(past_key_values, \"seen_tokens\", cache_length)"
)
if 'getattr(past_key_values, "seen_tokens"' in text:
return text, False # already patched
if OLD not in text:
return text, False
return text.replace(OLD, NEW), True
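# getattr() preserves the old behaviour on transformers versions where
# DynamicCache still exposes .seen_tokens, and falls back to the cache length
# (already computed above) on versions where the attribute was removed.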
def patch_chat_force_tts_template(text: str) -> tuple[str, bool]:
"""Fix #3: don't force ``use_tts_template=True`` on audio-containing content.
MiniCPM-o's ``chat()`` assumes "audio in implies TTS audio out". For MCQ /
freetext eval we want a text answer; the caller's ``use_tts_template`` kwarg
(default ``False``) must win so the assistant prefix doesn't get
``<|tts_bos|>`` appended (which causes the LM to emit audio-codec ids that
look like ``<think>`` repetitions when text-decoded).
"""
OLD = (
' elif isinstance(c, np.ndarray): # audio\n'
' audios.append(c)\n'
' audio_parts.append(i)\n'
' cur_msgs.append("<audio>./</audio>")\n'
' use_tts_template = True\n'
)
NEW = (
' elif isinstance(c, np.ndarray): # audio\n'
' audios.append(c)\n'
' audio_parts.append(i)\n'
' cur_msgs.append("<audio>./</audio>")\n'
' # PATCHED: honour caller-provided use_tts_template.\n'
' # Upstream force-sets True on any audio, which makes the model\n'
' # generate TTS codec ids (look like <think> noise as text).\n'
)
if "PATCHED: honour caller-provided use_tts_template" in text:
return text, False
if OLD not in text:
return text, False
return text.replace(OLD, NEW), True
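# After this patch the kwarg default (use_tts_template=False) wins for
# audio-in / text-out evaluation; callers that do want speech output can still
# pass use_tts_template=True explicitly.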
def patch_processor_image_bounds(text: str) -> tuple[str, bool]:
"""Fix ``image_bounds`` when start/end marker counts disagree (truncation)."""
OLD = " valid_image_nums = max(len(image_start_idx), len(image_end_idx))"
NEW = (
" # Pair only complete spans; max() breaks torch.hstack if counts differ.\n"
" valid_image_nums = min(len(image_start_idx), len(image_end_idx))"
)
if "valid_image_nums = min(len(image_start_idx), len(image_end_idx))" in text:
return text, False
if OLD not in text:
return text, False
return text.replace(OLD, NEW), True
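# Failure mode this guards against: max_length truncation can cut off the last
# image-end marker, leaving one more start index than end indices; max() then
# produces mismatched slices and torch.hstack raises. min() keeps only the
# complete (start, end) pairs.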
def patch_file(path: Path) -> bool:
    """Apply the three modeling-file fixes in place (one-time ``.bak`` backup)."""
    original = path.read_text()
text = original
any_change = False
text, c1 = patch_whisper_unpack(text)
any_change |= c1
text, c2 = patch_seen_tokens(text)
any_change |= c2
text, c3 = patch_chat_force_tts_template(text)
any_change |= c3
if any_change:
backup = path.with_suffix(path.suffix + ".bak")
if not backup.exists():
backup.write_text(original)
print(f"[patch] Backup -> {backup}")
path.write_text(text)
print(f"[patch] Patched {path.name}: "
f"whisper_unpack={c1}, seen_tokens={c2}, chat_tts_template={c3}")
else:
print(f"[patch] No changes needed (already patched or unknown revision)")
return any_change
def patch_processing_file(path: Path) -> bool:
"""Patch ``processing_minicpmo.py`` (image_bounds hstack)."""
original = path.read_text()
text = original
text, c = patch_processor_image_bounds(text)
if not c:
print(f"[patch] {path.name}: image_bounds already patched or pattern missing")
return False
backup = path.with_suffix(path.suffix + ".bak")
if not backup.exists():
backup.write_text(original)
print(f"[patch] Backup -> {backup}")
path.write_text(text)
print(f"[patch] Patched {path.name}: image_bounds min() fix")
return True
def main() -> int:
path = _find_modeling_file()
if path is None:
print("[patch] modeling_minicpmo.py not cached yet; fetching from HF...")
_download_model_code()
path = _find_modeling_file()
if path is None:
print("[patch] ERROR: could not locate modeling_minicpmo.py", file=sys.stderr)
return 1
print(f"[patch] Target: {path}")
patch_file(path)
proc = _find_processing_file()
if proc is not None:
print(f"[patch] Target: {proc}")
patch_processing_file(proc)
    else:
        print("[patch] (warn) processing_minicpmo.py not found next to the "
              "modeling file; re-run once the HF cache is populated")
    # Invalidate __pycache__ so the edited file is re-imported.
    for pc in path.parent.rglob("__pycache__"):
        shutil.rmtree(pc, ignore_errors=True)
return 0
if __name__ == "__main__":
sys.exit(main())