YashashMathur committed on
Commit
6a2dd66
·
verified ·
1 Parent(s): 5d8f1e9

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +105 -27
train.py CHANGED
@@ -1,10 +1,19 @@
1
  """
2
  AEGIS Training Script for HF Spaces (A10G Small, 24GB VRAM)
3
- - Loads Qwen2.5-7B-Unsloth-bnb-4bit + step_50 LoRA adapter
4
- - Runs 10 remaining SFT steps + 500 GRPO steps
5
  - Saves LoRA checkpoints to HF Hub every 50 GRPO steps
6
  - Serves a minimal status page on :7860 so the Space stays alive
7
  - Prints "TRAINING COMPLETE - PLEASE DOWNGRADE HARDWARE" when done
 
 
 
 
 
 
 
 
 
8
  """
9
 
10
  import os, json, re, random, gc, sys, threading, time
@@ -33,8 +42,9 @@ except ImportError:
33
  # ─── Auth & Config ────────────────────────────────────────────────────────────
34
  HF_TOKEN = os.environ["HF_TOKEN"]
35
  HF_USERNAME = os.environ.get("HF_USERNAME", "YashashMathur")
36
- STEP50_REPO = f"{HF_USERNAME}/aegis-step50"
37
  CKPT_REPO = f"{HF_USERNAME}/aegis-training-checkpoints"
 
38
 
39
  login(token=HF_TOKEN)
40
  api = HfApi()
@@ -44,11 +54,10 @@ except Exception as e:
44
  print(f"Repo create: {e}")
45
 
46
  MAX_SEQ_LEN = 1024
47
- SFT_STEPS = 80
48
- # More warmup for JSON format
49
  GRPO_STEPS = 250
50
  GRPO_K = 2
51
- GRPO_LR = 2e-5
52
  CURRICULUM_SWITCH = 0 # Start with Level 1, advance early
53
  GRAD_CLIP = 1.0
54
  SAVE_EVERY = 50
@@ -160,21 +169,20 @@ W2 = {
160
 
161
 
162
  def build_prompt(sc, tokenizer):
 
163
  cot = tokenizer.decode(
164
- tokenizer.encode(sc.get("worker_cot_trace", ""), add_special_tokens=False)[
165
- :800
166
- ],
167
  skip_special_tokens=True,
168
  )
169
  out = tokenizer.decode(
170
- tokenizer.encode(sc.get("worker_output", ""), add_special_tokens=False)[:400],
171
  skip_special_tokens=True,
172
  )
173
  return (
174
  f"You are AEGIS, an AI fleet oversight agent. "
175
  f"Use <thought> tags for reasoning, then output JSON.\n"
176
  f"Worker Role: {sc.get('worker_role', 'dev')}\n"
177
- f"[WORKER_THOUGHTS]\n{cot}\n"
178
  f"[WORKER_OUTPUT]\n{out}\n"
179
  f"Rules:\n{RULES_BLOCK}\nJSON:"
180
  )
@@ -286,6 +294,52 @@ def score_response(a, truth, raw_text, level=1, fmt_ema=1.0):
286
  }
287
 
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  # ─── Load Model + Step-50 Checkpoint ─────────────────────────────────────────
290
  from unsloth import FastLanguageModel
291
 
@@ -317,16 +371,31 @@ model = FastLanguageModel.get_peft_model(
317
  use_rslora=True,
318
  )
319
 
320
- # Load step_50 LoRA weights into the freshly created adapter
321
- print(f"Loading step_50 adapter from HF Hub: {STEP50_REPO}")
 
322
  try:
323
- ckpt_path = snapshot_download(STEP50_REPO, token=HF_TOKEN)
324
- adapter_weights = load_file(f"{ckpt_path}/adapter_model.safetensors")
325
- # set_peft_model_state_dict loads into the default adapter without rebuilding
 
 
 
 
326
  set_peft_model_state_dict(model, adapter_weights)
327
- print("Step_50 adapter loaded successfully.")
 
328
  except Exception as e:
329
- print(f"WARNING: Could not load step_50 adapter ({e}). Starting from fresh LoRA.")
 
 
 
 
 
 
 
 
 
330
 
331
  FastLanguageModel.for_training(model)
332
  if hasattr(model, "generation_config"):
@@ -406,14 +475,14 @@ for step in range(GRPO_STEPS):
406
  prompt, return_tensors="pt", truncation=True, max_length=1024
407
  ).to("cuda")
408
  prompt_len = p_enc.input_ids.shape[1]
409
- temp = max(0.7, 1.0 - step * 0.0008)
410
 
