Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """HF Space entry point for 3DReflecNet dataset preview. | |
| Loads only data/preview/preview.parquet so the Space exposes the configured | |
| preview instance subset instead of the full dataset metadata. | |
| """ | |
| from __future__ import annotations | |
| import atexit | |
| import io | |
| import os | |
| import shutil | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| import pandas as pd | |
| from datasets import load_dataset | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from utils import ( | |
| BOOL_FILTER_CHOICES, | |
| FILTER_ALL, | |
| filter_dataframe_advanced, | |
| get_distinct_text_choices, | |
| logger, | |
| require_bool_columns, | |
| require_columns, | |
| require_text_columns, | |
| setup_logging, | |
| ) | |
# Dataset repository to read from; overridable through the DATASET_REPO env var.
DATASET_REPO = os.environ.get("DATASET_REPO", "3DReflecNet/3DReflecNet")
# Optional Hub access token; None when the HF_TOKEN env var is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Maximum number of instances shown in the gallery at once.
MAX_RESULTS = 300
# Metadata columns that the loaders validate as boolean.
BOOL_COLUMNS = ["hasGlass", "isGenerated", "transparent", "near_light"]
# Per-process scratch directory for GLB copies; removed at interpreter exit,
# ignoring any error raised while deleting it.
_GLB_CACHE_DIR = Path(tempfile.mkdtemp(prefix="glb_cache_"))
atexit.register(shutil.rmtree, _GLB_CACHE_DIR, ignore_errors=True)
| # --------------------------------------------------------------------------- | |
| # Data loading | |
| # --------------------------------------------------------------------------- | |
def load_preview_dataframe() -> pd.DataFrame:
    """Fetch the preview Parquet from the Hub and return it as a validated frame.

    Only ``data/preview/preview.parquet`` of ``DATASET_REPO`` is read
    (non-streaming), restricted to the columns the explorer needs.

    Returns:
        One row per preview frame, with image payloads kept as raw bytes.

    Raises:
        TypeError: If ``frame_id`` contains nulls or is not an integer dtype,
            or if any image column holds a value that is not non-empty
            ``bytes``/``bytearray``.

    The ``require_*`` helpers from ``utils`` are also applied; what they raise
    on failure is defined there.
    """
    source = "preview parquet"
    wanted_columns = [
        "instance_id", "split", "frame_id", "rgb", "mask",
        "depth_preview", "normal_preview",
        "main_category", "sub_category", "model_name",
        "material_name", "env_name", "glb_path",
        "hasGlass", "isGenerated", "transparent", "near_light",
    ]
    text_columns = [
        "instance_id", "split", "main_category", "sub_category",
        "model_name", "material_name", "env_name", "glb_path",
    ]

    dataset = load_dataset(
        DATASET_REPO,
        data_files="data/preview/preview.parquet",
        split="train",
        streaming=False,
        token=HF_TOKEN,
    )
    dataset = dataset.select_columns(wanted_columns)
    frame = pd.DataFrame(list(dataset))

    # Schema checks delegated to the shared helpers.
    require_columns(frame, wanted_columns, source)
    require_text_columns(frame, text_columns, source)
    require_bool_columns(frame, BOOL_COLUMNS, source)

    frame_ids = frame["frame_id"]
    if frame_ids.isna().any() or not pd.api.types.is_integer_dtype(frame_ids):
        raise TypeError(f"Expected integer dtype for column 'frame_id' in preview parquet, got {frame_ids.dtype}.")

    def is_bad_payload(value: Any) -> bool:
        # A payload is acceptable only when it is a non-empty bytes-like blob.
        return not (isinstance(value, (bytes, bytearray)) and len(value) != 0)

    for column in ("rgb", "mask", "depth_preview", "normal_preview"):
        if frame[column].map(is_bad_payload).any():
            raise TypeError(f"Expected non-empty binary values for column {column!r} in preview parquet.")
    return frame
