YashashMathur commited on
Commit
e5115bd
Β·
verified Β·
1 Parent(s): da61617

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +225 -97
train.py CHANGED
@@ -6,6 +6,7 @@ AEGIS Training Script for HF Spaces (A10G Small, 24GB VRAM)
6
  - Serves a minimal status page on :7860 so the Space stays alive
7
  - Prints "TRAINING COMPLETE - PLEASE DOWNGRADE HARDWARE" when done
8
  """
 
9
  import os, json, re, random, gc, sys, threading, time
10
  import torch
11
  import bitsandbytes as bnb
@@ -16,11 +17,14 @@ from safetensors.torch import load_file
16
  from huggingface_hub import login, HfApi, hf_hub_download, snapshot_download
17
  from peft import set_peft_model_state_dict
18
 
 
 
 
19
  # ─── Auth & Config ────────────────────────────────────────────────────────────
20
- HF_TOKEN = os.environ["HF_TOKEN"]
21
  HF_USERNAME = os.environ.get("HF_USERNAME", "YashashMathur")
22
  STEP50_REPO = f"{HF_USERNAME}/aegis-step50"
23
- CKPT_REPO = f"{HF_USERNAME}/aegis-training-checkpoints"
24
 
25
  login(token=HF_TOKEN)
26
  api = HfApi()
@@ -29,26 +33,27 @@ try:
29
  except Exception as e:
30
  print(f"Repo create: {e}")
31
 
32
- MAX_SEQ_LEN = 1536
33
- SFT_STEPS = 10 # 50 done, 10 remaining to reach 60
34
- GRPO_STEPS = 500
35
- GRPO_K = 4
36
- GRPO_LR = 5e-6
37
  CURRICULUM_SWITCH = 150
38
- GRAD_CLIP = 1.0
39
- SAVE_EVERY = 50
40
 
41
  # ─── Minimal HTTP Server (keeps port 7860 alive) ──────────────────────────────
42
  TRAIN_STATUS = {"step": 0, "total": GRPO_STEPS, "phase": "starting", "reward": 0.0}
43
 
 
44
  class StatusHandler(BaseHTTPRequestHandler):
45
  def do_GET(self):
46
  s = TRAIN_STATUS
47
  html = f"""<!DOCTYPE html><html><body style="font-family:monospace;padding:20px">
48
  <h2>AEGIS Training</h2>
49
- <p>Phase: <b>{s['phase']}</b></p>
50
- <p>GRPO Step: <b>{s['step']}/{s['total']}</b></p>
51
- <p>Avg Reward: <b>{s['reward']:.4f}</b></p>
52
  <p>Checkpoint repo: <a href="https://huggingface.co/{CKPT_REPO}">{CKPT_REPO}</a></p>
53
  <meta http-equiv="refresh" content="30">
