File size: 3,172 Bytes
ae60795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# training/plot_rewards.py
# Run after grpo_train.py to generate reward graphs.
# Usage: python training/plot_rewards.py --input salespath_training_outputs/reward_history.txt

import argparse
import os
import sys

import matplotlib
matplotlib.use("Agg")  # headless safe
import matplotlib.pyplot as plt
import numpy as np


def load_rewards(path: str) -> list[float]:
    """Read per-episode rewards from a tab-separated history file.

    Each non-blank line contributes one reward: the text after the last tab
    (or the whole line when it contains no tab) parsed as a float. Blank
    lines are skipped; any columns before the last one are ignored.

    Args:
        path: Path to the reward history file (e.g. ``reward_history.txt``).

    Returns:
        The rewards in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a line's last field is not a valid float; the message
            names the file and the offending line number.
    """
    rewards: list[float] = []
    # Explicit encoding so parsing does not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            # Text after the last tab; the whole line when there is no tab.
            field = line.rpartition("\t")[2]
            try:
                rewards.append(float(field))
            except ValueError as err:
                raise ValueError(
                    f"{path}:{lineno}: cannot parse reward from {field!r}"
                ) from err
    return rewards


def rolling_mean(data: list[float], window: int = 20) -> list[float]:
    """Return the trailing moving average of *data*.

    Element ``i`` is the mean of ``data[max(0, i - window + 1) : i + 1]``, so
    the first ``window - 1`` entries average over the shorter prefix that is
    available rather than being dropped or padded. The output has the same
    length as *data*.

    Args:
        data: Sequence of values to smooth.
        window: Number of trailing samples per average; must be >= 1.

    Returns:
        One float per input element (an empty list for empty input).

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        # A non-positive window yields empty slices, and np.mean of an empty
        # slice returns nan with a RuntimeWarning instead of failing loudly.
        raise ValueError(f"window must be >= 1, got {window}")
    return [
        float(np.mean(data[max(0, i - window + 1) : i + 1]))
        for i in range(len(data))
    ]


def plot(rewards: list[float], output_path: str, *, window: int = 20):
    """Render the reward curve and distribution to an image and print stats.

    Top panel: raw per-episode rewards plus a trailing rolling mean, with a
    dashed reference line at 0. Bottom panel: a 30-bin histogram of the
    rewards with the mean marked. After saving, the episode count and the
    mean/max/min/std of *rewards* are printed to stdout.

    Args:
        rewards: Per-episode rewards. Must be non-empty: ``np.min`` /
            ``np.max`` raise ValueError on an empty array.
        output_path: Destination passed to ``Figure.savefig``; its parent
            directory must already exist.
        window: Rolling-mean window for the smoothed curve (keyword-only,
            default 20).
    """
    values = np.asarray(rewards, dtype=float)
    steps = list(range(len(values)))
    smooth = rolling_mean(rewards, window=window)
    mean = float(np.mean(values))

    fig, (ax, ax2) = plt.subplots(2, 1, figsize=(12, 8))
    try:
        fig.suptitle("SalesPath — Training Reward Curve", fontsize=14, fontweight="bold")

        # Top: raw + smoothed reward
        ax.plot(steps, values, alpha=0.3, color="#5b9bd5", linewidth=0.8, label="Episode reward")
        ax.plot(steps, smooth, color="#e07b3c", linewidth=2.0, label=f"Rolling mean ({window})")
        ax.axhline(0, color="gray", linestyle="--", linewidth=0.8)
        ax.set_ylabel("Total Reward")
        ax.set_xlabel("Episode")
        ax.legend(loc="upper left")
        # Keep the [-1.1, 1.1] framing when every reward lies in [-1, 1], but
        # widen it otherwise so out-of-range episodes are not clipped from view.
        ax.set_ylim(
            min(-1.1, float(np.min(values)) - 0.1),
            max(1.1, float(np.max(values)) + 0.1),
        )
        ax.grid(True, alpha=0.3)

        # Bottom: histogram of rewards
        ax2.hist(values, bins=30, color="#5b9bd5", edgecolor="white", alpha=0.8)
        ax2.axvline(mean, color="#e07b3c", linewidth=2, label=f"Mean: {mean:.3f}")
        ax2.set_xlabel("Reward")
        ax2.set_ylabel("Count")
        ax2.set_title("Reward Distribution")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Saved reward graph → {output_path}")
    print(f"  Episodes:    {len(values)}")
    print(f"  Mean reward: {mean:.4f}")
    print(f"  Max reward:  {np.max(values):.4f}")
    print(f"  Min reward:  {np.min(values):.4f}")
    print(f"  Std reward:  {np.std(values):.4f}")


def main(argv: "list[str] | None" = None) -> None:
    """Command-line entry point: load the reward history and write the graph.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps normal command-line behaviour.

    Exits with status 1, after printing an error to stderr, when the input
    path is not an existing file or the file contains no rewards.
    """
    parser = argparse.ArgumentParser(description="Plot SalesPath reward history.")
    parser.add_argument(
        "--input",
        default="salespath_training_outputs/reward_history.txt",
        help="Path to reward_history.txt",
    )
    parser.add_argument(
        "--output",
        default="salespath_training_outputs/reward_graph.png",
        help="Output PNG path",
    )
    args = parser.parse_args(argv)

    # isfile (not exists) so a directory path is rejected here instead of
    # crashing later inside open().
    if not os.path.isfile(args.input):
        print(f"ERROR: {args.input} not found. Run grpo_train.py first.", file=sys.stderr)
        sys.exit(1)

    rewards = load_rewards(args.input)
    if not rewards:
        print("ERROR: No rewards found in file.", file=sys.stderr)
        sys.exit(1)

    out_dir = os.path.dirname(args.output)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plot(rewards, args.output)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()