Spaces:
Running
Running
Update landmarkdiff/metrics_viz.py to v0.3.2
Browse files- landmarkdiff/metrics_viz.py +35 -61
landmarkdiff/metrics_viz.py
CHANGED
|
@@ -81,26 +81,22 @@ class MetricsVisualizer:
|
|
| 81 |
def _get_plt(self) -> Any:
|
| 82 |
"""Import matplotlib with configuration."""
|
| 83 |
import matplotlib
|
| 84 |
-
|
| 85 |
matplotlib.use("Agg")
|
| 86 |
import matplotlib.pyplot as plt
|
| 87 |
-
|
| 88 |
try:
|
| 89 |
plt.style.use(self.style)
|
| 90 |
except OSError:
|
| 91 |
plt.style.use("seaborn-v0_8")
|
| 92 |
# Publication font sizes
|
| 93 |
-
plt.rcParams.update(
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
}
|
| 103 |
-
)
|
| 104 |
return plt
|
| 105 |
|
| 106 |
# ------------------------------------------------------------------
|
|
@@ -141,7 +137,7 @@ class MetricsVisualizer:
|
|
| 141 |
if n_metrics == 1:
|
| 142 |
axes = [axes]
|
| 143 |
|
| 144 |
-
for ax, metric in zip(axes, metrics
|
| 145 |
values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
|
| 146 |
colors = [self.COLORS.get(p, "#999999") for p in procedures]
|
| 147 |
|
|
@@ -149,21 +145,16 @@ class MetricsVisualizer:
|
|
| 149 |
ax.set_xticks(range(n_procs))
|
| 150 |
ax.set_xticklabels(
|
| 151 |
[p[:5].title() for p in procedures],
|
| 152 |
-
rotation=30,
|
| 153 |
-
ha="right",
|
| 154 |
)
|
| 155 |
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
| 156 |
ax.set_title(self.METRIC_LABELS.get(metric, metric))
|
| 157 |
|
| 158 |
# Add value labels on bars
|
| 159 |
-
for bar, val in zip(bars, values
|
| 160 |
ax.text(
|
| 161 |
-
bar.get_x() + bar.get_width() / 2,
|
| 162 |
-
|
| 163 |
-
f"{val:.3f}",
|
| 164 |
-
ha="center",
|
| 165 |
-
va="bottom",
|
| 166 |
-
fontsize=8,
|
| 167 |
)
|
| 168 |
|
| 169 |
fig.suptitle(title, fontweight="bold")
|
|
@@ -200,8 +191,9 @@ class MetricsVisualizer:
|
|
| 200 |
|
| 201 |
if metrics is None:
|
| 202 |
metrics = sorted(
|
| 203 |
-
set.intersection(
|
| 204 |
-
|
|
|
|
| 205 |
)
|
| 206 |
|
| 207 |
n_metrics = len(metrics)
|
|
@@ -265,7 +257,9 @@ class MetricsVisualizer:
|
|
| 265 |
plt = self._get_plt()
|
| 266 |
|
| 267 |
fitz_types = sorted(metrics_by_type.keys())
|
| 268 |
-
procedures = sorted(
|
|
|
|
|
|
|
| 269 |
|
| 270 |
# Build matrix
|
| 271 |
matrix = np.zeros((len(fitz_types), len(procedures)))
|
|
@@ -273,9 +267,7 @@ class MetricsVisualizer:
|
|
| 273 |
for j, proc in enumerate(procedures):
|
| 274 |
matrix[i, j] = metrics_by_type[ft].get(proc, 0)
|
| 275 |
|
| 276 |
-
fig, ax = plt.subplots(
|
| 277 |
-
figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8))
|
| 278 |
-
)
|
| 279 |
|
| 280 |
cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
|
| 281 |
im = ax.imshow(matrix, cmap=cmap, aspect="auto")
|
|
@@ -289,15 +281,9 @@ class MetricsVisualizer:
|
|
| 289 |
# Annotate cells
|
| 290 |
for i in range(len(fitz_types)):
|
| 291 |
for j in range(len(procedures)):
|
| 292 |
-
ax.text(
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
f"{matrix[i, j]:.3f}",
|
| 296 |
-
ha="center",
|
| 297 |
-
va="center",
|
| 298 |
-
fontsize=9,
|
| 299 |
-
color="white" if matrix[i, j] < np.median(matrix) else "black",
|
| 300 |
-
)
|
| 301 |
|
| 302 |
fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
|
| 303 |
|
|
@@ -341,21 +327,18 @@ class MetricsVisualizer:
|
|
| 341 |
fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
|
| 342 |
|
| 343 |
bp = ax.boxplot(
|
| 344 |
-
data,
|
| 345 |
-
patch_artist=True,
|
| 346 |
-
widths=0.6,
|
| 347 |
medianprops={"color": "black", "linewidth": 1.5},
|
| 348 |
)
|
| 349 |
|
| 350 |
colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
|
| 351 |
-
for patch, color in zip(bp["boxes"], colors
|
| 352 |
patch.set_facecolor(color)
|
| 353 |
patch.set_alpha(0.7)
|
| 354 |
|
| 355 |
ax.set_xticklabels(
|
| 356 |
[g.title() for g in groups],
|
| 357 |
-
rotation=30,
|
| 358 |
-
ha="right",
|
| 359 |
)
|
| 360 |
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
| 361 |
|
|
@@ -364,16 +347,9 @@ class MetricsVisualizer:
|
|
| 364 |
ax.set_title(title, fontweight="bold")
|
| 365 |
|
| 366 |
# Add sample count annotations
|
| 367 |
-
for i, (_g, vals) in enumerate(zip(groups, data
|
| 368 |
-
ax.text(
|
| 369 |
-
|
| 370 |
-
ax.get_ylim()[0],
|
| 371 |
-
f"n={len(vals)}",
|
| 372 |
-
ha="center",
|
| 373 |
-
va="bottom",
|
| 374 |
-
fontsize=8,
|
| 375 |
-
color="gray",
|
| 376 |
-
)
|
| 377 |
|
| 378 |
fig.tight_layout()
|
| 379 |
out_path = self.output_dir / filename
|
|
@@ -453,12 +429,10 @@ class MetricsVisualizer:
|
|
| 453 |
parts.append(val_str)
|
| 454 |
lines.append(" & ".join(parts) + " \\\\")
|
| 455 |
|
| 456 |
-
lines.extend(
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
]
|
| 462 |
-
)
|
| 463 |
|
| 464 |
return "\n".join(lines)
|
|
|
|
| 81 |
def _get_plt(self) -> Any:
|
| 82 |
"""Import matplotlib with configuration."""
|
| 83 |
import matplotlib
|
|
|
|
| 84 |
matplotlib.use("Agg")
|
| 85 |
import matplotlib.pyplot as plt
|
|
|
|
| 86 |
try:
|
| 87 |
plt.style.use(self.style)
|
| 88 |
except OSError:
|
| 89 |
plt.style.use("seaborn-v0_8")
|
| 90 |
# Publication font sizes
|
| 91 |
+
plt.rcParams.update({
|
| 92 |
+
"font.size": 10,
|
| 93 |
+
"axes.titlesize": 12,
|
| 94 |
+
"axes.labelsize": 11,
|
| 95 |
+
"xtick.labelsize": 9,
|
| 96 |
+
"ytick.labelsize": 9,
|
| 97 |
+
"legend.fontsize": 9,
|
| 98 |
+
"figure.titlesize": 13,
|
| 99 |
+
})
|
|
|
|
|
|
|
| 100 |
return plt
|
| 101 |
|
| 102 |
# ------------------------------------------------------------------
|
|
|
|
| 137 |
if n_metrics == 1:
|
| 138 |
axes = [axes]
|
| 139 |
|
| 140 |
+
for ax, metric in zip(axes, metrics):
|
| 141 |
values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
|
| 142 |
colors = [self.COLORS.get(p, "#999999") for p in procedures]
|
| 143 |
|
|
|
|
| 145 |
ax.set_xticks(range(n_procs))
|
| 146 |
ax.set_xticklabels(
|
| 147 |
[p[:5].title() for p in procedures],
|
| 148 |
+
rotation=30, ha="right",
|
|
|
|
| 149 |
)
|
| 150 |
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
| 151 |
ax.set_title(self.METRIC_LABELS.get(metric, metric))
|
| 152 |
|
| 153 |
# Add value labels on bars
|
| 154 |
+
for bar, val in zip(bars, values):
|
| 155 |
ax.text(
|
| 156 |
+
bar.get_x() + bar.get_width() / 2, bar.get_height(),
|
| 157 |
+
f"{val:.3f}", ha="center", va="bottom", fontsize=8,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
)
|
| 159 |
|
| 160 |
fig.suptitle(title, fontweight="bold")
|
|
|
|
| 191 |
|
| 192 |
if metrics is None:
|
| 193 |
metrics = sorted(
|
| 194 |
+
set.intersection(
|
| 195 |
+
*(set(v.keys()) for v in experiments.values())
|
| 196 |
+
) & set(self.METRIC_LABELS.keys())
|
| 197 |
)
|
| 198 |
|
| 199 |
n_metrics = len(metrics)
|
|
|
|
| 257 |
plt = self._get_plt()
|
| 258 |
|
| 259 |
fitz_types = sorted(metrics_by_type.keys())
|
| 260 |
+
procedures = sorted(
|
| 261 |
+
set.union(*(set(v.keys()) for v in metrics_by_type.values()))
|
| 262 |
+
)
|
| 263 |
|
| 264 |
# Build matrix
|
| 265 |
matrix = np.zeros((len(fitz_types), len(procedures)))
|
|
|
|
| 267 |
for j, proc in enumerate(procedures):
|
| 268 |
matrix[i, j] = metrics_by_type[ft].get(proc, 0)
|
| 269 |
|
| 270 |
+
fig, ax = plt.subplots(figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8)))
|
|
|
|
|
|
|
| 271 |
|
| 272 |
cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
|
| 273 |
im = ax.imshow(matrix, cmap=cmap, aspect="auto")
|
|
|
|
| 281 |
# Annotate cells
|
| 282 |
for i in range(len(fitz_types)):
|
| 283 |
for j in range(len(procedures)):
|
| 284 |
+
ax.text(j, i, f"{matrix[i, j]:.3f}",
|
| 285 |
+
ha="center", va="center", fontsize=9,
|
| 286 |
+
color="white" if matrix[i, j] < np.median(matrix) else "black")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
|
| 289 |
|
|
|
|
| 327 |
fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
|
| 328 |
|
| 329 |
bp = ax.boxplot(
|
| 330 |
+
data, patch_artist=True, widths=0.6,
|
|
|
|
|
|
|
| 331 |
medianprops={"color": "black", "linewidth": 1.5},
|
| 332 |
)
|
| 333 |
|
| 334 |
colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
|
| 335 |
+
for patch, color in zip(bp["boxes"], colors):
|
| 336 |
patch.set_facecolor(color)
|
| 337 |
patch.set_alpha(0.7)
|
| 338 |
|
| 339 |
ax.set_xticklabels(
|
| 340 |
[g.title() for g in groups],
|
| 341 |
+
rotation=30, ha="right",
|
|
|
|
| 342 |
)
|
| 343 |
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
| 344 |
|
|
|
|
| 347 |
ax.set_title(title, fontweight="bold")
|
| 348 |
|
| 349 |
# Add sample count annotations
|
| 350 |
+
for i, (_g, vals) in enumerate(zip(groups, data)):
|
| 351 |
+
ax.text(i + 1, ax.get_ylim()[0], f"n={len(vals)}",
|
| 352 |
+
ha="center", va="bottom", fontsize=8, color="gray")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
|
| 354 |
fig.tight_layout()
|
| 355 |
out_path = self.output_dir / filename
|
|
|
|
| 429 |
parts.append(val_str)
|
| 430 |
lines.append(" & ".join(parts) + " \\\\")
|
| 431 |
|
| 432 |
+
lines.extend([
|
| 433 |
+
"\\bottomrule",
|
| 434 |
+
"\\end{tabular}",
|
| 435 |
+
"\\end{table}",
|
| 436 |
+
])
|
|
|
|
|
|
|
| 437 |
|
| 438 |
return "\n".join(lines)
|