"""
test_env_connection.py — Validates that train_minimal.py is correctly
wired to call the live environment via HTTP.

This script verifies:
1. The module docstring declares env-connected training, and forbidden
   phrases such as "no server required" appear nowhere in the source
2. make_row() produces task_id and seed columns
3. reward_fn accepts extra columns via **kwargs (not positional args
   such as ground_truths)
4. run_episode_via_http() makes actual HTTP POST calls (/reset, then /step)
5. reward_fn dispatches to run_episode_via_http(), not training_reward()
6. _wait_for_env() raises RuntimeError when the server is unreachable
7. The WandB config carries the env-connected tags

Usage:
    python train/test_env_connection.py
"""
import json
import sys
import inspect
from unittest.mock import patch, MagicMock
sys.path.insert(0, ".")
# Mock out heavy training-only dependencies that may not be installed locally
# We only need to test the HTTP wiring logic, not actual GPU training
for mod_name in ["wandb", "trl", "trl.GRPOConfig", "trl.GRPOTrainer",
                 "datasets", "datasets.Dataset", "unsloth"]:
    if mod_name not in sys.modules:
        sys.modules[mod_name] = MagicMock()
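# A MagicMock placed in sys.modules satisfies both `import trl` and
# `from trl import GRPOTrainer`, since attribute access on a MagicMock
# auto-creates child mocks; the dotted entries also cover `import trl.X` forms.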
# torch may or may not be installed — mock if missing
try:
    import torch
except ImportError:
    sys.modules["torch"] = MagicMock()
    sys.modules["torch.cuda"] = MagicMock()
# ── Test 1: Verify the module docstring says server IS required ─────────────
print("Test 1: Checking module docstring...")
with open("train/train_minimal.py", encoding="utf-8") as f:
    source = f.read()
# MUST NOT contain anti-patterns
forbidden = ["no server required", "no HTTP server", "no server needed", "direct-reward", "no-server"]
for phrase in forbidden:
    if phrase.lower() in source.lower():
        print(f" ❌ FAIL: Found forbidden phrase '{phrase}' in train_minimal.py")
        sys.exit(1)
# MUST contain these indicators that it's env-connected
required = [
    "POST /step",
    "POST /reset",
    "/reset",
    "/step",
    "env-connected",
    "http-reward",
    "MR-2",
]
for phrase in required:
    if phrase not in source:
        print(f" ❌ FAIL: Missing required phrase '{phrase}' in train_minimal.py")
        sys.exit(1)
print(" ✅ Module docstring correctly declares env-connected training")
# ── Test 2: make_row() includes task_id and seed ────────────────────────────
print("Test 2: Checking make_row() output columns...")
from server.claim_generator import generate_claim
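# Minimal stand-in for a Hugging Face tokenizer: make_row() only needs
# apply_chat_template(), so mirroring that single method is enough here.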
class MockTokenizer:
    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
        return json.dumps(messages)
# Import make_row
from train.train_minimal import make_row
ep = generate_claim(seed=42, fraud_type="medical_inflation", coverage_type="health", difficulty="medium")
tok = MockTokenizer()
row = make_row(ep, tok)
assert "task_id" in row, f"❌ FAIL: make_row() missing 'task_id'. Got keys: {list(row.keys())}"
assert "seed" in row, f"❌ FAIL: make_row() missing 'seed'. Got keys: {list(row.keys())}"
assert row["task_id"] == "contradictory_claim", f"❌ FAIL: task_id should be 'contradictory_claim', got '{row['task_id']}'"
assert row["seed"] == "42", f"❌ FAIL: seed should be '42' (str), got '{row['seed']}'"
print(f" ✅ make_row() includes task_id='{row['task_id']}' and seed='{row['seed']}'")
# ── Test 3: reward_fn uses **kwargs (not positional) ────────────────────────
print("Test 3: Checking reward_fn signature...")
from train.train_minimal import reward_fn
sig = inspect.signature(reward_fn)
params = list(sig.parameters.keys())
# Must accept **kwargs
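# Illustrative sketch of a compatible signature (inferred from the checks
# below, not the exact definition in train_minimal):
#     def reward_fn(completions, prompts, **kwargs) -> list[float]: ...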
assert any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()), \
    f"❌ FAIL: reward_fn does not accept **kwargs. Params: {params}"
# Must NOT have 'expected_signals_list' as a positional param (old signature)
assert "expected_signals_list" not in params, \
    "❌ FAIL: reward_fn still has 'expected_signals_list' positional param (old signature)"
# Must NOT have 'ground_truths' as a positional param (should come via **kwargs)
assert "ground_truths" not in params, \
    "❌ FAIL: reward_fn still has 'ground_truths' as positional param. Should come via **kwargs"
print(f" ✅ reward_fn signature: ({', '.join(params)}) — uses **kwargs correctly")
# ── Test 4: run_episode_via_http makes HTTP calls ───────────────────────────
print("Test 4: Verifying run_episode_via_http() makes HTTP POST calls...")
from train.train_minimal import run_episode_via_http
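# Call shape inferred from the invocation below (an assumption, not the
# exact definition):
#     run_episode_via_http(task_id, seed, decision, confidence, reason,
#                          base_url) -> float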
# Mock requests to verify it makes the right calls
with patch("train.train_minimal.http_client") as mock_http:
# Setup mock responses
mock_reset_resp = MagicMock()
mock_reset_resp.json.return_value = {"session_id": "test-session-123"}
mock_reset_resp.raise_for_status = MagicMock()
mock_step_resp = MagicMock()
mock_step_resp.json.return_value = {"reward": 0.85, "done": True}
mock_step_resp.raise_for_status = MagicMock()
mock_http.post.side_effect = [mock_reset_resp, mock_step_resp]
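    # side_effect hands these mocks out in call order: the first post() gets
    # the /reset response, the second gets the /step response.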
    reward = run_episode_via_http(
        task_id="clean_claim",
        seed=42,
        decision="approve_claim",
        confidence="HIGH",
        reason="All documents verified.",
        base_url="http://fake:7860",
    )
    # Verify POST /reset was called
    calls = mock_http.post.call_args_list
    assert len(calls) == 2, f"❌ FAIL: Expected 2 POST calls, got {len(calls)}"
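    # Each call_args_list entry is an (args, kwargs) pair: args[0] is the URL
    # passed to post(), kwargs["json"] is the request body.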
    reset_call = calls[0]
    assert "/reset" in reset_call[0][0], "❌ FAIL: First POST not to /reset"
    reset_body = reset_call[1]["json"]
    assert reset_body["task_id"] == "clean_claim", "❌ FAIL: /reset body missing task_id"
    assert reset_body["seed"] == 42, "❌ FAIL: /reset body missing seed"
    step_call = calls[1]
    assert "/step" in step_call[0][0], "❌ FAIL: Second POST not to /step"
    step_body = step_call[1]["json"]
    assert step_body["session_id"] == "test-session-123", "❌ FAIL: /step missing session_id from /reset"
    assert step_body["action"]["action_type"] == "approve_claim", "❌ FAIL: action_type wrong"
    assert step_body["action"]["confidence"] == "HIGH", "❌ FAIL: confidence wrong"
    assert reward == 0.85, f"❌ FAIL: reward should be 0.85 from /step, got {reward}"
print(" ✅ run_episode_via_http() makes POST /reset then POST /step correctly")
print(f" → /reset body: {{task_id, seed}}")
print(f" → /step body: {{action: {{action_type, confidence, reasoning}}, session_id}}")
print(f" → reward returned from /step response: 0.85")
# ── Test 5: reward_fn calls run_episode_via_http (not training_reward) ──────
print("Test 5: Verifying reward_fn calls HTTP, not training_reward()...")
with patch("train.train_minimal.run_episode_via_http") as mock_episode:
mock_episode.return_value = 0.75
completions = [
[{"content": "DECISION: approve_claim\nCONFIDENCE: HIGH\nREASON: docs verified"}],
[{"content": "DECISION: deny_claim\nCONFIDENCE: MED\nREASON: suspicious docs"}],
]
prompts = ["prompt1", "prompt2"]
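    # TRL's GRPOTrainer forwards extra dataset columns (task_id, seed,
    # ground_truth) to the reward function as keyword lists aligned with
    # completions, hence the parallel lists below.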
    rewards = reward_fn(
        completions,
        prompts,
        task_id=["clean_claim", "contradictory_claim"],
        seed=["42", "43"],
        ground_truth=["approve_claim", "deny_claim"],
    )
    assert mock_episode.call_count == 2, f"❌ FAIL: Expected 2 HTTP calls, got {mock_episode.call_count}"
    assert rewards == [0.75, 0.75], f"❌ FAIL: rewards should be [0.75, 0.75], got {rewards}"
print(" ✅ reward_fn calls run_episode_via_http() for each completion")
# ── Test 6: _wait_for_env fails without server ──────────────────────────────
print("Test 6: Verifying training fails without server...")
from train.train_minimal import _wait_for_env
try:
    # A single retry against a port that is almost certainly unused
    _wait_for_env("http://localhost:19999", retries=1)
    print(" ❌ FAIL: Should have raised RuntimeError when server is unreachable")
    sys.exit(1)
except RuntimeError as e:
    assert "not reachable" in str(e).lower(), f"❌ FAIL: Error message unclear: {e}"
    print(" ✅ _wait_for_env raises RuntimeError when server is down")
# ── Test 7: WandB config says env-connected ─────────────────────────────────
print("Test 7: Checking WandB tags and config...")
assert '"env-connected"' in source, "❌ FAIL: WandB tags don't include 'env-connected'"
assert '"http-reward"' in source, "❌ FAIL: WandB tags don't include 'http-reward'"
assert '"env_http_reward"' in source, "❌ FAIL: reward_type not set to 'env_http_reward'"
assert '"no-server"' not in source, "❌ FAIL: WandB tags still contain 'no-server'"
assert '"direct-reward"' not in source, "❌ FAIL: WandB tags still contain 'direct-reward'"
print(" ✅ WandB config correctly reflects env-connected training")
# ── Final Summary ───────────────────────────────────────────────────────────
print()
print("=" * 70)
print(" ALL 7 TESTS PASSED ✅")
print()
print(" MR-2 Compliance verified:")
print(" • reward_fn calls POST /reset + POST /step (not training_reward)")
print(" • make_row() includes task_id + seed for /reset")
print(" • Training WILL FAIL if environment server is not running")
print(" • No 'no-server' or 'direct-reward' remnants in code")
print("=" * 70)
"""
This script validates:
1. The module docstring declares env-connected training
2. make_row() includes task_id and seed columns
3. reward_fn uses **kwargs (not positional args)
4. run_episode_via_http() makes correct POST /reset then POST /step
5. reward_fn dispatches to run_episode_via_http (not training_reward)
6. _wait_for_env raises RuntimeError when server is unreachable
7. WandB config has correct env-connected tags
"""
|