abhshkp's picture
Upload folder using huggingface_hub
e1e1ce9 verified
"""Plotting utilities for position-bias curves."""
import logging
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
logger = logging.getLogger(__name__)
def plot_curve(
x_values,
y_values,
title: str,
save_path: str,
xlabel: str = "Position (0=start, 1=end)",
ylabel: str = "Accuracy",
ylim: tuple = (-0.05, 1.05),
color: str = "#E63946",
):
"""Plot a standard position-bias accuracy curve."""
plt.figure(figsize=(8, 5))
plt.plot(x_values, y_values, marker="o", linewidth=2.5, markersize=10, color=color)
plt.xlabel(xlabel, fontsize=13)
plt.ylabel(ylabel, fontsize=13)
plt.title(title, fontsize=13)
plt.ylim(ylim)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(save_path, dpi=200)
plt.close()
logger.info(f"Plot saved: {save_path}")
def plot_bar(categories, values, title: str, save_path: str, ylabel: str = "Accuracy", ylim=(0, 1.05), colors=None):
"""Plot a bar chart (e.g., for multi-needle start/middle/end)."""
if colors is None:
colors = ["#2E86AB", "#E63946", "#2E86AB"]
plt.figure(figsize=(6, 5))
plt.bar(categories, values, color=colors, edgecolor="black", linewidth=1.2)
plt.ylabel(ylabel, fontsize=13)
plt.title(title, fontsize=13)
plt.ylim(ylim)
plt.grid(True, alpha=0.3, axis="y")
plt.tight_layout()
plt.savefig(save_path, dpi=200)
plt.close()
logger.info(f"Bar plot saved: {save_path}")
def plot_multi_curves(curves, labels, title, save_path, xlabel="Position", ylabel="Accuracy"):
"""Overlay multiple curves for comparison."""
plt.figure(figsize=(10, 6))
cmap = plt.get_cmap("tab10")
for i, (x, y, label) in enumerate(zip(curves["x"], curves["y"], labels)):
plt.plot(x, y, marker="o", linewidth=2.0, markersize=8, label=label, color=cmap(i))
plt.xlabel(xlabel, fontsize=13)
plt.ylabel(ylabel, fontsize=13)
plt.title(title, fontsize=13)
plt.ylim(-0.05, 1.05)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(save_path, dpi=200)
plt.close()
logger.info(f"Multi-curve plot saved: {save_path}")