Spaces:
Sleeping
Sleeping
| """ | |
| Theme-4 training starter for SmartPayEnv. | |
| This script demonstrates a novel self-improvement loop: | |
| 1) sample K candidate actions per observation | |
| 2) score each candidate with /simulate rewards (group-relative signal) | |
| 3) collect best/worst pairs for preference-style post-training | |
| It is intentionally lightweight so teams can run it in Colab with TRL/Unsloth. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import requests | |
# Runtime configuration. Each value can be overridden through an environment
# variable of the name shown; the second argument is the fallback used when
# the variable is unset.
# Trailing slashes are stripped from the server URL so endpoint paths can be
# appended as f"{ENV_URL}/endpoint" without producing a double slash.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
# Number of rollout steps to collect by default.
MAX_STEPS = int(os.environ.get("MAX_STEPS", "200"))
# Number of candidate actions sampled and scored per step by default.
GROUP_SIZE = int(os.environ.get("GROUP_SIZE", "8"))
# Difficulty value sent to the /reset endpoint by default.
DIFFICULTY = int(os.environ.get("DIFFICULTY", "2"))
# Default seed for the `random` module (note: read from the SEED variable).
RANDOM_SEED = int(os.environ.get("SEED", "42"))
@dataclass
class RolloutExample:
    """One preference pair: a prompt plus a preferred and a dispreferred answer.

    The ``@dataclass`` decorator generates the keyword-accepting ``__init__``
    (plus ``__repr__`` and ``__eq__``) from the annotated fields below.
    Without it the class has annotations only, and constructing it with
    keyword arguments raises ``TypeError``.
    """

    # Text shown to the model.
    prompt: str
    # Preferred completion (stored as a string).
    chosen: str
    # Dispreferred completion (stored as a string).
    rejected: str
    # Reward associated with ``chosen``.
    chosen_reward: float
    # Reward associated with ``rejected``.
    rejected_reward: float
| def _action_candidates() -> list[dict[str, int]]: | |
| all_actions: list[dict[str, int]] = [] | |
| for gateway in (0, 1, 2): | |
| for fraud_decision in (0, 1, 2, 3): | |
| for retry_strategy in (0, 1): | |
| all_actions.append( | |
| { | |
| "gateway": gateway, | |
| "fraud_decision": fraud_decision, | |
| "retry_strategy": retry_strategy, | |
| } | |
| ) | |
| random.shuffle(all_actions) | |
| return all_actions | |
def _simulate_reward(action: dict[str, int]) -> float:
    """Return the reward the ``/simulate`` endpoint reports for ``action``.

    Posts ``{"action": action}`` as JSON to ``{ENV_URL}/simulate`` with a
    30-second timeout and reads the top-level ``"reward"`` field of the JSON
    response.

    Args:
        action: Action dict to score.

    Returns:
        The reward as a float; ``0.0`` when the field is missing or null.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: For other request failures raised by
            ``requests`` (for example a timeout or connection error).
    """
    response = requests.post(f"{ENV_URL}/simulate", json={"action": action}, timeout=30)
    response.raise_for_status()
    reward = response.json().get("reward")
    # A JSON null arrives as None; float(None) would raise TypeError, which
    # is not a RequestException, so treat it the same as a missing reward.
    return 0.0 if reward is None else float(reward)
def _step(action: dict[str, int]) -> dict[str, Any]:
    """Apply ``action`` through the ``/step`` endpoint and return the reply.

    Posts ``{"action": action}`` as JSON to ``{ENV_URL}/step`` with a
    30-second timeout. An error status raises ``requests.HTTPError``;
    otherwise the decoded JSON body is returned unchanged.
    """
    reply = requests.post(
        f"{ENV_URL}/step",
        json={"action": action},
        timeout=30,
    )
    reply.raise_for_status()
    decoded = reply.json()
    return decoded
def _reset(difficulty: int = 2) -> dict[str, Any]:
    """Reset the environment and return the initial observation.

    Posts ``{"difficulty": difficulty}`` as JSON to ``{ENV_URL}/reset`` with
    a 30-second timeout. An error status raises ``requests.HTTPError``.

    Returns:
        The value under the reply's ``"observation"`` key when present,
        otherwise the whole decoded reply.
    """
    reply = requests.post(
        f"{ENV_URL}/reset",
        json={"difficulty": difficulty},
        timeout=30,
    )
    reply.raise_for_status()
    body = reply.json()
    # Accept both an {"observation": ...} envelope and a bare observation.
    return body.get("observation", body)
def collect_group_relative_pairs(
    max_steps: int = MAX_STEPS,
    group_size: int = GROUP_SIZE,
    difficulty: int = DIFFICULTY,
    seed: int = RANDOM_SEED,
) -> list[RolloutExample]:
    """Roll out the environment and collect best-vs-worst action pairs.

    At each step a group of candidate actions is sampled and scored with
    ``_simulate_reward``. The highest- and lowest-scoring candidates are
    stored as a ``RolloutExample`` for the current observation, then the best
    action is applied with ``_step``.

    Args:
        max_steps: Maximum number of rollout steps, and therefore pairs.
        group_size: Candidates sampled per step, capped at the size of the
            action pool.
        difficulty: Difficulty passed to ``_reset``.
        seed: Seed applied to the module-level ``random`` state.

    Returns:
        The collected examples, in rollout order. This may be shorter than
        ``max_steps``: collection stops as soon as fewer than two candidates
        could be scored in a step.

    Raises:
        requests.RequestException: If resetting or stepping the environment
            fails. Failures while scoring one candidate are skipped instead.
    """
    random.seed(seed)
    obs = _reset(difficulty=difficulty)
    dataset: list[RolloutExample] = []
    actions_pool = _action_candidates()
    # The pool never changes, so the per-step sample size is loop-invariant.
    sample_size = min(group_size, len(actions_pool))

    for _ in range(max_steps):
        sampled = random.sample(actions_pool, k=sample_size)
        scored: list[tuple[dict[str, int], float]] = []
        for action in sampled:
            # Best effort: a candidate whose simulation request fails is
            # dropped from the group rather than aborting the rollout.
            try:
                reward = _simulate_reward(action)
            except requests.RequestException:
                continue
            scored.append((action, reward))

        # A preference pair needs at least two scored candidates.
        if len(scored) < 2:
            break

        # Stable sort, highest reward first: ties keep their sampling order.
        scored.sort(key=lambda x: x[1], reverse=True)
        best_action, best_reward = scored[0]
        worst_action, worst_reward = scored[-1]

        prompt = (
            "SmartPayEnv observation:\n"
            f"{json.dumps(obs, sort_keys=True)}\n"
            "Return one action JSON with fields: gateway, fraud_decision, retry_strategy."
        )
        dataset.append(
            RolloutExample(
                prompt=prompt,
                chosen=json.dumps(best_action, sort_keys=True),
                rejected=json.dumps(worst_action, sort_keys=True),
                chosen_reward=best_reward,
                rejected_reward=worst_reward,
            )
        )

        step_payload = _step(best_action)
        obs = step_payload.get("observation", step_payload)
        # The done flag may sit beside the observation in the step reply or
        # inside the observation itself; honour either so that a finished
        # episode is always followed by a reset.
        episode_over = bool(step_payload.get("done", False)) or bool(
            obs.get("done", False)
        )
        if episode_over:
            obs = _reset(difficulty=difficulty)

    return dataset
def export_jsonl(dataset: list[RolloutExample], output_path: str) -> None:
    """Write ``dataset`` to ``output_path`` as JSON Lines.

    The file is opened for writing as UTF-8, replacing any existing content.
    Each example becomes one line holding a JSON object with the keys
    ``prompt``, ``chosen``, ``rejected``, ``chosen_reward`` and
    ``rejected_reward``, in that order.
    """
    # Attribute names double as the JSON keys; tuple order fixes key order.
    columns = ("prompt", "chosen", "rejected", "chosen_reward", "rejected_reward")
    with open(output_path, "w", encoding="utf-8") as sink:
        for example in dataset:
            record = {column: getattr(example, column) for column in columns}
            sink.write(json.dumps(record) + "\n")
if __name__ == "__main__":
    # Command-line entry point: parse options, collect pairs, write JSONL.
    cli = argparse.ArgumentParser(
        description="Collect group-relative preference pairs from SmartPayEnv."
    )
    cli.add_argument("--env-url", default=ENV_URL, help="SmartPayEnv server URL")
    # The integer options all have the same shape, so register them from a
    # table; registration order (and thus --help order) is kept.
    for flag, fallback, description in (
        ("--max-steps", MAX_STEPS, "Number of rollout steps"),
        ("--group-size", GROUP_SIZE, "Actions sampled per step"),
        ("--difficulty", DIFFICULTY, "Environment difficulty 0/1/2"),
        ("--seed", RANDOM_SEED, "Random seed"),
    ):
        cli.add_argument(flag, type=int, default=fallback, help=description)
    cli.add_argument("--output", default="theme4_grpo_pairs.jsonl", help="Output JSONL path")
    options = cli.parse_args()

    # Rebind the module-level URL so the request helpers, which read ENV_URL
    # when they are called, talk to the server chosen on the command line.
    ENV_URL = options.env_url.rstrip("/")

    pairs = collect_group_relative_pairs(
        max_steps=options.max_steps,
        group_size=options.group_size,
        difficulty=options.difficulty,
        seed=options.seed,
    )
    export_jsonl(pairs, options.output)
    print(f"Collected {len(pairs)} preference pairs into {options.output}")