sh4shv4t committed on
Commit
0faca0b
·
1 Parent(s): 70be177

feat: added reward audit program

Browse files
Files changed (1) hide show
  1. scripts/debug_rewards_colabstyle.py +120 -0
scripts/debug_rewards_colabstyle.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Colab-style reward diagnostic.

Calls the project's reward functions on one hand-written completion, prints
the source of ``negotiation_efficiency_reward``, re-derives the efficiency
figures by hand for the hiring scenario, and reports which data paths and
which ``training/grpo_train.py`` lines are present locally.

The repo root is read from the ``PARLAY_REPO`` environment variable. When it
is unset, the parent of this script's directory is used; when ``__file__`` is
not defined (e.g. code pasted into a notebook cell), the current working
directory is used, so run from the repo root in that case.
"""
import inspect
import json
import os
import re
import sys

# Resolve the fallback lazily: evaluating ``__file__`` as the default argument
# of ``os.environ.get`` would raise NameError in a notebook cell even when
# PARLAY_REPO is set.
REPO = os.environ.get("PARLAY_REPO")
if REPO is None:
    try:
        REPO = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    except NameError:
        REPO = os.getcwd()
# Must happen before the project import below so ``training`` is importable.
sys.path.insert(0, REPO)

# ── 1. Test the reward functions directly ─────────────────────
from training.reward_fn import (  # noqa: E402
    anti_capitulation_reward,
    format_reward,
    negotiation_efficiency_reward,
    tom_accuracy_reward,
)

# Valid JSON, single line (user's Colab string had a broken newline inside the string)
completions = [
    (
        '{"utterance": "I\'m willing to negotiate, but I need a significant raise.", '
        '"offer_amount": 150000, "tactical_move": null}'
    )
]

# Per-sample keyword arguments passed to the reward functions; each value is a
# one-element list matching the single completion above.
kwargs_hiring = {
    "batna_seller": [195000.0],
    "batna_buyer": [264500.0],
    "zopa_width": [69500.0],
    "scenario_id": ["hiring_package"],
    "persona": ["shark"],
}
kwargs_saas = {
    "batna_seller": [125000.0],
    "batna_buyer": [165000.0],
    "zopa_width": [40000.0],
    "scenario_id": ["saas_enterprise"],
    "persona": ["shark"],
}

print("=== REPO ===")
print(f" sys.path[0] = {sys.path[0]}")

print("\n=== REWARD FUNCTION OUTPUTS ===")
print(f"format_reward: {format_reward(completions)}")
print(f"anti_cap (hiring): {anti_capitulation_reward(completions, **kwargs_hiring)}")
print(f"tom_reward (hiring): {tom_accuracy_reward(completions, **kwargs_hiring)}")
print(f"efficiency (hiring): {negotiation_efficiency_reward(completions, **kwargs_hiring)}")
print(f"efficiency (saas): {negotiation_efficiency_reward(completions, **kwargs_saas)}")

# ── 2. Read reward_fn.py source and print the efficiency function ─
print("\n=== negotiation_efficiency_reward SOURCE ===")
src = inspect.getsource(negotiation_efficiency_reward)
print(src)

# ── 3. Step through the logic manually ───────────────────────
print("\n=== MANUAL TRACE (hiring_package, offer=150000) ===")
raw = completions[0]
try:
    parsed = json.loads(raw)
    offer = parsed.get("offer_amount")
    print(f" parsed offer_amount: {offer!r} (type: {type(offer).__name__})")
except (json.JSONDecodeError, AttributeError) as exc:
    # JSONDecodeError: malformed JSON; AttributeError: valid JSON that is not
    # an object (no ``.get``). Either way there is no offer to trace.
    print(f" JSON parse failed: {exc}")
    offer = None

# Reuse the values handed to the reward functions above so the manual trace
# cannot drift from the actual call.
batna_seller = kwargs_hiring["batna_seller"][0]
batna_buyer = kwargs_hiring["batna_buyer"][0]
zopa_width = kwargs_hiring["zopa_width"][0]
scenario_id = kwargs_hiring["scenario_id"][0]

print(f" scenario_id: {scenario_id}")
print(f" batna_seller: {batna_seller} batna_buyer: {batna_buyer}")
print(f" zopa_width: {zopa_width}")
if offer is not None:
    # Show both readings of the offer so the printed reward above can be
    # matched against whichever side the reward function assumes.
    e_seller = (offer - batna_seller) / zopa_width
    e_buyer = (batna_buyer - offer) / zopa_width
    print(
        f" efficiency if treated as SELLER: {e_seller:.4f} (offer - batna_seller) / width"
    )
    print(
        f" efficiency if treated as BUYER: {e_buyer:.4f} (batna_buyer - offer) / width"
    )
    print(
        f" offer ({offer}) vs batna_seller ({batna_seller}): "
        f"{'ABOVE' if offer >= batna_seller else 'BELOW — anti-cap may fire'}"
    )
    print(
        f" offer ({offer}) vs batna_buyer ({batna_buyer}): "
        f"{'AT OR BELOW' if offer <= batna_buyer else 'ABOVE batna_buyer'}"
    )

# ── 4. Check dataset paths (local) ─
print("\n=== GRPO DATASET / DATA PATHS CHECK ===")
for p in [
    os.path.join(REPO, "data", "grpo_dataset"),
    os.path.join(REPO, "data", "episodes.jsonl"),
    os.path.join(REPO, "data", "episodes_v2.jsonl"),
    REPO,
]:
    print(f" exists={os.path.exists(p)!s:5} {p}")

# Grep-relevant lines from grpo_train
print("\n=== grpo_train.py — lines mentioning build / batna / zopa / kwargs ===")
gp = os.path.join(REPO, "training", "grpo_train.py")
grep_pattern = re.compile(
    r"build_grpo|batna|zopa_width|def build|scenario_id|format_grpo"
)
if os.path.isfile(gp):
    with open(gp, encoding="utf-8") as f:
        for i, line in enumerate(f, start=1):
            if grep_pattern.search(line):
                print(f" L{i}: {line.rstrip()}")
else:
    # Say so explicitly; an empty section would otherwise look like "no matches".
    print(f" (not found: {gp})")

print("\n=== DONE ===")