K446 committed
Commit efbeb4b · Parent: 114859b

Replace env-simulation reward with fast pure-heuristic to fix hang

Files changed (2):
  1. run_training.py +2 -1
  2. training/train_grpo.py +52 -77
run_training.py CHANGED
@@ -218,7 +218,7 @@ def run_grpo_training():
             else:
                 obs_dicts.append(ctx)
 
-        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=1)
+        return compute_grpo_reward_env(texts, obs_dicts, task_config)
 
     # Set generation config explicitly so EOS is always respected and
     # generation never runs to max_completion_length every single time.
@@ -252,6 +252,7 @@ def run_grpo_training():
         optim="paged_adamw_8bit",
         warmup_ratio=0.05,
         lr_scheduler_type="cosine",
+        dataloader_num_workers=0,  # avoid subprocess issues with reward fn
         **({'torch_compile': False} if 'torch_compile' in _grpo_params else {}),
         **({'use_vllm': False} if 'use_vllm' in _grpo_params else {}),
     )
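For reference, here is a minimal sketch of how a reward callable of this shape plugs into TRL's GRPOTrainer. The wiring below is an illustrative assumption, not this repo's actual code: `make_reward_fn`, the `observation` dataset column, `model`, and `train_dataset` are hypothetical names.

    from trl import GRPOConfig, GRPOTrainer

    def make_reward_fn(task_config):
        # TRL calls reward functions with the group's completions and passes
        # extra dataset columns (here the hypothetical 'observation') as kwargs.
        def reward_fn(completions, observation=None, **kwargs):
            # Plain-format completions are strings; chat-format ones are message lists.
            texts = [c if isinstance(c, str) else c[0]["content"] for c in completions]
            return compute_grpo_reward_env(texts, observation, task_config)
        return reward_fn

    args = GRPOConfig(
        output_dir="outputs/grpo",
        num_generations=8,         # GRPO group size; reward spread within a group drives the advantage
        dataloader_num_workers=0,  # as in this commit's change above
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=make_reward_fn(task_config),
        args=args,
        train_dataset=train_dataset,
    )

Per the commit's own comment, dataloader_num_workers=0 is there to avoid subprocess issues interacting with the reward function.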
training/train_grpo.py CHANGED
@@ -248,26 +248,21 @@ def compute_grpo_reward_env(
     completions: list,
     observations: list,
     task_config: dict,
-    horizon: int = 3,
+    horizon: int = 1,  # retained but unused by the pure-heuristic reward
 ) -> list:
-    """Environment-grounded reward: step the actual physics simulation.
-
-    For each LLM-generated action:
-    1. Restore the env to the observation state
-    2. Step with the proposed action and get the real reward
-    3. Run a short rollout (horizon steps) with heuristic continuation
-       to capture trajectory-level impact
-    4. Add format/schema bonuses
-
-    This directly addresses the proxy-reward disconnect that caused
-    the original GRPO training to show zero improvement.
+    """Fast multi-signal reward for GRPO: no env simulation, to avoid hangs.
+
+    Signals (ordered by discriminative power):
+    1. JSON validity    : -0.5 (invalid) vs 0 (valid) — creates a hard cliff
+    2. Schema check     : +0.1 for correct bus_id types and non-empty adjustments
+    3. Direction        : ±0.4 based on whether the delta corrects the frequency error
+    4. Proportionality  : up to +0.2 for magnitude matching the frequency error
+    5. Stability bonus  : +0.1 for a small action when the grid is already stable
     """
-    from src.baseline import heuristic_policy
-
     global _REWARD_CALL_COUNT
     _REWARD_CALL_COUNT += 1
-    if _REWARD_CALL_COUNT <= 3 or _REWARD_CALL_COUNT % 50 == 0:
-        print(f" [reward_fn] call #{_REWARD_CALL_COUNT} | n_completions={len(completions)}", flush=True)
+    if _REWARD_CALL_COUNT <= 5 or _REWARD_CALL_COUNT % 100 == 0:
+        print(f" [reward_fn] call #{_REWARD_CALL_COUNT} | n={len(completions)}", flush=True)
 
     rewards = []
     for completion, obs_dict in zip(completions, observations):
@@ -275,7 +270,6 @@ def compute_grpo_reward_env(
             rewards.append(0.0)
             continue
 
-        # Deserialize if needed (TRL may pass strings)
         if isinstance(obs_dict, str):
             try:
                 obs_dict = json.loads(obs_dict)
@@ -285,77 +279,58 @@ def compute_grpo_reward_env(
 
         freq = obs_dict.get('grid_frequency', 50.0)
         freq_error = freq - 50.0
+        abs_error = abs(freq_error)
 
-        # ── 1. JSON validity signal — biggest discriminator ──
-        # Raw text check first (faster than extract_action)
-        raw_has_json = '{' in completion and '}' in completion
+        # ── 1. JSON validity ──
         try:
-            import re as _re
-            _m = _re.search(r'\{[\s\S]*\}', completion)
+            _m = re.search(r'\{[\s\S]*\}', completion)  # assumes a module-level `import re`
             _parsed = json.loads(_m.group()) if _m else None
-            json_valid = _parsed is not None and 'bus_adjustments' in _parsed
+            json_valid = (
+                _parsed is not None
+                and isinstance(_parsed.get('bus_adjustments'), list)
+            )
         except Exception:
             json_valid = False
 
         if not json_valid:
-            # Invalid / missing JSON — strong penalty so the group has variance
            rewards.append(-0.5)
            continue
 
-        action = extract_action(completion)
-        has_adjustments = bool(action.bus_adjustments)
-
-        # ── 2. Format reward — directional correctness ──
-        format_score = 0.0
-        if has_adjustments:
-            total_delta = sum(a.delta for a in action.bus_adjustments)
-            # Reward correct direction relative to frequency error
-            if abs(freq_error) > 0.05:
-                # freq too low → need positive delta; freq too high → negative delta
-                correct_dir = (freq_error < 0 and total_delta > 0) or \
-                              (freq_error > 0 and total_delta < 0)
-                format_score = 0.3 if correct_dir else -0.3
+        # ── 2. Schema / action quality ──
+        adjustments = _parsed.get('bus_adjustments', [])
+        schema_score = 0.0
+        valid_adjs = []
+        for adj in adjustments:
+            if isinstance(adj, dict) and isinstance(adj.get('bus_id'), int) and isinstance(adj.get('delta'), (int, float)):
+                valid_adjs.append(adj)
+        if valid_adjs:
+            schema_score = 0.1
+        elif abs_error > 0.05:
+            schema_score = -0.1  # should have acted but gave no valid adjustments
+
+        # ── 3. Directional correctness ──
+        direction_score = 0.0
+        if valid_adjs:
+            total_delta = sum(a['delta'] for a in valid_adjs)
+            if abs_error > 0.05:
+                correct = (freq_error < 0 and total_delta > 0) or \
+                          (freq_error > 0 and total_delta < 0)
+                direction_score = 0.4 if correct else -0.4
             else:
-                # Stable grid: small action is fine, large one wastes resources
-                format_score = 0.1 if abs(total_delta) < 5.0 else -0.1
-        else:
-            # No-op: fine when stable, bad when deviating
-            format_score = 0.1 if abs(freq_error) < 0.05 else -0.3
-
-        # ── 3. Environment-grounded reward ──
-        try:
-            env = _get_reward_env(task_config)
-            env._set_state(obs_dict)
-
-            obs_after, reward, done, info = env.step(action)
-            env_score = reward.value
-
-            if info.is_blackout:
-                rewards.append(-1.0)
-                continue
-
-            # horizon=1: just immediate reward — avoids 24 extra env steps per optimizer step
-            rollout_reward = 0.0
-            for _ in range(horizon - 1):
-                if done:
-                    break
-                h_action = heuristic_policy(obs_after)
-                obs_after, r, done, info = env.step(h_action)
-                rollout_reward += r.value
-                if info.is_blackout:
-                    rollout_reward -= 10.0
-                    break
-
-            total_env_score = env_score + 0.5 * rollout_reward
-
-            # Narrower normalizer → wider spread across completions
-            # Typical per-step reward: 0.5–1.5 (good), -100 (blackout)
-            normalized = total_env_score / 3.0
-
-        except Exception:
-            normalized = _compute_heuristic_score(action, obs_dict)
-
-        total = format_score + normalized
+                # Grid stable: small action OK, large action penalised
+                direction_score = 0.1 if abs(total_delta) < 5.0 else -0.2
+
+        # ── 4. Proportionality ──
+        prop_score = 0.0
+        if valid_adjs and abs_error > 0.05:
+            total_delta = sum(a['delta'] for a in valid_adjs)
+            ideal = abs_error * 15.0  # rough MW-per-Hz gain (tuning constant)
+            actual = abs(total_delta)
+            if actual > 0.1:
+                ratio = min(actual, ideal) / max(actual, ideal, 0.1)
+                prop_score = 0.2 * ratio  # up to +0.2 for perfect proportionality
+
+        total = schema_score + direction_score + prop_score
         rewards.append(max(-1.0, min(1.0, total)))
 
     return rewards
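As a quick sanity check of the new signal weights, here is a self-contained sketch exercising the rewritten reward on toy completions (the import path is assumed from the file name; task_config is unused on the heuristic path):

    from training.train_grpo import compute_grpo_reward_env

    # Frequency 0.2 Hz low -> a positive delta is needed; ideal magnitude = 0.2 * 15.0 = 3.0 MW
    obs = {"grid_frequency": 49.8}
    completions = [
        "no json here",                                         # validity cliff: -0.5
        '{"bus_adjustments": []}',                              # valid JSON but no action: -0.1
        '{"bus_adjustments": [{"bus_id": 3, "delta": -3.0}]}',  # wrong direction: 0.1 - 0.4 + 0.2 = -0.1
        '{"bus_adjustments": [{"bus_id": 3, "delta": 3.0}]}',   # correct and proportional: 0.1 + 0.4 + 0.2 = 0.7
    ]
    print(compute_grpo_reward_env(completions, [obs] * len(completions), task_config={}))
    # -> [-0.5, -0.1, -0.1, 0.7]

Note that prop_score uses abs(total_delta) only, so a confidently wrong-direction action ties with a no-op at -0.1 here; gating the proportionality bonus on directional correctness would separate the two cases.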