| """Utilities for persisting and aggregating GIFT-Eval results.""" |
|
|
| import argparse |
| import csv |
| import glob |
| import logging |
| from pathlib import Path |
|
|
| import pandas as pd |
|
|
| from src.gift_eval.constants import ( |
| ALL_DATASETS, |
| DATASET_PROPERTIES, |
| MED_LONG_DATASETS, |
| PRETTY_NAMES, |
| STANDARD_METRIC_NAMES, |
| ) |
| from src.gift_eval.core import DatasetMetadata, EvaluationItem |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def _ensure_results_csv(csv_file_path: Path) -> None:
    """Create the results CSV with its header row if it does not exist yet.

    Parameters
    ----------
    csv_file_path:
        Target ``results.csv`` path; missing parent directories are created.
    """
    if csv_file_path.exists():
        return
    csv_file_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit utf-8: the default locale encoding is platform-dependent and can
    # fail on non-ASCII dataset/model names (e.g. cp1252 on Windows).
    with open(csv_file_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        header = (
            ["dataset", "model"]
            + [f"eval_metrics/{name}" for name in STANDARD_METRIC_NAMES]
            + ["domain", "num_variates"]
        )
        writer.writerow(header)
|
|
|
|
def _coerce_metric_value(value):
    """Reduce a raw metric value to a CSV-friendly scalar where possible.

    Single-element sequences (e.g. 1-element numpy arrays or Series) are
    unwrapped to their sole element; numpy-style scalars are converted with
    ``.item()``. Two pitfalls of the naive ``len``/``item`` cascade are
    handled: 0-d numpy arrays expose ``__len__`` but raise ``TypeError`` when
    it is called, and ``.item()`` on a multi-element array raises
    ``ValueError`` — multi-element values are now passed through unchanged.
    """
    if hasattr(value, "__len__") and not isinstance(value, (str, bytes)):
        try:
            length = len(value)
        except TypeError:
            # 0-d numpy array: treat it as a plain scalar below.
            length = None
        if length == 1:
            return value[0]
        if length is not None:
            # Multi-element value: write it as-is rather than crash on .item().
            return value
    if hasattr(value, "item"):
        return value.item()
    return value


def write_results_to_disk(
    items: list[EvaluationItem],
    dataset_name: str,
    output_dir: Path,
    model_name: str,
    create_plots: bool,
) -> None:
    """Append one CSV row per evaluation item and optionally save its figures.

    Rows go to ``<output_dir>/<dataset_name>/results.csv`` (created with a
    header on first use); plots, when requested and matplotlib is available,
    go under ``<output_dir>/<dataset_name>/plots/<key>/<term>/``.

    Parameters
    ----------
    items:
        Evaluation results, each carrying metrics, dataset metadata, and
        optional figures.
    dataset_name:
        Subdirectory name for this dataset's outputs.
    output_dir:
        Root results directory.
    model_name:
        Model identifier written into each row.
    create_plots:
        When true, save each item's figures to disk.
    """
    output_dir = output_dir / dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv_path = output_dir / "results.csv"
    _ensure_results_csv(output_csv_path)

    # Plotting is optional; degrade gracefully when matplotlib is missing.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None

    with open(output_csv_path, "a", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        for item in items:
            md: DatasetMetadata = item.dataset_metadata
            metric_values: list[float | None] = []
            for metric_name in STANDARD_METRIC_NAMES:
                raw = item.metrics.get(metric_name)
                metric_values.append(None if raw is None else _coerce_metric_value(raw))

            ds_key = md.key.lower()
            props = DATASET_PROPERTIES.get(ds_key, {})
            domain = props.get("domain", "unknown")
            # Fall back to metadata when the properties table has no entry:
            # a dataset forced to univariate counts as a single variate.
            num_variates = props.get("num_variates", 1 if md.to_univariate else md.target_dim)

            row = [md.full_name, model_name] + metric_values + [domain, num_variates]
            writer.writerow(row)

            if create_plots and item.figures and plt is not None:
                plots_dir = output_dir / "plots" / md.key / md.term
                plots_dir.mkdir(parents=True, exist_ok=True)
                for fig, filename in item.figures:
                    fig.savefig(plots_dir / filename, dpi=300, bbox_inches="tight")
                    # Close each figure to release its memory.
                    plt.close(fig)

    logger.info(
        "Evaluation complete for dataset '%s'. Results saved to %s",
        dataset_name,
        output_csv_path,
    )
    if create_plots:
        logger.info("Plots saved under %s", output_dir / "plots")
|
|
|
|
def get_all_datasets_full_name() -> list[str]:
    """Get all possible dataset full names for validation.

    Builds every expected ``<key>/<freq>/<term>`` combination, where only
    datasets listed in ``MED_LONG_DATASETS`` also get the medium/long terms.
    """
    full_names: list[str] = []

    for dataset in ALL_DATASETS:
        # Resolve the pretty key and frequency once per dataset.
        if "/" in dataset:
            raw_key, freq = dataset.split("/")
        else:
            raw_key, freq = dataset, None
        key = PRETTY_NAMES.get(raw_key.lower(), raw_key.lower())
        if freq is None:
            freq = DATASET_PROPERTIES.get(key, {}).get("frequency")

        terms = ("short", "medium", "long") if dataset in MED_LONG_DATASETS else ("short",)
        for term in terms:
            full_names.append(f"{key}/{freq if freq else 'unknown'}/{term}")

    return full_names
|
|
|
|
def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
    """Aggregate results from multiple CSV files into a single dataframe.

    Recursively collects every ``results.csv`` under *result_root_dir*, drops
    duplicate dataset rows (first occurrence wins), logs a completion summary
    against the full expected dataset list, and writes the merged table to
    ``all_results.csv`` in the root. Returns ``None`` when nothing usable is
    found.
    """
    root = Path(result_root_dir)
    logger.info("Aggregating results in: %s", root)

    found_files = glob.glob(f"{root}/**/results.csv", recursive=True)
    if not found_files:
        logger.error("No result files found!")
        return None

    frames: list[pd.DataFrame] = []
    for csv_path in found_files:
        try:
            frame = pd.read_csv(csv_path)
        except pd.errors.EmptyDataError:
            logger.warning("Skipping empty file: %s", csv_path)
        except Exception as exc:
            logger.error("Error reading %s: %s", csv_path, exc)
        else:
            if len(frame) > 0:
                frames.append(frame)
            else:
                logger.warning("Empty file: %s", csv_path)

    if not frames:
        logger.warning("No valid CSV files found to combine")
        return None

    combined = pd.concat(frames, ignore_index=True).sort_values("dataset")

    # De-duplicate on dataset name, keeping the first occurrence.
    if combined.dataset.duplicated().any():
        dupes = combined.dataset[combined.dataset.duplicated()].tolist()
        logger.warning("Warning: Duplicate datasets found: %s", dupes)
        combined = combined.drop_duplicates(subset=["dataset"], keep="first")
        logger.info("Removed duplicates, %s unique datasets remaining", len(combined))

    logger.info("Combined results: %s datasets", len(combined))

    expected_names = get_all_datasets_full_name()
    finished = [name for name in combined.dataset.tolist() if name in expected_names]
    unfinished = [name for name in expected_names if name not in finished]

    logger.info("=== EXPERIMENT SUMMARY ===")
    logger.info("Total expected datasets: %s", len(expected_names))
    logger.info("Completed experiments: %s", len(finished))
    logger.info("Missing/failed experiments: %s", len(unfinished))

    logger.info("Completed experiments:")
    for idx, name in enumerate(finished, start=1):
        logger.info("  %3d: %s", idx, name)

    if unfinished:
        logger.info("Missing or failed experiments:")
        for idx, name in enumerate(unfinished, start=1):
            logger.info("  %3d: %s", idx, name)

    rate = len(finished) / len(expected_names) * 100 if expected_names else 0.0
    logger.info("Completion rate: %.1f%%", rate)

    output_file = root / "all_results.csv"
    combined.to_csv(output_file, index=False)
    logger.info("Combined results saved to: %s", output_file)

    return combined
|
|
|
|
# Public API of this module; `main` is the CLI-only entry point and is not exported.
__all__ = [
    "aggregate_results",
    "get_all_datasets_full_name",
    "write_results_to_disk",
]
|
|
|
|
def main() -> None:
    """CLI entry point for aggregating results from disk.

    Parses ``--result_root_dir``, configures basic INFO logging, and runs
    :func:`aggregate_results` over that directory.
    """
    arg_parser = argparse.ArgumentParser(description="Aggregate GIFT-Eval results from multiple CSV files")
    arg_parser.add_argument(
        "--result_root_dir",
        type=str,
        required=True,
        help="Root directory containing result subdirectories",
    )
    cli_args = arg_parser.parse_args()
    root_dir = Path(cli_args.result_root_dir)

    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    logger.info("Searching in directory: %s", root_dir)

    aggregate_results(result_root_dir=root_dir)
|
|
|
|
# Allow running this module directly as a script (see main() for CLI flags).
if __name__ == "__main__":
    main()
|
|