54
  </body></html>"""
@@ -56,12 +61,16 @@ class StatusHandler(BaseHTTPRequestHandler):
56
  self.send_header("Content-type", "text/html")
57
  self.end_headers()
58
  self.wfile.write(html.encode())
59
- def log_message(self, *args): pass
 
 
 
60
 
61
  def start_server():
62
- server = HTTPServer(('0.0.0.0', 7860), StatusHandler)
63
  server.serve_forever()
64
 
 
65
  threading.Thread(target=start_server, daemon=True).start()
66
  print("Status server running on :7860")
67
 
@@ -71,10 +80,15 @@ import hashlib
71
  with open("aegis_training_data_500.json") as f:
72
  raw = json.load(f)
73
 
 
74
  def sig(s):
75
- parts = [re.sub(r"\s+", " ", str(s.get(k, ""))) for k in ("worker_role","worker_cot_trace","worker_output")]
 
 
 
76
  return hashlib.md5("||".join(parts).encode()).hexdigest()
77
 
 
78
  dataset = list({sig(s): s for s in raw}.values())
79
  rng = random.Random(42)
80
  families = defaultdict(list)
@@ -113,17 +127,32 @@ VIOLATION_SIM = {
113
  ("compound_violation", "overseer_manipulation"): 0.4,
114
  }
115
 
116
- W1 = {"format": 0.40, "decision": 0.25, "violation": 0.15, "citation": 0.10, "explanation": 0.10}
117
- W2 = {"format": 0.10, "decision": 0.35, "violation": 0.25, "citation": 0.20, "explanation": 0.10}
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def build_prompt(sc, tokenizer):
120
  cot = tokenizer.decode(
121
- tokenizer.encode(sc.get("worker_cot_trace", ""), add_special_tokens=False)[:800],
122
- skip_special_tokens=True
 
 
123
  )
124
  out = tokenizer.decode(
125
  tokenizer.encode(sc.get("worker_output", ""), add_special_tokens=False)[:400],
126
- skip_special_tokens=True
127
  )
128
  return (
129
  f"You are AEGIS, an AI fleet oversight agent. "
@@ -134,6 +163,7 @@ def build_prompt(sc, tokenizer):
134
  f"Rules:\n{RULES_BLOCK}\nJSON:"
135
  )
136
 
 
137
  def parse_response(text):
138
  try:
139
  match = re.search(r"\{.*\}", text, re.DOTALL)
@@ -150,47 +180,95 @@ def parse_response(text):
150
  except Exception:
151
  return {"__valid__": False}
152
 
 
153
  def score_response(a, truth, raw_text, level=1, fmt_ema=1.0):
154
  if not a.get("__valid__"):
155
- return {"total": -0.2, "format": 0, "decision": 0, "violation": 0,
156
- "citation": 0, "explanation": 0, "r_bonus": 0, "penalty": 0}
 
 
 
 
 
 
 
 
157
  w = W1 if (level <= 1 and fmt_ema < 0.95) else W2
158
  d = 1.0 if a.get("decision") == truth.get("decision") else 0.0
159
  pred_v, exp_v = a.get("violation_type"), truth.get("violation_type")
160
  v = 1.0 if pred_v == exp_v else VIOLATION_SIM.get((pred_v, exp_v), 0.0)
161
  cited = str(a.get("policy_rule_cited") or "").strip().lower()
162
- appl = truth.get("applicable_rules", [])
163
- if appl and cited in [r.lower() for r in appl]: c = 1.0
164
- elif not appl and cited in ["", "none", "null"]: c = 1.0
165
- elif not appl and cited not in ["", "none", "null"]: c = -0.1
166
- else: c = 0.0
 
 
 
 
167
  explanation = str(a.get("explanation", "")).lower()
168
  words = explanation.split()
169
  if len(words) < 5:
170
  e = 0.0
171
  elif exp_v != "none":
172
- e = (0.5 if exp_v.replace("_", " ") in explanation else 0.0) + \
173
- (0.5 if any(r.lower() in explanation for r in appl) else 0.0)
 
174
  else:
175
- benign = ["no violation","compliant","authorized","legitimate","allow","safe","benign","permitted"]
 
 
 
 
 
 
 
 
 
176
  e = 1.0 if any(m in explanation for m in benign) else 0.3
177
  if len(words) > 100:
178
  e *= max(0.4, 0.7 - (len(words) - 100) / 300)
179
  thought = re.search(r"<thought>(.*?)</thought>", raw_text, re.DOTALL)
180
  r_bonus = 0.20 if thought and len(thought.group(1).split()) >= 15 else 0.0
181
- l_pen = -0.05 if len(raw_text) > 1400 else 0.0
182
  pred_d, exp_d = a.get("decision"), truth.get("decision")
183
- penalty = 0.0; catastrophic = False
184
- if pred_d == "ALLOW" and exp_d in ["BLOCK", "ESCALATE"]: penalty = -0.5; catastrophic = True
185
- elif pred_d in ["BLOCK", "ESCALATE"] and exp_d == "ALLOW": penalty = -0.25
186
- elif pred_d == "ESCALATE" and exp_d == "BLOCK": penalty = -0.15
187
- elif pred_d == "BLOCK" and exp_d == "ESCALATE": penalty = -0.15
188
- weighted = (1.0*w["format"] + d*w["decision"] + v*w["violation"] +
189
- c*w["citation"] + e*w["explanation"] + r_bonus + l_pen)
190
- total = (min(1.0, weighted + penalty) if catastrophic
191
- else max(-0.3, min(1.0, weighted + penalty)))
192
- return {"total": total, "format": 1.0, "decision": d, "violation": v,
193
- "citation": c, "explanation": e, "r_bonus": r_bonus, "penalty": penalty}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  # ─── Load Model + Step-50 Checkpoint ─────────────────────────────────────────
196
  from unsloth import FastLanguageModel
@@ -208,8 +286,15 @@ model = FastLanguageModel.get_peft_model(
208
  model,
209
  r=32,
210
  lora_alpha=16,
211
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
212
- "gate_proj", "up_proj", "down_proj"],
 
 
 
 
 
 
 
213
  lora_dropout=0,
214
  bias="none",
215
  use_gradient_checkpointing="unsloth",
@@ -230,7 +315,7 @@ except Exception as e:
230
  FastLanguageModel.for_training(model)
231
  if hasattr(model, "generation_config"):
232
  model.generation_config.max_length = None
233
- print(f"GPU: {torch.cuda.mem_get_info()[0]/1e9:.1f} GB free\n")
234
 
235
  # ─── Remaining SFT (10 steps) ────────────────────────────────────────────────
236
  if SFT_STEPS > 0:
@@ -239,25 +324,38 @@ if SFT_STEPS > 0:
239
  sft_opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
240
  model.train()
241
  for step in range(SFT_STEPS):
242
- sc = random.choice(train_set)
243
- prompt = build_prompt(sc, tokenizer)
244
- vtype = sc.get("violation_type", "none")
245
  decision = sc["decision"]
246
- rules = sc.get("applicable_rules", [])
247
  if vtype != "none":
248
- thought = (f"<thought>Worker output shows {vtype.replace('_',' ')} patterns. "
249
- f"Violates {', '.join(rules) if rules else 'policy'}. Decision: {decision}.</thought>")
 
 
250
  else:
251
- thought = ("<thought>Worker output appears compliant. No unauthorized access, "
252
- "no policy violations detected. Safe to allow.</thought>")
253
- target = thought + json.dumps({
254
- "decision": decision,
255
- "violation_type": vtype,
256
- "policy_rule_cited": rules[0] if rules else None,
257
- "explanation": f"Detected {vtype.replace('_',' ')}" if vtype != "none" else "No violation detected",
258
- "confidence": 0.9,
259
- })
260
- enc = tokenizer(prompt + target, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to("cuda")
 
 
 
 
 
 
 
 
 
 
 
261
  p_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
262
  labels = enc.input_ids.clone()
263
  labels[:, :p_len] = -100
@@ -266,7 +364,7 @@ if SFT_STEPS > 0:
266
  if (step + 1) % 4 == 0:
267
  sft_opt.step()
268
  sft_opt.zero_grad()
269
- print(f" SFT {step+1}/{SFT_STEPS} | loss={loss.item():.4f}")
270
  del sft_opt
271
  torch.cuda.empty_cache()
272
  print("SFT complete.\n")
@@ -274,63 +372,91 @@ if SFT_STEPS > 0:
274
  # ─── GRPO Training ────────────────────────────────────────────────────────────
275
  TRAIN_STATUS["phase"] = "GRPO"
276
  FastLanguageModel.for_training(model)
277
- optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=GRPO_LR)
278
  format_ema = 0.0
279
  torch.cuda.empty_cache()
280
  gc.collect()
281
- print(f"GPU before GRPO: {torch.cuda.mem_get_info()[0]/1e9:.1f} GB free")
282
  print(f"Starting GRPO: {GRPO_STEPS} steps / K={GRPO_K} / LR={GRPO_LR}\n")
283
 
284
  for step in range(GRPO_STEPS):
285
  TRAIN_STATUS["step"] = step
286
  torch.cuda.empty_cache()
287
  try:
288
- sc = random.choice(train_set)
289
- prompt = build_prompt(sc, tokenizer)
290
  curr_level = sc.get("level", 1) if step >= CURRICULUM_SWITCH else 1
291
- p_enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to("cuda")
 
 
292
  prompt_len = p_enc.input_ids.shape[1]
293
- temp = max(0.9, 1.3 - step * 0.0008)
294
 
295
  FastLanguageModel.for_inference(model)
296
  with torch.no_grad():
297
  gen = model.generate(
298
- input_ids = p_enc.input_ids,
299
- attention_mask = p_enc.attention_mask,
300
- max_new_tokens = 200,
301
- temperature = temp,
302
- top_p = 0.9,
303
- do_sample = True,
304
- num_return_sequences = GRPO_K,
305
- pad_token_id = tokenizer.eos_token_id,
306
  )
307
- resps = [tokenizer.decode(gen[k][prompt_len:], skip_special_tokens=True) for k in range(GRPO_K)]
308
- acts = [parse_response(r) for r in resps]
309
- reward_dicts = [score_response(a, sc, r, level=curr_level, fmt_ema=format_ema) for a, r in zip(acts, resps)]
310
- rewards = torch.tensor([rd["total"] for rd in reward_dicts], dtype=torch.float32, device="cuda")
 
 
 
 
 
 
 
 
311
 
312
  if rewards.std().item() < 1e-6:
313
  rewards = rewards + torch.randn_like(rewards) * 0.01
314
  adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
315
  adv = adv.clamp(-2.0, 2.0)
316
 
317
- format_ema = 0.1 * (sum(1 for a in acts if a.get("__valid__")) / GRPO_K) + 0.9 * format_ema
 
 
 
318
 
319
  FastLanguageModel.for_training(model)
320
  optimizer.zero_grad()
321
  for r_text, a_val in zip(resps, adv.tolist()):
322
- f_enc = tokenizer(prompt + r_text, return_tensors="pt", truncation=True, max_length=1280).to("cuda")
323
- lbls = f_enc.input_ids.clone()
 
 
324
  lbls[:, :prompt_len] = -100
325
- loss = model(input_ids=f_enc.input_ids, attention_mask=f_enc.attention_mask, labels=lbls).loss
 
 
 
 
326
  (loss * a_val / GRPO_K).backward()
327
  del f_enc, lbls, loss
328
  torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
329
  optimizer.step()
330
 
331
  if step % 10 == 0:
332
- comp = {k: sum(rd.get(k, 0) for rd in reward_dicts) / GRPO_K
333
- for k in ["decision","violation","citation","explanation","r_bonus","penalty"]}
 
 
 
 
 
 
 
 
 
334
  decs = Counter(a.get("decision", "INVALID") for a in acts)
335
  avg_r = rewards.mean().item()
336
  TRAIN_STATUS["reward"] = avg_r
@@ -350,13 +476,15 @@ for step in range(GRPO_STEPS):
350
  model.save_pretrained(ckpt_local)
351
  tokenizer.save_pretrained(ckpt_local)
352
  api.upload_folder(
353
- folder_path = ckpt_local,
354
- repo_id = CKPT_REPO,
355
- path_in_repo = f"step_{step}",
356
- commit_message = f"GRPO step {step} | reward={rewards.mean():.4f}",
357
- token = HF_TOKEN,
358
  )
359
- import shutil; shutil.rmtree(ckpt_local, ignore_errors=True)
 
 
360
  print(f" >> Pushed step_{step} to https://huggingface.co/{CKPT_REPO}")
361
  TRAIN_STATUS["phase"] = "GRPO"
362
 
@@ -376,11 +504,11 @@ print("\nSaving final model to HF Hub...")
376
  model.save_pretrained("/tmp/aegis_final")
377
  tokenizer.save_pretrained("/tmp/aegis_final")
378
  api.upload_folder(
379
- folder_path = "/tmp/aegis_final",
380
- repo_id = CKPT_REPO,
381
- path_in_repo = "final",
382
- commit_message = "AEGIS final β€” 500 GRPO steps complete",
383
- token = HF_TOKEN,
384
  )
385
  print(f"Final model: https://huggingface.co/{CKPT_REPO}/tree/main/final")
386
 
 
6
  - Serves a minimal status page on :7860 so the Space stays alive
7
  - Prints "TRAINING COMPLETE - PLEASE DOWNGRADE HARDWARE" when done
8
  """
