# training/plot_rewards.py
# Run after grpo_train.py to generate reward graphs.
# Usage: python training/plot_rewards.py --input salespath_training_outputs/reward_history.txt
import argparse
import os
import sys
import matplotlib
matplotlib.use("Agg") # headless safe
import matplotlib.pyplot as plt
import numpy as np
def load_rewards(path: str) -> list[float]:
    """Read per-episode rewards from a tab-separated text file.

    Each non-blank line is split on tab characters and its *last* field is
    parsed as a float; any leading fields are ignored. Blank or
    whitespace-only lines are skipped.

    Args:
        path: Path to the reward history file.

    Returns:
        The parsed rewards, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the last field of a non-blank line is not a valid
            float. The message names the file and 1-based line number.
    """
    rewards: list[float] = []
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            field = line.split("\t")[-1]
            try:
                rewards.append(float(field))
            except ValueError as err:
                # Point at the offending line instead of a bare float() error.
                raise ValueError(
                    f"{path}:{lineno}: cannot parse reward from {field!r}"
                ) from err
    return rewards
def rolling_mean(data: list[float], window: int = 20) -> list[float]:
    """Return the trailing moving average of *data*.

    Output element ``i`` is the mean of up to *window* values ending at
    position ``i`` (``data[max(0, i - window + 1) : i + 1]``). The first
    ``window - 1`` outputs therefore average fewer points rather than being
    padded, and the result has the same length as *data*.

    Args:
        data: Values to smooth, in order.
        window: Number of trailing values to average; must be >= 1.

    Returns:
        One float per input element (an empty list for empty input).

    Raises:
        ValueError: If *window* is less than 1.
    """
    if window < 1:
        # window <= 0 gives empty slices, so np.mean would return NaN (with a
        # RuntimeWarning) for every element instead of a usable curve.
        raise ValueError(f"window must be >= 1, got {window}")
    return [
        float(np.mean(data[max(0, end - window) : end]))
        for end in range(1, len(data) + 1)
    ]
def _plot_reward_curve(ax, rewards: list[float], smooth: list[float], window: int) -> None:
    """Draw raw and rolling-mean per-episode rewards on *ax*."""
    steps = list(range(len(rewards)))
    ax.plot(steps, rewards, alpha=0.3, color="#5b9bd5", linewidth=0.8, label="Episode reward")
    ax.plot(steps, smooth, color="#e07b3c", linewidth=2.0, label=f"Rolling mean ({window})")
    ax.axhline(0, color="gray", linestyle="--", linewidth=0.8)
    ax.set_ylabel("Total Reward")
    ax.set_xlabel("Episode")
    ax.legend(loc="upper left")
    # Keep the [-1.1, 1.1] view for rewards within [-1, 1], but widen it when
    # data falls outside so out-of-range episodes are not clipped off-screen.
    ax.set_ylim(min(-1.1, min(rewards) - 0.1), max(1.1, max(rewards) + 0.1))
    ax.grid(True, alpha=0.3)


def _plot_reward_histogram(ax, rewards: list[float], mean: float) -> None:
    """Draw the reward distribution on *ax* with a vertical line at *mean*."""
    ax.hist(rewards, bins=30, color="#5b9bd5", edgecolor="white", alpha=0.8)
    ax.axvline(mean, color="#e07b3c", linewidth=2, label=f"Mean: {mean:.3f}")
    ax.set_xlabel("Reward")
    ax.set_ylabel("Count")
    ax.set_title("Reward Distribution")
    ax.legend()
    ax.grid(True, alpha=0.3)


def _print_reward_summary(rewards: list[float], output_path: str) -> None:
    """Print where the graph was saved plus count/mean/max/min/std of *rewards*."""
    print(f"Saved reward graph → {output_path}")
    print(f" Episodes: {len(rewards)}")
    print(f" Mean reward: {np.mean(rewards):.4f}")
    print(f" Max reward: {np.max(rewards):.4f}")
    print(f" Min reward: {np.min(rewards):.4f}")
    print(f" Std reward: {np.std(rewards):.4f}")


def plot(rewards: list[float], output_path: str, window: int = 20) -> None:
    """Save a two-panel reward figure to *output_path* and print summary stats.

    The top panel shows raw per-episode rewards with a trailing rolling mean.
    The bottom panel shows a 30-bin histogram of the rewards, marked at their
    mean.

    Args:
        rewards: Per-episode rewards, in order; must be non-empty.
        output_path: Destination image path (written at 150 dpi).
        window: Rolling-mean window size; also shown in the legend label.

    Raises:
        ValueError: If *rewards* is empty.
    """
    # len() rather than truthiness so numpy arrays are accepted too.
    if len(rewards) == 0:
        raise ValueError("rewards must be non-empty")
    smooth = rolling_mean(rewards, window=window)
    mean = float(np.mean(rewards))

    fig, (ax_curve, ax_hist) = plt.subplots(2, 1, figsize=(12, 8))
    try:
        fig.suptitle("SalesPath — Training Reward Curve", fontsize=14, fontweight="bold")
        _plot_reward_curve(ax_curve, rewards, smooth, window)
        _plot_reward_histogram(ax_hist, rewards, mean)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until explicitly closed; release it
        # so repeated calls don't accumulate figures and memory.
        plt.close(fig)
    _print_reward_summary(rewards, output_path)
def main(argv: "list[str] | None" = None) -> None:
    """Command-line entry point: load a reward history file and plot it.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``, for calling
            from other code. ``None`` (the default) uses the real CLI args.

    Prints an error to stderr and exits with status 1 in two cases: the
    input path does not exist, or the file yields no rewards.
    """
    parser = argparse.ArgumentParser(description="Plot SalesPath reward history.")
    parser.add_argument(
        "--input",
        default="salespath_training_outputs/reward_history.txt",
        help="Path to reward_history.txt",
    )
    parser.add_argument(
        "--output",
        default="salespath_training_outputs/reward_graph.png",
        help="Output PNG path",
    )
    args = parser.parse_args(argv)

    if not os.path.exists(args.input):
        print(f"ERROR: {args.input} not found. Run grpo_train.py first.", file=sys.stderr)
        sys.exit(1)
    rewards = load_rewards(args.input)
    if not rewards:
        print("ERROR: No rewards found in file.", file=sys.stderr)
        sys.exit(1)

    # dirname() is "" for a bare filename such as "graph.png", and
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the output path actually names one.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plot(rewards, args.output)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|