Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import re | |
| import av | |
| from PIL import Image | |
| from transformers import AutoModelForImageTextToText, AutoProcessor | |
| from gradio import Server | |
| from gradio.data_classes import FileData | |
| from fastapi.responses import HTMLResponse | |
| import spaces | |
# Load the processor and the model once, at import time, so every request
# served by this module reuses the same instances.
model_id = "openbmb/MiniCPM-V-4.6"
print(f"Loading model: {model_id}...")

processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Weights are requested in bfloat16 and mapped to the "cuda" device.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="cuda",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
def _pick_frame_indices(total_frames, max_frames):
    """Return distinct, ascending, evenly spaced indices into ``range(total_frames)``.

    At most ``max_frames`` indices are returned. When the clip has no more
    frames than the budget, every frame index is returned exactly once, so the
    result never contains duplicates.
    """
    if total_frames <= 0 or max_frames <= 0:
        return []
    if total_frames <= max_frames:
        return list(range(total_frames))
    # total_frames > max_frames, so consecutive values differ by at least 1.
    return [i * total_frames // max_frames for i in range(max_frames)]


def load_video(video_path, max_frames=64):
    """Decode a video with PyAV and return up to ``max_frames`` sampled frames.

    Args:
        video_path: Path (or file-like object) passed to ``av.open``.
        max_frames: Upper bound on the number of frames returned.

    Returns:
        A list of images produced by ``frame.to_image()``, evenly spaced over
        the clip, or ``None`` if the file could not be opened or decoded.
    """
    try:
        # The context manager closes the container even when decoding raises.
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            total_frames = stream.frames

            if total_frames <= 0:
                # The container does not report a frame count: decode
                # everything first, then sample from what was decoded.
                print("Frame count unknown, decoding all and sampling...")
                decoded = [
                    frame.to_image() for frame in container.decode(video=0)
                ]
                return [
                    decoded[i]
                    for i in _pick_frame_indices(len(decoded), max_frames)
                ]

            wanted = _pick_frame_indices(total_frames, max_frames)
            frames = []
            if not wanted:
                return frames

            cursor = 0
            for position, frame in enumerate(container.decode(video=0)):
                # `wanted` is strictly increasing, so one forward pass with a
                # single cursor visits every requested index.
                if position == wanted[cursor]:
                    frames.append(frame.to_image())
                    cursor += 1
                    if cursor == len(wanted):
                        break
            return frames
    except Exception as e:
        # Best-effort loader: callers treat None as "could not decode".
        print(f"Error loading video: {e}")
        return None
# Response normalization.
#
# Group 1 captures spans that are kept verbatim: fenced code blocks, inline
# code, and math delimited by $$...$$, $...$, \(...\) or \[...\].
# The second alternative matches a literal backslash-escaped "\r\n", "\n" or
# "\r" that is not itself preceded by a backslash.
_PATTERN = re.compile(
    r'(```[\s\S]*?```|`[^`]+`|\$\$[\s\S]*?\$\$|\$[^$]+\$|\\\([\s\S]*?\\\)|\\\[[\s\S]*?\\\])'
    r'|(?<!\\)(?:\\r\\n|\\[nr])'
)


def _keep_protected_or_newline(match):
    """Return the protected span unchanged, or a real newline for an escape."""
    protected = match.group(1)
    if protected:
        return protected
    return '\n'


def normalize_response_text(text: str) -> str:
    """Turn literal ``\\n`` / ``\\r`` / ``\\r\\n`` escapes into real newlines.

    Code spans and math spans are left exactly as they are. Values that are
    not strings, or strings without any backslash, are returned unchanged.
    """
    if isinstance(text, str) and "\\" in text:
        return _PATTERN.sub(_keep_protected_or_newline, text)
    return text
app = Server()

# Frame budget for video inputs; used both when sampling frames and as the
# processor's "max_num_frames".
_MAX_VIDEO_FRAMES = 64


def predict(message: str, file: FileData = None, downsample_mode: str = "16x") -> str:
    """Run one chat turn for text-only, image, or video input.

    Args:
        message: The user's text prompt.
        file: Optional uploaded file. Its path is read from the ``"path"`` key
            (or ``.path`` attribute). ``None`` means text-only inference.
        downsample_mode: Forwarded to the processor and to ``model.generate``
            when a file is supplied. Accepted values are defined by the
            model's own code, which is not visible in this file.

    Returns:
        The decoded text of the first generated sequence after
        ``normalize_response_text``, or an error string when a video cannot
        be decoded. Generation samples with temperature 0.7, so repeated calls
        may return different text.
    """
    if file is None:
        # Text-only inference: no media-specific processor options.
        content = [{"type": "text", "text": message}]
        processor_kwargs = None
    else:
        file_path = _resolve_file_path(file)
        if _is_video_file(file_path):
            print(f"Processing as video: {file_path}")
            frames = load_video(file_path, max_frames=_MAX_VIDEO_FRAMES)
            if not frames:
                return "Error: Could not decode video file."
            content = [
                {"type": "video", "video": frames},
                {"type": "text", "text": message},
            ]
            processor_kwargs = {
                "downsample_mode": downsample_mode,
                "max_num_frames": _MAX_VIDEO_FRAMES,
                "stack_frames": 1,
                "max_slice_nums": 1,
                "use_image_id": False,
                # Frames were already sampled by load_video, so the processor
                # must not resample (which would require video metadata).
                "do_sample_frames": False,
            }
        else:
            print(f"Processing as image: {file_path}")
            content = [
                {"type": "image", "url": file_path},
                {"type": "text", "text": message},
            ]
            processor_kwargs = {
                "downsample_mode": downsample_mode,
                "max_slice_nums": 9,
            }

    messages = [{"role": "user", "content": content}]
    template_kwargs = {
        "tokenize": True,
        "add_generation_prompt": True,
        "return_dict": True,
        "return_tensors": "pt",
    }
    if processor_kwargs is not None:
        template_kwargs["processor_kwargs"] = processor_kwargs
    inputs = processor.apply_chat_template(messages, **template_kwargs).to(model.device)

    generate_kwargs = {
        **inputs,
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 0.7,
    }
    if file is not None:
        generate_kwargs["downsample_mode"] = downsample_mode

    with torch.no_grad():
        generated_ids = model.generate(**generate_kwargs)

    # Drop the prompt tokens so only the newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return normalize_response_text(output_text[0])


def _resolve_file_path(file):
    """Return the local path of an upload given as a mapping or an object."""
    try:
        return file["path"]
    except TypeError:
        # Object-style payloads are not subscriptable; read the attribute.
        return file.path


def _is_still_image(file_path):
    """Return True when PIL can open the file and it has exactly one frame."""
    try:
        with Image.open(file_path) as img:
            return getattr(img, "n_frames", 1) == 1
    except Exception:
        # Detection is best-effort: any failure just means "not a still image".
        return False


def _has_video_stream(file_path):
    """Return True when PyAV can open the file and finds a video stream."""
    try:
        with av.open(file_path) as container:
            return len(container.streams.video) > 0
    except Exception:
        # Detection is best-effort: an unreadable file is treated as non-video.
        return False


def _is_video_file(file_path):
    """Decide whether an upload should take the video path.

    PyAV also opens still images as single-frame video streams, so still
    images are ruled out with PIL first; otherwise common image uploads would
    never reach the image branch.
    """
    if _is_still_image(file_path):
        return False
    return _has_video_stream(file_path)
async def homepage():
    """Return the text of the ``index.html`` file that sits next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    index_file = os.path.join(module_dir, "index.html")
    with open(index_file, "r", encoding="utf-8") as page:
        contents = page.read()
    return contents


if __name__ == "__main__":
    # Start the server only when this file is executed as a script.
    app.launch(show_error=True)