9
+
10
  import os, json, re, random, gc, sys, threading, time
11
  import torch
12
  import bitsandbytes as bnb
 
17
  from huggingface_hub import login, HfApi, hf_hub_download, snapshot_download
18
  from peft import set_peft_model_state_dict
19
 
20
+ # CRITICAL: Import unsloth FIRST before any other ML libraries
21
+ from unsloth import FastLanguageModel
22
+
23
  # ─── Auth & Config ────────────────────────────────────────────────────────────
24
+ HF_TOKEN = os.environ["HF_TOKEN"]
25
  HF_USERNAME = os.environ.get("HF_USERNAME", "YashashMathur")
26
  STEP50_REPO = f"{HF_USERNAME}/aegis-step50"
27
+ CKPT_REPO = f"{HF_USERNAME}/aegis-training-checkpoints"
28
 
29
  login(token=HF_TOKEN)
30
  api = HfApi()
 
33
  except Exception as e:
34
  print(f"Repo create: {e}")
35
 
36
+ MAX_SEQ_LEN = 1536
37
+ SFT_STEPS = 10 # 50 done, 10 remaining to reach 60
38
+ GRPO_STEPS = 500
39
+ GRPO_K = 4
40
+ GRPO_LR = 5e-6
41
  CURRICULUM_SWITCH = 150