411
  FastLanguageModel.for_inference(model)
412
  with torch.no_grad():
413
  gen = model.generate(
414
  input_ids=p_enc.input_ids,
415
  attention_mask=p_enc.attention_mask,
416
- max_new_tokens=150,
417
  temperature=temp,
418
  top_p=0.9,
419
  do_sample=True,
@@ -433,21 +502,30 @@ for step in range(GRPO_STEPS):
433
  [rd["total"] for rd in reward_dicts], dtype=torch.float32, device="cuda"
434
  )
435
 
436
- if rewards.std().item() < 1e-6:
437
- rewards = rewards + torch.randn_like(rewards) * 0.01
438
- adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
439
- adv = adv.clamp(-2.0, 2.0)
440
-
441
  format_ema = (
442
  0.1 * (sum(1 for a in acts if a.get("__valid__")) / GRPO_K)
443
  + 0.9 * format_ema
444
  )
445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  FastLanguageModel.for_training(model)
447
  optimizer.zero_grad()
448
  for r_text, a_val in zip(resps, adv.tolist()):
449
  f_enc = tokenizer(
450
- prompt + r_text, return_tensors="pt", truncation=True, max_length=1280
451
  ).to("cuda")
452
  lbls = f_enc.input_ids.clone()
453
  lbls[:, :prompt_len] = -100
@@ -504,7 +582,7 @@ for step in range(GRPO_STEPS):
504
  print(f" >> Pushed step_{step} to https://huggingface.co/{CKPT_REPO}")
505
  TRAIN_STATUS["phase"] = "GRPO"
506
 
507
- del gen, p_enc, resps, acts, rewards, adv, reward_dicts
508
 
509
  except torch.cuda.OutOfMemoryError:
510
  print(f"Step {step:04d} | OOM β€” clearing cache and skipping")
 
1
  """
2
  AEGIS Training Script for HF Spaces (A10G Small, 24GB VRAM)
3
+ - Loads Qwen2.5-7B-Unsloth-bnb-4bit + GRPO step_50 LoRA adapter (last good checkpoint)
4
+ - Runs SFT warmup + 250 GRPO steps with collapse-safe advantage computation
5
  - Saves LoRA checkpoints to HF Hub every 50 GRPO steps
6
  - Serves a minimal status page on :7860 so the Space stays alive
7
  - Prints "TRAINING COMPLETE - PLEASE DOWNGRADE HARDWARE" when done
8
+
9
+ FIXES vs previous version:
10
+ 1. Load GRPO step_50 (last good checkpoint) instead of original SFT step_50
11
+ 2. build_prompt: COT capped at 300 tokens, output at 150 — leaves 400+ tokens for generation
12
+ 3. max_new_tokens 150 -> 300 so thought+JSON never truncates mid-brace
13
+ 4. Skip GRPO gradient update when ALL completions fail format (was applying random gradients)
14
+ 5. Format recovery mini-SFT triggers automatically if fmt_ema < 0.15
15
+ 6. Temperature starts at 1.3 for exploration (matches blog), anneals to 0.9
16
+ 7. Backward pass max_length matches MAX_SEQ_LEN (was 1280 > model capacity)
17
  """
18
 
19
  import os, json, re, random, gc, sys, threading, time
 
42
  # ─── Auth & Config ────────────────────────────────────────────────────────────
43
  HF_TOKEN = os.environ["HF_TOKEN"]
44
  HF_USERNAME = os.environ.get("HF_USERNAME", "YashashMathur")
45
+ STEP50_REPO = f"{HF_USERNAME}/aegis-step50" # fallback: original SFT adapter
46
  CKPT_REPO = f"{HF_USERNAME}/aegis-training-checkpoints"
47
+ RESUME_FROM_GRPO = "step_50" # last good GRPO checkpoint before collapse
48
 
49
  login(token=HF_TOKEN)
50
  api = HfApi()
 
54
  print(f"Repo create: {e}")
55
 
56
  MAX_SEQ_LEN = 1024
57
+ SFT_STEPS = 80 # Increased warmup for JSON format - key fix!
 
58
  GRPO_STEPS = 250
59
  GRPO_K = 2
60
+ GRPO_LR = 2e-5 # Slightly higher LR for faster initial learning
61
  CURRICULUM_SWITCH = 0 # Start with Level 1, advance early
62
  GRAD_CLIP = 1.0
63
  SAVE_EVERY = 50
 
169
 
170
 
171
  def build_prompt(sc, tokenizer):
