Pratap-K committed on
Commit c620fb9 · 1 Parent(s): 2a81a82

Update training

Files changed (2):
  1. notebooks/train_smartpayenev.ipynb  +54 -29
  2. server/app.py  +44 -8
notebooks/train_smartpayenev.ipynb CHANGED
@@ -120,25 +120,34 @@
  "DIFFICULTY = 2\n",
  "SEED = 42\n",
  "\n",
+ "# ── Minimal-viable QUICK config — every variable dialled to the lowest\n",
+ "# value that still produces all 7 plots + meaningful accuracy comparison.\n",
+ "# Approx wall time on a Colab T4: QUICK ~3-5 min, FULL ~12-18 min.\n",
+ "\n",
  "# Co-evolution loop\n",
- "N_ROUNDS = 3 if QUICK_MODE else 6  # defender<->fraud alternations\n",
- "GRPO_STEPS_PER_ROUND = 12 if QUICK_MODE else 40\n",
- "ES_STEPS_PER_ROUND = 4 if QUICK_MODE else 10\n",
- "ES_POPULATION = 4 if QUICK_MODE else 8\n",
- "ES_SIGMA = 0.25  # exploration std for ES\n",
- "ES_LR = 0.4  # ES update rate\n",
- "\n",
- "# Defender / GRPO\n",
- "PROMPT_DATASET_SIZE = 48 if QUICK_MODE else 192\n",
- "GRPO_NUM_GENERATIONS = 8 if QUICK_MODE else 8  # bigger group = better advantage\n",
- "ROLLOUT_STEPS_PER_REWARD = 4 if QUICK_MODE else 6  # multi-step rollout per generation\n",
- "\n",
- "# Eval\n",
- "EVAL_EPISODES = 3 if QUICK_MODE else 5\n",
- "EVAL_STEPS_PER_EPISODE = 30 if QUICK_MODE else 60\n",
+ "N_ROUNDS = 2 if QUICK_MODE else 4  # need >=2 to see co-evolution curve\n",
+ "GRPO_STEPS_PER_ROUND = 4 if QUICK_MODE else 20\n",
+ "ES_STEPS_PER_ROUND = 2 if QUICK_MODE else 6\n",
+ "ES_POPULATION = 3 if QUICK_MODE else 6  # ES needs >=3 for ranked weights\n",
+ "ES_SIGMA = 0.25  # exploration std for ES\n",
+ "ES_LR = 0.4  # ES update rate\n",
+ "\n",
+ "# Defender / GRPO (rewards are mean over a K-step rollout)\n",
+ "PROMPT_DATASET_SIZE = 16 if QUICK_MODE else 96\n",
+ "GRPO_NUM_GENERATIONS = 4 if QUICK_MODE else 6  # >=2 for group-relative advantage\n",
+ "ROLLOUT_STEPS_PER_REWARD = 2 if QUICK_MODE else 4\n",
+ "\n",
+ "# Final frozen-holdout eval\n",
+ "EVAL_EPISODES = 2 if QUICK_MODE else 4\n",
+ "EVAL_STEPS_PER_EPISODE = 15 if QUICK_MODE else 40\n",
+ "\n",
+ "# Inner micro-eval used by ES + per-round defender check (called many times,\n",
+ "# so keep these tiny — they dominate co-training wall time).\n",
+ "COEVO_EVAL_EPISODES = 1 if QUICK_MODE else 2\n",
+ "COEVO_EVAL_STEPS = 6 if QUICK_MODE else 12\n",
  "\n",
  "MODEL_ID = 'unsloth/Qwen2.5-0.5B-Instruct'\n",
- "MAX_SEQ_LEN = 2048\n",
+ "MAX_SEQ_LEN = 1024 if QUICK_MODE else 2048\n",
  "LOAD_IN_4BIT = True\n",
  "\n",
  "os.makedirs('artifacts', exist_ok=True)\n",
@@ -148,6 +157,9 @@
  "      '| ROUNDS =', N_ROUNDS,\n",
  "      '| GRPO/round =', GRPO_STEPS_PER_ROUND,\n",
  "      '| ES/round =', ES_STEPS_PER_ROUND,\n",
+ "      '| pop =', ES_POPULATION,\n",
+ "      '| K-rollout =', ROLLOUT_STEPS_PER_REWARD,\n",
+ "      '| eval =', f'{EVAL_EPISODES}x{EVAL_STEPS_PER_EPISODE}',\n",
  "      '| MODEL_ID =', MODEL_ID)"
  ]
 },
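Note on the "rewards are mean over a K-step rollout" comment above: it is the load-bearing GRPO detail here, since each completion in a group is scored by actually stepping the environment, and the server's /reset_seeded endpoint (see server/app.py below) lets all GRPO_NUM_GENERATIONS completions for one prompt start from the identical state. The reward cell itself is not part of this diff, so the following is only a sketch of that shape; parse_action and the 'reward' key in the step payload are assumptions, while env_reset_seeded/env_step and the config names come from the commit.

    def rollout_reward(completion, prompt_seed, k=ROLLOUT_STEPS_PER_REWARD):
        # Hypothetical sketch: score one GRPO completion by a K-step rollout.
        # parse_action() stands in for the notebook's own output parser.
        env_reset_seeded(seed=prompt_seed, difficulty=DIFFICULTY)  # shared seed per group
        action = parse_action(completion)
        total = 0.0
        for _ in range(int(k)):
            payload = env_step(action)
            total += float(payload.get('reward', 0.0))  # assumed payload key
        return total / max(int(k), 1)                   # mean over the rollout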
@@ -453,8 +465,12 @@
  "    def apply(self):\n",
  "        env_configure_adversary(**self.theta, strategy='mixed')\n",
  "\n",
- "    def evaluate_against_defender(self, defender_fn, n_episodes=2, n_steps=12):\n",
- "        \"\"\"Defender_fn(obs)->action_dict. Returns mean defender reward (lower = harder fraud).\"\"\"\n",
+ "    def evaluate_against_defender(self, defender_fn,\n",
+ "                                  n_episodes=COEVO_EVAL_EPISODES,\n",
+ "                                  n_steps=COEVO_EVAL_STEPS):\n",
+ "        \"\"\"Defender_fn(obs)->action_dict. Returns mean defender reward (lower = harder fraud).\n",
+ "        Defaults are intentionally tiny — this is called ES_POPULATION times per\n",
+ "        ES step, so any extra step here multiplies the wall time fast.\"\"\"\n",
  "        rewards = []\n",
  "        for ep in range(int(n_episodes)):\n",
  "            obs = env_reset_seeded(seed=10_000 + ep, difficulty=DIFFICULTY)\n",
@@ -612,14 +628,21 @@
  "    return rewards\n",
  "\n",
  "# ── Defender policy fn (used inside ES eval) ──────────────────────────\n",
+ "# Cap inputs/outputs aggressively so each defender call is ~few hundred ms,\n",
+ "# not seconds. ES calls this ES_POPULATION * COEVO_EVAL_EPISODES * COEVO_EVAL_STEPS\n",
+ "# times per ES step, so latency here dominates total wall time.\n",
+ "_DEF_MAX_PROMPT = 512 if QUICK_MODE else 1024\n",
+ "_DEF_MAX_NEW = 24 if QUICK_MODE else 48\n",
+ "\n",
  "@torch.no_grad()\n",
  "def _defender_action(obs):\n",
  "    FastLanguageModel.for_inference(model)\n",
  "    device = next(model.parameters()).device\n",
  "    prompt = make_prompt(obs)\n",
- "    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n",
+ "    inputs = tokenizer(prompt, return_tensors='pt', truncation=True,\n",
+ "                       max_length=_DEF_MAX_PROMPT).to(device)\n",
  "    out = model.generate(\n",
- "        **inputs, max_new_tokens=48, do_sample=False,\n",
+ "        **inputs, max_new_tokens=_DEF_MAX_NEW, do_sample=False,\n",
  "        pad_token_id=tokenizer.pad_token_id,\n",
  "    )\n",
  "    text = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
@@ -631,12 +654,12 @@
  "    return GRPOConfig(\n",
  "        output_dir='outputs/theme4_grpo_unsloth',\n",
  "        num_generations=GRPO_NUM_GENERATIONS,\n",
- "        max_prompt_length=1024,\n",
- "        max_completion_length=48,\n",
+ "        max_prompt_length=_DEF_MAX_PROMPT,\n",
+ "        max_completion_length=_DEF_MAX_NEW,\n",
  "        per_device_train_batch_size=1,\n",
  "        gradient_accumulation_steps=2,\n",
  "        max_steps=int(max_steps),\n",
- "        logging_steps=2,\n",
+ "        logging_steps=1,\n",
  "        learning_rate=1e-5,\n",
  "        save_strategy='no',\n",
  "        report_to=[],\n",
@@ -653,12 +676,13 @@
  "loss_history_all = []\n",
  "reward_log_all = []\n",
  "\n",
- "# Quick eval helper (small to keep co-training cheap)\n",
- "def quick_defender_eval(n_eps=2, n_steps=12):\n",
+ "# Quick eval helper — tiny by design (called 3x per round: once after defender\n",
+ "# phase, twice for the exploitability gap). Uses the same COEVO_* knobs.\n",
+ "def quick_defender_eval(n_eps=COEVO_EVAL_EPISODES, n_steps=COEVO_EVAL_STEPS):\n",
  "    rs = []\n",
- "    for ep in range(n_eps):\n",
+ "    for ep in range(int(n_eps)):\n",
  "        obs = env_reset_seeded(seed=20_000 + ep, difficulty=DIFFICULTY)\n",
- "        for _ in range(n_steps):\n",
+ "        for _ in range(int(n_steps)):\n",
  "            a = _defender_action(obs)\n",
  "            payload = env_step(a)\n",
  "            obs = payload.get('observation', payload)\n",
@@ -743,11 +767,12 @@
  "\n",
  "def trained_policy(obs):\n",
  "    prompt = make_prompt(obs)\n",
- "    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n",
+ "    inputs = tokenizer(prompt, return_tensors='pt', truncation=True,\n",
+ "                       max_length=_DEF_MAX_PROMPT).to(device)\n",
  "    with torch.no_grad():\n",
  "        out = model.generate(\n",
  "            **inputs,\n",
- "            max_new_tokens=64,\n",
+ "            max_new_tokens=_DEF_MAX_NEW,\n",
  "            do_sample=False,\n",
  "            pad_token_id=tokenizer.pad_token_id,\n",
  "        )\n",
 
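The fraud side's update rule is likewise outside this diff, but the knobs above pin down its shape: a rank-weighted evolution-strategies step over the theta dict that evaluate_against_defender scores (the "ES needs >=3 for ranked weights" comment refers to the centred rank weights below degenerating for tiny populations). A minimal sketch, assuming eval_fn wraps evaluate_against_defender and that theta holds the intensity/noise_boost/pattern_rate floats the server's /configure_adversary accepts; none of this code is in the commit.

    import numpy as np

    def es_step(theta, eval_fn, pop=ES_POPULATION, sigma=ES_SIGMA, lr=ES_LR, seed=0):
        # One rank-weighted ES update over the adversary's theta dict (sketch).
        rng = np.random.default_rng(seed)
        keys = sorted(theta)
        base = np.array([theta[k] for k in keys], dtype=float)
        noise = rng.standard_normal((pop, len(keys)))   # one perturbation per member
        # evaluate_against_defender returns mean defender reward, and lower means
        # harder fraud, so the adversary climbs the negated score.
        scores = np.array([-eval_fn(dict(zip(keys, base + sigma * eps)))
                           for eps in noise])
        ranks = scores.argsort().argsort()              # 0 = worst ... pop-1 = best
        weights = ranks / (pop - 1) - 0.5               # centred rank weights (pop >= 3)
        grad = weights @ noise / (pop * sigma)          # ES gradient estimate
        return dict(zip(keys, base + lr * grad))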
server/app.py CHANGED
@@ -45,9 +45,48 @@ except (ImportError, ValueError):
  from server.SmartPayEnv_environment import SmartpayenvEnvironment
 
 
+ # ── Singleton env so custom endpoints share state with openenv ─────────
+ # Different openenv versions store the env in different places
+ # (app.env, app.state.env, per-request factory, etc.). Rather than
+ # guessing, we use a singleton subclass: no matter how many times
+ # openenv instantiates the env class, it always gets the same object,
+ # and we can always reach it via _SHARED_ENV.
+ _SHARED_ENV: SmartpayenvEnvironment | None = None
+
+
+ class SharedSmartpayenvEnvironment(SmartpayenvEnvironment):
+     """Singleton subclass — always returns the same env instance."""
+
+     def __new__(cls, *args, **kwargs):
+         global _SHARED_ENV
+         if _SHARED_ENV is None:
+             inst = super().__new__(cls)
+             super(SharedSmartpayenvEnvironment, inst).__init__(*args, **kwargs)
+             inst._singleton_initialized = True  # type: ignore[attr-defined]
+             _SHARED_ENV = inst
+         return _SHARED_ENV
+
+     def __init__(self, *args, **kwargs):  # noqa: D401
+         # Already initialised by __new__ on first construction; subsequent
+         # constructions are no-ops so we don't reset the env.
+         if getattr(self, "_singleton_initialized", False):
+             return
+         super().__init__(*args, **kwargs)
+         self._singleton_initialized = True
+
+
+ def _get_env() -> SmartpayenvEnvironment:
+     """Return the shared env, creating it if openenv hasn't yet."""
+     global _SHARED_ENV
+     if _SHARED_ENV is None:
+         SharedSmartpayenvEnvironment()  # populates _SHARED_ENV
+     assert _SHARED_ENV is not None
+     return _SHARED_ENV
+
+
  # Create the app with web interface and README integration
  app = create_app(
-     SmartpayenvEnvironment,
+     SharedSmartpayenvEnvironment,
      SmartpayenvAction,
      SmartpayenvObservation,
      env_name="SmartPayEnv",
@@ -57,11 +96,8 @@ app = create_app(
 
  @app.post("/simulate", response_model=SmartpayenvObservation)
  async def simulate(action: SmartpayenvAction):
-     """
-     Simulates an action without advancing the true environment state.
-     """
-     # OpenEnv environments are stored in app.env
-     return app.env.simulate(action)
+     """Simulates an action without advancing the true environment state."""
+     return _get_env().simulate(action)
 
 
  # ── Theme-4 co-evolution endpoints ────────────────────────────────────
@@ -85,7 +121,7 @@ class SeededReset(BaseModel):
  @app.post("/configure_adversary")
  async def configure_adversary(cfg: AdversaryConfig):
      """Set the learnable fraud agent's behaviour. Returns the active config."""
-     return app.env.configure_adversary(
+     return _get_env().configure_adversary(
          intensity=cfg.intensity,
          noise_boost=cfg.noise_boost,
          pattern_rate=cfg.pattern_rate,
@@ -97,7 +133,7 @@ async def configure_adversary(cfg: AdversaryConfig):
  async def reset_seeded(req: SeededReset):
      """Deterministic reset: same `seed` => same starting trajectory.
      Useful for GRPO so all completions in a group share the same state."""
-     return app.env.reset(difficulty=int(req.difficulty), seed=req.seed)
+     return _get_env().reset(difficulty=int(req.difficulty), seed=req.seed)
 
 
  def main():
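Why the singleton subclass works: however openenv's create_app constructs the environment class, __new__ hands back the one shared instance, so the custom endpoints and openenv's own routes can never hold two diverging envs. A quick sanity check, assuming the module path server.app and that SmartpayenvEnvironment can be constructed without arguments:

    # Hypothetical sanity check for the singleton behaviour.
    from server.app import SharedSmartpayenvEnvironment, _get_env

    a = SharedSmartpayenvEnvironment()   # e.g. the instance openenv creates
    b = SharedSmartpayenvEnvironment()   # a second construction elsewhere
    assert a is b                        # same object: __new__ returned _SHARED_ENV
    assert _get_env() is a               # custom endpoints reach that same instance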
 
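For context, the env_configure_adversary / env_reset_seeded / env_step helpers the notebook calls are presumably thin HTTP wrappers over this server. A sketch under those assumptions; the base URL, the strategy field, and the /step route (openenv's conventional step endpoint, not shown in this diff) are guesses:

    import requests

    BASE = "http://localhost:8000"  # assumed: wherever server/app.py is served

    def env_configure_adversary(intensity, noise_boost, pattern_rate, strategy='mixed'):
        r = requests.post(f"{BASE}/configure_adversary", json={
            "intensity": intensity, "noise_boost": noise_boost,
            "pattern_rate": pattern_rate, "strategy": strategy,  # strategy assumed
        })
        r.raise_for_status()
        return r.json()

    def env_reset_seeded(seed, difficulty):
        r = requests.post(f"{BASE}/reset_seeded",
                          json={"seed": seed, "difficulty": difficulty})
        r.raise_for_status()
        return r.json()

    def env_step(action):
        # Assumed standard openenv step route; not part of this commit.
        r = requests.post(f"{BASE}/step", json=action)
        r.raise_for_status()
        return r.json()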