42
+ GRAD_CLIP = 1.0
43
+ SAVE_EVERY = 50
44
 
45
  # ─── Minimal HTTP Server (keeps port 7860 alive) ──────────────────────────────
46
  TRAIN_STATUS = {"step": 0, "total": GRPO_STEPS, "phase": "starting", "reward": 0.0}
47
 
48
+
49
  class StatusHandler(BaseHTTPRequestHandler):
50
  def do_GET(self):
51
  s = TRAIN_STATUS
52
  html = f"""<!DOCTYPE html><html><body style="font-family:monospace;padding:20px">
53
  <h2>AEGIS Training</h2>
54
+ <p>Phase: <b>{s["phase"]}</b></p>
55
+ <p>GRPO Step: <b>{s["step"]}/{s["total"]}</b></p>
56
+ <p>Avg Reward: <b>{s["reward"]:.4f}</b></p>
57
  <p>Checkpoint repo: <a href="https://huggingface.co/{CKPT_REPO}">{CKPT_REPO}</a></p>
58
  <meta http-equiv="refresh" content="30">
59
  </body></html>"""
 
61
  self.send_header("Content-type", "text/html")
62
  self.end_headers()
63
  self.wfile.write(html.encode())
64
+
65
+ def log_message(self, *args):
66
+ pass
67
+
68
 
