Replace env-simulation reward with fast pure-heuristic to fix hang
Browse files
- run_training.py +2 -1
- training/train_grpo.py +52 -77
run_training.py
CHANGED
|
@@ -218,7 +218,7 @@ def run_grpo_training():
|
|
| 218 |
else:
|
| 219 |
obs_dicts.append(ctx)
|
| 220 |
|
| 221 |
-
return compute_grpo_reward_env(texts, obs_dicts, task_config)
|
| 222 |
|
| 223 |
# Set generation config explicitly so EOS is always respected and
|
| 224 |
# generation never runs to max_completion_length every single time.
|
|
@@ -252,6 +252,7 @@ def run_grpo_training():
|
|
| 252 |
optim="paged_adamw_8bit",
|
| 253 |
warmup_ratio=0.05,
|
| 254 |
lr_scheduler_type="cosine",
|
|
|
|
| 255 |
**({'torch_compile': False} if 'torch_compile' in _grpo_params else {}),
|
| 256 |
**({'use_vllm': False} if 'use_vllm' in _grpo_params else {}),
|
| 257 |
)
|
|
|
|
| 218 |
else:
|
| 219 |
obs_dicts.append(ctx)
|
| 220 |
|
| 221 |
+
return compute_grpo_reward_env(texts, obs_dicts, task_config)
|
| 222 |
|
| 223 |
# Set generation config explicitly so EOS is always respected and
|
| 224 |
# generation never runs to max_completion_length every single time.
|
|
|
|
| 252 |
optim="paged_adamw_8bit",
|
| 253 |
warmup_ratio=0.05,
|
| 254 |
lr_scheduler_type="cosine",
|
| 255 |
+
dataloader_num_workers=0, # avoid subprocess issues with reward fn
|
| 256 |
**({'torch_compile': False} if 'torch_compile' in _grpo_params else {}),
|
| 257 |
**({'use_vllm': False} if 'use_vllm' in _grpo_params else {}),
|
| 258 |
)
|
training/train_grpo.py
CHANGED
|
@@ -248,26 +248,21 @@ def compute_grpo_reward_env(
|
|
| 248 |
completions: list,
|
| 249 |
observations: list,
|
| 250 |
task_config: dict,
|
| 251 |
-
horizon: int =
|
| 252 |
) -> list:
|
| 253 |
-
"""
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
This directly addresses the proxy-reward disconnect that caused
|
| 263 |
-
the original GRPO training to show zero improvement.
|
| 264 |
"""
|
| 265 |
-
from src.baseline import heuristic_policy
|
| 266 |
-
|
| 267 |
global _REWARD_CALL_COUNT
|
| 268 |
_REWARD_CALL_COUNT += 1
|
| 269 |
-
if _REWARD_CALL_COUNT <=
|
| 270 |
-
print(f" [reward_fn] call #{_REWARD_CALL_COUNT} |
|
| 271 |
|
| 272 |
rewards = []
|
| 273 |
for completion, obs_dict in zip(completions, observations):
|
|
@@ -275,7 +270,6 @@ def compute_grpo_reward_env(
|
|
| 275 |
rewards.append(0.0)
|
| 276 |
continue
|
| 277 |
|
| 278 |
-
# Deserialize if needed (TRL may pass strings)
|
| 279 |
if isinstance(obs_dict, str):
|
| 280 |
try:
|
| 281 |
obs_dict = json.loads(obs_dict)
|
|
@@ -285,77 +279,58 @@ def compute_grpo_reward_env(
|
|
| 285 |
|
| 286 |
freq = obs_dict.get('grid_frequency', 50.0)
|
| 287 |
freq_error = freq - 50.0
|
|
|
|
| 288 |
|
| 289 |
-
# ── 1. JSON validity
|
| 290 |
-
# Raw text check first (faster than extract_action)
|
| 291 |
-
raw_has_json = '{' in completion and '}' in completion
|
| 292 |
try:
|
| 293 |
-
|
| 294 |
-
_m = _re.search(r'\{[\s\S]*\}', completion)
|
| 295 |
_parsed = json.loads(_m.group()) if _m else None
|
| 296 |
-
json_valid =
|
|
|
|
|
|
|
|
|
|
| 297 |
except Exception:
|
| 298 |
json_valid = False
|
| 299 |
|
| 300 |
if not json_valid:
|
| 301 |
-
# Invalid / missing JSON — strong penalty so the group has variance
|
| 302 |
rewards.append(-0.5)
|
| 303 |
continue
|
| 304 |
|
| 305 |
-
action
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
else:
|
| 319 |
-
#
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
if info.is_blackout:
|
| 334 |
-
rewards.append(-1.0)
|
| 335 |
-
continue
|
| 336 |
-
|
| 337 |
-
# horizon=1: just immediate reward — avoids 24 extra env steps per optimizer step
|
| 338 |
-
rollout_reward = 0.0
|
| 339 |
-
for _ in range(horizon - 1):
|
| 340 |
-
if done:
|
| 341 |
-
break
|
| 342 |
-
h_action = heuristic_policy(obs_after)
|
| 343 |
-
obs_after, r, done, info = env.step(h_action)
|
| 344 |
-
rollout_reward += r.value
|
| 345 |
-
if info.is_blackout:
|
| 346 |
-
rollout_reward -= 10.0
|
| 347 |
-
break
|
| 348 |
-
|
| 349 |
-
total_env_score = env_score + 0.5 * rollout_reward
|
| 350 |
-
|
| 351 |
-
# Narrower normalizer → wider spread across completions
|
| 352 |
-
# Typical per-step reward: 0.5–1.5 (good), -100 (blackout)
|
| 353 |
-
normalized = total_env_score / 3.0
|
| 354 |
-
|
| 355 |
-
except Exception:
|
| 356 |
-
normalized = _compute_heuristic_score(action, obs_dict)
|
| 357 |
-
|
| 358 |
-
total = format_score + normalized
|
| 359 |
rewards.append(max(-1.0, min(1.0, total)))
|
| 360 |
|
| 361 |
return rewards
|
|
|
|
| 248 |
completions: list,
|
| 249 |
observations: list,
|
| 250 |
task_config: dict,
|
| 251 |
+
horizon: int = 1,
|
| 252 |
) -> list:
|
| 253 |
+
"""Fast multi-signal reward for GRPO — no env simulation to avoid hangs.
|
| 254 |
+
|
| 255 |
+
Signals (ordered by discriminative power):
|
| 256 |
+
1. JSON validity : -0.5 (invalid) vs 0 (valid) — creates hard cliff
|
| 257 |
+
2. Schema check : +0.1 for correct bus_id types and non-empty adjustments
|
| 258 |
+
3. Direction : ±0.4 based on whether delta corrects frequency error
|
| 259 |
+
4. Proportionality : ±0.2 based on magnitude relative to freq error
|
| 260 |
+
5. Stability bonus : +0.1 for small action when grid is already stable
|
|
|
|
|
|
|
|
|
|
| 261 |
"""
|
|
|
|
|
|
|
| 262 |
global _REWARD_CALL_COUNT
|
| 263 |
_REWARD_CALL_COUNT += 1
|
| 264 |
+
if _REWARD_CALL_COUNT <= 5 or _REWARD_CALL_COUNT % 100 == 0:
|
| 265 |
+
print(f" [reward_fn] call #{_REWARD_CALL_COUNT} | n={len(completions)}", flush=True)
|
| 266 |
|
| 267 |
rewards = []
|
| 268 |
for completion, obs_dict in zip(completions, observations):
|
|
|
|
| 270 |
rewards.append(0.0)
|
| 271 |
continue
|
| 272 |
|
|
|
|
| 273 |
if isinstance(obs_dict, str):
|
| 274 |
try:
|
| 275 |
obs_dict = json.loads(obs_dict)
|
|
|
|
| 279 |
|
| 280 |
freq = obs_dict.get('grid_frequency', 50.0)
|
| 281 |
freq_error = freq - 50.0
|
| 282 |
+
abs_error = abs(freq_error)
|
| 283 |
|
| 284 |
+
# ── 1. JSON validity ──
|
|
|
|
|
|
|
| 285 |
try:
|
| 286 |
+
_m = re.search(r'\{[\s\S]*\}', completion)
|
|
|
|
| 287 |
_parsed = json.loads(_m.group()) if _m else None
|
| 288 |
+
json_valid = (
|
| 289 |
+
_parsed is not None
|
| 290 |
+
and isinstance(_parsed.get('bus_adjustments'), list)
|
| 291 |
+
)
|
| 292 |
except Exception:
|
| 293 |
json_valid = False
|
| 294 |
|
| 295 |
if not json_valid:
|
|
|
|
| 296 |
rewards.append(-0.5)
|
| 297 |
continue
|
| 298 |
|
| 299 |
+
# ── 2. Schema / action quality ──
|
| 300 |
+
adjustments = _parsed.get('bus_adjustments', [])
|
| 301 |
+
schema_score = 0.0
|
| 302 |
+
valid_adjs = []
|
| 303 |
+
for adj in adjustments:
|
| 304 |
+
if isinstance(adj.get('bus_id'), int) and isinstance(adj.get('delta'), (int, float)):
|
| 305 |
+
valid_adjs.append(adj)
|
| 306 |
+
if valid_adjs:
|
| 307 |
+
schema_score = 0.1
|
| 308 |
+
elif abs_error > 0.05:
|
| 309 |
+
schema_score = -0.1 # should have acted but gave no valid adjustments
|
| 310 |
+
|
| 311 |
+
# ── 3. Directional correctness ──
|
| 312 |
+
direction_score = 0.0
|
| 313 |
+
if valid_adjs:
|
| 314 |
+
total_delta = sum(a['delta'] for a in valid_adjs)
|
| 315 |
+
if abs_error > 0.05:
|
| 316 |
+
correct = (freq_error < 0 and total_delta > 0) or \
|
| 317 |
+
(freq_error > 0 and total_delta < 0)
|
| 318 |
+
direction_score = 0.4 if correct else -0.4
|
| 319 |
else:
|
| 320 |
+
# Grid stable — small action OK, large action penalised
|
| 321 |
+
direction_score = 0.1 if abs(total_delta) < 5.0 else -0.2
|
| 322 |
+
|
| 323 |
+
# ── 4. Proportionality ──
|
| 324 |
+
prop_score = 0.0
|
| 325 |
+
if valid_adjs and abs_error > 0.05:
|
| 326 |
+
total_delta = sum(a['delta'] for a in valid_adjs)
|
| 327 |
+
ideal = abs_error * 15.0 # rough MW per Hz gain
|
| 328 |
+
actual = abs(total_delta)
|
| 329 |
+
if actual > 0.1:
|
| 330 |
+
ratio = min(actual, ideal) / max(actual, ideal, 0.1)
|
| 331 |
+
prop_score = 0.2 * ratio # up to +0.2 for perfect proportionality
|
| 332 |
+
|
| 333 |
+
total = schema_score + direction_score + prop_score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
rewards.append(max(-1.0, min(1.0, total)))
|
| 335 |
|
| 336 |
return rewards
|