#!/usr/bin/env python3 """HF Space entry point for 3DReflecNet dataset preview. Loads only data/preview/preview.parquet so the Space exposes the configured preview instance subset instead of the full dataset metadata. """ from __future__ import annotations import atexit import io import os import shutil import tempfile from pathlib import Path from typing import Any import gradio as gr import pandas as pd from datasets import load_dataset from huggingface_hub import hf_hub_download from PIL import Image from utils import ( BOOL_FILTER_CHOICES, FILTER_ALL, aggregate_by_model, filter_dataframe_advanced, format_instance_choice, format_model_choice, get_distinct_text_choices, logger, parse_choice_index, require_bool_columns, require_columns, require_text_columns, setup_logging, ) DATASET_REPO = os.environ.get("DATASET_REPO", "3DReflecNet/3DReflecNet") HF_TOKEN = os.environ.get("HF_TOKEN", None) MAX_RESULTS = 300 BOOL_COLUMNS = ["hasGlass", "isGenerated", "transparent", "near_light"] _GLB_CACHE_DIR = Path(tempfile.mkdtemp(prefix="glb_cache_")) atexit.register(shutil.rmtree, str(_GLB_CACHE_DIR), True) # --------------------------------------------------------------------------- # Data loading # --------------------------------------------------------------------------- def load_preview_dataframe() -> pd.DataFrame: """Load the small preview Parquet into memory.""" PREVIEW_COLS = [ "instance_id", "split", "frame_id", "rgb", "mask", "depth_preview", "normal_preview", "main_category", "sub_category", "model_name", "material_name", "env_name", "glb_path", "hasGlass", "isGenerated", "transparent", "near_light", ] ds = load_dataset( DATASET_REPO, data_files="data/preview/preview.parquet", split="train", streaming=False, token=HF_TOKEN, ).select_columns(PREVIEW_COLS) df = pd.DataFrame(list(ds)) require_columns(df, PREVIEW_COLS, "preview parquet") require_text_columns( df, [ "instance_id", "split", "main_category", "sub_category", "model_name", "material_name", "env_name", "glb_path", ], "preview parquet", ) require_bool_columns(df, BOOL_COLUMNS, "preview parquet") if df["frame_id"].isna().any() or not pd.api.types.is_integer_dtype(df["frame_id"]): raise TypeError(f"Expected integer dtype for column 'frame_id' in preview parquet, got {df['frame_id'].dtype}.") for col in ["rgb", "mask", "depth_preview", "normal_preview"]: invalid = df[col].map(lambda value: not isinstance(value, (bytes, bytearray)) or len(value) == 0) if invalid.any(): raise TypeError(f"Expected non-empty binary values for column {col!r} in preview parquet.") return df def build_preview_instance_dataframe(preview_df: pd.DataFrame) -> pd.DataFrame: """Derive one row per preview instance from preview frame rows.""" instance_cols = [ "instance_id", "main_category", "sub_category", "model_name", "material_name", "env_name", "hasGlass", "isGenerated", "transparent", "near_light", "glb_path", ] require_columns(preview_df, instance_cols, "preview parquet") rows: list[dict[str, Any]] = [] for instance_id, group in preview_df.groupby("instance_id", sort=True): row: dict[str, Any] = {} for col in instance_cols: values = group[col].drop_duplicates().tolist() if len(values) != 1: raise ValueError(f"Inconsistent {col!r} values for preview instance {instance_id!r}.") row[col] = values[0] rows.append(row) df = pd.DataFrame(rows, columns=instance_cols) require_text_columns( df, [ "instance_id", "main_category", "sub_category", "model_name", "material_name", "env_name", "glb_path", ], "preview instance dataframe", ) require_bool_columns(df, BOOL_COLUMNS, "preview instance dataframe") if df["glb_path"].map(lambda value: not value.strip()).any(): raise ValueError("Preview instance dataframe contains empty GLB paths.") return df def load_instance_frames( preview_df: pd.DataFrame, instance_id: str, split: str = "train", max_frames: int = 50, ) -> list[dict[str, Any]]: """Load preview image payloads for one instance from preview Parquet.""" selected = preview_df[ (preview_df["instance_id"].astype(str) == str(instance_id)) & (preview_df["split"].astype(str) == split) ].copy() selected = selected.sort_values("frame_id").head(max_frames) rows = selected.to_dict(orient="records") if not rows and split == "train": selected = preview_df[ preview_df["instance_id"].astype(str) == str(instance_id) ].copy() selected = selected.sort_values(["split", "frame_id"]).head(max_frames) rows = selected.to_dict(orient="records") frames: list[dict[str, Any]] = [] for example in rows: frame_id = int(example["frame_id"]) frame_item: dict[str, Any] = {"frame_id": frame_id} for key in ("rgb", "mask", "depth_preview", "normal_preview"): img_bytes = example[key] if not isinstance(img_bytes, (bytes, bytearray)) or not img_bytes: raise TypeError(f"Expected non-empty image bytes for {key} frame {frame_id}.") with Image.open(io.BytesIO(img_bytes)) as img: frame_item[key] = img.copy() frames.append(frame_item) return frames def render_frame_gallery(frame_items: list[dict[str, Any]], frame_index: float) -> list[tuple[Any, str]]: """Render RGB/Mask/Depth/Normal gallery for one selected frame index (1-based).""" if not frame_items: return [] idx = int(round(frame_index)) - 1 idx = max(0, min(idx, len(frame_items) - 1)) selected = frame_items[idx] frame_id = int(selected["frame_id"]) gallery: list[tuple[Any, str]] = [] for key, label in [ ("rgb", "RGB"), ("mask", "Mask"), ("depth_preview", "Depth"), ("normal_preview", "Normal"), ]: gallery.append((selected[key], f"{label} frame_{frame_id:05d}")) return gallery # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def download_glb(glb_path: str) -> str: """Download pre-converted GLB file from HF dataset repo.""" if not glb_path: raise ValueError("GLB path is required.") local = _GLB_CACHE_DIR / Path(glb_path).name if local.exists(): return str(local) downloaded = hf_hub_download( repo_id=DATASET_REPO, filename=glb_path, repo_type="dataset", token=HF_TOKEN, ) shutil.copy2(downloaded, str(local)) logger.info("GLB ready: %s", local) return str(local) def build_stats_markdown(df: pd.DataFrame, preview_df: pd.DataFrame) -> str: """Generate preview subset overview statistics.""" total_instances = len(df) total_models = df["model_name"].nunique() total_frames = len(preview_df) main_cats = df["main_category"].dropna().astype(str) main_dist = main_cats.value_counts().head(10) dist_lines = " | ".join(f"**{cat}**: {cnt}" for cat, cnt in main_dist.items()) return ( f"**Preview Subset** — " f"**{total_instances}** instances, " f"**{total_frames}** frames, " f"**{total_models}** models, " f"**{main_cats.nunique()}** main categories\n\n" f"Distribution: {dist_lines}" ) # --------------------------------------------------------------------------- # App builder # --------------------------------------------------------------------------- def build_app(instance_df: pd.DataFrame, preview_df: pd.DataFrame) -> gr.Blocks: model_name_choices = get_distinct_text_choices(instance_df, "model_name") material_name_choices = get_distinct_text_choices(instance_df, "material_name") env_name_choices = get_distinct_text_choices(instance_df, "env_name") model_display_cols = [ "model_name", "material_name", "env_name", "hasGlass", "isGenerated", "transparent", "near_light", "instance_count", "instance_ids", ] instance_display_cols = [ "instance_id", "model_name", "material_name", "env_name", "hasGlass", "isGenerated", "transparent", "near_light", ] model_extra_cols = [ "material_name", "env_name", "hasGlass", "isGenerated", "transparent", "near_light", "glb_path", ] stats_md = build_stats_markdown(instance_df, preview_df) # ---- 3D Viewer callbacks ---- def search_models( model_name: str, material_name: str, env_name: str, has_glass: str, is_generated: str, transparent: str, near_light: str, ): filtered = filter_dataframe_advanced( instance_df, model_name=model_name, material_name=material_name, env_name=env_name, has_glass=has_glass, is_generated=is_generated, transparent=transparent, near_light=near_light, ) aggregated = aggregate_by_model(filtered, extra_columns=model_extra_cols) shown = aggregated.head(MAX_RESULTS).copy() rows = shown.to_dict(orient="records") choices = [format_model_choice(i, r) for i, r in enumerate(rows)] selected = choices[0] if choices else None summary = ( f"Matched **{len(aggregated)}** models, showing **{len(rows)}**. " f"Preview instances: **{len(filtered)}**." ) table = shown[model_display_cols] if not shown.empty else pd.DataFrame(columns=model_display_cols) meta = rows[0] if rows else {} return summary, table, gr.update(choices=choices, value=selected), rows, meta def on_model_select(choice: str, rows: list[dict[str, Any]]): if not choice or not rows: return {} idx = parse_choice_index(choice, len(rows)) if idx is None: return {} return rows[idx] def on_load_3d(rows: list[dict[str, Any]], choice: str): if not choice or not rows: return None idx = parse_choice_index(choice, len(rows)) if idx is None: return None glb = rows[idx]["glb_path"] logger.info("on_load_3d: glb_path=%r", glb) if not isinstance(glb, str) or not glb.strip(): raise ValueError(f"Selected model row does not contain a GLB path: {rows[idx]!r}") return download_glb(glb) # ---- Image Viewer callbacks ---- def search_instances( model_name: str, material_name: str, env_name: str, has_glass: str, is_generated: str, transparent: str, near_light: str, ): filtered = filter_dataframe_advanced( instance_df, model_name=model_name, material_name=material_name, env_name=env_name, has_glass=has_glass, is_generated=is_generated, transparent=transparent, near_light=near_light, ) shown = filtered.head(MAX_RESULTS).copy() rows = shown[instance_display_cols].to_dict(orient="records") choices = [format_instance_choice(i, r) for i, r in enumerate(rows)] selected = choices[0] if choices else None summary = f"Matched **{len(filtered)}** preview instances, showing **{len(rows)}**." table = shown[instance_display_cols] if not shown.empty else pd.DataFrame(columns=instance_display_cols) return summary, table, gr.update(choices=choices, value=selected), rows def on_load_images(rows: list[dict[str, Any]], choice: str): slider_empty = gr.update(minimum=1, maximum=1, step=1, value=1, interactive=False) if not choice or not rows: return [], slider_empty, [] idx = parse_choice_index(choice, len(rows)) if idx is None: return [], slider_empty, [] instance_id = rows[idx]["instance_id"] if not isinstance(instance_id, str) or not instance_id.strip(): raise ValueError(f"Selected instance row has invalid instance_id: {rows[idx]!r}") logger.info("Loading images for instance: %s", instance_id) frame_items = load_instance_frames(preview_df, instance_id, split="train", max_frames=50) if not frame_items: return [], slider_empty, [] slider_ready = gr.update( minimum=1, maximum=len(frame_items), step=1, value=1, interactive=True, ) return render_frame_gallery(frame_items, 1), slider_ready, frame_items def on_frame_change(frame_idx: float, frame_items: list[dict[str, Any]]): return render_frame_gallery(frame_items, frame_idx) # ---- UI ---- with gr.Blocks(title="3DReflecNet Dataset Explorer") as demo: gr.Markdown("# 3DReflecNet Dataset Explorer") gr.Markdown( "Filter the configured preview subset by model/material/environment dropdowns and boolean scene tags." ) gr.Markdown(stats_md) with gr.Tabs(): # === Tab 1: 3D Viewer === with gr.TabItem("3D Viewer"): with gr.Row(): m3d_model_name = gr.Dropdown(label="model_name", choices=model_name_choices, value=FILTER_ALL) m3d_material_name = gr.Dropdown(label="material_name", choices=material_name_choices, value=FILTER_ALL) m3d_env_name = gr.Dropdown(label="env_name", choices=env_name_choices, value=FILTER_ALL) with gr.Row(): m3d_has_glass = gr.Dropdown(label="hasGlass", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL) m3d_is_generated = gr.Dropdown(label="isGenerated", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL) m3d_transparent = gr.Dropdown(label="transparent", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL) m3d_near_light = gr.Dropdown(label="near_light", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL) m3d_btn = gr.Button("Search", variant="primary") m3d_summary = gr.Markdown("Click **Search** to browse models.") m3d_table = gr.Dataframe(headers=model_display_cols, interactive=False, wrap=True) m3d_select = gr.Dropdown(label="Select model", choices=[], value=None) m3d_meta = gr.JSON(label="Model Metadata") m3d_load_btn = gr.Button("Load 3D Preview", variant="primary") m3d_viewer = gr.Model3D( label="3D Preview (GLB)", clear_color=(0.35, 0.35, 0.38, 1.0), camera_position=(35, 70, 3.5), ) m3d_state = gr.State([]) m3d_btn.click( fn=search_models, inputs=[ m3d_model_name, m3d_material_name, m3d_env_name, m3d_has_glass, m3d_is_generated, m3d_transparent, m3d_near_light, ], outputs=[m3d_summary, m3d_table, m3d_select, m3d_state, m3d_meta], ) m3d_select.change( fn=on_model_select, inputs=[m3d_select, m3d_state], outputs=[m3d_meta], ) m3d_load_btn.click( fn=on_load_3d, inputs=[m3d_state, m3d_select], outputs=[m3d_viewer], ) # === Tab 2: Image Viewer === with gr.TabItem("Image Viewer"): with gr.Row(): img_model_name = gr.Dropdown(label="model_name", choices=model_name_choices, value=FILTER_ALL) img_material_name = gr.Dropdown(label="material_name", choices=material_name_choices, value=FILTER_ALL) img_env_name = gr.Dropdown(label="env_name", choices=env_name_choices, value=FILTER_ALL) with gr.Row(): img_has_glass = gr.Dropdown(label="hasGlass", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL) img_is_generated = gr.Dropdown(label="isGenerated", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL) img_transparent = gr.Dropdown(label="transparent", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL) img_near_light = gr.Dropdown(label="near_light", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL) img_btn = gr.Button("Search", variant="primary") img_summary = gr.Markdown("Click **Search** to browse instances.") img_table = gr.Dataframe(headers=instance_display_cols, interactive=False, wrap=True) img_select = gr.Dropdown(label="Select instance", choices=[], value=None) img_load_btn = gr.Button("Load Instance Frames", variant="primary") img_gallery = gr.Gallery(label="Frame Images", columns=4, rows=1, object_fit="contain", height="auto") img_frame_slider = gr.Slider( label="Frame", minimum=1, maximum=1, step=1, value=1, interactive=False, ) img_state = gr.State([]) img_frame_state = gr.State([]) img_btn.click( fn=search_instances, inputs=[ img_model_name, img_material_name, img_env_name, img_has_glass, img_is_generated, img_transparent, img_near_light, ], outputs=[img_summary, img_table, img_select, img_state], ) img_load_btn.click( fn=on_load_images, inputs=[img_state, img_select], outputs=[img_gallery, img_frame_slider, img_frame_state], ) img_frame_slider.change( fn=on_frame_change, inputs=[img_frame_slider, img_frame_state], outputs=[img_gallery], ) return demo def main() -> None: setup_logging() logger.info("DATASET_REPO = %r", DATASET_REPO) logger.info("HF_TOKEN set = %s, length = %d", HF_TOKEN is not None, len(HF_TOKEN) if HF_TOKEN else 0) logger.info("Loading preview subset from Hugging Face Hub...") preview_df = load_preview_dataframe() instance_df = build_preview_instance_dataframe(preview_df) logger.info("Loaded %d preview rows for %d preview instance(s).", len(preview_df), len(instance_df)) app = build_app(instance_df, preview_df) app.launch() if __name__ == "__main__": main()