Update training
- notebooks/train_smartpayenev.ipynb +54 -29
- server/app.py +44 -8
notebooks/train_smartpayenev.ipynb
CHANGED
@@ -120,25 +120,34 @@
 "DIFFICULTY = 2\n",
 "SEED = 42\n",
 "\n",
+"# ── Minimal-viable QUICK config — every variable dialled to the lowest\n",
+"# value that still produces all 7 plots + meaningful accuracy comparison.\n",
+"# Approx wall time on a Colab T4: QUICK ~3-5 min, FULL ~12-18 min.\n",
+"\n",
 "# Co-evolution loop\n",
-"N_ROUNDS =
-"GRPO_STEPS_PER_ROUND =
-"ES_STEPS_PER_ROUND =
-"ES_POPULATION =
-"ES_SIGMA = 0.25
-"ES_LR = 0.4
-"\n",
-"# Defender / GRPO\n",
-"PROMPT_DATASET_SIZE =
-"GRPO_NUM_GENERATIONS =
-"ROLLOUT_STEPS_PER_REWARD =
-"\n",
-"#
-"EVAL_EPISODES =
-"EVAL_STEPS_PER_EPISODE =
+"N_ROUNDS = 2 if QUICK_MODE else 4  # need >=2 to see co-evolution curve\n",
+"GRPO_STEPS_PER_ROUND = 4 if QUICK_MODE else 20\n",
+"ES_STEPS_PER_ROUND = 2 if QUICK_MODE else 6\n",
+"ES_POPULATION = 3 if QUICK_MODE else 6  # ES needs >=3 for ranked weights\n",
+"ES_SIGMA = 0.25  # exploration std for ES\n",
+"ES_LR = 0.4  # ES update rate\n",
+"\n",
+"# Defender / GRPO (rewards are mean over a K-step rollout)\n",
+"PROMPT_DATASET_SIZE = 16 if QUICK_MODE else 96\n",
+"GRPO_NUM_GENERATIONS = 4 if QUICK_MODE else 6  # >=2 for group-relative advantage\n",
+"ROLLOUT_STEPS_PER_REWARD = 2 if QUICK_MODE else 4\n",
+"\n",
+"# Final frozen-holdout eval\n",
+"EVAL_EPISODES = 2 if QUICK_MODE else 4\n",
+"EVAL_STEPS_PER_EPISODE = 15 if QUICK_MODE else 40\n",
+"\n",
+"# Inner micro-eval used by ES + per-round defender check (called many times,\n",
+"# so keep these tiny — they dominate co-training wall time).\n",
+"COEVO_EVAL_EPISODES = 1 if QUICK_MODE else 2\n",
+"COEVO_EVAL_STEPS = 6 if QUICK_MODE else 12\n",
 "\n",
 "MODEL_ID = 'unsloth/Qwen2.5-0.5B-Instruct'\n",
-"MAX_SEQ_LEN = 2048\n",
+"MAX_SEQ_LEN = 1024 if QUICK_MODE else 2048\n",
 "LOAD_IN_4BIT = True\n",
 "\n",
 "os.makedirs('artifacts', exist_ok=True)\n",
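For review convenience, the two presets imply the step budgets below; a minimal sketch using only constants visible in this hunk (QUICK_MODE itself is defined in an earlier cell and is assumed here to be a plain bool):

```python
# Step-budget check for the two presets, mirroring the constants above.
for quick in (True, False):
    n_rounds = 2 if quick else 4
    grpo_steps = n_rounds * (4 if quick else 20)                      # total GRPO optimizer steps
    es_evals = n_rounds * (2 if quick else 6) * (3 if quick else 6)   # total ES fitness evals
    print('QUICK' if quick else 'FULL ',
          '| GRPO steps =', grpo_steps,   # 8 vs 80
          '| ES evals =', es_evals)       # 12 vs 144
```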
@@ -148,6 +157,9 @@
 " '| ROUNDS =', N_ROUNDS,\n",
 " '| GRPO/round =', GRPO_STEPS_PER_ROUND,\n",
 " '| ES/round =', ES_STEPS_PER_ROUND,\n",
+" '| pop =', ES_POPULATION,\n",
+" '| K-rollout =', ROLLOUT_STEPS_PER_REWARD,\n",
+" '| eval =', f'{EVAL_EPISODES}x{EVAL_STEPS_PER_EPISODE}',\n",
 " '| MODEL_ID =', MODEL_ID)"
 ]
 },
@@ -453,8 +465,12 @@
 " def apply(self):\n",
 " env_configure_adversary(**self.theta, strategy='mixed')\n",
 "\n",
-" def evaluate_against_defender(self, defender_fn,
-"
+" def evaluate_against_defender(self, defender_fn,\n",
+" n_episodes=COEVO_EVAL_EPISODES,\n",
+" n_steps=COEVO_EVAL_STEPS):\n",
+" \"\"\"defender_fn(obs) -> action_dict. Returns mean defender reward (lower = harder fraud).\n",
+" Defaults are intentionally tiny — this is called ES_POPULATION times per\n",
+" ES step, so any extra step here multiplies the wall time fast.\"\"\"\n",
 " rewards = []\n",
 " for ep in range(int(n_episodes)):\n",
 " obs = env_reset_seeded(seed=10_000 + ep, difficulty=DIFFICULTY)\n",
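The "ES needs >=3 for ranked weights" comment in the config hunk refers to rank-based recombination. A minimal sketch of one such ES step, assuming a rank-weighted update (the notebook's actual attacker update may differ in detail); fitness here would be the negated output of evaluate_against_defender, since lower defender reward means a stronger attacker:

```python
import numpy as np

def es_step(theta, fitness_fn, pop=3, sigma=0.25, lr=0.4, seed=0):
    """One rank-weighted evolution-strategies step on a flat parameter vector."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((pop, theta.size))   # one perturbation per member
    scores = np.array([fitness_fn(theta + sigma * e) for e in eps])
    ranks = scores.argsort().argsort()             # 0 = worst ... pop-1 = best
    w = ranks / (pop - 1) - 0.5                    # centred rank weights, sum ≈ 0
    return theta + lr / (pop * sigma) * (w @ eps)  # weighted recombination
```

With pop=3 the weights come out as -0.5, 0, +0.5, which is the smallest population where the middle member is neutral rather than pushed toward an extreme.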
@@ -612,14 +628,21 @@
 " return rewards\n",
 "\n",
 "# ── Defender policy fn (used inside ES eval) ──────────────────────────\n",
+"# Cap inputs/outputs aggressively so each defender call is ~a few hundred ms,\n",
+"# not seconds. ES calls this ES_POPULATION * COEVO_EVAL_EPISODES * COEVO_EVAL_STEPS\n",
+"# times per ES step, so latency here dominates total wall time.\n",
+"_DEF_MAX_PROMPT = 512 if QUICK_MODE else 1024\n",
+"_DEF_MAX_NEW = 24 if QUICK_MODE else 48\n",
+"\n",
 "@torch.no_grad()\n",
 "def _defender_action(obs):\n",
 " FastLanguageModel.for_inference(model)\n",
 " device = next(model.parameters()).device\n",
 " prompt = make_prompt(obs)\n",
-" inputs = tokenizer(prompt, return_tensors='pt', truncation=True,
+" inputs = tokenizer(prompt, return_tensors='pt', truncation=True,\n",
+" max_length=_DEF_MAX_PROMPT).to(device)\n",
 " out = model.generate(\n",
-" **inputs, max_new_tokens=
+" **inputs, max_new_tokens=_DEF_MAX_NEW, do_sample=False,\n",
 " pad_token_id=tokenizer.pad_token_id,\n",
 " )\n",
 " text = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
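To make the "latency dominates" claim concrete, the call-count arithmetic under the QUICK preset works out as follows (the per-call timings are rough assumptions for a 0.5B model on a T4, not measurements):

```python
ES_POPULATION, COEVO_EVAL_EPISODES, COEVO_EVAL_STEPS = 3, 1, 6   # QUICK values
calls_per_es_step = ES_POPULATION * COEVO_EVAL_EPISODES * COEVO_EVAL_STEPS
print(calls_per_es_step, 'defender generate() calls per ES step')  # 18
# At ~0.3 s/call (24 new tokens) one ES step costs ~5 s of pure generation;
# at ~2 s/call (uncapped output) the same step would take ~36 s.
```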
@@ -631,12 +654,12 @@
 " return GRPOConfig(\n",
 " output_dir='outputs/theme4_grpo_unsloth',\n",
 " num_generations=GRPO_NUM_GENERATIONS,\n",
-" max_prompt_length=
-" max_completion_length=
+" max_prompt_length=_DEF_MAX_PROMPT,\n",
+" max_completion_length=_DEF_MAX_NEW,\n",
 " per_device_train_batch_size=1,\n",
 " gradient_accumulation_steps=2,\n",
 " max_steps=int(max_steps),\n",
-" logging_steps=
+" logging_steps=1,\n",
 " learning_rate=1e-5,\n",
 " save_strategy='no',\n",
 " report_to=[],\n",
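A hedged sketch of how this config plugs into trl's GRPOTrainer. make_grpo_config is the notebook helper shown above; prompt_dataset, rollout_reward, and score_rollout are assumed names for the notebook's prompt table and K-step-rollout reward, and exact GRPOTrainer keyword names vary across trl versions:

```python
from trl import GRPOTrainer

def rollout_reward(prompts, completions, **kwargs):
    # One scalar per completion: mean env reward over ROLLOUT_STEPS_PER_REWARD
    # steps (score_rollout is an assumed notebook helper).
    return [score_rollout(p, c) for p, c in zip(prompts, completions)]

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[rollout_reward],
    args=make_grpo_config(max_steps=GRPO_STEPS_PER_ROUND),
    train_dataset=prompt_dataset,   # PROMPT_DATASET_SIZE rows of {'prompt': ...}
    processing_class=tokenizer,
)
trainer.train()
```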
@@ -653,12 +676,13 @@
 "loss_history_all = []\n",
 "reward_log_all = []\n",
 "\n",
-"# Quick eval helper (
-"
+"# Quick eval helper — tiny by design (called 3x per round: once after defender\n",
+"# phase, twice for the exploitability gap). Uses the same COEVO_* knobs.\n",
+"def quick_defender_eval(n_eps=COEVO_EVAL_EPISODES, n_steps=COEVO_EVAL_STEPS):\n",
 " rs = []\n",
-" for ep in range(n_eps):\n",
+" for ep in range(int(n_eps)):\n",
 " obs = env_reset_seeded(seed=20_000 + ep, difficulty=DIFFICULTY)\n",
-" for _ in range(n_steps):\n",
+" for _ in range(int(n_steps)):\n",
 " a = _defender_action(obs)\n",
 " payload = env_step(a)\n",
 " obs = payload.get('observation', payload)\n",
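For reference, the "3x per round" pattern the comment describes would look roughly like this each round (a sketch: attacker, best_attacker, and rnd are assumed names for the current adversary, the strongest-so-far adversary, and the round counter):

```python
import numpy as np

# 1st call: defender progress right after the GRPO phase
r_after_grpo = float(np.mean(quick_defender_eval()))
# 2nd + 3rd calls: exploitability gap = reward vs current adversary minus
# reward vs the strongest adversary found so far
attacker.apply()
r_vs_current = float(np.mean(quick_defender_eval()))
best_attacker.apply()
r_vs_best = float(np.mean(quick_defender_eval()))
reward_log_all.append({'round': rnd, 'after_grpo': r_after_grpo,
                       'gap': r_vs_current - r_vs_best})
```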
@@ -743,11 +767,12 @@
 "\n",
 "def trained_policy(obs):\n",
 " prompt = make_prompt(obs)\n",
-" inputs = tokenizer(prompt, return_tensors='pt', truncation=True,
+" inputs = tokenizer(prompt, return_tensors='pt', truncation=True,\n",
+" max_length=_DEF_MAX_PROMPT).to(device)\n",
 " with torch.no_grad():\n",
 " out = model.generate(\n",
 " **inputs,\n",
-" max_new_tokens=
+" max_new_tokens=_DEF_MAX_NEW,\n",
 " do_sample=False,\n",
 " pad_token_id=tokenizer.pad_token_id,\n",
 " )\n",
server/app.py
CHANGED
@@ -45,9 +45,48 @@ except (ImportError, ValueError):
 from server.SmartPayEnv_environment import SmartpayenvEnvironment
 
 
+# ── Singleton env so custom endpoints share state with openenv ─────────
+# Different openenv versions store the env in different places
+# (app.env, app.state.env, per-request factory, etc.). Rather than
+# guessing, we use a singleton subclass: no matter how many times
+# openenv instantiates the env class, it always gets the same object,
+# and we can always reach it via _SHARED_ENV.
+_SHARED_ENV: SmartpayenvEnvironment | None = None
+
+
+class SharedSmartpayenvEnvironment(SmartpayenvEnvironment):
+    """Singleton subclass — always returns the same env instance."""
+
+    def __new__(cls, *args, **kwargs):
+        global _SHARED_ENV
+        if _SHARED_ENV is None:
+            inst = super().__new__(cls)
+            super(SharedSmartpayenvEnvironment, inst).__init__(*args, **kwargs)
+            inst._singleton_initialized = True  # type: ignore[attr-defined]
+            _SHARED_ENV = inst
+        return _SHARED_ENV
+
+    def __init__(self, *args, **kwargs):  # noqa: D401
+        # Already initialised by __new__ on first construction; subsequent
+        # constructions are no-ops so we don't reset the env.
+        if getattr(self, "_singleton_initialized", False):
+            return
+        super().__init__(*args, **kwargs)
+        self._singleton_initialized = True
+
+
+def _get_env() -> SmartpayenvEnvironment:
+    """Return the shared env, creating it if openenv hasn't yet."""
+    global _SHARED_ENV
+    if _SHARED_ENV is None:
+        SharedSmartpayenvEnvironment()  # populates _SHARED_ENV
+    assert _SHARED_ENV is not None
+    return _SHARED_ENV
+
+
 # Create the app with web interface and README integration
 app = create_app(
-    SmartpayenvEnvironment,
+    SharedSmartpayenvEnvironment,
     SmartpayenvAction,
     SmartpayenvObservation,
     env_name="SmartPayEnv",
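The contract the singleton gives you, as a quick illustrative check:

```python
# Every construction returns the same object, and _get_env() reaches it too.
a = SharedSmartpayenvEnvironment()
b = SharedSmartpayenvEnvironment()
assert a is b
assert _get_env() is a
```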
@@ -57,11 +96,8 @@ app = create_app(
 
 @app.post("/simulate", response_model=SmartpayenvObservation)
 async def simulate(action: SmartpayenvAction):
-    """
-
-    """
-    # OpenEnv environments are stored in app.env
-    return app.env.simulate(action)
+    """Simulates an action without advancing the true environment state."""
+    return _get_env().simulate(action)
 
 
 # ── Theme-4 co-evolution endpoints ────────────────────────────────────
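Client-side, the endpoint is a plain POST; a sketch, where the host/port and the action fields are placeholders that depend on the deployed Space and the SmartpayenvAction schema:

```python
import requests

resp = requests.post(
    'http://localhost:8000/simulate',                    # placeholder URL
    json={'transaction_id': 'tx-001', 'approve': True},  # placeholder fields
)
resp.raise_for_status()
print(resp.json())  # a SmartpayenvObservation; true env state is untouched
```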
@@ -85,7 +121,7 @@ class SeededReset(BaseModel):
 @app.post("/configure_adversary")
 async def configure_adversary(cfg: AdversaryConfig):
     """Set the learnable fraud agent's behaviour. Returns the active config."""
-    return app.env.configure_adversary(
+    return _get_env().configure_adversary(
         intensity=cfg.intensity,
         noise_boost=cfg.noise_boost,
         pattern_rate=cfg.pattern_rate,
@@ -97,7 +133,7 @@ async def configure_adversary(cfg: AdversaryConfig):
 async def reset_seeded(req: SeededReset):
     """Deterministic reset: same `seed` => same starting trajectory.
     Useful for GRPO so all completions in a group share the same state."""
-    return app.env.reset(difficulty=int(req.difficulty), seed=req.seed)
+    return _get_env().reset(difficulty=int(req.difficulty), seed=req.seed)
 
 
 def main():
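These two endpoints back the notebook's env_configure_adversary and env_reset_seeded helpers; a sketch of the raw calls, with the base URL and example values as placeholders (the field names match AdversaryConfig and SeededReset above, though AdversaryConfig may carry additional fields):

```python
import requests

BASE = 'http://localhost:8000'  # placeholder URL
requests.post(f'{BASE}/configure_adversary',
              json={'intensity': 0.5, 'noise_boost': 0.1, 'pattern_rate': 0.2})
obs = requests.post(f'{BASE}/reset_seeded',
                    json={'seed': 10_000, 'difficulty': 2}).json()
```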
|