muskan singh commited on
Commit
869f731
·
1 Parent(s): e22f664

training script

Browse files
Files changed (1) hide show
  1. training/train.py +512 -0
training/train.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OrgOS GRPO Training Script
3
+ Equivalent to training/grpo_orgos.ipynb but runs headlessly.
4
+
5
+ Outputs:
6
+ training_log.txt — structured training log for submission
7
+ before_after_curves.png — score improvement chart
8
+ orgos_lora_adapter/ — trained LoRA weights
9
+ """
10
+
11
+ import datetime
12
+ import json
13
+ import os
14
+ import re
15
+ import subprocess
16
+ import sys
17
+ import time
18
+ from typing import List
19
+
20
+ import httpx
21
+ import matplotlib
22
+ matplotlib.use("Agg") # headless — no display needed
23
+ import matplotlib.pyplot as plt
24
+ import matplotlib.gridspec as gridspec
25
+ import numpy as np
26
+ import torch
27
+ from datasets import Dataset
28
+ from transformers import TrainerCallback
29
+ from trl import GRPOConfig, GRPOTrainer
30
+ from unsloth import FastLanguageModel
31
+
32
+ # ------------------------------------------------------------------
33
+ # Config
34
+ # ------------------------------------------------------------------
35
+
36
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-3B-Instruct")
37
+ ENV_URL = "http://localhost:8000"
38
+ LOG_FILE = "training_log.txt"
39
+ N_PROMPTS_PER_WORKFLOW = 20
40
+ N_EVAL = 10
41
+ NUM_EPOCHS = 3
42
+ BATCH_SIZE = 4
43
+ GRAD_ACCUM = 2
44
+ LR = 5e-5
45
+ NUM_GEN = 4
46
+ TEMPERATURE = 0.8
47
+ BETA = 0.04
48
+ LORA_R = 16
49
+ MAX_SEQ_LEN = 2048
50
+
51
+ # ------------------------------------------------------------------
52
+ # Logger
53
+ # ------------------------------------------------------------------
54
+
55
+ with open(LOG_FILE, "w") as f:
56
+ f.write(f"# OrgOS GRPO Training Log\n")
57
+ f.write(f"# Generated: {datetime.datetime.utcnow().isoformat()}Z\n\n")
58
+
59
+
60
+ def tlog(line: str) -> None:
61
+ print(line, flush=True)
62
+ with open(LOG_FILE, "a") as f:
63
+ f.write(line + "\n")
64
+
65
+
66
+ # ------------------------------------------------------------------
67
+ # Start OrgOS environment server
68
+ # ------------------------------------------------------------------
69
+
70
+ def start_env_server():
71
+ print("Starting OrgOS environment server...", flush=True)
72
+ proc = subprocess.Popen(
73
+ [sys.executable, "-m", "uvicorn", "server.app:app",
74
+ "--host", "0.0.0.0", "--port", "8000"],
75
+ stdout=subprocess.DEVNULL,
76
+ stderr=subprocess.DEVNULL,
77
+ )
78
+ # Wait until healthy
79
+ for _ in range(20):
80
+ time.sleep(2)
81
+ try:
82
+ health = httpx.get(f"{ENV_URL}/health", timeout=5).json()
83
+ if health.get("status") == "healthy":
84
+ tlog(f"[ENV] status=healthy version={health.get('version', '?')}")
85
+ return proc
86
+ except Exception:
87
+ pass
88
+ raise RuntimeError("OrgOS server failed to start after 40 seconds")
89
+
90
+
91
+ # ------------------------------------------------------------------
92
+ # Model
93
+ # ------------------------------------------------------------------
94
+
95
+ def load_model():
96
+ model, tokenizer = FastLanguageModel.from_pretrained(
97
+ model_name = MODEL_NAME,
98
+ max_seq_length = MAX_SEQ_LEN,
99
+ dtype = None,
100
+ load_in_4bit = True,
101
+ )
102
+ model = FastLanguageModel.get_peft_model(
103
+ model,
104
+ r = LORA_R,
105
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
106
+ "gate_proj", "up_proj", "down_proj"],
107
+ lora_alpha = LORA_R,
108
+ lora_dropout = 0,
109
+ bias = "none",
110
+ use_gradient_checkpointing = "unsloth",
111
+ random_state = 42,
112
+ )
113
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
114
+ tlog(f"[TRAIN_CONFIG] model={MODEL_NAME} lora_r={LORA_R} "
115
+ f"max_seq_len={MAX_SEQ_LEN} trainable_params={trainable:,} quantization=4bit")
116
+ return model, tokenizer
117
+
118
+
119
+ # ------------------------------------------------------------------
120
+ # Helpers
121
+ # ------------------------------------------------------------------
122
+
123
+ SYSTEM_PROMPT = """\
124
+ You are OrgOS Agent — an enterprise workflow automation agent.
125
+ You operate across four SaaS applications: Jira, Zendesk, Salesforce, and Workday.
126
+
127
+ Each turn you receive a JSON observation with:
128
+ - workflow_goal : the task you must complete
129
+ - pending_steps : remaining steps in the workflow
130
+ - app_states : current state of each app
131
+ - schema_hints : field renames in effect this episode (e.g. {"jira.priority": "severity"})
132
+ - active_rules : current SLA / approval thresholds
133
+ - message : feedback from the last action
134
+ - current_score : your cumulative score (0.001-0.999)
135
+
136
+ Respond ONLY with a valid JSON object — no markdown, no explanation.
137
+
138
+ Action format:
139
+ {"app": "<app>", "operation": "<op>", "args": {...}}
140
+
141
+ Available apps and key operations:
142
+ jira: get_issue, create_issue, update_status, set_priority, assign_owner,
143
+ add_label, link_zendesk_ticket, close_issue, list_issues
144
+ zendesk: get_ticket, acknowledge_ticket, set_urgency, assign_agent,
145
+ escalate_to_jira, resolve_ticket, add_note, list_tickets,
146
+ create_agent_profile
147
+ salesforce: get_account, list_accounts, update_deal_stage, flag_churn_risk,
148
+ assign_account_owner, log_interaction, get_opportunity
149
+ workday: get_employee, list_employees, provision_access, log_sla_event,
150
+ request_budget_approval, create_onboarding_task, complete_task
151
+
152
+ CRITICAL RULES:
153
+ 1. Read schema_hints FIRST — if "jira.priority" -> "severity", use "severity" not "priority" in args.
154
+ 2. Complete ALL pending_steps in order.
155
+ 3. Do not repeat a successful action.
156
+ 4. If an operation fails, read the message carefully and adapt.
157
+ 5. Use list_* operations to discover record IDs when needed.
158
+ 6. Stop when pending_steps is empty or done=true.
159
+ """
160
+
161
+
162
+ def obs_to_text(obs: dict) -> str:
163
+ hints = obs.get("schema_hints", {})
164
+ pending = obs.get("pending_steps", [])
165
+ lines = [
166
+ f"current_score: {obs['current_score']}",
167
+ f"step_count: {obs['step_count']}",
168
+ f"workflow_id: {obs['workflow_id']}",
169
+ "",
170
+ "=== WORKFLOW GOAL ===",
171
+ obs["workflow_goal"],
172
+ "",
173
+ "=== PENDING STEPS ===",
174
+ "\n".join(f" - {s}" for s in pending) or " (all steps complete!)",
175
+ "",
176
+ "=== SCHEMA HINTS (use these field names) ===",
177
+ json.dumps(hints, indent=2) if hints else " (no drift — use canonical names)",
178
+ "",
179
+ "=== ACTIVE RULES ===",
180
+ json.dumps(obs.get("active_rules", {}), indent=2),
181
+ "",
182
+ "=== LAST MESSAGE ===",
183
+ obs["message"],
184
+ "",
185
+ "=== APP STATES ===",
186
+ ]
187
+ for app_name, view in obs.get("app_states", {}).items():
188
+ lines.append(f" [{app_name.upper()}]")
189
+ lines.append(f" {view}")
190
+ lines.append("")
191
+ return "\n".join(lines)
192
+
193
+
194
+ def parse_action(text: str):
195
+ text = re.sub(r"```(?:json)?\s*", "", text.strip()).strip()
196
+ try:
197
+ return json.loads(text)
198
+ except json.JSONDecodeError:
199
+ m = re.search(r"\{.*\}", text, re.DOTALL)
200
+ if m:
201
+ try:
202
+ return json.loads(m.group())
203
+ except Exception:
204
+ pass
205
+ return None
206
+
207
+
208
+ def build_prompt(obs_text: str, tokenizer) -> str:
209
+ messages = [{"role": "user", "content": SYSTEM_PROMPT + "\n\n---\n\n" + obs_text}]
210
+ return tokenizer.apply_chat_template(
211
+ messages, tokenize=False, add_generation_prompt=True
212
+ )
213
+
214
+
215
+ # ------------------------------------------------------------------
216
+ # Prompt dataset
217
+ # ------------------------------------------------------------------
218
+
219
+ def build_prompt_dataset(tokenizer) -> Dataset:
220
+ rows = []
221
+ print("Collecting prompts from env resets...", flush=True)
222
+ for wf in ["A", "B", "C"]:
223
+ for _ in range(N_PROMPTS_PER_WORKFLOW):
224
+ result = httpx.post(f"{ENV_URL}/reset", json={"workflow_id": wf}).json()
225
+ obs = result["observation"]
226
+ obs_text = obs_to_text(obs)
227
+ rows.append({
228
+ "prompt": build_prompt(obs_text, tokenizer),
229
+ "workflow_id": wf,
230
+ "obs_text": obs_text,
231
+ })
232
+ tlog(f"[TRAIN_CONFIG] algorithm=GRPO prompts={len(rows)} "
233
+ f"workflows=A,B,C prompts_per_workflow={N_PROMPTS_PER_WORKFLOW}")
234
+ return Dataset.from_list(rows)
235
+
236
+
237
+ # ------------------------------------------------------------------
238
+ # Reward function
239
+ # ------------------------------------------------------------------
240
+
241
+ def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:
242
+ workflow_ids = kwargs.get("workflow_id", ["A"] * len(completions))
243
+ rewards = []
244
+ for completion, wf_id in zip(completions, workflow_ids):
245
+ action = parse_action(completion)
246
+ if action is None:
247
+ rewards.append(-0.1)
248
+ continue
249
+ try:
250
+ httpx.post(f"{ENV_URL}/reset", json={"workflow_id": wf_id}, timeout=10)
251
+ result = httpx.post(f"{ENV_URL}/step", json=action, timeout=10).json()
252
+ rewards.append(float(result["reward"]))
253
+ except Exception:
254
+ rewards.append(-0.1)
255
+ return rewards
256
+
257
+
258
+ # ------------------------------------------------------------------
259
+ # Episode evaluation
260
+ # ------------------------------------------------------------------
261
+
262
+ def run_episode_with_model(model, tokenizer, workflow_id: str, max_steps: int = 15) -> float:
263
+ result = httpx.post(f"{ENV_URL}/reset", json={"workflow_id": workflow_id}).json()
264
+ obs = result["observation"]
265
+ history = []
266
+
267
+ for _ in range(max_steps):
268
+ if obs["done"]:
269
+ break
270
+
271
+ obs_text = obs_to_text(obs)
272
+ history.append({"role": "user", "content": obs_text})
273
+
274
+ messages = list(history)
275
+ messages[0] = {"role": "user",
276
+ "content": SYSTEM_PROMPT + "\n\n---\n\n" + messages[0]["content"]}
277
+
278
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
279
+ inputs = tokenizer(text, return_tensors="pt").to(model.device)
280
+
281
+ with torch.no_grad():
282
+ out = model.generate(
283
+ **inputs,
284
+ max_new_tokens = 256,
285
+ temperature = 0.0,
286
+ do_sample = False,
287
+ pad_token_id = tokenizer.eos_token_id,
288
+ )
289
+ action_str = tokenizer.decode(
290
+ out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
291
+ ).strip()
292
+
293
+ history.append({"role": "assistant", "content": action_str})
294
+
295
+ action = parse_action(action_str)
296
+ if action is None:
297
+ break
298
+
299
+ result = httpx.post(f"{ENV_URL}/step", json=action).json()
300
+ obs = result["observation"]
301
+ if obs["done"]:
302
+ break
303
+
304
+ return obs.get("current_score", 0.001)
305
+
306
+
307
+ def evaluate(model, tokenizer, phase: str) -> dict:
308
+ scores = {wf: [] for wf in ["A", "B", "C"]}
309
+ tlog(f"[EVAL_START] phase={phase}")
310
+ for wf in ["A", "B", "C"]:
311
+ for ep in range(N_EVAL):
312
+ score = run_episode_with_model(model, tokenizer, wf)
313
+ scores[wf].append(score)
314
+ tlog(f"[EVAL] phase={phase} workflow={wf} episode={ep+1} score={score:.4f}")
315
+ wf_mean = np.mean(scores[wf])
316
+ tlog(f"[EVAL_WORKFLOW] phase={phase} workflow={wf} "
317
+ f"mean={wf_mean:.4f} min={min(scores[wf]):.4f} max={max(scores[wf]):.4f}")
318
+ overall = np.mean([s for v in scores.values() for s in v])
319
+ tlog(f"[EVAL_END] phase={phase} overall_mean={overall:.4f}")
320
+ return scores
321
+
322
+
323
+ # ------------------------------------------------------------------
324
+ # Plot
325
+ # ------------------------------------------------------------------
326
+
327
+ def plot_results(baseline_scores: dict, post_scores: dict) -> None:
328
+ fig = plt.figure(figsize=(14, 8), facecolor="#0f172a")
329
+ fig.suptitle("OrgOS: Before vs After GRPO Training", fontsize=15,
330
+ color="white", fontweight="bold", y=0.98)
331
+
332
+ gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)
333
+ COLORS = {"before": "#f87171", "after": "#34d399", "bg": "#1e293b", "grid": "#334155"}
334
+ LABELS = {
335
+ "A": "Workflow A\nCustomer Bug Fix",
336
+ "B": "Workflow B\nEmployee Onboarding",
337
+ "C": "Workflow C\nChurn Risk Alert",
338
+ }
339
+
340
+ for col, wf in enumerate(["A", "B", "C"]):
341
+ ax = fig.add_subplot(gs[0, col])
342
+ ax.set_facecolor(COLORS["bg"])
343
+ ax.grid(color=COLORS["grid"], linewidth=0.5, alpha=0.7)
344
+ before = baseline_scores[wf]
345
+ after = post_scores[wf]
346
+ delta = np.mean(after) - np.mean(before)
347
+ ax.plot(before, color=COLORS["before"], linewidth=1.5, alpha=0.8, label="Before GRPO")
348
+ ax.plot(after, color=COLORS["after"], linewidth=1.5, alpha=0.8, label="After GRPO")
349
+ ax.axhline(np.mean(before), color=COLORS["before"], linestyle="--", linewidth=1, alpha=0.5)
350
+ ax.axhline(np.mean(after), color=COLORS["after"], linestyle="--", linewidth=1, alpha=0.5)
351
+ ax.set_title(LABELS[wf] + f"\n(Δ = {delta:+.4f})", color="white", fontsize=9)
352
+ ax.set_xlabel("Episode", color="#94a3b8", fontsize=8)
353
+ ax.set_ylabel("Final Score", color="#94a3b8", fontsize=8)
354
+ ax.tick_params(colors="#64748b", labelsize=7)
355
+ ax.set_ylim(0, 1)
356
+ ax.legend(fontsize=7, facecolor="#1e293b", labelcolor="white",
357
+ edgecolor="#475569", framealpha=0.8)
358
+ for spine in ax.spines.values():
359
+ spine.set_edgecolor("#334155")
360
+
361
+ ax_hist = fig.add_subplot(gs[1, :])
362
+ ax_hist.set_facecolor(COLORS["bg"])
363
+ ax_hist.grid(color=COLORS["grid"], linewidth=0.5, alpha=0.5, axis="x")
364
+ all_before = [s for v in baseline_scores.values() for s in v]
365
+ all_after = [s for v in post_scores.values() for s in v]
366
+ bins = np.linspace(0, 1, 25)
367
+ ax_hist.hist(all_before, bins=bins, color=COLORS["before"], alpha=0.6,
368
+ label=f"Before GRPO (mean={np.mean(all_before):.4f})", edgecolor="none")
369
+ ax_hist.hist(all_after, bins=bins, color=COLORS["after"], alpha=0.6,
370
+ label=f"After GRPO (mean={np.mean(all_after):.4f})", edgecolor="none")
371
+ ax_hist.axvline(np.mean(all_before), color=COLORS["before"], linestyle="--", linewidth=1.5)
372
+ ax_hist.axvline(np.mean(all_after), color=COLORS["after"], linestyle="--", linewidth=1.5)
373
+ ax_hist.set_title("Score Distribution Across All Workflows", color="white", fontsize=10)
374
+ ax_hist.set_xlabel("Final Score", color="#94a3b8", fontsize=9)
375
+ ax_hist.set_ylabel("Count", color="#94a3b8", fontsize=9)
376
+ ax_hist.tick_params(colors="#64748b", labelsize=8)
377
+ ax_hist.legend(fontsize=9, facecolor="#1e293b", labelcolor="white",
378
+ edgecolor="#475569", framealpha=0.9)
379
+ for spine in ax_hist.spines.values():
380
+ spine.set_edgecolor("#334155")
381
+
382
+ plt.savefig("before_after_curves.png", dpi=150, bbox_inches="tight",
383
+ facecolor="#0f172a", edgecolor="none")
384
+ plt.close()
385
+ tlog("[ARTIFACT] file=before_after_curves.png")
386
+
387
+
388
+ # ------------------------------------------------------------------
389
+ # Training callback
390
+ # ------------------------------------------------------------------
391
+
392
+ class OrgOSLogCallback(TrainerCallback):
393
+ def on_log(self, args, state, control, logs=None, **kwargs):
394
+ if not logs:
395
+ return
396
+ step = state.global_step
397
+ loss = logs.get("loss", logs.get("train_loss", "?"))
398
+ mean_reward = logs.get("reward", logs.get("mean_reward", "?"))
399
+ kl = logs.get("kl", logs.get("approx_kl", "?"))
400
+ lr_now = logs.get("learning_rate", "?")
401
+
402
+ loss_str = f"{loss:.6f}" if isinstance(loss, float) else str(loss)
403
+ reward_str = f"{mean_reward:.4f}" if isinstance(mean_reward, float) else str(mean_reward)
404
+ kl_str = f"{kl:.6f}" if isinstance(kl, float) else str(kl)
405
+ lr_str = f"{lr_now:.2e}" if isinstance(lr_now, float) else str(lr_now)
406
+
407
+ tlog(f"[TRAIN_STEP] step={step} loss={loss_str} "
408
+ f"mean_reward={reward_str} kl={kl_str} lr={lr_str}")
409
+
410
+
411
+ # ------------------------------------------------------------------
412
+ # Main
413
+ # ------------------------------------------------------------------
414
+
415
+ def main():
416
+ server_proc = start_env_server()
417
+
418
+ try:
419
+ model, tokenizer = load_model()
420
+
421
+ prompt_dataset = build_prompt_dataset(tokenizer)
422
+
423
+ # Sanity-check reward function
424
+ test_r = orgos_reward_fn(
425
+ completions = ['{"app": "zendesk", "operation": "list_tickets", "args": {"state": "new"}}',
426
+ "not json"],
427
+ prompts = ["", ""],
428
+ workflow_id = ["A", "A"],
429
+ )
430
+ tlog(f"[REWARD_FN_CHECK] valid_action={test_r[0]:.4f} invalid_action={test_r[1]:.4f}")
431
+
432
+ # Baseline evaluation
433
+ FastLanguageModel.for_inference(model)
434
+ baseline_scores = evaluate(model, tokenizer, phase="baseline")
435
+ baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])
436
+
437
+ # GRPO training
438
+ model.train()
439
+ tlog(f"[TRAIN_CONFIG] epochs={NUM_EPOCHS} batch_size={BATCH_SIZE} "
440
+ f"grad_accum={GRAD_ACCUM} lr={LR} num_generations={NUM_GEN} "
441
+ f"temperature={TEMPERATURE} beta_kl={BETA}")
442
+
443
+ grpo_config = GRPOConfig(
444
+ output_dir = "./orgos_grpo_ckpt",
445
+ num_train_epochs = NUM_EPOCHS,
446
+ per_device_train_batch_size = BATCH_SIZE,
447
+ gradient_accumulation_steps = GRAD_ACCUM,
448
+ learning_rate = LR,
449
+ warmup_steps = 10,
450
+ logging_steps = 5,
451
+ save_steps = 100,
452
+ bf16 = torch.cuda.is_bf16_supported(),
453
+ fp16 = not torch.cuda.is_bf16_supported(),
454
+ max_grad_norm = 1.0,
455
+ num_generations = NUM_GEN,
456
+ max_new_tokens = 256,
457
+ temperature = TEMPERATURE,
458
+ beta = BETA,
459
+ report_to = "none",
460
+ seed = 42,
461
+ )
462
+
463
+ trainer = GRPOTrainer(
464
+ model = model,
465
+ args = grpo_config,
466
+ reward_funcs = orgos_reward_fn,
467
+ train_dataset = prompt_dataset,
468
+ processing_class = tokenizer,
469
+ callbacks = [OrgOSLogCallback()],
470
+ )
471
+
472
+ tlog("[TRAIN_START]")
473
+ train_result = trainer.train()
474
+ tlog(f"[TRAIN_END] total_steps={train_result.global_step} "
475
+ f"train_loss={train_result.training_loss:.6f} "
476
+ f"train_runtime_s={train_result.metrics.get('train_runtime', 0):.1f}")
477
+
478
+ # Post-training evaluation
479
+ FastLanguageModel.for_inference(model)
480
+ post_scores = evaluate(model, tokenizer, phase="post_training")
481
+ post_mean = np.mean([s for v in post_scores.values() for s in v])
482
+ improvement = post_mean - baseline_mean
483
+
484
+ tlog(
485
+ f"[TRAIN_SUMMARY] "
486
+ f"model={MODEL_NAME} algorithm=GRPO "
487
+ f"baseline_mean={baseline_mean:.4f} "
488
+ f"post_training_mean={post_mean:.4f} "
489
+ f"improvement={improvement:+.4f} "
490
+ f"workflow_A_before={np.mean(baseline_scores['A']):.4f} "
491
+ f"workflow_A_after={np.mean(post_scores['A']):.4f} "
492
+ f"workflow_B_before={np.mean(baseline_scores['B']):.4f} "
493
+ f"workflow_B_after={np.mean(post_scores['B']):.4f} "
494
+ f"workflow_C_before={np.mean(baseline_scores['C']):.4f} "
495
+ f"workflow_C_after={np.mean(post_scores['C']):.4f}"
496
+ )
497
+
498
+ # Save artifacts
499
+ plot_results(baseline_scores, post_scores)
500
+ model.save_pretrained("orgos_lora_adapter")
501
+ tokenizer.save_pretrained("orgos_lora_adapter")
502
+ tlog("[ARTIFACT] file=orgos_lora_adapter/")
503
+ tlog("[ARTIFACT] file=training_log.txt")
504
+
505
+ print(f"\nDone. Improvement: {baseline_mean:.4f} → {post_mean:.4f} ({improvement:+.4f})")
506
+
507
+ finally:
508
+ server_proc.terminate()
509
+
510
+
511
+ if __name__ == "__main__":
512
+ main()