# landmarkdiff/metrics_agg.py
"""Metrics aggregation across checkpoints, experiments, and procedures.
Collects evaluation results from multiple sources and computes aggregate
statistics, confidence intervals, and significance tests for paper reporting.
Usage:
from landmarkdiff.metrics_agg import MetricsAggregator
agg = MetricsAggregator()
agg.add("baseline", "rhinoplasty", {"ssim": 0.82, "lpips": 0.18})
agg.add("ours", "rhinoplasty", {"ssim": 0.91, "lpips": 0.09})
print(agg.summary_table())
print(agg.improvement_over("baseline"))
"""
from __future__ import annotations
import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass
class MetricRecord:
"""A single evaluation record."""
experiment: str
procedure: str
metrics: dict[str, float]
checkpoint_step: int | None = None
metadata: dict[str, Any] = field(default_factory=dict)
class MetricsAggregator:
"""Aggregate and analyze evaluation metrics.

    Supports multiple experiments, procedures, and per-sample results for
    computing means, confidence intervals, and relative improvements.
"""
HIGHER_BETTER = {
"ssim": True, "psnr": True, "identity_sim": True,
"lpips": False, "fid": False, "nme": False,
}
def __init__(self) -> None:
self.records: list[MetricRecord] = []
def add(
self,
experiment: str,
procedure: str,
metrics: dict[str, float],
checkpoint_step: int | None = None,
**metadata: Any,
) -> None:
"""Add a single evaluation record."""
self.records.append(MetricRecord(
experiment=experiment,
procedure=procedure,
metrics=metrics,
checkpoint_step=checkpoint_step,
metadata=metadata,
))
def add_batch(
self,
experiment: str,
records: list[dict[str, Any]],
) -> None:
"""Add multiple records for an experiment.
Each record dict should have 'procedure' and metric keys.
"""
for rec in records:
proc = rec.get("procedure", "all")
            # Treat every numeric (non-bool) value other than 'procedure' as a metric.
            metrics = {
                k: v for k, v in rec.items()
                if k != "procedure"
                and isinstance(v, (int, float))
                and not isinstance(v, bool)
            }
self.add(experiment, proc, metrics)
@property
def experiments(self) -> list[str]:
"""Unique experiment names in insertion order."""
seen: dict[str, None] = {}
for r in self.records:
seen.setdefault(r.experiment, None)
return list(seen.keys())
@property
def procedures(self) -> list[str]:
"""Unique procedure names in insertion order."""
seen: dict[str, None] = {}
for r in self.records:
seen.setdefault(r.procedure, None)
return list(seen.keys())
@property
def metric_names(self) -> list[str]:
"""All unique metric names."""
names: set[str] = set()
for r in self.records:
names.update(r.metrics.keys())
return sorted(names)
def filter(
self,
experiment: str | None = None,
procedure: str | None = None,
) -> list[MetricRecord]:
"""Filter records by experiment and/or procedure."""
results = self.records
if experiment is not None:
results = [r for r in results if r.experiment == experiment]
if procedure is not None:
results = [r for r in results if r.procedure == procedure]
return results
def mean(
self,
experiment: str,
metric: str,
procedure: str | None = None,
) -> float:
"""Compute mean of a metric for an experiment."""
recs = self.filter(experiment=experiment, procedure=procedure)
vals = [r.metrics[metric] for r in recs if metric in r.metrics]
if not vals:
return float("nan")
return sum(vals) / len(vals)
def std(
self,
experiment: str,
metric: str,
procedure: str | None = None,
) -> float:
"""Compute standard deviation of a metric."""
recs = self.filter(experiment=experiment, procedure=procedure)
vals = [r.metrics[metric] for r in recs if metric in r.metrics]
if len(vals) < 2:
return 0.0
m = sum(vals) / len(vals)
var = sum((v - m) ** 2 for v in vals) / (len(vals) - 1)
return math.sqrt(var)
def ci_95(
self,
experiment: str,
metric: str,
procedure: str | None = None,
) -> tuple[float, float]:
"""Compute 95% confidence interval (mean +/- 1.96*SE)."""
recs = self.filter(experiment=experiment, procedure=procedure)
vals = [r.metrics[metric] for r in recs if metric in r.metrics]
if not vals:
return (float("nan"), float("nan"))
n = len(vals)
m = sum(vals) / n
if n < 2:
return (m, m)
var = sum((v - m) ** 2 for v in vals) / (n - 1)
se = math.sqrt(var / n)
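        # 1.96 is the two-sided z critical value for 95% coverage under a
        # normal approximation.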
return (m - 1.96 * se, m + 1.96 * se)
def improvement_over(
self,
baseline: str,
metric: str | None = None,
) -> dict[str, dict[str, float]]:
"""Compute relative improvement of all experiments over a baseline.
Returns:
{experiment: {metric: relative_improvement_pct}}
"""
metrics = [metric] if metric else self.metric_names
result: dict[str, dict[str, float]] = {}
for exp in self.experiments:
if exp == baseline:
continue
improvements: dict[str, float] = {}
for m in metrics:
base_val = self.mean(baseline, m)
exp_val = self.mean(exp, m)
if math.isnan(base_val) or math.isnan(exp_val) or base_val == 0:
continue
higher_better = self.HIGHER_BETTER.get(m, True)
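                # Flip the ratio for lower-is-better metrics so a positive
                # percentage always means "better than the baseline".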
if higher_better:
pct = (exp_val - base_val) / abs(base_val) * 100
else:
pct = (base_val - exp_val) / abs(base_val) * 100
improvements[m] = round(pct, 2)
result[exp] = improvements
return result
def best_experiment(
self,
metric: str,
procedure: str | None = None,
) -> str | None:
"""Find the experiment with the best mean for a metric."""
higher_better = self.HIGHER_BETTER.get(metric, True)
best_exp = None
best_val = float("-inf") if higher_better else float("inf")
for exp in self.experiments:
val = self.mean(exp, metric, procedure)
if math.isnan(val):
continue
if higher_better and val > best_val:
best_val = val
best_exp = exp
elif not higher_better and val < best_val:
best_val = val
best_exp = exp
return best_exp
def summary_table(
self,
metrics: list[str] | None = None,
procedure: str | None = None,
include_std: bool = False,
) -> str:
"""Generate a text summary table.
Args:
metrics: Metrics to include. None = all.
procedure: Filter by procedure. None = aggregate.
include_std: Show mean +/- std.
Returns:
Formatted text table.
"""
metrics = metrics or self.metric_names
exps = self.experiments
# Header
cols = ["Experiment"] + metrics
header = " | ".join(f"{c:>16s}" for c in cols)
lines = [header, "-" * len(header)]
for exp in exps:
parts = [f"{exp:>16s}"]
for m in metrics:
val = self.mean(exp, m, procedure)
if math.isnan(val):
parts.append(f"{'--':>16s}")
elif include_std:
s = self.std(exp, m, procedure)
parts.append(f"{val:>8.4f}±{s:<6.4f}")
else:
parts.append(f"{val:>16.4f}")
lines.append(" | ".join(parts))
return "\n".join(lines)
def to_json(self, path: str | Path | None = None) -> str:
"""Export all records as JSON.
Args:
path: Optional file path to write to.
Returns:
JSON string.
"""
data = {
"experiments": self.experiments,
"procedures": self.procedures,
"metrics": self.metric_names,
"records": [
{
"experiment": r.experiment,
"procedure": r.procedure,
"metrics": r.metrics,
"checkpoint_step": r.checkpoint_step,
"metadata": r.metadata,
}
for r in self.records
],
}
j = json.dumps(data, indent=2)
if path is not None:
Path(path).parent.mkdir(parents=True, exist_ok=True)
Path(path).write_text(j)
return j
@staticmethod
def from_json(path: str | Path) -> MetricsAggregator:
"""Load aggregator from JSON."""
with open(path) as f:
data = json.load(f)
agg = MetricsAggregator()
for rec in data.get("records", []):
agg.add(
experiment=rec["experiment"],
procedure=rec["procedure"],
metrics=rec["metrics"],
checkpoint_step=rec.get("checkpoint_step"),
**rec.get("metadata", {}),
)
return agg
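

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative values only, not real evaluation
# results): build an aggregator, add per-sample records for a hypothetical
# baseline and "ours" run, print the summary table, a 95% CI, and the
# relative improvement over the baseline, then round-trip through JSON.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    agg = MetricsAggregator()
    agg.add_batch("baseline", [
        {"procedure": "rhinoplasty", "ssim": 0.81, "lpips": 0.19},
        {"procedure": "rhinoplasty", "ssim": 0.83, "lpips": 0.17},
        {"procedure": "rhinoplasty", "ssim": 0.82, "lpips": 0.18},
    ])
    agg.add_batch("ours", [
        {"procedure": "rhinoplasty", "ssim": 0.90, "lpips": 0.10},
        {"procedure": "rhinoplasty", "ssim": 0.92, "lpips": 0.08},
        {"procedure": "rhinoplasty", "ssim": 0.91, "lpips": 0.09},
    ])

    print(agg.summary_table(include_std=True))
    print("best (ssim):", agg.best_experiment("ssim"))
    print("ours ssim 95% CI:", agg.ci_95("ours", "ssim"))
    print("improvement over baseline:", agg.improvement_over("baseline"))

    # Round-trip the records through JSON in a temporary directory.
    import os
    import tempfile

    tmp_path = os.path.join(tempfile.mkdtemp(), "metrics.json")
    agg.to_json(tmp_path)
    restored = MetricsAggregator.from_json(tmp_path)
    print("restored records:", len(restored.records))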