| import os |
| import sys |
| import time |
| import threading |
| import argparse |
| import tempfile |
| import shutil |
| from typing import Generator, Optional, Tuple |
| import logging |
|
|
| import gradio as gr |
| import spaces |
| from huggingface_hub import hf_hub_download |
| import torch |
| from PIL import Image |
|
|
| |
# Make the directory that contains this file importable, so the project-local
# modules imported further down resolve regardless of the working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_APP_DIR)
|
|
| from gradio_models import GradioGaussianSplatting2D, StreamingResults |
| from utils.misc_utils import load_cfg |
| from main import get_log_dir |
|
|
|
|
class TrainingState:
    """Mutable record of the current training session.

    One instance is created at module level and shared by every callback in
    this file.
    """

    def __init__(self):
        # True while a training run is considered active.
        self.is_training = False
        # Thread object running the training function, or None.
        self.training_thread = None
        # Model object of the current/last run, or None.
        self.model = None
        # Scratch directory created for the uploaded image, or None.
        self.temp_dir = None
        # Progress container handed to the model and read by the UI callbacks.
        self.results = StreamingResults()

    def reset(self):
        """Return to a clean "no session" state.

        Clears the training flag, drops the reference to the previous model,
        forgets a worker thread that has already finished, removes the
        scratch directory and installs a fresh ``StreamingResults``.
        """
        self.is_training = False

        # Do not keep the previous run's model reachable once a session is
        # reset; nothing in a new session should act on a stale model.
        self.model = None

        # Only forget the worker once it has actually finished, so callers
        # can still detect a run that is winding down.
        worker = self.training_thread
        if worker is not None and not worker.is_alive():
            self.training_thread = None

        # Best-effort cleanup: a temp directory that cannot be removed must
        # not make every later reset (and therefore every later start) fail.
        if self.temp_dir:
            shutil.rmtree(self.temp_dir, ignore_errors=True)
        self.temp_dir = None

        self.results = StreamingResults()


# Single shared session state used by all callbacks below.
training_state = TrainingState()
|
|
|
|
def ensure_models_available():
    """Fetch the required model weights from the HuggingFace Hub if missing.

    Each file lives at ``<name>`` in the hub repository and is copied to
    ``models/<name>`` locally. Files that already exist locally are left
    untouched. Any failure is reported on stdout and swallowed, so the caller
    continues even when the download did not succeed.
    """
    # Repository-relative names; the local copy goes under "models/".
    hub_names = (
        "emlnet/res_decoder.pth",
        "emlnet/res_imagenet.pth",
        "emlnet/res_places.pth",
        "torch/checkpoints/alexnet-owt-7be5be79.pth",
    )
    targets = [(hub_name, f"models/{hub_name}") for hub_name in hub_names]

    something_missing = any(not os.path.exists(dest) for _, dest in targets)
    if not something_missing:
        print("✅ Model files are already available locally.")
        return

    print("📥 Downloading model files from HuggingFace...")
    try:
        os.makedirs("models", exist_ok=True)

        for remote_file, local_file in targets:
            if os.path.exists(local_file):
                continue

            # Create the nested destination folder before copying into it.
            os.makedirs(os.path.dirname(local_file), exist_ok=True)

            print(f"📥 Downloading {remote_file} -> {local_file}...")
            cached_copy = hf_hub_download(
                repo_id="blanchon/image-gs-models-utils",
                filename=remote_file,
                repo_type="model",
            )

            # Copy the downloaded file to its expected local location.
            shutil.copy2(cached_copy, local_file)

        print("✅ Model files downloaded successfully!")
    except Exception as e:
        print(f"❌ Failed to download model files: {e}")
        print("⚠️ The app may not work properly without these model files.")


