userisuser's picture
Deploy MiniCPM-V 4.6 Gradio Server demo
ecb8ee5
import os
import torch
import re
import av
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
import spaces
# Model setup: the processor and the weights are loaded once, at import time.
model_id = "openbmb/MiniCPM-V-4.6"
print(f"Loading model: {model_id}...")

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Weights are requested in bfloat16 and mapped onto the CUDA device.
_model_load_options = {
    "torch_dtype": torch.bfloat16,
    "trust_remote_code": True,
    "device_map": "cuda",
}
model = AutoModelForImageTextToText.from_pretrained(model_id, **_model_load_options)
def _sample_indices(total_frames, max_frames):
    """Return at most *max_frames* evenly spaced, strictly increasing indices."""
    if total_frames <= max_frames:
        # Short clip: keep every frame instead of generating duplicate indices.
        return list(range(total_frames))
    return [int(i * total_frames / max_frames) for i in range(max_frames)]


def load_video(video_path, max_frames=64):
    """Decode a video with PyAV and return up to *max_frames* sampled frames.

    Frames are sampled evenly across the clip and converted with
    ``frame.to_image()``. When the stream reports a positive frame count, only
    the selected frames are converted and decoding stops after the last one.
    When the count is unknown (``<= 0``), every frame is decoded first and the
    sample is taken afterwards.

    Args:
        video_path: Path (or anything ``av.open`` accepts) of the video file.
        max_frames: Upper bound on the number of frames returned; must be >= 1.

    Returns:
        A list of frames in presentation order, or ``None`` when the file
        cannot be opened/decoded or *max_frames* is invalid. The error is
        printed rather than raised.
    """
    try:
        if max_frames < 1:
            raise ValueError(f"max_frames must be >= 1, got {max_frames}")

        # The context manager closes the container even when decoding fails.
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            total_frames = stream.frames

            if total_frames <= 0:
                print("Frame count unknown, decoding all and sampling...")
                decoded = [frame.to_image() for frame in container.decode(video=0)]
                return [decoded[i] for i in _sample_indices(len(decoded), max_frames)]

            indices = _sample_indices(total_frames, max_frames)
            wanted = set(indices)
            last_wanted = indices[-1]
            frames = []
            for i, frame in enumerate(container.decode(video=0)):
                if i in wanted:
                    frames.append(frame.to_image())
                if i >= last_wanted:
                    # Nothing further is needed; skip decoding the tail.
                    break
            return frames
    except Exception as e:
        # Best-effort boundary: callers treat None as "could not decode".
        print(f"Error loading video: {e}")
        return None
# Response normalization helpers.
#
# Group 1 captures spans that are returned untouched: fenced code, inline
# code, and $$...$$, $...$, \(...\) and \[...\] spans. The second alternative
# matches a literal backslash escape (\r\n, \n or \r) that is not itself
# preceded by a backslash.
_PATTERN = re.compile(
    r'(```[\s\S]*?```|`[^`]+`|\$\$[\s\S]*?\$\$|\$[^$]+\$|\\\([\s\S]*?\\\)|\\\[[\s\S]*?\\\])'
    r'|(?<!\\)(?:\\r\\n|\\[nr])'
)


def normalize_response_text(text: str) -> str:
    """Turn literal ``\\n`` / ``\\r`` / ``\\r\\n`` escapes into real newlines.

    Spans captured by group 1 of ``_PATTERN`` (code and math delimiters) are
    copied through unchanged. Non-string input, and strings that contain no
    backslash at all, are returned as-is.
    """
    if not isinstance(text, str):
        return text
    if "\\" not in text:
        return text

    def _replacement(match):
        protected = match.group(1)
        if protected:
            return protected
        return '\n'

    return _PATTERN.sub(_replacement, text)
# Application object: the endpoints in this file are registered on it through
# its `api()` and `get()` decorators.
app = Server()
# Frame budget used both for sampling and for the processor's video settings.
_VIDEO_MAX_FRAMES = 64


def _is_video_file(file_path):
    """Return True when PyAV can open *file_path* and finds a video stream."""
    # NOTE(review): FFmpeg may also demux still images as a single-frame video
    # stream, in which case images would take the video path — confirm that
    # this is the intended behavior.
    try:
        with av.open(file_path) as container:
            return len(container.streams.video) > 0
    except Exception:
        # Anything PyAV cannot open is treated as "not a video".
        return False


def _build_inputs(content, processor_kwargs=None):
    """Apply the chat template to one user turn and move tensors to the model device."""
    messages = [{"role": "user", "content": content}]
    template_kwargs = {
        "tokenize": True,
        "add_generation_prompt": True,
        "return_dict": True,
        "return_tensors": "pt",
    }
    if processor_kwargs is not None:
        template_kwargs["processor_kwargs"] = processor_kwargs
    return processor.apply_chat_template(messages, **template_kwargs).to(model.device)


@app.api()
@spaces.GPU(duration=120)
def predict(message: str, file: FileData = None, downsample_mode: str = "16x") -> str:
    """
    General inference endpoint for text-only, image and video requests.

    Args:
        message: The user's text prompt.
        file: Optional upload; its ``"path"`` entry is the local file to read.
            When omitted, the request is handled as text-only.
        downsample_mode: Forwarded to the processor and, for file inputs, to
            ``model.generate``.

    Returns:
        The decoded model reply after ``normalize_response_text``, or an error
        string when a video upload yields no frames.
    """
    text_part = {"type": "text", "text": message}

    if file is None:
        # Text-only inference
        inputs = _build_inputs([text_part])
    else:
        file_path = file["path"]
        if _is_video_file(file_path):
            print(f"Processing as video: {file_path}")
            frames = load_video(file_path, max_frames=_VIDEO_MAX_FRAMES)
            if not frames:
                return "Error: Could not decode video file."
            inputs = _build_inputs(
                [{"type": "video", "video": frames}, text_part],
                {
                    "downsample_mode": downsample_mode,
                    "max_num_frames": _VIDEO_MAX_FRAMES,
                    "stack_frames": 1,
                    "max_slice_nums": 1,
                    "use_image_id": False,
                    # Frames are already sampled by load_video, so the
                    # processor must not resample (it would need metadata).
                    "do_sample_frames": False,
                },
            )
        else:
            print(f"Processing as image: {file_path}")
            inputs = _build_inputs(
                [{"type": "image", "url": file_path}, text_part],
                {
                    "downsample_mode": downsample_mode,
                    "max_slice_nums": 9,
                },
            )

    generate_kwargs = {
        **inputs,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.7,
    }
    if file is not None:
        generate_kwargs["downsample_mode"] = downsample_mode

    with torch.no_grad():
        generated_ids = model.generate(**generate_kwargs)

    # Drop the prompt tokens so only the newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return normalize_response_text(output_text[0])
@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Return the text of the ``index.html`` file that sits next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    index_file = os.path.join(module_dir, "index.html")
    with open(index_file, mode="r", encoding="utf-8") as page:
        markup = page.read()
    return markup
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app.launch(show_error=True)