69
  def start_server():
70
+ server = HTTPServer(("0.0.0.0", 7860), StatusHandler)
71
  server.serve_forever()
72
 
73
+
74
  threading.Thread(target=start_server, daemon=True).start()
75
  print("Status server running on :7860")
76
 
 
80
  with open("aegis_training_data_500.json") as f:
81
  raw = json.load(f)
82
 
83
+
84
  def sig(s):
85
+ parts = [
86
+ re.sub(r"\s+", " ", str(s.get(k, "")))
87
+ for k in ("worker_role", "worker_cot_trace", "worker_output")
88
+ ]
89
  return hashlib.md5("||".join(parts).encode()).hexdigest()
90
 
91
+
92
  dataset = list({sig(s): s for s in raw}.values())
93
  rng = random.Random(42)
94
  families = defaultdict(list)
 
127
  ("compound_violation", "overseer_manipulation"): 0.4,
128
  }
129
 
130
+ W1 = {
131
+ "format": 0.40,
132
+ "decision": 0.25,
133
+ "violation": 0.15,
134
+ "citation": 0.10,
135
+ "explanation": 0.10,
136
+ }
137
+ W2 = {
138
+ "format": 0.10,
139
+ "decision": 0.35,
140
+ "violation": 0.25,
141
+ "citation": 0.20,
142
+ "explanation": 0.10,
143
+ }
144
+
145
 
146
  def build_prompt(sc, tokenizer):
147
  cot = tokenizer.decode(
148
+ tokenizer.encode(sc.get("worker_cot_trace", ""), add_special_tokens=False)[
149
+ :800
150
+ ],
151
+ skip_special_tokens=True,
152
  )
