#!/usr/bin/env python3
"""Patch MiniCPM-o 4.5 custom code in the Hugging Face modules cache.

``modeling_minicpmo.py`` (transformers >= 4.52):

  1. `WhisperEncoderLayer.forward` unpacks 3 values from `self.self_attn(...)`,
     but new `WhisperAttention.forward` returns 2 values.
  2. `prepare_inputs_for_generation` reads `past_key_values.seen_tokens`, which
     was removed from `DynamicCache`.
  3. `chat()` force-sets ``use_tts_template = True`` whenever audio is in the
     ``content`` list. That appends ``<|tts_bos|>`` to the assistant prefix
     and the model then generates **audio (TTS codec) ids**; decoded as text
     they look like ``<think>`` floods / gibberish. We want audio-in +
     **text-out** for benchmark eval, so respect the caller's kwarg instead.

``processing_minicpmo.py``:

  4. `_convert` used ``max(len(image_start_idx), len(image_end_idx))`` when
     building ``image_bounds``; after ``max_length`` truncation start/end counts
     can differ by one and ``torch.hstack`` raises (common with many video
     frames under the default ``chat(..., max_inp_length=8192)``). Use ``min``.

Idempotent. Also downloads model code on demand so files exist before patching.
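
Typical sequence (a sketch; the loading call follows the usual
``trust_remote_code`` recipe and may differ in your eval harness)::

    python <this script>   # apply the patches (safe to re-run)

    # then, in the eval process:
    from transformers import AutoModel
    model = AutoModel.from_pretrained("openbmb/MiniCPM-o-4_5", trust_remote_code=True)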
"""
from __future__ import annotations

import os
import shutil
import sys
from pathlib import Path

MODEL_ID = "openbmb/MiniCPM-o-4_5"


def _find_modeling_file() -> Path | None:
    """Locate the cached modeling_minicpmo.py (matches HF's module dir naming)."""
    hf_home = Path(os.environ.get("HF_HOME") or Path.home() / ".cache" / "huggingface")
    candidates = [
        hf_home / "modules" / "transformers_modules",
        # Fallback: the hub snapshot cache. ``hf_hub_download`` puts files
        # here; transformers copies them into the modules cache on the first
        # ``trust_remote_code`` load, so patching the snapshot also works.
        hf_home / "hub" / f"models--{MODEL_ID.replace('/', '--')}" / "snapshots",
    ]
    hits: list[Path] = []
    for root in candidates:
        if not root.exists():
            continue
        for p in root.rglob("modeling_minicpmo.py"):
            hits.append(p)
    if not hits:
        return None
    # Prefer the deepest (snapshot-hashed) one.
    hits.sort(key=lambda p: len(p.parts), reverse=True)
    return hits[0]


def _find_processing_file() -> Path | None:
    """``processing_minicpmo.py`` lives next to the cached ``modeling_minicpmo.py``."""
    modeling = _find_modeling_file()
    if modeling is None:
        return None
    proc = modeling.parent / "processing_minicpmo.py"
    return proc if proc.is_file() else None


def _download_model_code() -> None:
    """Force HF to download MiniCPM-o's custom code so the file is cached.

    We only need the Python files + config (not weights) for patching. We use
    `hf_hub_download` for the individual code files to avoid fetching the
    multi-GB safetensors shards just to patch a .py file.
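
    The files land in the hub snapshot cache; `_find_modeling_file` searches
    there as a fallback, so they are patchable before the first model load.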
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError:
        print("[patch] huggingface_hub not installed; skipping auto-download.")
        return

    for fn in [
        "config.json",
        "configuration_minicpm.py",
        "modeling_minicpmo.py",
        "modeling_navit_siglip.py",
        "processing_minicpmo.py",
        "resampler.py",
        "utils.py",
    ]:
        try:
            hf_hub_download(repo_id=MODEL_ID, filename=fn)
        except Exception as exc:
            # Some files may not exist in every revision; that's fine.
            print(f"[patch] (warn) could not fetch {fn}: {exc}")


def patch_whisper_unpack(text: str) -> tuple[str, bool]:
    """Fix #1: WhisperAttention now returns 2 values, not 3."""
    OLD = (
        "        hidden_states, attn_weights, past_key_values = self.self_attn(\n"
        "            hidden_states=hidden_states,\n"
        "            attention_mask=attention_mask,\n"
        "            layer_head_mask=layer_head_mask,\n"
        "            output_attentions=output_attentions,\n"
        "            past_key_value=past_key_values,\n"
        "        )"
    )
    NEW = (
        "        _attn_out = self.self_attn(\n"
        "            hidden_states=hidden_states,\n"
        "            attention_mask=attention_mask,\n"
        "            layer_head_mask=layer_head_mask,\n"
        "            output_attentions=output_attentions,\n"
        "            past_key_value=past_key_values,\n"
        "        )\n"
        "        if len(_attn_out) == 3:\n"
        "            hidden_states, attn_weights, past_key_values = _attn_out\n"
        "        else:\n"
        "            hidden_states, attn_weights = _attn_out"
    )
    if NEW.split("\n", 1)[0] in text:
        return text, False  # already patched
    if OLD not in text:
        return text, False  # not applicable (different revision?)
    return text.replace(OLD, NEW), True


def patch_seen_tokens(text: str) -> tuple[str, bool]:
    """Fix #2: DynamicCache.seen_tokens was removed in newer transformers."""
    OLD = (
        "            cache_length = past_key_values.get_seq_length()\n"
        "            past_length = past_key_values.seen_tokens"
    )
    NEW = (
        "            cache_length = past_key_values.get_seq_length()\n"
        "            past_length = getattr(past_key_values, \"seen_tokens\", cache_length)"
    )
    if 'getattr(past_key_values, "seen_tokens"' in text:
        return text, False  # already patched
    if OLD not in text:
        return text, False
    return text.replace(OLD, NEW), True


def patch_chat_force_tts_template(text: str) -> tuple[str, bool]:
    """Fix #3: don't force ``use_tts_template=True`` on audio-containing content.

    MiniCPM-o's ``chat()`` assumes "audio in implies TTS audio out". For MCQ /
    freetext eval we want a text answer; the caller's ``use_tts_template`` kwarg
    (default ``False``) must win so the assistant prefix doesn't get
    ``<|tts_bos|>`` appended (which causes the LM to emit audio-codec ids that
    look like ``<think>`` repetitions when text-decoded).
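
    After the patch, a text answer comes back via (a sketch; the ``msgs`` /
    ``tokenizer`` kwargs follow the MiniCPM-o README and are assumptions)::

        model.chat(msgs=[{"role": "user", "content": [audio_np, question]}],
                   tokenizer=tokenizer, use_tts_template=False)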
    """
    OLD = (
        '                    elif isinstance(c, np.ndarray):  # audio\n'
        '                        audios.append(c)\n'
        '                        audio_parts.append(i)\n'
        '                        cur_msgs.append("<audio>./</audio>")\n'
        '                        use_tts_template = True\n'
    )
    NEW = (
        '                    elif isinstance(c, np.ndarray):  # audio\n'
        '                        audios.append(c)\n'
        '                        audio_parts.append(i)\n'
        '                        cur_msgs.append("<audio>./</audio>")\n'
        '                        # PATCHED: honour caller-provided use_tts_template.\n'
        '                        # Upstream force-sets True on any audio, which makes the model\n'
        '                        # generate TTS codec ids (look like <think> noise as text).\n'
    )
    if "PATCHED: honour caller-provided use_tts_template" in text:
        return text, False
    if OLD not in text:
        return text, False
    return text.replace(OLD, NEW), True


def patch_processor_image_bounds(text: str) -> tuple[str, bool]:
    """Fix ``image_bounds`` when start/end marker counts disagree (truncation)."""
    OLD = "        valid_image_nums = max(len(image_start_idx), len(image_end_idx))"
    NEW = (
        "        # Pair only complete spans; max() breaks torch.hstack if counts differ.\n"
        "        valid_image_nums = min(len(image_start_idx), len(image_end_idx))"
    )
    if "valid_image_nums = min(len(image_start_idx), len(image_end_idx))" in text:
        return text, False
    if OLD not in text:
        return text, False
    return text.replace(OLD, NEW), True


def patch_file(path: Path) -> bool:
    original = path.read_text()
    text = original
    any_change = False

    text, c1 = patch_whisper_unpack(text)
    any_change |= c1
    text, c2 = patch_seen_tokens(text)
    any_change |= c2
    text, c3 = patch_chat_force_tts_template(text)
    any_change |= c3

    if any_change:
        backup = path.with_suffix(path.suffix + ".bak")
        if not backup.exists():
            backup.write_text(original)
            print(f"[patch] Backup -> {backup}")
        path.write_text(text)
        print(f"[patch] Patched {path.name}: "
              f"whisper_unpack={c1}, seen_tokens={c2}, chat_tts_template={c3}")
    else:
        print(f"[patch] No changes needed (already patched or unknown revision)")
    return any_change


def patch_processing_file(path: Path) -> bool:
    """Patch ``processing_minicpmo.py`` (image_bounds hstack)."""
    original = path.read_text()
    text = original
    text, c = patch_processor_image_bounds(text)
    if not c:
        print(f"[patch] {path.name}: image_bounds already patched or pattern missing")
        return False
    backup = path.with_suffix(path.suffix + ".bak")
    if not backup.exists():
        backup.write_text(original)
        print(f"[patch] Backup -> {backup}")
    path.write_text(text)
    print(f"[patch] Patched {path.name}: image_bounds min() fix")
    return True


def main() -> int:
    path = _find_modeling_file()
    if path is None:
        print("[patch] modeling_minicpmo.py not cached yet; fetching from HF...")
        _download_model_code()
        path = _find_modeling_file()
    if path is None:
        print("[patch] ERROR: could not locate modeling_minicpmo.py", file=sys.stderr)
        return 1
    print(f"[patch] Target: {path}")
    patch_file(path)

    proc = _find_processing_file()
    if proc is not None:
        print(f"[patch] Target: {proc}")
        patch_processing_file(proc)
    else:
        print("[patch] (warn) processing_minicpmo.py not found next to modeling; "
              "run once with HF cache populated")

    # Invalidate __pycache__ so the edited file is re-imported.
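    # (A process that already imported the unpatched module must be restarted.)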
    for pc in path.parent.rglob("__pycache__"):
        shutil.rmtree(pc, ignore_errors=True)
    return 0


if __name__ == "__main__":
    sys.exit(main())