#!/usr/bin/env python3
"""Evaluate MiniCPM-o 4.5 on Daily-Omni.
Daily-Omni videos include embedded audio; we extract it and feed both frames
and waveform to MiniCPM-o.
"""
from __future__ import annotations

# Local helper; imported early (as in the original order) in case it performs
# path or environment setup before torch is loaded.
import _common

import argparse
import contextlib
import gc
import io
import json
import traceback
from pathlib import Path

import torch
from tqdm import tqdm
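
# Per-benchmark helpers (dataset loader, answer extraction, metric computation,
# summary printing) and default paths come from the shared _common module,
# keyed by benchmark name.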
ch = _common.ch("daily_omni")
load_daily_omni = ch.load_daily_omni
extract_answer = ch.extract_answer
compute_metrics = ch.compute_metrics
print_summary = ch.print_summary
DEFAULT_DATA_DIR = ch.DEFAULT_DATA_DIR
DEFAULT_OUTPUT_DIR = ch.DEFAULT_OUTPUT_DIR
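
# load_model / run_inference wrap MiniCPM-o checkpoint loading and single-sample
# generation; imported after the _common wiring above, presumably because that
# setup makes the module importable.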
from minicpmo_inference import load_model, run_inference


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Evaluate MiniCPM-o on Daily-Omni.")
    p.add_argument("--model-id", type=str, default="openbmb/MiniCPM-o-4_5")
    p.add_argument("--data-dir", type=Path, default=DEFAULT_DATA_DIR)
    p.add_argument("--output-dir", type=Path,
                   default=Path("/home/ubuntu/eval_results/daily_omni_minicpmo"))
    p.add_argument("--max-samples", type=int, default=-1)
    p.add_argument("--max-new-tokens", type=int, default=32)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--label", type=str, default="minicpmo_daily_omni")
    p.add_argument("--max-frames", type=int, default=64)
    p.add_argument("--fps", type=float, default=1.0)
    p.add_argument("--attn", type=str, default="flash_attention_2",
                   choices=["sdpa", "flash_attention_2", "eager"])
    p.add_argument("--no-audio", action="store_true",
                   help="Video-only mode (skip audio extraction).")
    p.add_argument(
        "--skip-audio-durations",
        type=str,
        default="",
        help=(
            "Comma-separated `video_duration` values from the dataset for which "
            "audio is omitted (those clips run video-only). Useful when the "
            "MiniCPM-o forward pass fails on some lengths with audio+vision "
            '(e.g. an empty `raw_output` and log errors like "Expected size 122 '
            'but got size 121"). Example: --skip-audio-durations 60s'
        ),
    )
    # vLLM flags are accepted for CLI parity only: MiniCPM-o 4.5 multimodal
    # inference is not yet supported in vLLM upstream.
    p.add_argument("--vllm", action="store_true", default=False,
                   help="No-op for MiniCPM-o 4.5; automatically falls back to transformers.")
    p.add_argument("--tp", type=int, default=None)
    p.add_argument("--gpu-memory-utilization", type=float, default=0.90)
    p.add_argument("--max-model-len", type=int, default=65536)
    p.add_argument("--batch-size", type=int, default=32)
    # Data-parallel sharding (see the note below parse_args).
    p.add_argument("--shard", type=int, default=0)
    p.add_argument("--num-shards", type=int, default=1)
    return p.parse_args()
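

# Data-parallel sharding: shard i of N takes the questions with index % N == i,
# so N worker processes cover the dataset disjointly. Example (illustrative)
# four-GPU launch:
#   for i in 0 1 2 3; do CUDA_VISIBLE_DEVICES=$i \
#     python eval_daily_omni.py --shard $i --num-shards 4 & done; wait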
def main() -> None:
    args = parse_args()
    out_dir = args.output_dir / args.label
    out_dir.mkdir(parents=True, exist_ok=True)
    shard_suffix = (f".shard{args.shard}of{args.num_shards}"
                    if args.num_shards > 1 else "")
    results_jsonl = out_dir / f"eval_results{shard_suffix}.jsonl"
    metrics_json = out_dir / "metrics.json"
    summary_txt = out_dir / "summary.txt"

    if args.vllm:
        print("[warn] --vllm requested, but MiniCPM-o 4.5 multimodal vLLM is not "
              "yet supported upstream; falling back to transformers.")

    print("[data] Loading Daily-Omni dataset...")
    test_data = load_daily_omni(args.data_dir, args.max_samples)
    if args.num_shards > 1:
        test_data = [x for i, x in enumerate(test_data)
                     if i % args.num_shards == args.shard]
        print(f"[shard] shard {args.shard}/{args.num_shards}: {len(test_data)} questions")
    else:
        print(f"[data] {len(test_data)} questions ready")

    # Resume support: skip question IDs already recorded in the results file.
    processed: set[str] = set()
    if results_jsonl.exists():
        with open(results_jsonl) as f:
            for line in f:
                obj = json.loads(line)
                processed.add(obj["question_id"])
        print(f"[resume] {len(processed)} already processed")

    model, tokenizer = load_model(
        args.model_id, attn_implementation=args.attn, init_audio=not args.no_audio,
    )

    skip_audio_durs = {
        x.strip()
        for x in args.skip_audio_durations.split(",")
        if x.strip()
    }
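
    # Example: --skip-audio-durations "60s,120s" parses to {"60s", "120s"};
    # clips whose `video_duration` matches are evaluated without audio.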
    for item in tqdm(test_data, desc="Daily-Omni", unit="q"):
        if item["question_id"] in processed:
            continue
        # Drop audio for durations the user has flagged as problematic.
        use_audio = not args.no_audio and (
            item.get("video_duration", "") not in skip_audio_durs
        )
        try:
            raw_output = run_inference(
                model, tokenizer,
                video_path=item["video_path"],
                audio_path=None,  # audio, when used, comes from the video track
                prompt=item["prompt"],
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                max_frames=args.max_frames,
                fps=args.fps,
                use_audio_from_video=use_audio,
            )
        except Exception as exc:
            print(f" [error] {item['question_id']}: {exc}")
            traceback.print_exc()
            raw_output = ""
        pred = extract_answer(raw_output)
        result = {
            "question_id": item["question_id"],
            "video_id": item["video_id"],
            "question_type": item.get("question_type", ""),
            "content_parent_category": item.get("content_parent_category", ""),
            "content_fine_category": item.get("content_fine_category", ""),
            "video_category": item.get("video_category", ""),
            "video_duration": item.get("video_duration", ""),
            "question": item["question"],
            "choices": item["choices"],
            "gt_answer": item["gt_answer"],
            "pred_answer": pred,
            "raw_output": raw_output,
        }
        # Append immediately so an interruption loses at most one sample.
        with open(results_jsonl, "a", encoding="utf-8") as f:
            f.write(json.dumps(result, ensure_ascii=False) + "\n")
        processed.add(item["question_id"])
        gc.collect()
        torch.cuda.empty_cache()

    if args.num_shards > 1:
        print(f"\n[shard {args.shard}/{args.num_shards}] Done. Results: {results_jsonl}")
        print(f"[shard] Run merge_shards.py --bench daily_omni --label-dir {out_dir}")
        return

    all_results = []
    if results_jsonl.exists():
        with open(results_jsonl) as f:
            for line in f:
                all_results.append(json.loads(line))

    metrics = compute_metrics(all_results)
    metrics["eval_config"] = {
        "model_id": args.model_id,
        "data_dir": str(args.data_dir),
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
        "max_frames": args.max_frames,
        "fps": args.fps,
        "attn": args.attn,
        "no_audio": args.no_audio,
        "skip_audio_durations": sorted(skip_audio_durs),
    }
    with open(metrics_json, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    print_summary(metrics, args.label)
    # Capture print_summary's stdout so the same report is also written to disk.
    with open(summary_txt, "w", encoding="utf-8") as f:
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            print_summary(metrics, args.label)
        f.write(buf.getvalue())

    print(f"\n[output] Results: {results_jsonl}")
    print(f"[output] Metrics: {metrics_json}")
    print(f"[output] Summary: {summary_txt}")


if __name__ == "__main__":
    main()