| """ |
| Evaluation script for the HoHo wireframe prediction model (it runs inference and scoring only; no training happens here). |
| This script loads the HoHo25k dataset, processes samples through a wireframe prediction pipeline |
| using PointNet models, and evaluates performance using HSS, F1, and IoU metrics. It supports |
| configurable thresholds, visualization of results, and saves detailed performance metrics to files. |
| Key features: |
| - Command-line argument support for model configuration |
| - PointNet-based vertex and edge prediction |
| - Real-time performance monitoring and visualization |
| - Comprehensive metric evaluation and result logging |
| - Support for CUDA acceleration when available |
| """ |
| from datasets import load_dataset |
| from hoho2025.vis import plot_all_modalities |
| from hoho2025.viz3d import * |
| import pycolmap |
| import tempfile,zipfile |
| import io |
| import open3d as o3d |
| import os |
| import argparse |
| import numpy as np |
|
|
| from visu import plot_reconstruction_local, plot_wireframe_local, plot_bpo_cameras_from_entry_local, _plotly_rgb_to_normalized_o3d_color |
| from utils import read_colmap_rec, empty_solution |
|
|
| |
| from hoho2025.metric_helper import hss |
| from predict import predict_wireframe, predict_wireframe_old |
| from tqdm import tqdm |
| from fast_pointnet_v2 import load_pointnet_model |
| from fast_pointnet_class import load_pointnet_model as load_pointnet_class_model |
| import torch |
| import time |
|
|
| |
def _str2bool(value):
    """Parse a boolean command-line value.

    Accepts 'true'/'false' (case-insensitive, surrounding whitespace ignored)
    plus the common aliases '1'/'0', 'yes'/'no' and 'y'/'n'.

    Raises:
        argparse.ArgumentTypeError: for any other value, so a typo such as
            'ture' is reported instead of silently becoming False.
    """
    text = str(value).strip().lower()
    if text in ('true', '1', 'yes', 'y'):
        return True
    if text in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


# --- Command-line configuration ---------------------------------------------
# Every new option defaults to the value that used to be hard-coded, so existing
# invocations behave exactly as before.
parser = argparse.ArgumentParser(description="Train and evaluate HoHo model with custom config.")
parser.add_argument('--vertex_threshold', type=float, default=0.59, help='Vertex threshold for prediction.')
parser.add_argument('--edge_threshold', type=float, default=0.65, help='Edge threshold for prediction.')
parser.add_argument('--only_predicted_connections', type=_str2bool, default=True, help='Use only predicted connections (True/False).')
parser.add_argument('--max_samples', type=int, default=50000, help='Maximum number of samples to process.')
parser.add_argument('--results_dir', type=str, default="results", help='Directory to save result files.')
parser.add_argument('--cache_dir', type=str, default="YOUR_CACHE_DIR_PATH/hoho25k/", help='Cache directory passed to datasets.load_dataset for usm3d/hoho25k.')
parser.add_argument('--show_visu', type=_str2bool, default=True, help='Open an Open3D visualization window for every sample (True/False).')
parser.add_argument('--pnet_path', type=str, default="pnet.pth", help='Path to the PointNet model checkpoint.')
parser.add_argument('--pnet_class_path', type=str, default="pnet_class.pth", help='Path to the PointNet classification model checkpoint.')
args = parser.parse_args()

# Prediction settings forwarded to predict_wireframe(); the same values are
# also encoded into the results file name at the end of the run.
config = {
    'vertex_threshold': args.vertex_threshold,
    'edge_threshold': args.edge_threshold,
    'only_predicted_connections': args.only_predicted_connections,
}
print(f"Running with configuration: {config}")

os.makedirs(args.results_dir, exist_ok=True)

# The default cache_dir is a placeholder path; pass --cache_dir to use a real one.
ds = load_dataset("usm3d/hoho25k", cache_dir=args.cache_dir, trust_remote_code=True)

# Per-sample metrics appended by the evaluation loop.
scores_hss = []
scores_f1 = []
scores_iou = []

# When true, an Open3D viewer is opened for every evaluated sample.
show_visu = args.show_visu

device = "cuda" if torch.cuda.is_available() else "cpu"

# PointNet checkpoints; both are handed to predict_wireframe() for every sample.
pnet_model = load_pointnet_model(model_path=args.pnet_path, device=device, predict_score=True)
pnet_class_model = load_pointnet_class_model(model_path=args.pnet_class_path, device=device)