# Runs at import time: fetch the weights, then point torch.hub at the local
# "models/torch" directory.
ensure_models_available()
torch.hub.set_dir("models/torch")
|
|
|
|
def create_args_from_config(
    image_path: str,
    exp_name: str,
    num_gaussians: int,
    quantize: bool,
    pos_bits: int,
    scale_bits: int,
    rot_bits: int,
    feat_bits: int,
    init_mode: str,
    init_random_ratio: float,
    max_steps: int,
    vis_gaussians: bool,
    save_image_steps: int,
    l1_loss_ratio: float,
    l2_loss_ratio: float,
    ssim_loss_ratio: float,
    pos_lr: float,
    scale_lr: float,
    rot_lr: float,
    feat_lr: float,
    disable_lr_schedule: bool,
    disable_prog_optim: bool,
) -> argparse.Namespace:
    """Build the training namespace from the UI values.

    Starts from the defaults that ``load_cfg`` registers for
    ``cfgs/default.yaml``, overrides them with the given values, forces
    ``eval`` to ``False`` and finally stores ``get_log_dir(args)`` in
    ``args.log_dir``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # Parse an empty command line so only the configured defaults are present.
    base_parser = load_cfg(cfg_path="cfgs/default.yaml", parser=argparse.ArgumentParser())
    args = base_parser.parse_args([])

    # UI-provided values, keyed by the attribute name they override.
    overrides = {
        "input_path": image_path,
        "exp_name": exp_name,
        "num_gaussians": num_gaussians,
        "quantize": quantize,
        "pos_bits": pos_bits,
        "scale_bits": scale_bits,
        "rot_bits": rot_bits,
        "feat_bits": feat_bits,
        "init_mode": init_mode,
        "init_random_ratio": init_random_ratio,
        "max_steps": max_steps,
        "vis_gaussians": vis_gaussians,
        "save_image_steps": save_image_steps,
        "l1_loss_ratio": l1_loss_ratio,
        "l2_loss_ratio": l2_loss_ratio,
        "ssim_loss_ratio": ssim_loss_ratio,
        "pos_lr": pos_lr,
        "scale_lr": scale_lr,
        "rot_lr": rot_lr,
        "feat_lr": feat_lr,
        "disable_lr_schedule": disable_lr_schedule,
        "disable_prog_optim": disable_prog_optim,
        "eval": False,
    }
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)

    # The log directory is derived from the fully populated namespace.
    args.log_dir = get_log_dir(args)
    return args
|
|
|
|
@spaces.GPU(duration=300)
def train_model(args: argparse.Namespace) -> None:
    """Build the model for ``args`` and run its optimisation.

    Wrapped with ``spaces.GPU(duration=300)``. Any exception is recorded in
    the shared training logs and sent to ``logging`` instead of propagating;
    the shared ``is_training`` flag is cleared in every case.
    """
    import traceback

    try:
        # Store the model on the shared state before optimising so other
        # callbacks can reach it while the run is in progress.
        training_state.model = GradioGaussianSplatting2D(args, training_state.results)
        training_state.model.optimize()
    except Exception as e:
        error_text = str(e)
        trace_text = traceback.format_exc()

        # Surface the failure in the UI log stream ...
        training_state.results.training_logs.append(f"ERROR: {error_text}")
        training_state.results.training_logs.append(f"TRACEBACK: {trace_text}")
        # ... and in the process log.
        logging.error(f"Training failed: {error_text}")
        logging.error(f"TRACEBACK: {trace_text}")
    finally:
        training_state.is_training = False
|
|
|
|
def _button_state(interactive: bool) -> dict:
    """Return a Gradio update that only toggles a button's ``interactive`` flag."""
    # A bare bool sent to a Button output is treated as the button's value
    # (its label), so interactivity has to be sent as an explicit update.
    return gr.update(interactive=interactive)


