K446 committed
Commit 7be88b4 · 1 Parent(s): a6ecb81

Update run_training.py and train_grpo.py, remove Dockerfile.training

Files changed (3)
  1. Dockerfile.training +0 -24
  2. run_training.py +1 -1
  3. training/train_grpo.py +74 -32
Dockerfile.training DELETED
@@ -1,24 +0,0 @@
-# OpenGrid GRPO Training Space — Runs on A10G GPU
-# After training completes, serves results on port 7860
-
-FROM python:3.10-slim
-
-LABEL org.opencontainers.image.title="OpenGrid GRPO Training"
-LABEL org.opencontainers.image.description="GRPO training for power grid multi-agent controller"
-
-RUN useradd -m -u 1000 user
-USER user
-ENV PATH="/home/user/.local/bin:$PATH"
-
-WORKDIR /app
-
-# Install training dependencies
-COPY --chown=user requirements-training.txt .
-RUN pip install --no-cache-dir --upgrade -r requirements-training.txt
-
-# Copy application code
-COPY --chown=user . /app
-
-# Training entrypoint: runs GRPO then serves results
-EXPOSE 7860
-CMD ["python", "run_training.py"]
run_training.py CHANGED
@@ -242,7 +242,7 @@ def run_grpo_training():
         save_steps=50,
         max_prompt_length=1024,
         max_completion_length=96,
-        num_generations=2,
+        num_generations=4,  # min for meaningful GRPO variance; 2 gives reward_std=0
         report_to="none",
         remove_unused_columns=False,
         bf16=_bf16,
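
The `num_generations` comment follows from how GRPO builds its learning signal: each reward is normalized against the mean and standard deviation of its own generation group, so a group of two completions that score identically has zero std and therefore zero advantage. A minimal sketch of that group normalization (the standard GRPO formula, not TRL's internal code):

```python
import numpy as np

def group_advantages(rewards: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Normalize rewards within one generation group: A_i = (r_i - mean) / (std + eps)."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Two generations with identical rewards: std = 0, every advantage is 0 -> no gradient.
print(group_advantages(np.array([0.05, 0.05])))            # [0. 0.]

# Four generations are more likely to disagree, giving a usable ranking signal.
print(group_advantages(np.array([0.05, -0.5, 0.30, 0.05])))
```

With four generations per prompt it is much more likely that at least one completion scores differently, which keeps reward_std above zero.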
training/train_grpo.py CHANGED
@@ -227,14 +227,20 @@ def rollout_multi_agent(env: OpenGridEnv, generate_fn, task_config: dict) -> dict:
 # GRPO Reward Functions
 # ============================================================================
 
-# Cache task configs to avoid re-deepcopy on every reward call
-_REWARD_ENV_CACHE = {}
+# Cache one env instance per task config — re-instantiating + deepcopy + reset
+# on every reward call adds significant per-step latency for GRPO.
+_REWARD_ENV_CACHE: dict = {}
+_REWARD_CALL_COUNT = 0
 
 
 def _get_reward_env(task_config: dict) -> OpenGridEnv:
-    """Get a fresh environment for reward computation."""
-    env = OpenGridEnv(copy.deepcopy(task_config))
-    env.reset()
+    """Return a cached env for this task_config, building it once."""
+    key = id(task_config)
+    env = _REWARD_ENV_CACHE.get(key)
+    if env is None:
+        env = OpenGridEnv(copy.deepcopy(task_config))
+        env.reset()
+        _REWARD_ENV_CACHE[key] = env
     return env
 
 
@@ -258,6 +264,11 @@ def compute_grpo_reward_env(
     """
     from src.baseline import heuristic_policy
 
+    global _REWARD_CALL_COUNT
+    _REWARD_CALL_COUNT += 1
+    if _REWARD_CALL_COUNT <= 3 or _REWARD_CALL_COUNT % 50 == 0:
+        print(f" [reward_fn] call #{_REWARD_CALL_COUNT} | n_completions={len(completions)}", flush=True)
+
     rewards = []
     for completion, obs_dict in zip(completions, observations):
         if obs_dict is None:
@@ -272,36 +283,58 @@ def compute_grpo_reward_env(
             rewards.append(0.0)
             continue
 
+        freq = obs_dict.get('grid_frequency', 50.0)
+        freq_error = freq - 50.0
+
+        # ── 1. JSON validity signal — biggest discriminator ──
+        # Raw text check first (faster than extract_action)
+        raw_has_json = '{' in completion and '}' in completion
+        try:
+            import re as _re
+            _m = _re.search(r'\{[\s\S]*\}', completion)
+            _parsed = json.loads(_m.group()) if _m else None
+            json_valid = _parsed is not None and 'bus_adjustments' in _parsed
+        except Exception:
+            json_valid = False
+
+        if not json_valid:
+            # Invalid / missing JSON — strong penalty so the group has variance
+            rewards.append(-0.5)
+            continue
+
         action = extract_action(completion)
         has_adjustments = bool(action.bus_adjustments)
 
-        # ── 1. Format reward (small but keeps gradient alive) ──
+        # ── 2. Format reward: directional correctness ──
         format_score = 0.0
         if has_adjustments:
-            format_score += 0.05
-        else:
-            freq = obs_dict.get('grid_frequency', 50.0)
-            if abs(freq - 50.0) < 0.05:
-                format_score += 0.05  # No-op when stable is fine
+            total_delta = sum(a.delta for a in action.bus_adjustments)
+            # Reward correct direction relative to frequency error
+            if abs(freq_error) > 0.05:
+                # freq too low → need positive delta; freq too high → negative delta
+                correct_dir = (freq_error < 0 and total_delta > 0) or \
+                              (freq_error > 0 and total_delta < 0)
+                format_score = 0.3 if correct_dir else -0.3
            else:
-                format_score -= 0.05  # No-op during deviation is bad
+                # Stable grid: small action is fine, large one wastes resources
+                format_score = 0.1 if abs(total_delta) < 5.0 else -0.1
+        else:
+            # No-op: fine when stable, bad when deviating
+            format_score = 0.1 if abs(freq_error) < 0.05 else -0.3
 
-        # ── 2. Environment-grounded reward ──
+        # ── 3. Environment-grounded reward ──
         try:
             env = _get_reward_env(task_config)
             env._set_state(obs_dict)
 
-            # Step with the LLM's proposed action
             obs_after, reward, done, info = env.step(action)
             env_score = reward.value
 
-            # Blackout = catastrophic
             if info.is_blackout:
                 rewards.append(-1.0)
                 continue
 
-            # ── 3. Mini-rollout: what happens next? ──
-            # Run a few more steps with heuristic to measure trajectory impact
+            # horizon=1: just immediate reward; avoids 24 extra env steps per optimizer step
             rollout_reward = 0.0
             for _ in range(horizon - 1):
                 if done:
@@ -313,16 +346,13 @@ def compute_grpo_reward_env(
                     rollout_reward -= 10.0
                     break
 
-            # Combine: immediate reward + discounted future
             total_env_score = env_score + 0.5 * rollout_reward
 
-            # Normalize to [-1, 1] range
-            # Typical per-step reward is ~0.5 to 1.5, rollout adds ~1-4
-            # So total_env_score is roughly in [-10, 4] range
-            normalized = total_env_score / 5.0
+            # Narrower normalizer → wider spread across completions
+            # Typical per-step reward: 0.5 to 1.5 (good), -10 to 0 (blackout)
+            normalized = total_env_score / 3.0
 
-        except Exception as e:
-            # Fallback: use lightweight heuristic scoring
+        except Exception:
            normalized = _compute_heuristic_score(action, obs_dict)
 
         total = format_score + normalized
@@ -379,12 +409,15 @@ def train_grpo(args):
     """Main GRPO training loop using TRL."""
     try:
         from trl import GRPOTrainer, GRPOConfig
-        from transformers import AutoTokenizer, AutoModelForCausalLM
+        from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
     except ImportError:
         print("ERROR: TRL not installed. Run: pip install trl transformers")
         print("For quantized training: pip install unsloth")
         sys.exit(1)
 
+    import inspect as _inspect
+    _grpo_params = set(_inspect.signature(GRPOConfig.__init__).parameters)
+
     print(f"[TRAIN] Model: {args.model}")
     print(f"[TRAIN] Task: {args.task}")
     print(f"[TRAIN] Epochs: {args.epochs}")
@@ -518,21 +551,30 @@ def train_grpo(args):
             else:
                 obs_dicts.append(ctx)
 
-        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=3)
+        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=1)
 
-    # GRPO Config — tuned for sustained learning signal
+    # GRPO Config — tuned for sustained learning signal AND visible progress
     grpo_config = GRPOConfig(
         output_dir=str(Path(args.output_dir) / "grpo_checkpoints"),
         num_train_epochs=args.epochs,
         per_device_train_batch_size=args.batch_size,
-        gradient_accumulation_steps=max(1, 16 // args.batch_size),
-        learning_rate=1e-5,  # Was 5e-6 — slightly more aggressive
-        logging_steps=5,
+        gradient_accumulation_steps=max(1, 8 // args.batch_size),
+        learning_rate=1e-5,
+        logging_steps=1,
         save_steps=50,
-        max_completion_length=256,
-        num_generations=8,  # Was 4 — wider group for better ranking signal
+        max_prompt_length=1024,
+        max_completion_length=96,
+        num_generations=4,
+        temperature=0.7,
         report_to="none",
         remove_unused_columns=False,
+        gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+        optim="paged_adamw_8bit",
+        warmup_ratio=0.05,
+        lr_scheduler_type="cosine",
+        **({'torch_compile': False} if 'torch_compile' in _grpo_params else {}),
+        **({'use_vllm': False} if 'use_vllm' in _grpo_params else {}),
     )
 
     # Create dataset — include obs_context so TRL passes it to reward_fn
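
The `**({'torch_compile': False} if 'torch_compile' in _grpo_params else {})` idiom guards against TRL releases whose GRPOConfig does not yet accept those fields. A more general sketch of the same pattern, as a hypothetical helper that is not part of this commit:

```python
import inspect

def filter_supported_kwargs(config_cls, **kwargs) -> dict:
    """Drop any kwarg that config_cls.__init__ does not accept, so one call
    site works across library versions with different config fields."""
    supported = set(inspect.signature(config_cls.__init__).parameters)
    return {k: v for k, v in kwargs.items() if k in supported}

# Hypothetical usage (requires trl to be installed):
# from trl import GRPOConfig
# grpo_config = GRPOConfig(
#     output_dir="out",
#     **filter_supported_kwargs(GRPOConfig, torch_compile=False, use_vllm=False),
# )
```

The commit keeps the conditionals inline at the call site, which avoids the extra helper while achieving the same version tolerance.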