Spaces:
Sleeping
Sleeping
Modify Train method
Browse files- env/rewards.py +7 -2
- train/train_grpo.py +51 -3
env/rewards.py
CHANGED
|
@@ -102,9 +102,14 @@ class RewardSystem:
|
|
| 102 |
# Smoothly squish reasoning quality using tanh to bound its impact
|
| 103 |
q_smooth = math.tanh(q)
|
| 104 |
|
| 105 |
-
#
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
|
|
|
|
|
|
|
|
|
|
| 108 |
components = {
|
| 109 |
"total_reward": total_r,
|
| 110 |
"C_correctness": c,
|
|
|
|
| 102 |
# Smoothly squish reasoning quality using tanh to bound its impact
|
| 103 |
q_smooth = math.tanh(q)
|
| 104 |
|
| 105 |
+
# Normalize variables mapping entirely into the [0, 1] domain
|
| 106 |
+
p_norm = (process_supervision + 1.0) / 2.0 # Scales [-1, 1] to [0, 1]
|
| 107 |
+
r_norm = (reflection_score + 0.5) / 1.5 # Scales [-0.5, 1.0] to [0, 1]
|
| 108 |
+
q_norm = min(1.0, max(0.0, q_smooth))
|
| 109 |
|
| 110 |
+
# New Simplified Composite Reward Equation (Strictly bounded [0, 1])
|
| 111 |
+
# Base coefficients sum exactly to 1.0. Noise is removed to satisfy bounds.
|
| 112 |
+
total_r = (0.4 * c) + (0.3 * q_norm) + (0.2 * p_norm) + (0.1 * r_norm)
|
| 113 |
components = {
|
| 114 |
"total_reward": total_r,
|
| 115 |
"C_correctness": c,
|
train/train_grpo.py
CHANGED
|
@@ -208,11 +208,59 @@ def main():
|
|
| 208 |
print("Starting LADDER Training (Curriculum: Recursive Variant Trees)...")
|
| 209 |
trainer.train()
|
| 210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
# Showcase TTRL
|
| 212 |
run_ttrl(model, tokenizer, "If 4(x+2) - 10 = 14, what is x?", env)
|
| 213 |
|
| 214 |
if __name__ == "__main__":
|
| 215 |
main()
|
| 216 |
-
|
| 217 |
-
if __name__ == "__main__":
|
| 218 |
-
main()
|
|
|
|
| 208 |
print("Starting LADDER Training (Curriculum: Recursive Variant Trees)...")
|
| 209 |
trainer.train()
|
| 210 |
|
| 211 |
+
# Generate Training Charts
|
| 212 |
+
try:
|
| 213 |
+
import matplotlib.pyplot as plt
|
| 214 |
+
import os
|
| 215 |
+
|
| 216 |
+
os.makedirs("outputs_math/plots", exist_ok=True)
|
| 217 |
+
history = trainer.state.log_history
|
| 218 |
+
|
| 219 |
+
# Plot Loss
|
| 220 |
+
losses = [x["loss"] for x in history if "loss" in x]
|
| 221 |
+
steps = [x["step"] for x in history if "loss" in x]
|
| 222 |
+
if losses:
|
| 223 |
+
plt.figure(figsize=(10, 6))
|
| 224 |
+
plt.plot(steps, losses, marker="o", color="blue", linewidth=2)
|
| 225 |
+
plt.title("GRPO Training Loss Over Steps")
|
| 226 |
+
plt.xlabel("Steps")
|
| 227 |
+
plt.ylabel("Loss")
|
| 228 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
| 229 |
+
plt.savefig("outputs_math/plots/training_loss.png")
|
| 230 |
+
plt.close()
|
| 231 |
+
|
| 232 |
+
# Plot Rewards
|
| 233 |
+
rewards = [x["reward"] for x in history if "reward" in x]
|
| 234 |
+
r_steps = [x["step"] for x in history if "reward" in x]
|
| 235 |
+
if rewards:
|
| 236 |
+
plt.figure(figsize=(10, 6))
|
| 237 |
+
plt.plot(r_steps, rewards, marker="x", color="green", linewidth=2)
|
| 238 |
+
plt.title("Average Completion Reward Over Steps")
|
| 239 |
+
plt.xlabel("Steps")
|
| 240 |
+
plt.ylabel("Rewards")
|
| 241 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
| 242 |
+
plt.savefig("outputs_math/plots/reward.png")
|
| 243 |
+
plt.close()
|
| 244 |
+
|
| 245 |
+
# Plot KL Divergence
|
| 246 |
+
kl = [x["kl"] for x in history if "kl" in x]
|
| 247 |
+
kl_steps = [x["step"] for x in history if "kl" in x]
|
| 248 |
+
if kl:
|
| 249 |
+
plt.figure(figsize=(10, 6))
|
| 250 |
+
plt.plot(kl_steps, kl, marker="^", color="red", linewidth=2)
|
| 251 |
+
plt.title("KL Divergence (Policy vs Reference)")
|
| 252 |
+
plt.xlabel("Steps")
|
| 253 |
+
plt.ylabel("KL Divergence")
|
| 254 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
| 255 |
+
plt.savefig("outputs_math/plots/kl_divergence.png")
|
| 256 |
+
plt.close()
|
| 257 |
+
|
| 258 |
+
print(f"✅ Generated training metric plots in 'outputs_math/plots' directory.")
|
| 259 |
+
except Exception as e:
|
| 260 |
+
print(f"Could not generate plots: {e}")
|
| 261 |
+
|
| 262 |
# Showcase TTRL
|
| 263 |
run_ttrl(model, tokenizer, "If 4(x+2) - 10 = 14, what is x?", env)
|
| 264 |
|
| 265 |
if __name__ == "__main__":
|
| 266 |
main()
|
|
|
|
|
|
|
|
|