dreamlessx commited on
Commit
0434bde
·
verified ·
1 Parent(s): 199f152

Update landmarkdiff/metrics_viz.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/metrics_viz.py +35 -61
landmarkdiff/metrics_viz.py CHANGED
@@ -81,26 +81,22 @@ class MetricsVisualizer:
81
  def _get_plt(self) -> Any:
82
  """Import matplotlib with configuration."""
83
  import matplotlib
84
-
85
  matplotlib.use("Agg")
86
  import matplotlib.pyplot as plt
87
-
88
  try:
89
  plt.style.use(self.style)
90
  except OSError:
91
  plt.style.use("seaborn-v0_8")
92
  # Publication font sizes
93
- plt.rcParams.update(
94
- {
95
- "font.size": 10,
96
- "axes.titlesize": 12,
97
- "axes.labelsize": 11,
98
- "xtick.labelsize": 9,
99
- "ytick.labelsize": 9,
100
- "legend.fontsize": 9,
101
- "figure.titlesize": 13,
102
- }
103
- )
104
  return plt
105
 
106
  # ------------------------------------------------------------------
@@ -141,7 +137,7 @@ class MetricsVisualizer:
141
  if n_metrics == 1:
142
  axes = [axes]
143
 
144
- for ax, metric in zip(axes, metrics, strict=False):
145
  values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
146
  colors = [self.COLORS.get(p, "#999999") for p in procedures]
147
 
@@ -149,21 +145,16 @@ class MetricsVisualizer:
149
  ax.set_xticks(range(n_procs))
150
  ax.set_xticklabels(
151
  [p[:5].title() for p in procedures],
152
- rotation=30,
153
- ha="right",
154
  )
155
  ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
156
  ax.set_title(self.METRIC_LABELS.get(metric, metric))
157
 
158
  # Add value labels on bars
159
- for bar, val in zip(bars, values, strict=False):
160
  ax.text(
161
- bar.get_x() + bar.get_width() / 2,
162
- bar.get_height(),
163
- f"{val:.3f}",
164
- ha="center",
165
- va="bottom",
166
- fontsize=8,
167
  )
168
 
169
  fig.suptitle(title, fontweight="bold")
@@ -200,8 +191,9 @@ class MetricsVisualizer:
200
 
201
  if metrics is None:
202
  metrics = sorted(
203
- set.intersection(*(set(v.keys()) for v in experiments.values()))
204
- & set(self.METRIC_LABELS.keys())
 
205
  )
206
 
207
  n_metrics = len(metrics)
@@ -265,7 +257,9 @@ class MetricsVisualizer:
265
  plt = self._get_plt()
266
 
267
  fitz_types = sorted(metrics_by_type.keys())
268
- procedures = sorted(set.union(*(set(v.keys()) for v in metrics_by_type.values())))
 
 
269
 
270
  # Build matrix
271
  matrix = np.zeros((len(fitz_types), len(procedures)))
@@ -273,9 +267,7 @@ class MetricsVisualizer:
273
  for j, proc in enumerate(procedures):
274
  matrix[i, j] = metrics_by_type[ft].get(proc, 0)
275
 
276
- fig, ax = plt.subplots(
277
- figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8))
278
- )
279
 
280
  cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
281
  im = ax.imshow(matrix, cmap=cmap, aspect="auto")
@@ -289,15 +281,9 @@ class MetricsVisualizer:
289
  # Annotate cells
290
  for i in range(len(fitz_types)):
291
  for j in range(len(procedures)):
292
- ax.text(
293
- j,
294
- i,
295
- f"{matrix[i, j]:.3f}",
296
- ha="center",
297
- va="center",
298
- fontsize=9,
299
- color="white" if matrix[i, j] < np.median(matrix) else "black",
300
- )
301
 
302
  fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
303
 
@@ -341,21 +327,18 @@ class MetricsVisualizer:
341
  fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
342
 
343
  bp = ax.boxplot(
344
- data,
345
- patch_artist=True,
346
- widths=0.6,
347
  medianprops={"color": "black", "linewidth": 1.5},
348
  )
349
 
350
  colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
351
- for patch, color in zip(bp["boxes"], colors, strict=False):
352
  patch.set_facecolor(color)
353
  patch.set_alpha(0.7)
354
 
355
  ax.set_xticklabels(
356
  [g.title() for g in groups],
357
- rotation=30,
358
- ha="right",
359
  )
360
  ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
361
 
@@ -364,16 +347,9 @@ class MetricsVisualizer:
364
  ax.set_title(title, fontweight="bold")
365
 
366
  # Add sample count annotations
367
- for i, (_g, vals) in enumerate(zip(groups, data, strict=False)):
368
- ax.text(
369
- i + 1,
370
- ax.get_ylim()[0],
371
- f"n={len(vals)}",
372
- ha="center",
373
- va="bottom",
374
- fontsize=8,
375
- color="gray",
376
- )
377
 
