""" Theme-4 training starter for SmartPayEnv. This script demonstrates a novel self-improvement loop: 1) sample K candidate actions per observation 2) score each candidate with /simulate rewards (group-relative signal) 3) collect best/worst pairs for preference-style post-training It is intentionally lightweight so teams can run it in Colab with TRL/Unsloth. """ from __future__ import annotations import argparse import json import os import random from dataclasses import dataclass from typing import Any import requests ENV_URL = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/") MAX_STEPS = int(os.getenv("MAX_STEPS", "200")) GROUP_SIZE = int(os.getenv("GROUP_SIZE", "8")) DIFFICULTY = int(os.getenv("DIFFICULTY", "2")) RANDOM_SEED = int(os.getenv("SEED", "42")) @dataclass class RolloutExample: prompt: str chosen: str rejected: str chosen_reward: float rejected_reward: float def _action_candidates() -> list[dict[str, int]]: all_actions: list[dict[str, int]] = [] for gateway in (0, 1, 2): for fraud_decision in (0, 1, 2, 3): for retry_strategy in (0, 1): all_actions.append( { "gateway": gateway, "fraud_decision": fraud_decision, "retry_strategy": retry_strategy, } ) random.shuffle(all_actions) return all_actions def _simulate_reward(action: dict[str, int]) -> float: response = requests.post(f"{ENV_URL}/simulate", json={"action": action}, timeout=30) response.raise_for_status() obs = response.json() return float(obs.get("reward", 0.0)) def _step(action: dict[str, int]) -> dict[str, Any]: response = requests.post(f"{ENV_URL}/step", json={"action": action}, timeout=30) response.raise_for_status() return response.json() def _reset(difficulty: int = 2) -> dict[str, Any]: response = requests.post(f"{ENV_URL}/reset", json={"difficulty": difficulty}, timeout=30) response.raise_for_status() payload = response.json() return payload.get("observation", payload) def collect_group_relative_pairs( max_steps: int = MAX_STEPS, group_size: int = GROUP_SIZE, difficulty: int = DIFFICULTY, seed: int = RANDOM_SEED, ) -> list[RolloutExample]: random.seed(seed) obs = _reset(difficulty=difficulty) dataset: list[RolloutExample] = [] actions_pool = _action_candidates() for _ in range(max_steps): sampled = random.sample(actions_pool, k=min(group_size, len(actions_pool))) scored: list[tuple[dict[str, int], float]] = [] for action in sampled: try: reward = _simulate_reward(action) scored.append((action, reward)) except requests.RequestException: continue if len(scored) < 2: break scored.sort(key=lambda x: x[1], reverse=True) best_action, best_reward = scored[0] worst_action, worst_reward = scored[-1] prompt = ( "SmartPayEnv observation:\n" f"{json.dumps(obs, sort_keys=True)}\n" "Return one action JSON with fields: gateway, fraud_decision, retry_strategy." ) dataset.append( RolloutExample( prompt=prompt, chosen=json.dumps(best_action, sort_keys=True), rejected=json.dumps(worst_action, sort_keys=True), chosen_reward=best_reward, rejected_reward=worst_reward, ) ) step_payload = _step(best_action) obs = step_payload.get("observation", step_payload) if bool(obs.get("done", False)): obs = _reset(difficulty=difficulty) return dataset def export_jsonl(dataset: list[RolloutExample], output_path: str) -> None: with open(output_path, "w", encoding="utf-8") as f: for row in dataset: f.write( json.dumps( { "prompt": row.prompt, "chosen": row.chosen, "rejected": row.rejected, "chosen_reward": row.chosen_reward, "rejected_reward": row.rejected_reward, } ) + "\n" ) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Collect group-relative preference pairs from SmartPayEnv.") parser.add_argument("--env-url", default=ENV_URL, help="SmartPayEnv server URL") parser.add_argument("--max-steps", type=int, default=MAX_STEPS, help="Number of rollout steps") parser.add_argument("--group-size", type=int, default=GROUP_SIZE, help="Actions sampled per step") parser.add_argument("--difficulty", type=int, default=DIFFICULTY, help="Environment difficulty 0/1/2") parser.add_argument("--seed", type=int, default=RANDOM_SEED, help="Random seed") parser.add_argument("--output", default="theme4_grpo_pairs.jsonl", help="Output JSONL path") args = parser.parse_args() ENV_URL = args.env_url.rstrip("/") data = collect_group_relative_pairs( max_steps=args.max_steps, group_size=args.group_size, difficulty=args.difficulty, seed=args.seed, ) export_jsonl(data, args.output) print(f"Collected {len(data)} preference pairs into {args.output}")