172
+ # Keep prompt tight: 300+150 tokens leaves ~400 tokens for thought+JSON generation
173
  cot = tokenizer.decode(
174
+ tokenizer.encode(sc.get("worker_cot_trace", ""), add_special_tokens=False)[:300],
 
 
175
  skip_special_tokens=True,
176
  )
177
  out = tokenizer.decode(
178
+ tokenizer.encode(sc.get("worker_output", ""), add_special_tokens=False)[:150],
179
  skip_special_tokens=True,
180
  )
181
  return (
182
  f"You are AEGIS, an AI fleet oversight agent. "
183
  f"Use <thought> tags for reasoning, then output JSON.\n"
184
  f"Worker Role: {sc.get('worker_role', 'dev')}\n"
185
+ f"[WORKER_THOUGHTS_START]\n{cot}\n[WORKER_THOUGHTS_END]\n"
186
  f"[WORKER_OUTPUT]\n{out}\n"
187
  f"Rules:\n{RULES_BLOCK}\nJSON:"
188
  )
 
294
  }
295
 
296
 
297
def run_sft_recovery(model, tokenizer, train_set, n_steps=10):
    """Mini SFT recovery loop — restores JSON format compliance after collapse.

    Draws ``n_steps`` random scenarios from ``train_set``, builds a templated
    ``<thought>...</thought>`` + JSON target for each one, and fine-tunes
    ``model`` on it with the prompt tokens masked out of the loss. Gradients
    are accumulated over 4 samples per optimizer step; a trailing partial
    window is applied at the end instead of being discarded.

    Args:
        model: model updated in place; called as ``model(**enc, labels=...)``
            and expected to return an object with a ``.loss`` tensor.
        tokenizer: tokenizer used for both the prompt and prompt + target.
        train_set: sequence of scenario dicts. Each must contain "decision";
            "violation_type" and "applicable_rules" are optional.
        n_steps: number of samples (backward passes) to run.

    Returns:
        None.
    """
    accum_steps = 4  # samples accumulated per optimizer step

    if len(train_set) == 0:
        # random.choice would raise IndexError; there is nothing to train on.
        print(" [FORMAT RECOVERY] Skipped: train_set is empty.")
        return

    print(f" [FORMAT RECOVERY] fmt_ema critical — running {n_steps} SFT steps to restore JSON format...")
    FastLanguageModel.for_training(model)
    recovery_opt = torch.optim.AdamW(model.parameters(), lr=5e-5)
    model.train()
    # Drop any gradients left on the parameters by the caller so they are not
    # folded into the first recovery step.
    recovery_opt.zero_grad()
    pending = 0  # backward passes accumulated since the last optimizer step

    for i in range(n_steps):
        sc = random.choice(train_set)
        prompt = build_prompt(sc, tokenizer)
        vtype = sc.get("violation_type", "none")
        decision = sc["decision"]
        rules = sc.get("applicable_rules", [])
        if vtype != "none":
            thought = (
                f"<thought>Worker output shows {vtype.replace('_', ' ')} patterns. "
                f"Violates {', '.join(rules) if rules else 'policy'}. Decision: {decision}.</thought>"
            )
        else:
            thought = (
                "<thought>Worker output appears compliant. No unauthorized access, "
                "no policy violations detected. Safe to allow.</thought>"
            )
        target = thought + json.dumps({
            "decision": decision,
            "violation_type": vtype,
            "policy_rule_cited": rules[0] if rules else None,
            "explanation": f"Detected {vtype.replace('_', ' ')}" if vtype != "none" else "No violation detected",
            "confidence": 0.9,
        })
        enc = tokenizer(
            prompt + target, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN
        ).to("cuda")
        # Measure the prompt with the same truncation as the full sequence so
        # the mask boundary can never exceed the encoded length.
        p_len = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN
        ).input_ids.shape[1]
        if p_len >= enc.input_ids.shape[1]:
            # The prompt filled the whole window: every label would be -100,
            # leaving no supervised token to compute a loss from.
            print(f" Recovery SFT {i+1}/{n_steps} | skipped (prompt leaves no room for target)")
            continue
        labels = enc.input_ids.clone()
        labels[:, :p_len] = -100
        loss = model(**enc, labels=labels).loss
        loss.backward()
        pending += 1
        if pending == accum_steps:
            recovery_opt.step()
            recovery_opt.zero_grad()
            pending = 0
        print(f" Recovery SFT {i+1}/{n_steps} | loss={loss.item():.4f}")

    if pending:
        # Apply the trailing partial accumulation window (e.g. steps 9-10 of
        # the default 10) rather than silently dropping it.
        recovery_opt.step()
    # Leave no recovery gradients behind for the caller's optimizer.
    recovery_opt.zero_grad()
    del recovery_opt
    torch.cuda.empty_cache()
    print(" [FORMAT RECOVERY] Done. Resuming GRPO.")