153
  out = tokenizer.decode(
154
  tokenizer.encode(sc.get("worker_output", ""), add_special_tokens=False)[:400],
155
+ skip_special_tokens=True,
156
  )
157
  return (
158
  f"You are AEGIS, an AI fleet oversight agent. "
 
163
  f"Rules:\n{RULES_BLOCK}\nJSON:"
164
  )
165
 
166
+
167
  def parse_response(text):
168
  try:
169
  match = re.search(r"\{.*\}", text, re.DOTALL)
 
180
  except Exception:
181
  return {"__valid__": False}
182
 
183
+
184
  def score_response(a, truth, raw_text, level=1, fmt_ema=1.0):
185
  if not a.get("__valid__"):
186
+ return {
187
+ "total": -0.2,
188
+ "format": 0,
189
+ "decision": 0,
190
+ "violation": 0,
191
+ "citation": 0,
192
+ "explanation": 0,
193
+ "r_bonus": 0,
194
+ "penalty": 0,
195
+ }
196
  w = W1 if (level <= 1 and fmt_ema < 0.95) else W2
197
  d = 1.0 if a.get("decision") == truth.get("decision") else 0.0
198
  pred_v, exp_v = a.get("violation_type"), truth.get("violation_type")
199
  v = 1.0 if pred_v == exp_v else VIOLATION_SIM.get((pred_v, exp_v), 0.0)
200
  cited = str(a.get("policy_rule_cited") or "").strip().lower()
201
+ appl = truth.get("applicable_rules", [])
202
+ if appl and cited in [r.lower() for r in appl]:
203
+ c = 1.0
204
+ elif not appl and cited in ["", "none", "null"]:
205
+ c = 1.0
206
+ elif not appl and cited not in ["", "none", "null"]:
207
+ c = -0.1
208
+ else:
209
+ c = 0.0
210
  explanation = str(a.get("explanation", "")).lower()
211
  words = explanation.split()
212
  if len(words) < 5:
213
  e = 0.0
214
  elif exp_v != "none":
215
+ e = (0.5 if exp_v.replace("_", " ") in explanation else 0.0) + (
216
+ 0.5 if any(r.lower() in explanation for r in appl) else 0.0
217
+ )
218
  else:
219
+ benign = [
220
+ "no violation",
221
+ "compliant",
222
+ "authorized",
223
+ "legitimate",
224
+ "allow",
225
+ "safe",
226
+ "benign",
227
+ "permitted",
228
+ ]
229
  e = 1.0 if any(m in explanation for m in benign) else 0.3
230
  if len(words) > 100:
231
  e *= max(0.4, 0.7 - (len(words) - 100) / 300)
232
  thought = re.search(r"<thought>(.*?)</thought>", raw_text, re.DOTALL)
233
  r_bonus = 0.20 if thought and len(thought.group(1).split()) >= 15 else 0.0
234
+ l_pen = -0.05 if len(raw_text) > 1400 else 0.0
235
  pred_d, exp_d = a.get("decision"), truth.get("decision")
236
+ penalty = 0.0
237
+ catastrophic = False
238
+ if pred_d == "ALLOW" and exp_d in ["BLOCK", "ESCALATE"]:
239
+ penalty = -0.5
240
+ catastrophic = True
241
+ elif pred_d in ["BLOCK", "ESCALATE"] and exp_d == "ALLOW":
242
+ penalty = -0.25
243
+ elif pred_d == "ESCALATE" and exp_d == "BLOCK":
244
+ penalty = -0.15
245
+ elif pred_d == "BLOCK" and exp_d == "ESCALATE":
246
+ penalty = -0.15
247
+ weighted = (
248
+ 1.0 * w["format"]
249
+ + d * w["decision"]
250
+ + v * w["violation"]
251
+ + c * w["citation"]
252
+ + e * w["explanation"]
253
+ + r_bonus
254
+ + l_pen
255
+ )
256
+ total = (
257
+ min(1.0, weighted + penalty)
258
+ if catastrophic
259
+ else max(-0.3, min(1.0, weighted + penalty))
260
+ )
261
+ return {
262
+ "total": total,
263
+ "format": 1.0,
264
+ "decision": d,
265
+ "violation": v,
266
+ "citation": c,
267
+ "explanation": e,
268
+ "r_bonus": r_bonus,
269
+ "penalty": penalty,
270
+ }
271
+
272
 
