File size: 2,166 Bytes
e1e1ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Plotting utilities for position-bias curves."""
import logging
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Module-level logger; the plotting functions in this module log the path of each saved figure at INFO level.
logger = logging.getLogger(__name__)


def plot_curve(
    x_values,
    y_values,
    title: str,
    save_path: str,
    xlabel: str = "Position (0=start, 1=end)",
    ylabel: str = "Accuracy",
    ylim: tuple = (-0.05, 1.05),
    color: str = "#E63946",
):
    """Plot a standard position-bias accuracy curve and save it to disk.

    Args:
        x_values: X coordinates of the points.
        y_values: Y coordinates of the points, one per entry of ``x_values``.
        title: Figure title.
        save_path: Destination file path passed to ``Figure.savefig``.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        ylim: ``(bottom, top)`` limits for the y-axis.
        color: Color used for the line and its markers.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(x_values, y_values, marker="o", linewidth=2.5, markersize=10, color=color)
        ax.set_xlabel(xlabel, fontsize=13)
        ax.set_ylabel(ylabel, fontsize=13)
        ax.set_title(title, fontsize=13)
        ax.set_ylim(ylim)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(save_path, dpi=200)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed, so
        # release it even when plotting or saving raises.
        plt.close(fig)
    logger.info("Plot saved: %s", save_path)


def plot_bar(categories, values, title: str, save_path: str, ylabel: str = "Accuracy", ylim=(0, 1.05), colors=None):
    """Plot a bar chart (e.g., for multi-needle start/middle/end) and save it to disk.

    Args:
        categories: Bar labels along the x-axis.
        values: Bar heights, one per category.
        title: Figure title.
        save_path: Destination file path passed to ``Figure.savefig``.
        ylabel: Y-axis label.
        ylim: ``(bottom, top)`` limits for the y-axis.
        colors: Bar fill colors. When omitted, a three-color scheme is used in
            which the middle bar is drawn in a different color from the two
            outer bars.
    """
    if colors is None:
        # Built per call so the default list is never shared between calls.
        colors = ["#2E86AB", "#E63946", "#2E86AB"]
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.bar(categories, values, color=colors, edgecolor="black", linewidth=1.2)
        ax.set_ylabel(ylabel, fontsize=13)
        ax.set_title(title, fontsize=13)
        ax.set_ylim(ylim)
        ax.grid(True, alpha=0.3, axis="y")
        fig.tight_layout()
        fig.savefig(save_path, dpi=200)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    logger.info("Bar plot saved: %s", save_path)


def plot_multi_curves(curves, labels, title, save_path, xlabel="Position", ylabel="Accuracy", ylim=(-0.05, 1.05)):
    """Overlay multiple curves on one set of axes for comparison and save the figure.

    Args:
        curves: Mapping with keys ``"x"`` and ``"y"``; each value holds one
            coordinate sequence per curve, in the same order as ``labels``.
        labels: Legend label for each curve.
        title: Figure title.
        save_path: Destination file path passed to ``Figure.savefig``.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        ylim: ``(bottom, top)`` limits for the y-axis; defaults to the
            previously hard-coded ``(-0.05, 1.05)``.
    """
    x_series = list(curves["x"])
    y_series = list(curves["y"])
    label_list = list(labels)
    if not (len(x_series) == len(y_series) == len(label_list)):
        # zip() below stops at the shortest input, so surplus curves would
        # otherwise vanish from the plot without any indication.
        logger.warning(
            "plot_multi_curves: mismatched lengths (x=%d, y=%d, labels=%d); only %d curve(s) will be plotted",
            len(x_series),
            len(y_series),
            len(label_list),
            min(len(x_series), len(y_series), len(label_list)),
        )
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        cmap = plt.get_cmap("tab10")
        for i, (x, y, label) in enumerate(zip(x_series, y_series, label_list)):
            # Integer lookups at or beyond cmap.N all resolve to the colormap's
            # single "over" color; wrap so curves past the 10th stay distinct
            # from their immediate predecessors.
            ax.plot(x, y, marker="o", linewidth=2.0, markersize=8, label=label, color=cmap(i % cmap.N))
        ax.set_xlabel(xlabel, fontsize=13)
        ax.set_ylabel(ylabel, fontsize=13)
        ax.set_title(title, fontsize=13)
        ax.set_ylim(ylim)
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(save_path, dpi=200)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    logger.info("Multi-curve plot saved: %s", save_path)