378
  fig.tight_layout()
379
  out_path = self.output_dir / filename
@@ -453,12 +429,10 @@ class MetricsVisualizer:
453
  parts.append(val_str)
454
  lines.append(" & ".join(parts) + " \\\\")
455
 
456
- lines.extend(
457
- [
458
- "\\bottomrule",
459
- "\\end{tabular}",
460
- "\\end{table}",
461
- ]
462
- )
463
 
464
  return "\n".join(lines)
 
81
  def _get_plt(self) -> Any:
82
  """Import matplotlib with configuration."""
83
  import matplotlib
 
84
  matplotlib.use("Agg")
85
  import matplotlib.pyplot as plt
 
86
  try:
87
  plt.style.use(self.style)
88
  except OSError:
89
  plt.style.use("seaborn-v0_8")
90
  # Publication font sizes
91
+ plt.rcParams.update({
92
+ "font.size": 10,
93
+ "axes.titlesize": 12,
94
+ "axes.labelsize": 11,
95
+ "xtick.labelsize": 9,
96
+ "ytick.labelsize": 9,
97
+ "legend.fontsize": 9,
98
+ "figure.titlesize": 13,
99
+ })
 
 
100
  return plt
101
 
102
  # ------------------------------------------------------------------
 
137
  if n_metrics == 1:
138
  axes = [axes]
139
 
140
+ for ax, metric in zip(axes, metrics):
141
  values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
142
  colors = [self.COLORS.get(p, "#999999") for p in procedures]
143
 
 
145
  ax.set_xticks(range(n_procs))
146
  ax.set_xticklabels(
147
  [p[:5].title() for p in procedures],
148
+ rotation=30, ha="right",
 
149
  )
150
  ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
151
  ax.set_title(self.METRIC_LABELS.get(metric, metric))
152
 
153
  # Add value labels on bars
154
+ for bar, val in zip(bars, values):
155
  ax.text(
156
+ bar.get_x() + bar.get_width() / 2, bar.get_height(),
157
+ f"{val:.3f}", ha="center", va="bottom", fontsize=8,
 
 
 
 
158
  )
159
 
160
  fig.suptitle(title, fontweight="bold")
 
191
 
192
  if metrics is None:
193
  metrics = sorted(
194
+ set.intersection(
195
+ *(set(v.keys()) for v in experiments.values())
196
+ ) & set(self.METRIC_LABELS.keys())
197
  )
198
 
199
  n_metrics = len(metrics)
 
257
  plt = self._get_plt()
258
 
259
  fitz_types = sorted(metrics_by_type.keys())
260
+ procedures = sorted(
261
+ set.union(*(set(v.keys()) for v in metrics_by_type.values()))
262
+ )
263
 
264
  # Build matrix
265
  matrix = np.zeros((len(fitz_types), len(procedures)))
 
267
  for j, proc in enumerate(procedures):
268
  matrix[i, j] = metrics_by_type[ft].get(proc, 0)
269
 
270
+ fig, ax = plt.subplots(figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8)))
 
 
271
 
272
  cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
273
  im = ax.imshow(matrix, cmap=cmap, aspect="auto")
 
281
  # Annotate cells
282
  for i in range(len(fitz_types)):
283
  for j in range(len(procedures)):
284
+ ax.text(j, i, f"{matrix[i, j]:.3f}",
285
+ ha="center", va="center", fontsize=9,
286
+ color="white" if matrix[i, j] < np.median(matrix) else "black")
 
 
 
 
 
 
287
 
288
  fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
289
 
 
327
  fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
328
 
329
  bp = ax.boxplot(
330
+ data, patch_artist=True, widths=0.6,
 
 
331
  medianprops={"color": "black", "linewidth": 1.5},
332
  )
333
 
334
  colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
335
+ for patch, color in zip(bp["boxes"], colors):
336
  patch.set_facecolor(color)
337
  patch.set_alpha(0.7)
338
 
339
  ax.set_xticklabels(
340
  [g.title() for g in groups],
341
+ rotation=30, ha="right",
 
342
  )
343
  ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
344
 
 
347
  ax.set_title(title, fontweight="bold")
348
 
349
  # Add sample count annotations
350
+ for i, (_g, vals) in enumerate(zip(groups, data)):
351
+ ax.text(i + 1, ax.get_ylim()[0], f"n={len(vals)}",
352
+ ha="center", va="bottom", fontsize=8, color="gray")
 
 
 
 
 
 
 
353
 
354
  fig.tight_layout()
355
  out_path = self.output_dir / filename
 
429
  parts.append(val_str)
430
  lines.append(" & ".join(parts) + " \\\\")
431
 
432
+ lines.extend([
433
+ "\\bottomrule",
434
+ "\\end{tabular}",
435
+ "\\end{table}",
436
+ ])
 
 
437
 
438
  return "\n".join(lines)