| """Plotting utilities for position-bias curves.""" |
| import logging |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
|
|
# Logger named after this module; no handlers or levels are configured here.
logger = logging.getLogger(name=__name__)
|
|
|
|
def plot_curve(
    x_values,
    y_values,
    title: str,
    save_path: str,
    xlabel: str = "Position (0=start, 1=end)",
    ylabel: str = "Accuracy",
    ylim: tuple = (-0.05, 1.05),
    color: str = "#E63946",
):
    """Plot a standard position-bias accuracy curve and save it to *save_path*.

    Args:
        x_values: X coordinates of the points (by default labelled as a
            0=start .. 1=end position).
        y_values: Y coordinates, one per entry of ``x_values``.
        title: Figure title.
        save_path: Destination file passed to ``Figure.savefig`` (saved at 200 dpi).
        xlabel: X-axis label.
        ylabel: Y-axis label.
        ylim: ``(bottom, top)`` y-axis limits; the default pads slightly
            beyond ``[0, 1]`` so markers at the extremes are not cut by the frame.
        color: Line and marker colour.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(x_values, y_values, marker="o", linewidth=2.5, markersize=10, color=color)
        ax.set_xlabel(xlabel, fontsize=13)
        ax.set_ylabel(ylabel, fontsize=13)
        ax.set_title(title, fontsize=13)
        ax.set_ylim(ylim)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(save_path, dpi=200)
    finally:
        # pyplot keeps every figure alive until it is closed; closing in
        # ``finally`` prevents a failed plot/save from leaking the figure.
        plt.close(fig)
    logger.info("Plot saved: %s", save_path)
|
|
|
|
def plot_bar(
    categories,
    values,
    title: str,
    save_path: str,
    ylabel: str = "Accuracy",
    ylim=(0, 1.05),
    colors=None,
):
    """Plot a bar chart (e.g., for multi-needle start/middle/end) and save it.

    Args:
        categories: Bar labels along the x-axis.
        values: Bar heights, one per category.
        title: Figure title.
        save_path: Destination file passed to ``Figure.savefig`` (saved at 200 dpi).
        ylabel: Y-axis label.
        ylim: ``(bottom, top)`` y-axis limits.
        colors: Per-bar face colours. Defaults to a blue/red/blue triple
            that highlights the middle bar. Matplotlib cycles the list if it
            is shorter than ``categories``.
    """
    if colors is None:
        colors = ["#2E86AB", "#E63946", "#2E86AB"]
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.bar(categories, values, color=colors, edgecolor="black", linewidth=1.2)
        ax.set_ylabel(ylabel, fontsize=13)
        ax.set_title(title, fontsize=13)
        ax.set_ylim(ylim)
        ax.grid(True, alpha=0.3, axis="y")
        fig.tight_layout()
        fig.savefig(save_path, dpi=200)
    finally:
        # Release the figure even if plotting or saving raised; otherwise
        # pyplot would keep it (and its memory) alive for the process lifetime.
        plt.close(fig)
    logger.info("Bar plot saved: %s", save_path)
|
|
|
|
def plot_multi_curves(
    curves,
    labels,
    title,
    save_path,
    xlabel="Position",
    ylabel="Accuracy",
    ylim=(-0.05, 1.05),
):
    """Overlay multiple curves on one set of axes for comparison and save it.

    Args:
        curves: Mapping with keys ``"x"`` and ``"y"``; ``curves["x"][i]`` and
            ``curves["y"][i]`` are the coordinates of the i-th curve.
        labels: Legend label for each curve, in the same order as ``curves``.
        title: Figure title.
        save_path: Destination file passed to ``Figure.savefig`` (saved at 200 dpi).
        xlabel: X-axis label.
        ylabel: Y-axis label.
        ylim: ``(bottom, top)`` y-axis limits (default matches the range that
            used to be hard-coded).

    If the numbers of x-series, y-series and labels differ, only the first
    ``min(...)`` curves are plotted and a warning is logged.
    """
    xs = list(curves["x"])
    ys = list(curves["y"])
    labels = list(labels)
    n_curves = min(len(xs), len(ys), len(labels))
    if not len(xs) == len(ys) == len(labels):
        # zip() truncates silently; make dropped curves visible in the logs.
        logger.warning(
            "plot_multi_curves: got %d x-series, %d y-series and %d labels; "
            "only the first %d curves will be plotted",
            len(xs), len(ys), len(labels), n_curves,
        )

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        cmap = plt.get_cmap("tab10")
        for i, (x, y, label) in enumerate(zip(xs, ys, labels)):
            # tab10 has only cmap.N entries; integer lookups past the end all
            # return the same "over" colour, so wrap the index instead.
            ax.plot(
                x, y,
                marker="o", linewidth=2.0, markersize=8,
                label=label, color=cmap(i % cmap.N),
            )
        ax.set_xlabel(xlabel, fontsize=13)
        ax.set_ylabel(ylabel, fontsize=13)
        ax.set_title(title, fontsize=13)
        ax.set_ylim(ylim)
        if n_curves:
            # With no plotted artists matplotlib would only emit a warning.
            ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(save_path, dpi=200)
    finally:
        # Always release the figure so failures don't leak pyplot figures.
        plt.close(fig)
    logger.info("Multi-curve plot saved: %s", save_path)
|
|