def decode_image_bytes(img_bytes: bytes | bytearray, context: str) -> Image.Image:
    """Decode an encoded image payload into a standalone PIL image.

    Args:
        img_bytes: Encoded image data.
        context: Description of the payload; used only in the error message.

    Returns:
        A copy of the decoded image, taken before the image opened on the
        in-memory buffer is closed.

    Raises:
        TypeError: If ``img_bytes`` is empty or is not ``bytes``/``bytearray``.
    """
    payload_ok = isinstance(img_bytes, (bytes, bytearray)) and len(img_bytes) > 0
    if not payload_ok:
        raise TypeError(f"Expected non-empty image bytes for {context}.")
    buffer = io.BytesIO(img_bytes)
    with Image.open(buffer) as opened:
        decoded = opened.copy()
    return decoded
def build_preview_instance_dataframe(preview_df: pd.DataFrame) -> pd.DataFrame:
    """Collapse per-frame preview rows into one metadata row per instance.

    Args:
        preview_df: Frame-level preview rows.

    Returns:
        A frame with one row per ``instance_id`` (groups visited in sorted
        key order) holding the instance-level metadata columns.

    Raises:
        ValueError: If a metadata column takes more than one distinct value
            within an instance, or if any ``glb_path`` is blank.
    """
    instance_cols = [
        "instance_id", "main_category", "sub_category", "model_name",
        "material_name", "env_name", "hasGlass", "isGenerated",
        "transparent", "near_light", "glb_path",
    ]
    require_columns(preview_df, instance_cols, "preview parquet")

    def collapse(instance_id: Any, frames: pd.DataFrame) -> dict[str, Any]:
        # Every frame of an instance must agree on each metadata column.
        record: dict[str, Any] = {}
        for column in instance_cols:
            distinct = frames[column].drop_duplicates().tolist()
            if len(distinct) != 1:
                raise ValueError(f"Inconsistent {column!r} values for preview instance {instance_id!r}.")
            record[column] = distinct[0]
        return record

    records = [
        collapse(instance_id, frames)
        for instance_id, frames in preview_df.groupby("instance_id", sort=True)
    ]
    instances = pd.DataFrame(records, columns=instance_cols)

    text_columns = [
        "instance_id", "main_category", "sub_category",
        "model_name", "material_name", "env_name", "glb_path",
    ]
    require_text_columns(instances, text_columns, "preview instance dataframe")
    require_bool_columns(instances, BOOL_COLUMNS, "preview instance dataframe")

    blank_paths = instances["glb_path"].map(lambda path: not path.strip())
    if blank_paths.any():
        raise ValueError("Preview instance dataframe contains empty GLB paths.")
    return instances
def train_frame_rows(preview_df: pd.DataFrame, instance_id: str, max_frames: int | None = None) -> list[dict[str, Any]]:
    """Return the train-split rows of one instance, ordered by ``frame_id``.

    Args:
        preview_df: Frame-level preview rows.
        instance_id: Instance to select; compared as a string.
        max_frames: When given, keep only the first ``max_frames`` rows after
            sorting.

    Returns:
        The selected rows as a list of record dicts.

    Raises:
        ValueError: If the instance has no rows in the ``train`` split.
    """
    same_instance = preview_df["instance_id"].astype(str) == str(instance_id)
    in_train_split = preview_df["split"].astype(str) == "train"
    matches = preview_df.loc[same_instance & in_train_split]
    if matches.empty:
        raise ValueError(f"Preview instance {instance_id!r} has no train split rows.")

    ordered = matches.sort_values("frame_id")
    limited = ordered if max_frames is None else ordered.head(max_frames)
    return limited.to_dict(orient="records")
def get_instance_thumbnail(preview_df: pd.DataFrame, instance_id: str) -> Image.Image:
    """Decode the RGB image of an instance's first train frame (lowest ``frame_id``)."""
    first_rows = train_frame_rows(preview_df, instance_id, max_frames=1)
    rgb_payload = first_rows[0]["rgb"]
    return decode_image_bytes(rgb_payload, f"{instance_id} thumbnail RGB")
def instance_caption(row: dict[str, Any]) -> str:
    """Build the gallery caption ``model | material | environment`` for an instance row."""
    caption_keys = ("model_name", "material_name", "env_name")
    return " | ".join(str(row[key]) for key in caption_keys)