def start_training_and_stream(
    image_file,
    exp_name: str,
    num_gaussians: int,
    quantize: bool,
    pos_bits: int,
    scale_bits: int,
    rot_bits: int,
    feat_bits: int,
    init_mode: str,
    init_random_ratio: float,
    max_steps: int,
    vis_gaussians: bool,
    save_image_steps: int,
    l1_loss_ratio: float,
    l2_loss_ratio: float,
    ssim_loss_ratio: float,
    pos_lr: float,
    scale_lr: float,
    rot_lr: float,
    feat_lr: float,
    disable_lr_schedule: bool,
    disable_prog_optim: bool,
) -> Generator[
    Tuple[
        str,
        str,
        Optional[Image.Image],
        Optional[Image.Image],
        Optional[Image.Image],
        dict,
        dict,
    ],
    None,
    None,
]:
    """Start a training run in a background thread and stream its progress.

    The uploaded image is saved as ``input_image.png`` in a fresh temporary
    directory, the arguments are built with ``create_args_from_config`` and
    ``train_model`` is started on a thread. The shared results object is then
    polled every 0.5 s.

    Yields:
        7-tuples in the order of the outputs wired to the start button:
        status text, log text, initialization map, current render,
        Gaussian-ID image, start-button update, stop-button update.
        A final tuple reflecting the end state (completed / stopped) is
        always yielded before the generator finishes, so the buttons are
        never left in their "running" configuration.
    """
    if training_state.is_training:
        yield (
            "Training is already in progress!",
            "",
            None,
            None,
            None,
            _button_state(False),
            _button_state(True),
        )
        return

    # After a stop request the flag is cleared immediately while the worker
    # may still be running. Starting again now would reset the results and
    # delete the temp directory underneath it, and its exit would later clear
    # the flag of the new run.
    previous_worker = training_state.training_thread
    if previous_worker is not None and previous_worker.is_alive():
        yield (
            "Previous training is still shutting down. Please try again in a moment.",
            "",
            None,
            None,
            None,
            _button_state(True),
            _button_state(False),
        )
        return

    if image_file is None:
        yield (
            "Please upload an image first!",
            "",
            None,
            None,
            None,
            _button_state(True),
            _button_state(False),
        )
        return

    try:
        # Fresh session state and scratch directory.
        training_state.reset()
        training_state.temp_dir = tempfile.mkdtemp()

        # Persist the uploaded image where the training code can read it.
        image_path = os.path.join(training_state.temp_dir, "input_image.png")
        image_file.save(image_path)

        args = create_args_from_config(
            image_path=image_path,
            exp_name=exp_name,
            num_gaussians=num_gaussians,
            quantize=quantize,
            pos_bits=pos_bits,
            scale_bits=scale_bits,
            rot_bits=rot_bits,
            feat_bits=feat_bits,
            init_mode=init_mode,
            init_random_ratio=init_random_ratio,
            max_steps=max_steps,
            vis_gaussians=vis_gaussians,
            save_image_steps=save_image_steps,
            l1_loss_ratio=l1_loss_ratio,
            l2_loss_ratio=l2_loss_ratio,
            ssim_loss_ratio=ssim_loss_ratio,
            pos_lr=pos_lr,
            scale_lr=scale_lr,
            rot_lr=rot_lr,
            feat_lr=feat_lr,
            disable_lr_schedule=disable_lr_schedule,
            disable_prog_optim=disable_prog_optim,
        )

        # The image is addressed relative to the temp directory.
        args.data_root = training_state.temp_dir
        args.input_path = "input_image.png"

        # Mark the run active before the thread starts so the first poll
        # below already sees it.
        training_state.is_training = True
        training_state.training_thread = threading.Thread(
            target=train_model, args=(args,)
        )
        training_state.training_thread.start()

        yield (
            "Training started! Check the progress below.",
            "Initializing training...",
            None,
            None,
            None,
            _button_state(False),
            _button_state(True),
        )

        while True:
            results = training_state.results

            # Read is_training BEFORE is_complete: train_model clears
            # is_training last, so once it reads False the completion flag
            # already has its final value.
            still_training = training_state.is_training
            is_complete = results.is_complete
            current_step = results.step
            total_steps = results.total_steps

            # Flag cleared while the worker is still alive and the run has
            # not completed: treated as a stop request.
            worker = training_state.training_thread
            if (
                not still_training
                and not is_complete
                and worker is not None
                and worker.is_alive()
            ):
                results.training_logs.append("🛑 Training stopped by user request")

            # Safety check: a step counter beyond the total ends the stream.
            overran = current_step > total_steps
            finished = is_complete or not still_training or overran

            if results.training_logs:
                logs_text = "\n".join(results.training_logs)
                if current_step > 0 and not overran:
                    metrics = results.metrics
                    logs_text += (
                        f"\nCurrent: Step {current_step}/{total_steps} | "
                        f"PSNR: {metrics['psnr']:.2f} | SSIM: {metrics['ssim']:.4f} | "
                        f"Loss: {metrics['loss']:.4f}"
                    )
            else:
                logs_text = "Waiting for training to start..."

            if is_complete:
                status = "✅ Training completed successfully!"
            elif not still_training:
                status = "⏹️ Training stopped."
            else:
                status = f"🔄 Training in progress... Step {current_step}/{total_steps}"

            # The last tuple of the stream always re-enables "start" and
            # disables "stop"; intermediate tuples do the opposite.
            yield (
                status,
                logs_text,
                results.initialization_map,
                results.current_render,
                results.current_gaussian_id,
                _button_state(finished),
                _button_state(not finished),
            )

            if finished:
                break

            time.sleep(0.5)

    except Exception as e:
        logging.exception("Training stream failed")
        training_state.reset()
        yield (
            f"Failed to start training: {str(e)}",
            "",
            None,
            None,
            None,
            _button_state(True),
            _button_state(False),
        )
