"""test_env_connection.py — Validate that train_minimal.py calls the live environment over HTTP.

Checks performed, in order:
  1. The source of train_minimal.py declares env-connected training: required
     phrases ("POST /step", "POST /reset", "env-connected", "http-reward",
     "MR-2") are present and "no server" / "direct-reward" phrases are absent.
  2. make_row() produces ``task_id`` and ``seed`` columns.
  3. reward_fn accepts ``**kwargs`` rather than positional ground-truth params.
  4. run_episode_via_http() issues POST /reset then POST /step and returns the
     reward from the /step response.
  5. reward_fn dispatches to run_episode_via_http() (not training_reward()).
  6. _wait_for_env() raises RuntimeError when the server is unreachable.
  7. The WandB tags/config reflect env-connected training.

Every check reports failure by printing a message and exiting with status 1.
Explicit checks are used instead of ``assert`` so that running under
``python -O`` (which strips asserts) cannot produce a false "ALL PASSED".

Usage:
    python train/test_env_connection.py
"""
import inspect
import json
import os
import re
import socket
import sys
from pathlib import Path
from typing import NoReturn
from unittest.mock import MagicMock, patch

# This script lives in <repo>/train/ (see Usage), so the repo root is two levels
# up. Resolving from __file__ instead of "." lets it run from any directory.
REPO_ROOT = Path(__file__).resolve().parent.parent
TRAIN_SCRIPT = REPO_ROOT / "train" / "train_minimal.py"

# Heavy training-only dependencies that may not be installed locally. Only the
# HTTP wiring is under test, not GPU training, so each is replaced with a
# MagicMock stub unless a real module was already imported.
HEAVY_TRAINING_MODULES = (
    "wandb",
    "trl",
    "trl.GRPOConfig",
    "trl.GRPOTrainer",
    "datasets",
    "datasets.Dataset",
    "unsloth",
)

# Phrases whose presence (case-insensitive) means train_minimal.py still
# describes the old server-less / direct-reward setup.
FORBIDDEN_PHRASES = (
    "no server required",
    "no HTTP server",
    "no server needed",
    "direct-reward",
    "no-server",
)

# Phrases (case-sensitive) that must appear in train_minimal.py. "POST /step"
# and "POST /reset" already imply the bare "/step" and "/reset" paths.
REQUIRED_PHRASES = (
    "POST /step",
    "POST /reset",
    "env-connected",
    "http-reward",
    "MR-2",
)


def fail(message: str) -> NoReturn:
    """Print *message* and terminate the script with exit status 1."""
    print(f" {message}")
    sys.exit(1)


def require(condition: bool, message: str) -> None:
    """Call :func:`fail` with *message* unless *condition* is truthy.

    Unlike ``assert``, this is not stripped under ``python -O``.
    """
    if not condition:
        fail(message)


