# Spaces: Running on Zero
| import os, sys, shutil | |
| import json | |
| import glob | |
| import math | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
| import cv2 | |
| import ffmpeg | |
| import numpy as np | |
| import torch | |
| import tempfile | |
| import spaces | |
| from fastapi.responses import HTMLResponse | |
# Workaround for Gradio's temp-file bug: point every temp-directory setting
# at one local folder before Gradio is imported.
BASE_TMP_DIR = os.path.abspath("./gradio_tmp")
os.makedirs(BASE_TMP_DIR, exist_ok=True)
os.environ.update(
    dict.fromkeys(("TMPDIR", "TEMP", "TMP", "GRADIO_TEMP_DIR"), BASE_TMP_DIR)
)
tempfile.tempdir = BASE_TMP_DIR
| from gradio import Server | |
| from gradio.data_classes import FileData | |
# Make the current working directory importable so the project packages
# imported below can be resolved.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
| from architecture.backbone import build_backbone | |
| from architecture.transformer import build_transformer | |
| from architecture.model import OmniShotCut | |
| from datasets.transforms import Video_Augmentation_Transform | |
| from util.visualization import visualize_concated_frames | |
| from config.label_correspondence import unique_intra_label_mapping, unique_inter_label_mapping | |
| from test_code.inference import single_video_inference, load_model | |
# -------------------------
# Global cache / constants
# -------------------------
# Transform instance created once at import time with set_type="val".
video_transform = Video_Augmentation_Transform(set_type="val")

# Inverted label mappings: look up the mapping key from a predicted integer id.
INTRA_ID2NAME = {label_id: name for name, label_id in unique_intra_label_mapping.items()}
INTER_ID2NAME = {label_id: name for name, label_id in unique_inter_label_mapping.items()}

# Fixed demo config
DEFAULT_CHECKPOINT_PATH = "checkpoints/OmniShotCut_ckpt.pth"
DEFAULT_NUM_CONTEXT_FRAMES = 0
DEFAULT_MAX_FRAMES_PER_IMG = 132
VIS_DIR = "demo_video_results"

# Public URL safe setting
MAX_GALLERY_PAGES = 20

# Prepare the checkpoint
_CHECKPOINT_URL = "https://huggingface.co/uva-cv-lab/OmniShotCut/resolve/main/OmniShotCut_ckpt.pth"
if not os.path.exists(DEFAULT_CHECKPOINT_PATH):
    import subprocess

    os.makedirs("checkpoints", exist_ok=True)
    # Download to a side file and rename only on success, so an interrupted
    # transfer can never be mistaken for a complete checkpoint on the next start.
    _partial_path = DEFAULT_CHECKPOINT_PATH + ".part"
    # check=True turns a failed download into an immediate, explicit error
    # instead of a confusing failure inside load_model().
    subprocess.run(["wget", "-O", _partial_path, _CHECKPOINT_URL], check=True)
    os.replace(_partial_path, DEFAULT_CHECKPOINT_PATH)

# Load the model
checkpoint_path = DEFAULT_CHECKPOINT_PATH
model, model_args = load_model(checkpoint_path)
| ######################## Utilities ######################## | |
def escape_html(x):
    """Return ``x`` as text that is safe to embed in HTML markup.

    ``None`` becomes the empty string; any other value is converted with
    ``str()`` first. The characters ``&``, ``<``, ``>``, ``"`` and ``'`` are
    replaced by their HTML entities.
    """
    # Imported locally so this function does not depend on a change to the
    # file-level import block.
    from html import escape as _escape

    text = "" if x is None else str(x)
    return _escape(text, quote=True)