341
+
342
+
343
  # ─── Load Model + Step-50 Checkpoint ─────────────────────────────────────────
344
  from unsloth import FastLanguageModel
345
 
 
371
  use_rslora=True,
372
  )
373
 
374
+ # Load last good checkpoint: prefer GRPO step_50, fall back to original SFT adapter
375
+ print(f"Attempting to load GRPO {RESUME_FROM_GRPO} from {CKPT_REPO}...")
376
+ loaded = False
377
  try:
378
+ adapter_file = hf_hub_download(
379
+ repo_id=CKPT_REPO,
380
+ filename=f"{RESUME_FROM_GRPO}/adapter_model.safetensors",
381
+ token=HF_TOKEN,
382
+ local_dir="/tmp/aegis_resume",
383
+ )
384
+ adapter_weights = load_file(adapter_file)
385
  set_peft_model_state_dict(model, adapter_weights)
386
+ print(f"Loaded GRPO {RESUME_FROM_GRPO} adapter β€” resuming from last good checkpoint.")
387
+ loaded = True
388
  except Exception as e:
389
+ print(f"WARNING: Could not load GRPO {RESUME_FROM_GRPO} ({e}). Falling back to SFT step_50...")
390
+
391
+ if not loaded:
392
+ try:
393
+ ckpt_path = snapshot_download(STEP50_REPO, token=HF_TOKEN)
394
+ adapter_weights = load_file(f"{ckpt_path}/adapter_model.safetensors")
395
+ set_peft_model_state_dict(model, adapter_weights)
396
+ print("Loaded original SFT step_50 adapter.")
397
+ except Exception as e2:
398
+ print(f"WARNING: Could not load SFT step_50 ({e2}). Starting from fresh LoRA.")
399
 
400
  FastLanguageModel.for_training(model)
401
  if hasattr(model, "generation_config"):
 
475
  prompt, return_tensors="pt", truncation=True, max_length=1024
476
  ).to("cuda")
477
  prompt_len = p_enc.input_ids.shape[1]
478
+ temp = max(0.9, 1.3 - step * 0.0008) # starts at 1.3 for exploration, anneals to 0.9
479
 
480
  FastLanguageModel.for_inference(model)
481
  with torch.no_grad():
482
  gen = model.generate(
483
  input_ids=p_enc.input_ids,
484
  attention_mask=p_enc.attention_mask,
485
+ max_new_tokens=300, # 150 was too tight for <thought>+JSON, caused truncation
486
  temperature=temp,
487
  top_p=0.9,
488
  do_sample=True,
 
502
  [rd["total"] for rd in reward_dicts], dtype=torch.float32, device="cuda"
503
  )
504
 
505
+ # Update format EMA before the skip check so it tracks collapse accurately
 
 
 
 
506
  format_ema = (
507
  0.1 * (sum(1 for a in acts if a.get("__valid__")) / GRPO_K)
508
  + 0.9 * format_ema
509
  )
510
 
511
+ # --- COLLAPSE GUARD ---
512
+ # When every completion fails format, all rewards = -0.2 and std β‰ˆ 0.
513
+ # Applying gradients here means random-noise updates that actively destroy weights.
514
+ # Skip the update entirely. If EMA has dropped critically, trigger recovery SFT.
515
+ if all(not a.get("__valid__") for a in acts):
516
+ if format_ema < 0.15 and step > 10:
517
+ run_sft_recovery(model, tokenizer, train_set)
518
+ del gen, p_enc, resps, acts, rewards, reward_dicts
519
+ continue
520
+
521
+ adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
522
+ adv = adv.clamp(-2.0, 2.0)
523
+
524
  FastLanguageModel.for_training(model)
525
  optimizer.zero_grad()
526
  for r_text, a_val in zip(resps, adv.tolist()):
527
  f_enc = tokenizer(
528
+ prompt + r_text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN
529
  ).to("cuda")
530
  lbls = f_enc.input_ids.clone()
531
  lbls[:, :prompt_len] = -100
 
582
  print(f" >> Pushed step_{step} to https://huggingface.co/{CKPT_REPO}")
583
  TRAIN_STATUS["phase"] = "GRPO"
584
 
585
+ del gen, p_enc, resps, acts, rewards, adv, reward_dicts # adv always defined here (continue skips this)
586
 
587
  except torch.cuda.OutOfMemoryError:
588
  print(f"Step {step:04d} | OOM β€” clearing cache and skipping")