# No voxel model is used here; predict_wireframe() receives None in its place.
voxel_model = None
|
|
|
|
# Main evaluation loop: predict a wireframe for each training sample, score it
# against the ground truth (HSS / F1 / IoU) and optionally visualize it.
idx = 0  # number of samples evaluated so far
prediction_times = []  # seconds per successful predict_wireframe() call
for a in tqdm(ds['train'], desc="Processing dataset"):
    # Check the limit before doing any work so that max_samples <= 0 evaluates
    # nothing (checking only after processing always scored one sample).
    if idx >= args.max_samples:
        print(f"Reached max_samples limit: {args.max_samples}")
        break

    try:
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        # Pass a shallow copy so predict_wireframe cannot alter the sample's
        # top-level entries that are read again below for scoring/visualization.
        pred_vertices, pred_edges = predict_wireframe(a.copy(), pnet_model, voxel_model, pnet_class_model, config)
        prediction_time = time.perf_counter() - start_time
    except Exception as e:
        # Best-effort evaluation: a failed prediction is scored as an empty
        # solution rather than aborting the whole run. Its time is not recorded.
        print(f"Error during prediction: {e}")
        pred_vertices, pred_edges = empty_solution()
    else:
        prediction_times.append(prediction_time)
        mean_time = np.mean(prediction_times)
        print(f"Prediction time: {prediction_time:.4f} seconds, Mean time: {mean_time:.4f} seconds")

    score = hss(pred_vertices, pred_edges, a['wf_vertices'], a['wf_edges'], vert_thresh=0.5, edge_thresh=0.5)
    print(f"Score: {score}")
    scores_hss.append(score.hss)
    scores_f1.append(score.f1)
    scores_iou.append(score.iou)

    if show_visu:
        # Draws the COLMAP reconstruction, the GROUND-TRUTH wireframe
        # (a['wf_*'], not the prediction) and the BPO cameras; the window title
        # carries this sample's scores.
        colmap = read_colmap_rec(a['colmap_binary'])
        pcd, geometries = plot_reconstruction_local(None, colmap, points=True, cameras=True, crop_outliers=True)
        wireframe = plot_wireframe_local(None, a['wf_vertices'], a['wf_edges'], a['wf_classifications'])
        bpo_cams = plot_bpo_cameras_from_entry_local(None, a)
        visu_all = [pcd] + geometries + wireframe + bpo_cams
        o3d.visualization.draw_geometries(visu_all, window_name=f"3D Reconstruction - HSS: {score.hss:.4f}, F1: {score.f1:.4f}, IoU: {score.iou:.4f}")

    idx += 1
|
|
# Ten-line end-of-run banner so the summary is easy to spot in long logs.
print("END OF DATASET\n" * 10, end="")

# Aggregate per-sample metrics; an empty list (nothing evaluated) reports 0.0.
mean_hss_val, mean_f1_val, mean_iou_val = (
    np.mean(metric_values) if metric_values else 0.0
    for metric_values in (scores_hss, scores_f1, scores_iou)
)

for metric_name, metric_mean in (("HSS", mean_hss_val), ("F1", mean_f1_val), ("IoU", mean_iou_val)):
    print(f"Mean {metric_name}: {metric_mean:.4f}")
print(f"Final Config: {config}")

# Only successful predictions were timed, so this list may be empty.
if prediction_times:
    print(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds")
|
|
|
|
| |
# Encode the configuration into a filesystem-friendly file name,
# e.g. vertex_threshold 0.59 -> "0p59".
vt_str = str(config['vertex_threshold']).replace('.', 'p')
et_str = str(config['edge_threshold']).replace('.', 'p')
opc_str = str(config['only_predicted_connections'])

results_filename = f"results_vt{vt_str}_et{et_str}_opc{opc_str}_samples{args.max_samples}.txt"
results_filepath = os.path.join(args.results_dir, results_filename)

with open(results_filepath, 'w', encoding='utf-8') as f:
    f.write(f"Configuration: {config}\n")
    # Despite its label, this line records the requested limit (kept unchanged
    # for existing readers of the file) ...
    f.write(f"Max Samples Processed: {args.max_samples}\n")
    # ... so also record how many samples were actually scored; the dataset can
    # be smaller than the limit.
    f.write(f"Samples Evaluated: {len(scores_hss)}\n")
    f.write(f"Mean HSS: {mean_hss_val:.4f}\n")
    f.write(f"Mean F1: {mean_f1_val:.4f}\n")
    f.write(f"Mean IoU: {mean_iou_val:.4f}\n")
    if prediction_times:
        f.write(f"Overall Mean Prediction Time: {np.mean(prediction_times):.4f} seconds\n")
    # One section per metric, one score per line, in evaluation order.
    for metric_name, metric_scores in (("HSS", scores_hss), ("F1", scores_f1), ("IoU", scores_iou)):
        f.write(f"\nIndividual {metric_name} Scores:\n")
        f.writelines(f"{s:.4f}\n" for s in metric_scores)

print(f"Results saved to {results_filepath}")
|
|
|
|