|
|
|
|
def stop_training() -> str:
    """Request that the running training stops.

    Clears the shared ``is_training`` flag, appends a notice to the training
    logs and, when a model object is present, sets its ``stop_requested``
    attribute.

    Returns:
        A status message for the UI.
    """
    state = training_state

    if state.is_training:
        state.is_training = False
        state.results.training_logs.append(
            "🛑 STOP: Training stop requested by user..."
        )

        # Signal the model itself, if one has been created.
        active_model = state.model
        if active_model:
            active_model.stop_requested = True

        return "Training stop requested. Will complete current step and stop."

    return "No training in progress."
|
|
|
|
def get_final_results() -> Tuple[Optional[Image.Image], Optional[str]]:
    """Return the final render and the final checkpoint path.

    Both values are read as-is from the shared results object.
    """
    return (
        training_state.results.final_render,
        training_state.results.final_checkpoint_path,
    )
|
|
|
|
def browse_step_results(
    step: int,
) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Return the stored render and Gaussian-ID image nearest to ``step``.

    Returns ``(None, None)`` while training has not completed or when no
    per-step renders were recorded.
    """
    if training_state.results.is_complete:
        recorded_steps = list(training_state.results.step_renders.keys())
        if recorded_steps:

            def distance(candidate):
                return abs(candidate - step)

            # Snap the requested step to the closest recorded one.
            nearest = min(recorded_steps, key=distance)
            return (
                training_state.results.step_renders.get(nearest),
                training_state.results.step_gaussian_ids.get(nearest),
            )

    return None, None
|
|
|
|
def _disabled_step_slider(info: str) -> gr.Slider:
    """Return the non-interactive placeholder slider with the given info text."""
    return gr.Slider(
        minimum=0,
        maximum=10000,
        value=0,
        step=100,
        label="Browse Training Steps",
        info=info,
        interactive=False,
    )


def update_step_slider_after_training() -> gr.Slider:
    """Return a step slider configured for the recorded training steps.

    While training is incomplete, or when no per-step renders were recorded,
    a disabled placeholder slider is returned. Otherwise the slider spans the
    smallest to the largest recorded step, starts at the largest one and is
    interactive.
    """
    results = training_state.results
    if not results.is_complete:
        return _disabled_step_slider("Training not complete yet")

    # Sort explicitly: the increment below is taken from the first two steps
    # and must not depend on the order in which renders were stored (an
    # unsorted order could yield a zero or negative increment).
    recorded_steps = sorted(results.step_renders.keys())
    if not recorded_steps:
        return _disabled_step_slider("No training steps available")

    min_step = recorded_steps[0]
    max_step = recorded_steps[-1]

    # Spacing between the first two recorded steps; 100 when only one exists.
    if len(recorded_steps) > 1:
        step_size = recorded_steps[1] - recorded_steps[0]
    else:
        step_size = 100

    return gr.Slider(
        minimum=min_step,
        maximum=max_step,
        value=max_step,
        step=step_size,
        label="Browse Training Steps",
        info=f"Browse results from steps {min_step}-{max_step} (interactive)",
        interactive=True,
    )
|
|
|
|
| |
# Gradio UI: configuration column on the left, live progress on the right,
# followed by the event wiring for the buttons and the step slider.
with gr.Blocks(title="Image-GS: 2D Gaussian Splatting", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Image-GS: Content-Adaptive Image Representation via 2D Gaussians

    Upload an image and configure parameters to train a 2D Gaussian Splatting representation.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Configuration")

            # Input image (delivered to the callback as a PIL image)
            image_input = gr.Image(
                label="Input Image",
                type="pil",
                height=300,
                sources=["upload"],
                show_label=True,
            )

            # Basic parameters
            with gr.Group():
                gr.Markdown("### Basic Parameters")
                exp_name = gr.Textbox(
                    label="Experiment Name",
                    value="gradio_experiment",
                    info="Name for this training run",
                )
                # step=100 keeps the minimum (100), the default (10000) and
                # the maximum (50000) all on the slider grid; with a step of
                # 1000 starting at 100 neither the default nor the maximum
                # was a reachable value.
                num_gaussians = gr.Slider(
                    minimum=100,
                    maximum=50000,
                    value=10000,
                    step=100,
                    label="Number of Gaussians",
                    info="Number of Gaussians (for compression rate control). More = higher quality but slower training",
                )
                max_steps = gr.Slider(
                    minimum=100,
                    maximum=20000,
                    value=10000,
                    step=100,
                    label="Maximum Training Steps",
                    info="Maximum number of optimization steps. Default: 10000",
                )

            # Quantization
            with gr.Group():
                gr.Markdown("### Quantization")
                quantize = gr.Checkbox(
                    label="Enable Quantization",
                    value=False,
                    info="Enable bit precision control of Gaussian parameters. Reduces memory usage.",
                )
                with gr.Row():
                    pos_bits = gr.Slider(
                        minimum=4,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Position Bits",
                        info="Bit precision of individual coordinate dimension",
                    )
                    scale_bits = gr.Slider(
                        minimum=4,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Scale Bits",
                        info="Bit precision of individual scale dimension",
                    )
                with gr.Row():
                    rot_bits = gr.Slider(
                        minimum=4,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Rotation Bits",
                        info="Bit precision of Gaussian orientation angle",
                    )
                    feat_bits = gr.Slider(
                        minimum=4,
                        maximum=32,
                        value=16,
                        step=1,
                        label="Feature Bits",
                        info="Bit precision of individual feature dimension",
                    )

            # Initialization
            with gr.Group():
                gr.Markdown("### Initialization")
                init_mode = gr.Radio(
                    choices=["gradient", "saliency", "random"],
                    value="saliency",
                    label="Initialization Mode",
                    info="Gaussian position initialization mode. Gradient uses image gradients, saliency uses attention maps.",
                )
                init_random_ratio = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Random Ratio",
                    info="Ratio of Gaussians with randomly initialized position (default: 0.3)",
                )

            # Advanced parameters (collapsed by default)
            with gr.Accordion("Advanced Parameters", open=False):
                gr.Markdown("#### Loss Weights")
                with gr.Row():
                    l1_loss_ratio = gr.Slider(0.0, 2.0, 1.0, step=0.1, label="L1 Loss")
                    l2_loss_ratio = gr.Slider(0.0, 2.0, 0.0, step=0.1, label="L2 Loss")
                    ssim_loss_ratio = gr.Slider(
                        0.0, 1.0, 0.1, step=0.01, label="SSIM Loss"
                    )

                gr.Markdown("#### Learning Rates")
                with gr.Row():
                    pos_lr = gr.Number(value=5e-4, label="Position LR", precision=6)
                    scale_lr = gr.Number(value=2e-3, label="Scale LR", precision=6)
                with gr.Row():
                    rot_lr = gr.Number(value=2e-3, label="Rotation LR", precision=6)
                    feat_lr = gr.Number(value=5e-3, label="Feature LR", precision=6)

                gr.Markdown("#### Optimization")
                disable_lr_schedule = gr.Checkbox(
                    label="Disable LR Schedule",
                    value=False,
                    info="Keep learning rate constant",
                )
                disable_prog_optim = gr.Checkbox(
                    label="Disable Progressive Optimization",
                    value=False,
                    info="Don't add Gaussians during training",
                )

            # Visualization
            with gr.Group():
                gr.Markdown("### Visualization")
                vis_gaussians = gr.Checkbox(
                    label="Visualize Gaussians",
                    value=True,
                    info="Visualize Gaussians during optimization (default: True)",
                )
                # The info text now states the slider's actual default (200);
                # it previously claimed 100, which is below the minimum.
                save_image_steps = gr.Slider(
                    minimum=200,
                    maximum=10000,
                    value=200,
                    step=100,
                    label="Save Image Every N Steps",
                    info="Frequency of rendering intermediate results during optimization (default: 200)",
                )

            # Control buttons
            with gr.Row():
                start_btn = gr.Button("Start Training", variant="primary", size="lg")
                stop_btn = gr.Button("Stop Training", variant="stop", size="lg")

            status_text = gr.Textbox(label="Status", interactive=False, lines=2)

        with gr.Column(scale=2):
            gr.Markdown("## Training Progress")

            # Streaming log output
            progress_logs = gr.Textbox(
                label="Training Logs",
                lines=10,
                max_lines=15,
                interactive=False,
                autoscroll=True,
            )

            gr.Markdown("### Initialization Map")
            initialization_map = gr.Image(
                label="Initialization Map",
                type="pil",
                height=200,
            )

            gr.Markdown("### Current Training Results")
            with gr.Row():
                current_render = gr.Image(
                    label="Current Render",
                    type="pil",
                    height=300,
                    show_label=True,
                    show_download_button=True,
                )
                current_gaussian_id = gr.Image(
                    label="Gaussian ID",
                    type="pil",
                    height=300,
                    show_label=True,
                    show_download_button=True,
                )

            # Step browser; starts disabled and is reconfigured by the
            # "Enable Step Browsing" button.
            step_slider = gr.Slider(
                minimum=0,
                maximum=10000,
                value=0,
                step=100,
                label="Browse Training Steps",
                info="Slide to view results from different training steps (disabled during training)",
                interactive=False,
            )

            gr.Markdown("## Final Results")
            with gr.Row():
                final_render = gr.Image(label="Final Render", type="pil", height=300)
                final_checkpoint = gr.File(label="Download Final Checkpoint (.pt)")

            with gr.Row():
                results_btn = gr.Button("Load Final Results", size="lg")
                enable_slider_btn = gr.Button(
                    "Enable Step Browsing", size="lg", variant="secondary"
                )

    # Event wiring. The input order must match the parameter order of
    # start_training_and_stream, and the output order must match the tuples
    # it yields.
    start_btn.click(
        fn=start_training_and_stream,
        inputs=[
            image_input,
            exp_name,
            num_gaussians,
            quantize,
            pos_bits,
            scale_bits,
            rot_bits,
            feat_bits,
            init_mode,
            init_random_ratio,
            max_steps,
            vis_gaussians,
            save_image_steps,
            l1_loss_ratio,
            l2_loss_ratio,
            ssim_loss_ratio,
            pos_lr,
            scale_lr,
            rot_lr,
            feat_lr,
            disable_lr_schedule,
            disable_prog_optim,
        ],
        outputs=[
            status_text,
            progress_logs,
            initialization_map,
            current_render,
            current_gaussian_id,
            start_btn,
            stop_btn,
        ],
    )

    stop_btn.click(fn=stop_training, outputs=status_text)

    results_btn.click(fn=get_final_results, outputs=[final_render, final_checkpoint])

    enable_slider_btn.click(fn=update_step_slider_after_training, outputs=[step_slider])

    # Moving the slider swaps the "current" images for the stored ones of the
    # nearest recorded step.
    step_slider.change(
        fn=browse_step_results,
        inputs=[step_slider],
        outputs=[current_render, current_gaussian_id],
    )
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue (max_size=20), then serve on all interfaces
    # at port 7860 without a public share link.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
|
|