def build_instance_gallery_items(
    rows: list[dict[str, Any]],
    preview_df: pd.DataFrame,
) -> list[tuple[Image.Image, str]]:
    """Pair each instance row with its thumbnail image and caption.

    Args:
        rows: Instance-level rows; each must carry ``instance_id`` plus the
            fields read by ``instance_caption``.
        preview_df: Frame-level preview rows the thumbnails are taken from.

    Returns:
        ``(thumbnail, caption)`` tuples in the same order as ``rows``.
    """
    gallery: list[tuple[Image.Image, str]] = []
    for instance_row in rows:
        thumbnail = get_instance_thumbnail(preview_df, instance_row["instance_id"])
        gallery.append((thumbnail, instance_caption(instance_row)))
    return gallery
def load_instance_frames(
    preview_df: pd.DataFrame, instance_id: str, max_frames: int = 50,
) -> list[dict[str, Any]]:
    """Decode the train-split preview images of one instance.

    Args:
        preview_df: Frame-level preview rows.
        instance_id: Instance whose frames are loaded.
        max_frames: Upper bound on the number of frames returned.

    Returns:
        One dict per frame, ordered by ``frame_id``, with keys ``frame_id``,
        ``rgb``, ``mask``, ``depth_preview`` and ``normal_preview``; the image
        entries are decoded PIL images.
    """
    image_keys = ("rgb", "mask", "depth_preview", "normal_preview")
    decoded_frames: list[dict[str, Any]] = []
    for record in train_frame_rows(preview_df, instance_id, max_frames=max_frames):
        frame_id = int(record["frame_id"])
        images = {
            key: decode_image_bytes(record[key], f"{key} frame {frame_id}")
            for key in image_keys
        }
        decoded_frames.append({"frame_id": frame_id, **images})
    return decoded_frames
def render_frame_images(frame_items: list[dict[str, Any]], frame_index: float) -> list[Any | None]:
    """Build the four image-component updates for one frame.

    Args:
        frame_items: Decoded frames as produced by ``load_instance_frames``.
        frame_index: 1-based frame position; rounded to the nearest integer
            and clamped to the available range.

    Returns:
        Updates for the RGB, Mask, Depth and Normal components, in that order.
        With no frames, each update clears the image and restores the plain
        label; otherwise the label carries the zero-padded ``frame_id``.
    """
    # (key in the frame dict, label shown on the component)
    panels = (
        ("rgb", "RGB"),
        ("mask", "Mask"),
        ("depth_preview", "Depth"),
        ("normal_preview", "Normal"),
    )
    if not frame_items:
        return [gr.update(value=None, label=title) for _, title in panels]

    position = int(round(frame_index)) - 1
    position = max(0, min(position, len(frame_items) - 1))
    chosen = frame_items[position]
    frame_id = int(chosen["frame_id"])
    return [
        gr.update(value=chosen[key], label=f"{title} frame_{frame_id:05d}")
        for key, title in panels
    ]
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def glb_cache_path(glb_path: str, cache_dir: str | Path) -> Path:
    """Return the local cache location for a repo-relative GLB path.

    The file keeps its original name (so the ``.glb`` suffix is preserved) but
    lives in a sub-directory named after a digest of the *full* repo path.
    Two files that share a basename in different repo folders therefore get
    distinct cache entries.
    """
    import hashlib

    digest = hashlib.sha256(glb_path.encode("utf-8")).hexdigest()[:16]
    return Path(cache_dir) / digest / Path(glb_path).name


def download_glb(glb_path: str) -> str:
    """Download a pre-converted GLB file from the HF dataset repo.

    Args:
        glb_path: Path of the GLB file inside ``DATASET_REPO``.

    Returns:
        Path of a local copy inside the process-wide GLB cache directory.

    Raises:
        ValueError: If ``glb_path`` is empty.
    """
    if not glb_path:
        raise ValueError("GLB path is required.")

    # Key the cache on the full repo path, not just the basename, so that
    # e.g. "a/model.glb" and "b/model.glb" never shadow each other.
    local = glb_cache_path(glb_path, _GLB_CACHE_DIR)
    if local.exists():
        return str(local)

    downloaded = hf_hub_download(
        repo_id=DATASET_REPO,
        filename=glb_path,
        repo_type="dataset",
        token=HF_TOKEN,
    )

    local.parent.mkdir(parents=True, exist_ok=True)
    # Copy under a temporary name and rename into place: a concurrent request
    # that passes the exists() check above must never see a partial file.
    fd, partial = tempfile.mkstemp(dir=str(local.parent), suffix=".part")
    os.close(fd)
    try:
        shutil.copy2(downloaded, partial)
        os.replace(partial, str(local))
    finally:
        # Only present when the copy or rename failed.
        if os.path.exists(partial):
            os.remove(partial)

    logger.info("GLB ready: %s", local)
    return str(local)