def find_unused_port() -> int:
    """Return a localhost TCP port that the OS reported as free.

    Binding to port 0 makes the OS choose an unused port; the socket is closed
    before returning, so nothing listens there (barring a race with another
    process claiming the same port in between).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]


def install_dependency_stubs() -> None:
    """Register MagicMock stand-ins for heavy modules before train_minimal is imported."""
    for mod_name in HEAVY_TRAINING_MODULES:
        sys.modules.setdefault(mod_name, MagicMock())
    # torch may or may not be installed — stub it only if the real import fails.
    try:
        import torch  # noqa: F401
    except ImportError:
        sys.modules["torch"] = MagicMock()
        sys.modules["torch.cuda"] = MagicMock()


class MockTokenizer:
    """Minimal tokenizer stub that renders chat messages as a JSON string."""

    def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=True):
        # tokenize / add_generation_prompt mirror the real signature and are ignored.
        return json.dumps(messages)


def check_source_declares_env_connected(source: str) -> None:
    """Test 1: the train_minimal.py source must declare env-connected training.

    Scans the whole source text, not only the module docstring.
    """
    print("Test 1: Checking module docstring...")
    lowered = source.lower()
    for phrase in FORBIDDEN_PHRASES:
        require(
            phrase.lower() not in lowered,
            f"❌ FAIL: Found forbidden phrase '{phrase}' in train_minimal.py",
        )
    for phrase in REQUIRED_PHRASES:
        require(
            phrase in source,
            f"❌ FAIL: Missing required phrase '{phrase}' in train_minimal.py",
        )
    print(" ✅ Module docstring correctly declares env-connected training")


def check_make_row_columns() -> None:
    """Test 2: make_row() must emit ``task_id`` and ``seed`` (seed as a str)."""
    print("Test 2: Checking make_row() output columns...")
    from server.claim_generator import generate_claim
    from train.train_minimal import make_row

    episode = generate_claim(
        seed=42, fraud_type="medical_inflation", coverage_type="health", difficulty="medium"
    )
    row = make_row(episode, MockTokenizer())
    # Key-presence checks come first: later messages index row[...] eagerly.
    require("task_id" in row, f"❌ FAIL: make_row() missing 'task_id'. Got keys: {list(row.keys())}")
    require("seed" in row, f"❌ FAIL: make_row() missing 'seed'. Got keys: {list(row.keys())}")
    require(
        row["task_id"] == "contradictory_claim",
        f"❌ FAIL: task_id should be 'contradictory_claim', got '{row['task_id']}'",
    )
    require(row["seed"] == "42", f"❌ FAIL: seed should be '42' (str), got '{row['seed']}'")
    print(f" ✅ make_row() includes task_id='{row['task_id']}' and seed='{row['seed']}'")


def check_reward_fn_signature() -> None:
    """Test 3: reward_fn must take **kwargs, not the old positional ground-truth params."""
    print("Test 3: Checking reward_fn signature...")
    from train.train_minimal import reward_fn

    sig = inspect.signature(reward_fn)
    params = list(sig.parameters)
    require(
        any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()),
        f"❌ FAIL: reward_fn does not accept **kwargs. Params: {params}",
    )
    require(
        "expected_signals_list" not in params,
        "❌ FAIL: reward_fn still has 'expected_signals_list' positional param (old signature)",
    )
    require(
        "ground_truths" not in params,
        "❌ FAIL: reward_fn still has 'ground_truths' as positional param. Should come via **kwargs",
    )
    print(f" ✅ reward_fn signature: ({', '.join(params)}) — uses **kwargs correctly")


def _posted_url_and_body(recorded_call):
    """Split a recorded ``http_client.post`` call into ``(url, json_body)``.

    The URL may be passed positionally or as ``url=``; ``json_body`` is None
    when no ``json=`` keyword was passed.
    """
    args, kwargs = recorded_call.args, recorded_call.kwargs
    url = args[0] if args else kwargs.get("url", "")
    return url, kwargs.get("json")


def check_run_episode_posts() -> None:
    """Test 4: run_episode_via_http() must POST /reset, then POST /step with the session id."""
    print("Test 4: Verifying run_episode_via_http() makes HTTP POST calls...")
    from train.train_minimal import run_episode_via_http

    with patch("train.train_minimal.http_client") as mock_http:
        # raise_for_status() is an auto-created no-op on each MagicMock response.
        mock_reset_resp = MagicMock()
        mock_reset_resp.json.return_value = {"session_id": "test-session-123"}
        mock_step_resp = MagicMock()
        mock_step_resp.json.return_value = {"reward": 0.85, "done": True}
        mock_http.post.side_effect = [mock_reset_resp, mock_step_resp]

        reward = run_episode_via_http(
            task_id="clean_claim",
            seed=42,
            decision="approve_claim",
            confidence="HIGH",
            reason="All documents verified.",
            base_url="http://fake:7860",
        )
        calls = mock_http.post.call_args_list

    require(len(calls) == 2, f"❌ FAIL: Expected 2 POST calls, got {len(calls)}")

    reset_url, reset_body = _posted_url_and_body(calls[0])
    require("/reset" in reset_url, "❌ FAIL: First POST not to /reset")
    require(isinstance(reset_body, dict), "❌ FAIL: /reset was not sent a JSON body")
    require(reset_body.get("task_id") == "clean_claim", "❌ FAIL: /reset body missing task_id")
    require(reset_body.get("seed") == 42, "❌ FAIL: /reset body missing seed")

    step_url, step_body = _posted_url_and_body(calls[1])
    require("/step" in step_url, "❌ FAIL: Second POST not to /step")
    require(isinstance(step_body, dict), "❌ FAIL: /step was not sent a JSON body")
    require(
        step_body.get("session_id") == "test-session-123",
        "❌ FAIL: /step missing session_id from /reset",
    )
    action = step_body.get("action")
    require(isinstance(action, dict), "❌ FAIL: /step body missing 'action' object")
    require(action.get("action_type") == "approve_claim", "❌ FAIL: action_type wrong")
    require(action.get("confidence") == "HIGH", "❌ FAIL: confidence wrong")
    require(reward == 0.85, f"❌ FAIL: reward should be 0.85 from /step, got {reward}")

    print(" ✅ run_episode_via_http() makes POST /reset then POST /step correctly")
    print(" → /reset body: {task_id, seed}")
    print(" → /step body: {action: {action_type, confidence, reasoning}, session_id}")
    print(f" → reward returned from /step response: {reward}")


def check_reward_fn_dispatches_http() -> None:
    """Test 5: reward_fn must call run_episode_via_http() once per completion."""
    print("Test 5: Verifying reward_fn calls HTTP, not training_reward()...")
    from train.train_minimal import reward_fn

    completions = [
        [{"content": "DECISION: approve_claim\nCONFIDENCE: HIGH\nREASON: docs verified"}],
        [{"content": "DECISION: deny_claim\nCONFIDENCE: MED\nREASON: suspicious docs"}],
    ]
    prompts = ["prompt1", "prompt2"]
    with patch("train.train_minimal.run_episode_via_http", return_value=0.75) as mock_episode:
        rewards = reward_fn(
            completions,
            prompts,
            task_id=["clean_claim", "contradictory_claim"],
            seed=["42", "43"],
            ground_truth=["approve_claim", "deny_claim"],
        )
    require(
        mock_episode.call_count == 2,
        f"❌ FAIL: Expected 2 HTTP calls, got {mock_episode.call_count}",
    )
    require(rewards == [0.75, 0.75], f"❌ FAIL: rewards should be [0.75, 0.75], got {rewards}")
    print(" ✅ reward_fn calls run_episode_via_http() for each completion")


def check_wait_for_env_fails_without_server() -> None:
    """Test 6: _wait_for_env() must raise RuntimeError ("not reachable") with no server."""
    print("Test 6: Verifying training fails without server...")
    from train.train_minimal import _wait_for_env

    # A just-released OS-assigned port has no listener, unlike a hard-coded
    # port number that some local service might happen to occupy.
    base_url = f"http://127.0.0.1:{find_unused_port()}"
    try:
        _wait_for_env(base_url, retries=1)
    except RuntimeError as exc:
        require("not reachable" in str(exc).lower(), f"❌ FAIL: Error message unclear: {exc}")
    else:
        fail("❌ FAIL: Should have raised RuntimeError when server is unreachable")
    print(" ✅ _wait_for_env raises RuntimeError when server is down")


def check_wandb_config(source: str) -> None:
    """Test 7: WandB tags/config literals in the source must reflect env-connected training."""
    print("Test 7: Checking WandB tags and config...")
    require('"env-connected"' in source, "❌ FAIL: WandB tags don't include 'env-connected'")
    require('"http-reward"' in source, "❌ FAIL: WandB tags don't include 'http-reward'")
    require('"env_http_reward"' in source, "❌ FAIL: reward_type not set to 'env_http_reward'")
    require('"no-server"' not in source, "❌ FAIL: WandB tags still contain 'no-server'")
    require('"direct-reward"' not in source, "❌ FAIL: WandB tags still contain 'direct-reward'")
    print(" ✅ WandB config correctly reflects env-connected training")


def print_summary() -> None:
    """Print the final pass banner."""
    print()
    print("=" * 70)
    print(" ALL 7 TESTS PASSED ✅")
    print()
    print(" MR-2 Compliance verified:")
    print(" • reward_fn calls POST /reset + POST /step (not training_reward)")
    print(" • make_row() includes task_id + seed for /reset")
    print(" • Training WILL FAIL if environment server is not running")
    print(" • No 'no-server' or 'direct-reward' remnants in code")
    print("=" * 70)


def main() -> None:
    """Run all seven checks in order, exiting with status 1 on the first failure."""
    sys.path.insert(0, str(REPO_ROOT))
    # Stubs must be registered before anything imports train.train_minimal.
    install_dependency_stubs()
    source = TRAIN_SCRIPT.read_text(encoding="utf-8")

    check_source_declares_env_connected(source)
    check_make_row_columns()
    check_reward_fn_signature()
    check_run_episode_posts()
    check_reward_fn_dispatches_http()
    check_wait_for_env_fails_without_server()
    check_wandb_config(source)
    print_summary()


# Guarded so that importing this module (e.g. pytest collecting it, since the
# filename matches test_*.py) neither installs stubs into sys.modules nor exits.
if __name__ == "__main__":
    main()