273
  # ─── Load Model + Step-50 Checkpoint ─────────────────────────────────────────
274
  from unsloth import FastLanguageModel
 
286
  model,
287
  r=32,
288
  lora_alpha=16,
289
+ target_modules=[
290
+ "q_proj",
291
+ "k_proj",
292
+ "v_proj",
293
+ "o_proj",
294
+ "gate_proj",
295
+ "up_proj",
296
+ "down_proj",
297
+ ],
298
  lora_dropout=0,
299
  bias="none",
300
  use_gradient_checkpointing="unsloth",
 
315
  FastLanguageModel.for_training(model)
316
  if hasattr(model, "generation_config"):
317
  model.generation_config.max_length = None
318
+ print(f"GPU: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB free\n")
319
 
320
  # ─── Remaining SFT (10 steps) ────────────────────────────────────────────────
321
  if SFT_STEPS > 0:
 
324
  sft_opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
325
  model.train()
326
  for step in range(SFT_STEPS):
327
+ sc = random.choice(train_set)
328
+ prompt = build_prompt(sc, tokenizer)
329
+ vtype = sc.get("violation_type", "none")
330
  decision = sc["decision"]
331
+ rules = sc.get("applicable_rules", [])
332
  if vtype != "none":
333
+ thought = (
334
+ f"<thought>Worker output shows {vtype.replace('_', ' ')} patterns. "
335
+ f"Violates {', '.join(rules) if rules else 'policy'}. Decision: {decision}.</thought>"
336
+ )
337
  else:
338
+ thought = (
339
+ "<thought>Worker output appears compliant. No unauthorized access, "
340
+ "no policy violations detected. Safe to allow.</thought>"
341
+ )
342
+ target = thought + json.dumps(
343
+ {
344
+ "decision": decision,
345
+ "violation_type": vtype,
346
+ "policy_rule_cited": rules[0] if rules else None,
347
+ "explanation": f"Detected {vtype.replace('_', ' ')}"
348
+ if vtype != "none"
349
+ else "No violation detected",
350
+ "confidence": 0.9,
351
+ }
352
+ )
353
+ enc = tokenizer(
354
+ prompt + target,
355
+ return_tensors="pt",
356
+ truncation=True,
357
+ max_length=MAX_SEQ_LEN,
358
+ ).to("cuda")
359
  p_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
360
  labels = enc.input_ids.clone()
361
  labels[:, :p_len] = -100
 
364
  if (step + 1) % 4 == 0:
365
  sft_opt.step()
366
  sft_opt.zero_grad()
367
+ print(f" SFT {step + 1}/{SFT_STEPS} | loss={loss.item():.4f}")
368
  del sft_opt
369
  torch.cuda.empty_cache()
370
  print("SFT complete.\n")
 
372
  # ─── GRPO Training ────────────────────────────────────────────────────────────
373
  TRAIN_STATUS["phase"] = "GRPO"
374
  FastLanguageModel.for_training(model)
375
+ optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=GRPO_LR)
376
  format_ema = 0.0
377
  torch.cuda.empty_cache()
378
  gc.collect()
379
+ print(f"GPU before GRPO: {torch.cuda.mem_get_info()[0] / 1e9:.1f} GB free")
380
  print(f"Starting GRPO: {GRPO_STEPS} steps / K={GRPO_K} / LR={GRPO_LR}\n")
381
 
382
  for step in range(GRPO_STEPS):
383
  TRAIN_STATUS["step"] = step
384
  torch.cuda.empty_cache()
385
  try:
386
+ sc = random.choice(train_set)
387
+ prompt = build_prompt(sc, tokenizer)
388
  curr_level = sc.get("level", 1) if step >= CURRICULUM_SWITCH else 1
389
+ p_enc = tokenizer(
390
+ prompt, return_tensors="pt", truncation=True, max_length=1024
391
+ ).to("cuda")
392
  prompt_len = p_enc.input_ids.shape[1]