| # --------------------------------------------------------------------------- | |
| # App builder | |
| # --------------------------------------------------------------------------- | |
def build_app(instance_df: pd.DataFrame, preview_df: pd.DataFrame) -> gr.Blocks:
    """Assemble the Gradio explorer UI.

    Layout: filter dropdowns, a thumbnail gallery of matching instances, a
    metadata/3D-model row, four per-frame image panels and a frame slider.

    Args:
        instance_df: One row per preview instance (filter source).
        preview_df: Frame-level preview rows (thumbnail and frame source).

    Returns:
        The constructed, not yet launched, ``gr.Blocks`` app.
    """
    model_name_choices = get_distinct_text_choices(instance_df, "model_name")
    material_name_choices = get_distinct_text_choices(instance_df, "material_name")
    env_name_choices = get_distinct_text_choices(instance_df, "env_name")

    def empty_slider() -> Any:
        """Slider update for the state in which no instance is selected."""
        return gr.update(minimum=1, maximum=1, step=1, value=1, interactive=False)

    def cleared_detail_outputs() -> tuple[Any, ...]:
        """Values that blank the detail pane.

        Order: metadata, 3D model, RGB, Mask, Depth, Normal, slider, frame
        state. The image updates come from ``render_frame_images([])`` so the
        labels are reset as well, instead of keeping a stale ``frame_NNNNN``
        suffix from the previous selection.
        """
        return ({}, None, *render_frame_images([], 1), empty_slider(), [])

    def filtered_instance_rows(
        model_name: str,
        material_name: str,
        env_name: str,
        has_glass: str,
        is_generated: str,
        transparent: str,
        near_light: str,
    ) -> tuple[pd.DataFrame, list[dict[str, Any]]]:
        """Apply the filters; return all matches and the capped rows to show."""
        filtered = filter_dataframe_advanced(
            instance_df,
            model_name=model_name,
            material_name=material_name,
            env_name=env_name,
            has_glass=has_glass,
            is_generated=is_generated,
            transparent=transparent,
            near_light=near_light,
        )
        rows = filtered.head(MAX_RESULTS).to_dict(orient="records")
        return filtered, rows

    def filter_gallery(
        model_name: str,
        material_name: str,
        env_name: str,
        has_glass: str,
        is_generated: str,
        transparent: str,
        near_light: str,
    ):
        """Refresh summary and gallery for new filters and clear the detail pane."""
        filtered, rows = filtered_instance_rows(
            model_name=model_name,
            material_name=material_name,
            env_name=env_name,
            has_glass=has_glass,
            is_generated=is_generated,
            transparent=transparent,
            near_light=near_light,
        )
        summary = f"Matched **{len(filtered)}** preview instances, showing **{len(rows)}**."
        gallery_items = build_instance_gallery_items(rows, preview_df)
        return (summary, gallery_items, rows, *cleared_detail_outputs())

    def on_instance_select(rows: list[dict[str, Any]], evt: gr.SelectData):
        """Load metadata, GLB and frames for the gallery item that was clicked."""
        if not rows:
            return cleared_detail_outputs()

        idx = evt.index[0] if isinstance(evt.index, tuple) else evt.index
        if not isinstance(idx, int) or idx < 0 or idx >= len(rows):
            raise IndexError(f"Selected gallery index is out of range: {evt.index!r}")

        row = rows[idx]
        instance_id = row["instance_id"]
        if not isinstance(instance_id, str) or not instance_id.strip():
            raise ValueError(f"Selected instance row has invalid instance_id: {rows[idx]!r}")

        logger.info("Loading images for instance: %s", instance_id)
        frame_items = load_instance_frames(preview_df, instance_id, max_frames=50)
        slider_ready = gr.update(
            minimum=1,
            maximum=len(frame_items),
            step=1,
            value=1,
            interactive=True,
        )
        return row, download_glb(row["glb_path"]), *render_frame_images(frame_items, 1), slider_ready, frame_items

    def on_frame_change(frame_idx: float, frame_items: list[dict[str, Any]]):
        """Show the frame at the slider position."""
        return render_frame_images(frame_items, frame_idx)

    initial_rows = instance_df.head(MAX_RESULTS).to_dict(orient="records")
    initial_gallery = build_instance_gallery_items(initial_rows, preview_df)
    initial_summary = f"Matched **{len(instance_df)}** preview instances, showing **{len(initial_rows)}**."

    with gr.Blocks(title="3DReflecNet Dataset Explorer") as demo:
        gr.Markdown("# 3DReflecNet Dataset Explorer")
        gr.Markdown(
            "Browse the configured preview subset. Select an RGB thumbnail to inspect the instance."
        )

        with gr.Row():
            model_name = gr.Dropdown(label="model_name", choices=model_name_choices, value=FILTER_ALL)
            material_name = gr.Dropdown(label="material_name", choices=material_name_choices, value=FILTER_ALL)
            env_name = gr.Dropdown(label="env_name", choices=env_name_choices, value=FILTER_ALL)
        with gr.Row():
            has_glass = gr.Dropdown(label="hasGlass", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL)
            is_generated = gr.Dropdown(label="isGenerated", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL)
            transparent = gr.Dropdown(label="transparent", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL)
            near_light = gr.Dropdown(label="near_light", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL)

        summary = gr.Markdown(initial_summary)
        instance_gallery = gr.Gallery(
            label="Preview Instances",
            value=initial_gallery,
            columns=5,
            object_fit="contain",
            height="auto",
        )

        with gr.Row():
            instance_meta = gr.JSON(label="Instance Metadata")
            model_viewer = gr.Model3D(
                label="3D Preview (GLB)",
                clear_color=(0.35, 0.35, 0.38, 1.0),
                camera_position=(35, 70, 3.5),
            )
        with gr.Row():
            rgb_image = gr.Image(label="RGB", height=360, interactive=False, scale=1, min_width=160)
            mask_image = gr.Image(label="Mask", height=360, interactive=False, scale=1, min_width=160)
            depth_image = gr.Image(label="Depth", height=360, interactive=False, scale=1, min_width=160)
            normal_image = gr.Image(label="Normal", height=360, interactive=False, scale=1, min_width=160)

        frame_slider = gr.Slider(
            label="Frame",
            minimum=1,
            maximum=1,
            step=1,
            value=1,
            interactive=False,
        )

        # Rows currently shown in the gallery, and decoded frames of the
        # currently selected instance.
        instance_state = gr.State(initial_rows)
        frame_state = gr.State([])

        filter_inputs = [
            model_name,
            material_name,
            env_name,
            has_glass,
            is_generated,
            transparent,
            near_light,
        ]
        # Detail-pane components, in the order produced by
        # cleared_detail_outputs() and on_instance_select().
        detail_outputs = [
            instance_meta,
            model_viewer,
            rgb_image,
            mask_image,
            depth_image,
            normal_image,
            frame_slider,
            frame_state,
        ]
        filter_outputs = [summary, instance_gallery, instance_state, *detail_outputs]

        # Any filter change re-runs the full filter with all current values.
        for filter_component in filter_inputs:
            filter_component.change(
                fn=filter_gallery,
                inputs=filter_inputs,
                outputs=filter_outputs,
            )

        instance_gallery.select(
            fn=on_instance_select,
            inputs=[instance_state],
            outputs=detail_outputs,
        )
        frame_slider.change(
            fn=on_frame_change,
            inputs=[frame_slider, frame_state],
            outputs=[rgb_image, mask_image, depth_image, normal_image],
        )
    return demo
def main() -> None:
    """Configure logging, load the preview subset, then build and launch the app."""
    setup_logging()
    logger.info("DATASET_REPO = %r", DATASET_REPO)
    token_length = len(HF_TOKEN) if HF_TOKEN else 0
    logger.info("HF_TOKEN set = %s, length = %d", HF_TOKEN is not None, token_length)

    logger.info("Loading preview subset from Hugging Face Hub...")
    preview_df = load_preview_dataframe()
    instance_df = build_preview_instance_dataframe(preview_df)
    logger.info("Loaded %d preview rows for %d preview instance(s).", len(preview_df), len(instance_df))

    build_app(instance_df, preview_df).launch()


if __name__ == "__main__":
    main()