def prepare_result_table(
    pred_ranges: List[List[int]],
    pred_intra_labels: List[int],
    pred_inter_labels: List[int],
    fps: float,
) -> str:
    """Render the predicted ranges as an HTML table.

    Args:
        pred_ranges: One ``[start_frame, end_frame]`` pair per table row.
        pred_intra_labels: Intra label id for each range. A row with no
            matching entry uses id ``-1``.
        pred_inter_labels: Inter label id for each range. A row with no
            matching entry uses id ``-1``.
        fps: Frame rate used to convert frame indices to seconds. When it is
            falsy or not positive, the time columns are left empty.

    Returns:
        The table markup as a string. Every cell value is passed through
        ``escape_html``. Label ids are shown by name when present in
        ``INTRA_ID2NAME`` / ``INTER_ID2NAME`` and as the number otherwise.
    """
    headers = (
        "Index",
        "Start Frame",
        "End Frame",
        "Start Time (s)",
        "End Time (s)",
        "Intra Label",
        "Inter Label",
    )
    # The fps check does not depend on the row, so evaluate it once.
    has_fps = bool(fps) and fps > 0

    def render_cells(tag, values):
        """Wrap each escaped value in ``<tag>...</tag>`` and concatenate."""
        return "".join(f"<{tag}>{escape_html(value)}</{tag}>" for value in values)

    # Collect fragments and join once instead of growing a string with ``+=``.
    parts = [
        '\n<div class="result-table-wrap">\n'
        '<table class="result-table">\n'
        "<thead>\n"
        "<tr>\n",
        render_cells("th", headers),
        "\n</tr>\n</thead>\n<tbody>\n",
    ]

    for idx, pred_range in enumerate(pred_ranges):
        start_frame = int(pred_range[0])
        end_frame = int(pred_range[1])
        intra_id = int(pred_intra_labels[idx]) if idx < len(pred_intra_labels) else -1
        inter_id = int(pred_inter_labels[idx]) if idx < len(pred_inter_labels) else -1
        row = (
            idx,
            start_frame,
            end_frame,
            round(start_frame / fps, 3) if has_fps else "",
            round(end_frame / fps, 3) if has_fps else "",
            INTRA_ID2NAME.get(intra_id, str(intra_id)),
            INTER_ID2NAME.get(inter_id, str(inter_id)),
        )
        parts.append("<tr>" + render_cells("td", row) + "</tr>")

    parts.append("\n</tbody>\n</table>\n</div>\n")
    return "".join(parts)
