import os
import torch
import re
import av
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from gradio import Server
from gradio.data_classes import FileData
from fastapi.responses import HTMLResponse
import spaces

# Load model and processor
model_id = "openbmb/MiniCPM-V-4.6"
print(f"Loading model: {model_id}...")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cuda",
)

# Upper bound on frames sampled from a video; used both when decoding and as the
# processor's max_num_frames so the two can never drift apart.
MAX_VIDEO_FRAMES = 64


def _sample_indices(total, max_frames):
    """Return sorted, duplicate-free indices of at most ``max_frames`` frames spread over ``total``."""
    if total <= max_frames:
        # Short clip: keep every frame. Sampling here would produce repeated
        # indices (e.g. [0, 0, 0, 1, ...]) that a single forward pass can't honor.
        return list(range(total))
    # step = total / max_frames > 1, so these indices are strictly increasing.
    return [i * total // max_frames for i in range(max_frames)]


def load_video(video_path, max_frames=MAX_VIDEO_FRAMES):
    """Decode up to ``max_frames`` evenly spaced frames from a video using PyAV.

    Args:
        video_path: Path to a file PyAV can open.
        max_frames: Maximum number of frames to return.

    Returns:
        A list of PIL images (possibly empty), or None if decoding failed.
    """
    try:
        # Context manager guarantees the container is closed even if decoding raises.
        with av.open(video_path) as container:
            total_frames = container.streams.video[0].frames
            if total_frames <= 0:
                # The container does not report a frame count: decode everything, then sample.
                print("Frame count unknown, decoding all and sampling...")
                all_frames = [frame.to_image() for frame in container.decode(video=0)]
                return [all_frames[i] for i in _sample_indices(len(all_frames), max_frames)]

            wanted = _sample_indices(total_frames, max_frames)
            if not wanted:
                return []
            wanted_set = set(wanted)
            last_wanted = wanted[-1]
            frames = []
            for i, frame in enumerate(container.decode(video=0)):
                if i in wanted_set:
                    frames.append(frame.to_image())
                if i >= last_wanted:
                    # Stop decoding once the last requested frame has been reached.
                    break
            return frames
    except Exception as e:
        print(f"Error loading video: {e}")
        return None


# Utility for response normalization.
# Group 1 captures spans that must be kept verbatim: fenced code, inline code,
# $$...$$, $...$, \(...\) and \[...\]. The second alternative (no group) matches
# text outside those spans that normalize_response_text replaces with a newline.
# NOTE(review): the tail of this pattern was lost when this file was flattened;
# "(?<!\\)\\n" (a literal backslash-n not preceded by another backslash) is
# reconstructed from how the replacement callback uses it. Confirm it against the
# upstream MiniCPM-V demo before relying on it.
_PATTERN = re.compile(
    r'(```[\s\S]*?```|`[^`]+`|\$\$[\s\S]*?\$\$|\$[^$]+\$|\\\([\s\S]*?\\\)|\\\[[\s\S]*?\\\])'
    r'|(?<!\\)\\n'
)


def normalize_response_text(text: str) -> str:
    """Turn escaped newline sequences in model output into real newlines.

    Code and math spans matched by group 1 of ``_PATTERN`` are returned unchanged.
    Non-string input, and strings containing no backslash at all, are returned as-is.
    """
    if not isinstance(text, str) or "\\" not in text:
        return text
    return _PATTERN.sub(lambda m: m.group(1) or '\n', text)


app = Server()


def _file_path(file):
    """Return the local path of an upload given either as a dict or as a FileData-like object."""
    if isinstance(file, dict):
        return file["path"]
    return file.path


def _is_video(file_path):
    """Return True if ``file_path`` should be processed as a video rather than a still image."""
    # Probe with PIL first: FFmpeg (and therefore PyAV) also opens JPEG/PNG files
    # as single-frame video streams, so checking for a video stream alone would
    # route ordinary images down the video path.
    try:
        with Image.open(file_path) as img:
            # Multi-frame images (e.g. animated GIF) are treated as videos.
            return bool(getattr(img, "is_animated", False))
    except OSError:  # includes UnidentifiedImageError and FileNotFoundError
        pass
    # Not an image PIL understands: it's a video if PyAV finds a video stream.
    try:
        with av.open(file_path) as container:
            return len(container.streams.video) > 0
    except Exception:
        return False


def _chat_inputs(content, processor_kwargs=None):
    """Build model inputs for a single user turn and move them to the model's device.

    ``processor_kwargs`` is forwarded to ``apply_chat_template`` only when given,
    so text-only requests call the template without it.
    """
    messages = [{"role": "user", "content": content}]
    template_kwargs = {
        "tokenize": True,
        "add_generation_prompt": True,
        "return_dict": True,
        "return_tensors": "pt",
    }
    if processor_kwargs is not None:
        template_kwargs["processor_kwargs"] = processor_kwargs
    return processor.apply_chat_template(messages, **template_kwargs).to(model.device)


@app.api()
@spaces.GPU(duration=120)
def predict(message: str, file: FileData = None, downsample_mode: str = "16x") -> str:
    """
    General inference endpoint for text-only, image, and video requests.

    Args:
        message: The user prompt.
        file: Optional uploaded image or video.
        downsample_mode: Passed to the processor and, when a file is given, to
            ``model.generate``.

    Returns:
        The decoded reply after ``normalize_response_text``, or an error
        string if a video file could not be decoded.
    """
    text_part = {"type": "text", "text": message}

    if file is None:
        # Text-only inference
        inputs = _chat_inputs([text_part])
    else:
        file_path = _file_path(file)
        if _is_video(file_path):
            print(f"Processing as video: {file_path}")
            frames = load_video(file_path, max_frames=MAX_VIDEO_FRAMES)
            if not frames:
                return "Error: Could not decode video file."
            inputs = _chat_inputs(
                [{"type": "video", "video": frames}, text_part],
                {
                    "downsample_mode": downsample_mode,
                    "max_num_frames": MAX_VIDEO_FRAMES,
                    "stack_frames": 1,
                    "max_slice_nums": 1,
                    "use_image_id": False,
                    # Frames were already sampled by load_video, so skip the
                    # processor's own sampling (which would require video metadata).
                    "do_sample_frames": False,
                },
            )
        else:
            print(f"Processing as image: {file_path}")
            inputs = _chat_inputs(
                [{"type": "image", "url": file_path}, text_part],
                {
                    "downsample_mode": downsample_mode,
                    "max_slice_nums": 9,
                },
            )

    generate_kwargs = {
        **inputs,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.7,
    }
    if file is not None:
        generate_kwargs["downsample_mode"] = downsample_mode

    with torch.no_grad():
        generated_ids = model.generate(**generate_kwargs)

    # Strip the prompt tokens so only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return normalize_response_text(output_text[0])


@app.get("/", response_class=HTMLResponse)
async def homepage():
    """Serve the ``index.html`` file located next to this script."""
    html_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "index.html")
    with open(html_path, "r", encoding="utf-8") as f:
        return f.read()


if __name__ == "__main__":
    app.launch(show_error=True)