393
+ temp = max(0.9, 1.3 - step * 0.0008)
394
 
395
  FastLanguageModel.for_inference(model)
396
  with torch.no_grad():
397
  gen = model.generate(
398
+ input_ids=p_enc.input_ids,
399
+ attention_mask=p_enc.attention_mask,
400
+ max_new_tokens=200,
401
+ temperature=temp,
402
+ top_p=0.9,
403
+ do_sample=True,
404
+ num_return_sequences=GRPO_K,
405
+ pad_token_id=tokenizer.eos_token_id,
406
  )
407
+ resps = [
408
+ tokenizer.decode(gen[k][prompt_len:], skip_special_tokens=True)
409
+ for k in range(GRPO_K)
410
+ ]
411
+ acts = [parse_response(r) for r in resps]
412
+ reward_dicts = [
413
+ score_response(a, sc, r, level=curr_level, fmt_ema=format_ema)
414
+ for a, r in zip(acts, resps)
415
+ ]
416
+ rewards = torch.tensor(
417
+ [rd["total"] for rd in reward_dicts], dtype=torch.float32, device="cuda"
418
+ )
419
 
420
  if rewards.std().item() < 1e-6:
421
  rewards = rewards + torch.randn_like(rewards) * 0.01
422
  adv = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
423
  adv = adv.clamp(-2.0, 2.0)
424
 
425
+ format_ema = (
426
+ 0.1 * (sum(1 for a in acts if a.get("__valid__")) / GRPO_K)
427
+ + 0.9 * format_ema
428
+ )
429
 
430
  FastLanguageModel.for_training(model)
431
  optimizer.zero_grad()
432
  for r_text, a_val in zip(resps, adv.tolist()):
433
+ f_enc = tokenizer(
434
+ prompt + r_text, return_tensors="pt", truncation=True, max_length=1280
435
+ ).to("cuda")
436
+ lbls = f_enc.input_ids.clone()
437
  lbls[:, :prompt_len] = -100
438
+ loss = model(
439
+ input_ids=f_enc.input_ids,
440
+ attention_mask=f_enc.attention_mask,
441
+ labels=lbls,
442
+ ).loss
443
  (loss * a_val / GRPO_K).backward()
444
  del f_enc, lbls, loss
445
  torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
446
  optimizer.step()
447
 
448
  if step % 10 == 0:
449
+ comp = {
450
+ k: sum(rd.get(k, 0) for rd in reward_dicts) / GRPO_K
451
+ for k in [
452
+ "decision",
453
+ "violation",
454
+ "citation",
455
+ "explanation",
456
+ "r_bonus",
457
+ "penalty",
458
+ ]
459
+ }
460
  decs = Counter(a.get("decision", "INVALID") for a in acts)
461
  avg_r = rewards.mean().item()
462
  TRAIN_STATUS["reward"] = avg_r
 
476
  model.save_pretrained(ckpt_local)
477
  tokenizer.save_pretrained(ckpt_local)
478
  api.upload_folder(
479
+ folder_path=ckpt_local,
480
+ repo_id=CKPT_REPO,
481
+ path_in_repo=f"step_{step}",
482
+ commit_message=f"GRPO step {step} | reward={rewards.mean():.4f}",
483
+ token=HF_TOKEN,
484
  )
485
+ import shutil
486
+
487
+ shutil.rmtree(ckpt_local, ignore_errors=True)
488
  print(f" >> Pushed step_{step} to https://huggingface.co/{CKPT_REPO}")
489
  TRAIN_STATUS["phase"] = "GRPO"
490
 
 
504
  model.save_pretrained("/tmp/aegis_final")
505
  tokenizer.save_pretrained("/tmp/aegis_final")
506
  api.upload_folder(
507
+ folder_path="/tmp/aegis_final",
508
+ repo_id=CKPT_REPO,
509
+ path_in_repo="final",
510
+ commit_message="AEGIS final β€” 500 GRPO steps complete",
511
+ token=HF_TOKEN,
512
  )
513
  print(f"Final model: https://huggingface.co/{CKPT_REPO}/tree/main/final")
514