def list_sample_videos(asset_dir: str = "__assets__", max_samples: int = 8) -> List[dict]:
    """Return up to ``max_samples`` ``.mp4`` files found in ``asset_dir``.

    ``asset_dir`` is joined onto the directory that contains this script.
    Each result is a dict with ``"path"`` and ``"name"`` keys, ordered by
    sorted file name; the extension match is case-insensitive. A directory
    that does not exist yields an empty list.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    asset_dir = os.path.join(base_dir, asset_dir)
    if not os.path.isdir(asset_dir):
        return []

    candidates = (
        (os.path.join(asset_dir, fname), fname)
        for fname in sorted(os.listdir(asset_dir))
    )
    found = [
        {"path": fpath, "name": fname}
        for fpath, fname in candidates
        if os.path.isfile(fpath) and fname.lower().endswith(".mp4")
    ]
    return found[:max_samples]
| from fastapi.staticfiles import StaticFiles | |
# -------------------------
# Server and API
# -------------------------
# Create the output directory first, then serve its contents under /outputs.
os.makedirs(VIS_DIR, exist_ok=True)
app = Server()
app.mount(
    "/outputs",
    StaticFiles(directory=VIS_DIR),
    name="outputs",
)
def get_examples() -> List[dict]:
    """Return the sample videos from ``__assets__/`` (at most 16).

    When the ``SPACE_ID`` environment variable is set, each sample is a dict
    with a Hub ``resolve/main`` ``"url"`` and its ``"orig_name"``. Otherwise
    each sample is a ``FileData`` pointing at the local file.
    """
    samples = list_sample_videos("__assets__/", max_samples=16)
    space_id = os.getenv("SPACE_ID")

    # Not running as a Space: hand back the local files directly.
    if not space_id:
        return [FileData(path=item["path"], orig_name=item["name"]) for item in samples]

    hub_base = f"https://huggingface.co/spaces/{space_id}/resolve/main/__assets__/"
    examples = []
    for item in samples:
        examples.append({"url": f"{hub_base}{item['name']}", "orig_name": item["name"]})
    return examples
# NOTE(review): ``spaces`` is imported at file level, but no ``@spaces.GPU`` or
# API-registration decorator is visible on this function - confirm none was lost.
def run_inference(video_file: dict) -> dict:
    """Run the model on one video and package the results for the front end.

    Args:
        video_file: Dict with a ``"path"`` entry naming the video on disk and,
            optionally, an ``"orig_name"`` entry used to recover a missing
            file extension.

    Returns:
        On success, a dict with ``"gallery"`` (a list of ``{"url": ...}``
        entries, at most ``MAX_GALLERY_PAGES``), ``"table"`` (HTML string)
        and ``"shot_count"`` (number of predicted ranges). When the input is
        rejected, ``{"error": <message>}``.
    """
    video_path = video_file.get("path")
    # Validate before touching the file: the copy below would raise on a
    # missing source instead of reporting the problem to the caller.
    if not video_path or not os.path.exists(video_path):
        return {"error": "Video file not found"}

    # ffmpeg/opencv often need a file extension to correctly parse the container
    if not os.path.splitext(video_path)[1]:
        orig_name = video_file.get("orig_name") or "input.mp4"
        ext = os.path.splitext(orig_name)[1] or ".mp4"
        path_with_ext = video_path + ext
        if not os.path.exists(path_with_ext):
            shutil.copy(video_path, path_with_ext)
        video_path = path_with_ext

    # Check if it's a Git LFS pointer instead of a real video
    if os.path.getsize(video_path) < 1000:
        with open(video_path, "rb") as f:
            header = f.read(100)
        if b"version https://git-lfs" in header:
            return {"error": "LFS pointer detected. Please ensure video files are fully downloaded on the Space."}

    print(f"Start processing: {video_path}")
    pred_ranges, pred_intra_labels, pred_inter_labels, video_np_full, fps = single_video_inference(
        video_path=video_path,
        model=model,
        model_args=model_args,
        num_context_frames=DEFAULT_NUM_CONTEXT_FRAMES,
    )
    print("Inference finished")

    # Prepare visualization directory. mkdtemp guarantees a fresh directory,
    # so two requests arriving within the same second cannot share (and
    # overwrite) one another's output pages.
    os.makedirs(VIS_DIR, exist_ok=True)
    cur_vis_dir = tempfile.mkdtemp(prefix=f"vis_{int(time.time())}_", dir=VIS_DIR)

    # Generate visualization frames
    page_paths = visualize_concated_frames(
        frames=video_np_full,
        out_dir=cur_vis_dir,
        highlight_ranges_closed=pred_ranges,
        max_frames_per_img=DEFAULT_MAX_FRAMES_PER_IMG,
        end_range_exclusive=True,
        fps=fps,
        start_index=0,
    )

    gallery_data = []
    for page_path in page_paths[:MAX_GALLERY_PAGES]:
        # URLs always use forward slashes, whatever the OS path separator is.
        rel_path = os.path.relpath(page_path, VIS_DIR).replace(os.sep, "/")
        gallery_data.append({"url": f"/outputs/{rel_path}"})

    result_table_html = prepare_result_table(
        pred_ranges=pred_ranges,
        pred_intra_labels=pred_intra_labels,
        pred_inter_labels=pred_inter_labels,
        fps=fps,
    )
    return {
        "gallery": gallery_data,
        "table": result_table_html,
        "shot_count": len(pred_ranges),
    }
# NOTE(review): ``HTMLResponse`` is imported at file level but unused in the
# visible code, and no route decorator is visible here - confirm none was lost.
async def homepage():
    """Return the text of ``index.html`` located next to this script."""
    page = Path(os.path.abspath(__file__)).parent / "index.html"
    return page.read_text(encoding="utf-8")
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch(
        show_error=True,
    )