"""

test_env_connection.py — Validates that train_minimal.py is correctly

wired to call the live environment via HTTP.



This script verifies:

  1. reward_fn signature accepts **kwargs (not positional args like ground_truths)

  2. make_row() produces task_id and seed columns

  3. run_episode_via_http() makes actual HTTP POST calls

  4. _start_env_server_if_needed() raises when server is unreachable

  5. The word "no server" / "no HTTP" does NOT appear in the docstring



Usage:

  python train/test_env_connection.py

"""

import json
import os
import re
import sys
import inspect
from pathlib import Path
from unittest.mock import patch, MagicMock

sys.path.insert(0, ".")

# Mock out heavy training-only dependencies that may not be installed locally
# We only need to test the HTTP wiring logic, not actual GPU training
for mod_name in ["wandb", "trl", "trl.GRPOConfig", "trl.GRPOTrainer",
                 "datasets", "datasets.Dataset", "unsloth"]:
    if mod_name not in sys.modules:
        sys.modules[mod_name] = MagicMock()
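
# Note: stubbing the parent package (e.g. "trl") is what makes
# `from trl import GRPOTrainer` succeed, since attribute lookups on a
# MagicMock return new MagicMocks; the dotted entries are extra safety.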

# torch may or may not be installed — mock if missing
try:
    import torch
except ImportError:
    sys.modules["torch"] = MagicMock()
    sys.modules["torch.cuda"] = MagicMock()

# ── Test 1: Verify the module source says the server IS required ────────────
print("Test 1: Checking train_minimal.py source for env-connected markers...")
with open("train/train_minimal.py", encoding="utf-8") as f:
    source = f.read()

# MUST NOT contain anti-patterns
forbidden = ["no server required", "no HTTP server", "no server needed", "direct-reward", "no-server"]
for phrase in forbidden:
    if phrase.lower() in source.lower():
        print(f"  ❌ FAIL: Found forbidden phrase '{phrase}' in train_minimal.py")
        sys.exit(1)

# MUST contain these indicators that it's env-connected
required = [
    "POST /step",
    "POST /reset",
    "/reset",
    "/step",
    "env-connected",
    "http-reward",
    "MR-2",
]
for phrase in required:
    if phrase not in source:
        print(f"  ❌ FAIL: Missing required phrase '{phrase}' in train_minimal.py")
        sys.exit(1)

print("  ✅ Module docstring correctly declares env-connected training")


# ── Test 2: make_row() includes task_id and seed ────────────────────────────
print("Test 2: Checking make_row() output columns...")
from server.claim_generator import generate_claim

class MockTokenizer:
    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
        return json.dumps(messages)

# Import make_row
from train.train_minimal import make_row

ep = generate_claim(seed=42, fraud_type="medical_inflation", coverage_type="health", difficulty="medium")
tok = MockTokenizer()
row = make_row(ep, tok)

assert "task_id" in row, f"❌ FAIL: make_row() missing 'task_id'. Got keys: {list(row.keys())}"
assert "seed" in row, f"❌ FAIL: make_row() missing 'seed'. Got keys: {list(row.keys())}"
assert row["task_id"] == "contradictory_claim", f"❌ FAIL: task_id should be 'contradictory_claim', got '{row['task_id']}'"
assert row["seed"] == "42", f"❌ FAIL: seed should be '42' (str), got '{row['seed']}'"
print(f"  ✅ make_row() includes task_id='{row['task_id']}' and seed='{row['seed']}'")


# ── Test 3: reward_fn uses **kwargs (not positional) ────────────────────────
print("Test 3: Checking reward_fn signature...")
from train.train_minimal import reward_fn

sig = inspect.signature(reward_fn)
params = list(sig.parameters.keys())

# Must accept **kwargs
assert any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()), \
    f"❌ FAIL: reward_fn does not accept **kwargs. Params: {params}"

# Must NOT have 'expected_signals_list' as a positional param (old signature)
assert "expected_signals_list" not in params, \
    f"❌ FAIL: reward_fn still has 'expected_signals_list' positional param (old signature)"

# Must NOT have 'ground_truths' as a positional param (should come via **kwargs)
assert "ground_truths" not in params, \
    f"❌ FAIL: reward_fn still has 'ground_truths' as positional param. Should come via **kwargs"

print(f"  ✅ reward_fn signature: ({', '.join(params)}) — uses **kwargs correctly")


# ── Test 4: run_episode_via_http makes HTTP calls ───────────────────────────
print("Test 4: Verifying run_episode_via_http() makes HTTP POST calls...")
from train.train_minimal import run_episode_via_http

# Mock requests to verify it makes the right calls
with patch("train.train_minimal.http_client") as mock_http:
    # Setup mock responses
    mock_reset_resp = MagicMock()
    mock_reset_resp.json.return_value = {"session_id": "test-session-123"}
    mock_reset_resp.raise_for_status = MagicMock()

    mock_step_resp = MagicMock()
    mock_step_resp.json.return_value = {"reward": 0.85, "done": True}
    mock_step_resp.raise_for_status = MagicMock()

    mock_http.post.side_effect = [mock_reset_resp, mock_step_resp]

    reward = run_episode_via_http(
        task_id="clean_claim",
        seed=42,
        decision="approve_claim",
        confidence="HIGH",
        reason="All documents verified.",
        base_url="http://fake:7860",
    )

    # Verify POST /reset was called
    calls = mock_http.post.call_args_list
    assert len(calls) == 2, f"❌ FAIL: Expected 2 POST calls, got {len(calls)}"

    reset_call = calls[0]
    assert "/reset" in reset_call[0][0], f"❌ FAIL: First POST not to /reset"
    reset_body = reset_call[1]["json"]
    assert reset_body["task_id"] == "clean_claim", f"❌ FAIL: /reset body missing task_id"
    assert reset_body["seed"] == 42, f"❌ FAIL: /reset body missing seed"

    step_call = calls[1]
    assert "/step" in step_call[0][0], f"❌ FAIL: Second POST not to /step"
    step_body = step_call[1]["json"]
    assert step_body["session_id"] == "test-session-123", f"❌ FAIL: /step missing session_id from /reset"
    assert step_body["action"]["action_type"] == "approve_claim", f"❌ FAIL: action_type wrong"
    assert step_body["action"]["confidence"] == "HIGH", f"❌ FAIL: confidence wrong"

    assert reward == 0.85, f"❌ FAIL: reward should be 0.85 from /step, got {reward}"

print("  ✅ run_episode_via_http() makes POST /reset then POST /step correctly")
print(f"     → /reset body: {{task_id, seed}}")
print(f"     → /step body: {{action: {{action_type, confidence, reasoning}}, session_id}}")
print(f"     → reward returned from /step response: 0.85")


# ── Test 5: reward_fn calls run_episode_via_http (not training_reward) ──────
print("Test 5: Verifying reward_fn calls HTTP, not training_reward()...")

with patch("train.train_minimal.run_episode_via_http") as mock_episode:
    mock_episode.return_value = 0.75

    completions = [
        [{"content": "DECISION: approve_claim\nCONFIDENCE: HIGH\nREASON: docs verified"}],
        [{"content": "DECISION: deny_claim\nCONFIDENCE: MED\nREASON: suspicious docs"}],
    ]
    prompts = ["prompt1", "prompt2"]

    rewards = reward_fn(
        completions,
        prompts,
        task_id=["clean_claim", "contradictory_claim"],
        seed=["42", "43"],
        ground_truth=["approve_claim", "deny_claim"],
    )

    assert mock_episode.call_count == 2, f"❌ FAIL: Expected 2 HTTP calls, got {mock_episode.call_count}"
    assert rewards == [0.75, 0.75], f"❌ FAIL: rewards should be [0.75, 0.75], got {rewards}"

print("  ✅ reward_fn calls run_episode_via_http() for each completion")


# ── Test 6: _wait_for_env raises when the server is unreachable ─────────────
print("Test 6: Verifying training fails without server...")
from train.train_minimal import _wait_for_env

try:
    # Use very short retries to a port that's definitely not running
    _wait_for_env("http://localhost:19999", retries=1)
    print("  ❌ FAIL: Should have raised RuntimeError when server is unreachable")
    sys.exit(1)
except RuntimeError as e:
    assert "not reachable" in str(e).lower(), f"❌ FAIL: Error message unclear: {e}"
    print(f"  ✅ _wait_for_env raises RuntimeError when server is down")


# ── Test 7: WandB config says env-connected ─────────────────────────────────
print("Test 7: Checking WandB tags and config...")
assert '"env-connected"' in source, "❌ FAIL: WandB tags don't include 'env-connected'"
assert '"http-reward"' in source, "❌ FAIL: WandB tags don't include 'http-reward'"
assert '"env_http_reward"' in source, "❌ FAIL: reward_type not set to 'env_http_reward'"
assert '"no-server"' not in source, "❌ FAIL: WandB tags still contain 'no-server'"
assert '"direct-reward"' not in source, "❌ FAIL: WandB tags still contain 'direct-reward'"
print("  ✅ WandB config correctly reflects env-connected training")


# ── Final Summary ───────────────────────────────────────────────────────────
print()
print("=" * 70)
print("  ALL 7 TESTS PASSED ✅")
print()
print("  MR-2 Compliance verified:")
print("    • reward_fn calls POST /reset + POST /step (not training_reward)")
print("    • make_row() includes task_id + seed for /reset")
print("    • Training WILL FAIL if environment server is not running")
print("    • No 'no-server' or 'direct-reward' remnants in code")
print("=" * 70)
"""

This script validates:

  1. The module docstring declares env-connected training

  2. make_row() includes task_id and seed columns

  3. reward_fn uses **kwargs (not positional args)

  4. run_episode_via_http() makes correct POST /reset then POST /step

  5. reward_fn dispatches to run_episode_via_http (not training_reward)

  6. _wait_for_env raises RuntimeError when server is unreachable

  7. WandB config has correct env-connected tags

"""