from __future__ import annotations

import argparse
import json
import random
import re
from pathlib import Path
from typing import Any

import requests
from transformers import TrainerCallback


LEGAL_ACTION_TYPES = [
    "reply_email",
    "archive_email",
    "reschedule_meeting",
    "cancel_meeting",
    "complete_task",
    "delegate_task",
    "send_message",
    "do_nothing",
]

MODEL_PRESETS: dict[str, str] = {
    # Preset for fast iteration: small, strong instruction following, QLoRA-friendly.
    "small_iter_fast": "unsloth/Qwen2.5-3B-Instruct",
    # Existing baseline used in this repo.
    "balanced_3b": "unsloth/Llama-3.2-3B-Instruct",
    # Larger option when compute budget is stable.
    "bigger_4b": "unsloth/Qwen3-4B-Instruct-2507",
}

TRAINING_PRESETS: dict[str, dict[str, float | int | str]] = {
    "hackathon_turbo": {
        "max_sft_steps": 80,
        "max_grpo_steps": 180,
        "env_reward_scale": 1.00,
        "local_reward_scale": 0.45,
        "complexity_curriculum": "easy_to_full",
        "curriculum_ramp_ratio": 0.65,
        "sft_samples": 180,
        # Optimizer / schedule knobs (stability-first settings for iterative runs).
        "sft_lr": 1.2e-5,
        "sft_grad_accum": 8,
        "grpo_lr": 3.0e-6,
        "grpo_grad_accum": 8,
        "grpo_beta": 0.08,
        "reward_ema_decay": 0.35,
    },
    # Quicker loop for smoke iterations on weaker hardware.
    "quick_smoke": {
        "max_sft_steps": 30,
        "max_grpo_steps": 80,
        "env_reward_scale": 0.95,
        "local_reward_scale": 0.35,
        "complexity_curriculum": "easy_to_full",
        "curriculum_ramp_ratio": 0.50,
        "sft_samples": 90,
        "sft_lr": 1.5e-5,
        "sft_grad_accum": 4,
        "grpo_lr": 4.0e-6,
        "grpo_grad_accum": 4,
        "grpo_beta": 0.06,
        "reward_ema_decay": 0.25,
    },
}
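# Note (a reading aid, not an enforced constraint): per_device_train_batch_size is
# fixed at 1 in the trainer configs below, so *_grad_accum is effectively the batch
# size per optimizer step, e.g. hackathon_turbo accumulates 8 micro-batches for both
# the SFT and GRPO phases.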


def _extract_briefing(reset_payload: dict[str, Any]) -> str:
    obs = reset_payload.get("observation", reset_payload)
    if isinstance(obs, dict):
        return str(obs.get("echoed_message", "")).strip()
    return ""


def _legal_action_heuristic(briefing: str) -> dict[str, Any]:
    # Minimal heuristic used only for SFT warm-start data generation; it keeps the
    # action schema valid and avoids biasing the data toward idle actions.
    lower = briefing.lower()
    if "e01" in lower:
        return {
            "action_type": "reply_email",
            "email_id": "e01",
            "message_body": "Acknowledged. Sharing a concise update shortly.",
        }
    if "m02" in lower:
        return {
            "action_type": "reschedule_meeting",
            "meeting_id": "m02",
            "new_time": "2026-04-21T18:00:00",
            "reason": "Resolve overlap with higher priority commitments.",
        }
    if "t06" in lower:
        return {"action_type": "complete_task", "task_id": "t06"}
    return {"action_type": random.choice(LEGAL_ACTION_TYPES)}


def generate_sft_jsonl_from_env(
    env_url: str,
    out_jsonl: Path,
    samples: int = 120,
    task_id: str = "phase2_core",
) -> None:
    out_jsonl.parent.mkdir(parents=True, exist_ok=True)
    rows: list[dict[str, str]] = []
    for _ in range(samples):
        r = requests.post(f"{env_url.rstrip('/')}/reset", json={"task_id": task_id}, timeout=30)
        r.raise_for_status()
        payload = r.json()
        briefing = _extract_briefing(payload)
        if not briefing:
            continue
        action = _legal_action_heuristic(briefing)
        prompt = (
            "You are Ghostexec AI Chief-of-Staff.\n"
            "Output one valid GhostexecAction JSON only.\n\n"
            f"{briefing}"
        )
        rows.append({"prompt": prompt, "completion": json.dumps(action, ensure_ascii=True)})
    with out_jsonl.open("w", encoding="utf-8") as fh:
        for row in rows:
            fh.write(json.dumps(row, ensure_ascii=True) + "\n")
    print(f"Wrote {len(rows)} SFT rows to {out_jsonl}")


def run_sft_then_grpo(
    model_name: str,
    env_url: str,
    sft_jsonl: Path,
    out_dir: Path,
    env_reward_scale: float,
    local_reward_scale: float,
    max_sft_steps: int,
    max_grpo_steps: int,
    complexity_curriculum: str,
    curriculum_ramp_ratio: float,
    *,
    sft_lr: float,
    sft_grad_accum: int,
    grpo_lr: float,
    grpo_grad_accum: int,
    grpo_beta: float,
    reward_ema_decay: float,
) -> None:
    try:
        from datasets import load_dataset
        from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
        from unsloth import FastLanguageModel
    except Exception as exc:  # pragma: no cover
        raise RuntimeError(
            "Missing training deps. Install unsloth, trl, datasets, transformers before running."
        ) from exc

    out_dir.mkdir(parents=True, exist_ok=True)

    def _trainable_lora_sum_abs(model) -> float:
        total = 0.0
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if "lora" not in n.lower():
                continue
            total += float(p.detach().abs().sum().item())
        return total

    policy, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
    )
    policy = FastLanguageModel.get_peft_model(
        policy,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
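    # With r=16 and lora_alpha=16 the effective LoRA scaling (alpha / r) is 1.0
    # under the default, non-rsLoRA convention.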

    ds = load_dataset("json", data_files=str(sft_jsonl), split="train")
    sft_cfg = SFTConfig(
        output_dir=str(out_dir / "sft"),
        max_steps=max_sft_steps,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=sft_grad_accum,
        learning_rate=sft_lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.06,
        max_grad_norm=1.0,
        adam_beta1=0.9,
        adam_beta2=0.95,
        logging_steps=5,
        save_steps=max(10, max_sft_steps),
        report_to=[],
    )
    sft_trainer = SFTTrainer(
        model=policy,
        tokenizer=tokenizer,
        train_dataset=ds,
        args=sft_cfg,
        dataset_text_field="prompt",
        formatting_func=lambda ex: [f"{p}\n\n{c}" for p, c in zip(ex["prompt"], ex["completion"])],
    )
    sft_before = _trainable_lora_sum_abs(policy)
    sft_trainer.train()
    sft_after = _trainable_lora_sum_abs(sft_trainer.model)
    sft_delta = abs(sft_after - sft_before)
    print(f"SFT LoRA delta(abs-sum): {sft_delta:.6f}")
    if sft_delta <= 1e-6:
        raise RuntimeError("SFT appears not to have updated LoRA weights (delta too small).")
    sft_dir = out_dir / "sft_adapter"
    sft_trainer.model.save_pretrained(sft_dir)
    tokenizer.save_pretrained(sft_dir)
    print(f"SFT complete. Adapter saved: {sft_dir}")

    def _extract_json(text: str) -> dict[str, Any] | None:
        m = re.search(r"\{.*\}", text, flags=re.S)
        if not m:
            return None
        try:
            obj = json.loads(m.group(0))
        except Exception:
            return None
        return obj if isinstance(obj, dict) else None
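    # Note: the greedy regex above spans from the first "{" to the last "}" in the
    # completion, so a reply containing several JSON objects will fail to parse and
    # falls through to the format penalties below.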

    def _env_step_reward_from_completion(text: str) -> float:
        payload = _extract_json(text)
        if payload is None:
            return -0.25
        payload.setdefault("action_type", "do_nothing")
        try:
            r = requests.post(f"{env_url.rstrip('/')}/reset", json={"task_id": "phase2_core"}, timeout=30)
            r.raise_for_status()
            s = requests.post(
                f"{env_url.rstrip('/')}/step",
                json={"action": payload},
                timeout=30,
            )
            s.raise_for_status()
            raw = s.json()
        except Exception:
            return 0.0
        rew = raw.get("reward")
        if rew is None and isinstance(raw.get("observation"), dict):
            rew = raw["observation"].get("reward", 0.0)
        try:
            return float(rew)
        except Exception:
            return 0.0
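    # Each reward evaluation resets the env to phase2_core and applies the decoded
    # action as a single step; transport or parsing failures score 0.0 rather than
    # aborting the training run.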

    progress = {"step": 0, "total": max(1, max_grpo_steps)}
    reward_ema_state = {"env": None}

    class _ProgressCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):  # type: ignore[override]
            progress["step"] = int(getattr(state, "global_step", progress["step"]))
            return control

    def _progress_frac() -> float:
        return min(1.0, progress["step"] / progress["total"])

    def _curriculum_phase_weight() -> float:
        frac = _progress_frac()
        ramp = max(0.05, min(1.0, curriculum_ramp_ratio))
        if complexity_curriculum == "off":
            return 1.0
        # easy_to_full: start with strong scaffold guidance, then smoothly
        # transition to full env-dominant optimization.
        if frac >= ramp:
            return 0.0
        return max(0.0, 1.0 - (frac / ramp))

    def _annealed_local_scale() -> float:
        frac = _progress_frac()
        base = local_reward_scale * (1.20 - 0.70 * frac)
        return base * (1.0 + 0.70 * _curriculum_phase_weight())

    def _annealed_env_scale() -> float:
        w = _curriculum_phase_weight()
        # Slightly downweight the env reward in the early, easy phase to reduce
        # variance, then recover to full strength by the end of the ramp.
        return env_reward_scale * (1.0 - 0.30 * w)
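    # Worked example with the hackathon_turbo values (local=0.45, env=1.00, ramp=0.65):
    #   frac=0.00  -> w=1.0, local scale = 0.45*1.20*1.70 ~= 0.92, env scale = 0.70
    #   frac>=0.65 -> w=0.0, local scale = 0.45*(1.20-0.70*frac), env scale = 1.00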

    def env_reward(completions, **_):
        scale = _annealed_env_scale()
        raw = [scale * _env_step_reward_from_completion(str(c)) for c in completions]
        if reward_ema_decay <= 0.0:
            return raw
        batch_mean = sum(raw) / max(len(raw), 1)
        prev = reward_ema_state["env"]
        d = max(0.0, min(1.0, reward_ema_decay))
        if prev is None:
            smoothed_mean = batch_mean
        else:
            smoothed_mean = (1.0 - d) * prev + d * batch_mean
        reward_ema_state["env"] = smoothed_mean
        delta = smoothed_mean - batch_mean
        return [r + delta for r in raw]
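    # The EMA shifts every reward in the batch so the batch mean tracks a smoothed
    # running mean: with d=0.35, prev=0.50 and batch_mean=0.10, the smoothed mean is
    # 0.65*0.50 + 0.35*0.10 = 0.36, so each reward in the batch gains delta = +0.26.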

    def format_reward(completions, **_):
        scale = _annealed_local_scale()
        outs: list[float] = []
        for c in completions:
            txt = str(c).strip()
            obj = _extract_json(txt)
            if obj is None:
                outs.append(-0.20 * scale)
                continue
            if obj.get("action_type") not in LEGAL_ACTION_TYPES:
                outs.append(-0.20 * scale)
                continue
            # Encourage concise, parseable schema-correct JSON.
            length_pen = -0.04 * scale if len(txt) > 500 else 0.0
            outs.append(0.12 * scale + length_pen)
        return outs

    def semantic_action_reward(completions, prompts=None, **_):
        scale = _annealed_local_scale()
        outs: list[float] = []
        for i, c in enumerate(completions):
            obj = _extract_json(str(c))
            if obj is None:
                outs.append(-0.10 * scale)
                continue
            at = str(obj.get("action_type", ""))
            ptxt = str(prompts[i] if prompts and i < len(prompts) else "").lower()
            bonus = 0.0
            if "critical" in ptxt and at == "reply_email":
                bonus += 0.08
            if "clash" in ptxt and at in ("reschedule_meeting", "cancel_meeting"):
                bonus += 0.08
            if ("overdue" in ptxt or "due soon" in ptxt) and at in ("complete_task", "delegate_task"):
                bonus += 0.08
            outs.append(scale * bonus)
        return outs

    def anti_idle_reward(completions, **_):
        scale = _annealed_local_scale()
        outs = []
        for c in completions:
            txt = str(c).lower()
            outs.append((-0.20 if "do_nothing" in txt else 0.02) * scale)
        return outs
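    # The four reward functions above are combined per completion by GRPOTrainer
    # (summed in the trl versions this was written against); env_reward carries the
    # main signal while the local shaping terms are annealed down as training runs.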

    grpo_cfg = GRPOConfig(
        output_dir=str(out_dir / "grpo"),
        learning_rate=grpo_lr,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=grpo_grad_accum,
        max_steps=max_grpo_steps,
        logging_steps=5,
        num_generations=2,
        beta=grpo_beta,
        lr_scheduler_type="cosine",
        warmup_ratio=0.06,
        max_grad_norm=1.0,
        adam_beta1=0.9,
        adam_beta2=0.95,
        report_to=[],
    )
    grpo_trainer = GRPOTrainer(
        model=sft_trainer.model,
        processing_class=tokenizer,
        reward_funcs=[env_reward, format_reward, semantic_action_reward, anti_idle_reward],
        train_dataset=ds,
        args=grpo_cfg,
        callbacks=[_ProgressCallback()],
    )
    grpo_before = _trainable_lora_sum_abs(sft_trainer.model)
    grpo_trainer.train()
    progress["step"] = progress["total"]
    grpo_after = _trainable_lora_sum_abs(grpo_trainer.model)
    grpo_delta = abs(grpo_after - grpo_before)
    print(f"GRPO LoRA delta(abs-sum): {grpo_delta:.6f}")
    if grpo_delta <= 1e-6:
        raise RuntimeError("GRPO appears not to have updated LoRA weights (delta too small).")
    final_dir = out_dir / "grpo_adapter"
    grpo_trainer.model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    print(f"GRPO complete. Adapter saved: {final_dir}")


def main() -> None:
    parser = argparse.ArgumentParser(description="Run SFT warmup before GRPO.")
    parser.add_argument(
        "--model-name",
        default="",
        help="Optional explicit model id. If omitted, --model-preset is used.",
    )
    parser.add_argument(
        "--model-preset",
        choices=sorted(MODEL_PRESETS.keys()),
        default="small_iter_fast",
        help="Recommended compute-aware preset. small_iter_fast is best for iteration speed.",
    )
    parser.add_argument(
        "--training-preset",
        choices=sorted(TRAINING_PRESETS.keys()),
        default="hackathon_turbo",
        help="Compute-aware run preset. hackathon_turbo is best default for iterative winning loops.",
    )
    parser.add_argument("--env-url", default="http://127.0.0.1:8000")
    parser.add_argument("--sft-jsonl", type=Path, default=Path("outputs/sft_from_env.jsonl"))
    parser.add_argument("--out-dir", type=Path, default=Path("outputs/train_runs/sft_then_grpo"))
    parser.add_argument("--generate-sft-from-env", action="store_true")
    parser.add_argument("--sft-samples", type=int, default=120)
    parser.add_argument("--max-sft-steps", type=int, default=60)
    parser.add_argument("--max-grpo-steps", type=int, default=120)
    parser.add_argument("--env-reward-scale", type=float, default=1.0)
    parser.add_argument("--local-reward-scale", type=float, default=0.35)
    parser.add_argument(
        "--complexity-curriculum",
        choices=["off", "easy_to_full"],
        default="easy_to_full",
        help="Reward curriculum: easy_to_full starts with stronger local scaffold and anneals to env-dominant.",
    )
    parser.add_argument(
        "--curriculum-ramp-ratio",
        type=float,
        default=0.60,
        help="Fraction of GRPO steps used to ramp from easy scaffold to full env weighting.",
    )
    parser.add_argument(
        "--reward-ema-decay",
        type=float,
        default=-1.0,
        help="EMA decay in [0,1] for env reward smoothing; -1 uses training preset default.",
    )
    args = parser.parse_args()
    model_name = args.model_name.strip() or MODEL_PRESETS[args.model_preset]
    p = TRAINING_PRESETS[args.training_preset]
    max_sft_steps = int(p["max_sft_steps"])
    max_grpo_steps = int(p["max_grpo_steps"])
    env_reward_scale = float(p["env_reward_scale"])
    local_reward_scale = float(p["local_reward_scale"])
    complexity_curriculum = str(p["complexity_curriculum"])
    curriculum_ramp_ratio = float(p["curriculum_ramp_ratio"])
    sft_samples = int(p["sft_samples"])
    sft_lr = float(p["sft_lr"])
    sft_grad_accum = int(p["sft_grad_accum"])
    grpo_lr = float(p["grpo_lr"])
    grpo_grad_accum = int(p["grpo_grad_accum"])
    grpo_beta = float(p["grpo_beta"])
    reward_ema_decay = float(p["reward_ema_decay"])
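    # CLI values override the preset only when they differ from their argparse
    # defaults; explicitly passing a flag set to its default keeps the preset value.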
    if args.max_sft_steps != 60:
        max_sft_steps = args.max_sft_steps
    if args.max_grpo_steps != 120:
        max_grpo_steps = args.max_grpo_steps
    if args.env_reward_scale != 1.0:
        env_reward_scale = args.env_reward_scale
    if args.local_reward_scale != 0.35:
        local_reward_scale = args.local_reward_scale
    if args.complexity_curriculum != "easy_to_full":
        complexity_curriculum = args.complexity_curriculum
    if args.curriculum_ramp_ratio != 0.60:
        curriculum_ramp_ratio = args.curriculum_ramp_ratio
    if args.sft_samples != 120:
        sft_samples = args.sft_samples
    if args.reward_ema_decay >= 0.0:
        reward_ema_decay = float(args.reward_ema_decay)
    print(f"Model preset: {args.model_preset} -> {model_name}")
    print(
        "Training preset:"
        f" {args.training_preset} -> sft={max_sft_steps}, grpo={max_grpo_steps},"
        f" env_scale={env_reward_scale}, local_scale={local_reward_scale},"
        f" curriculum={complexity_curriculum}, ramp={curriculum_ramp_ratio}"
    )

    if args.generate_sft_from_env or not args.sft_jsonl.exists():
        generate_sft_jsonl_from_env(
            env_url=args.env_url,
            out_jsonl=args.sft_jsonl,
            samples=sft_samples,
            task_id="phase2_core",
        )

    run_sft_then_grpo(
        model_name=model_name,
        env_url=args.env_url,
        sft_jsonl=args.sft_jsonl,
        out_dir=args.out_dir,
        env_reward_scale=env_reward_scale,
        local_reward_scale=local_reward_scale,
        max_sft_steps=max_sft_steps,
        max_grpo_steps=max_grpo_steps,
        complexity_curriculum=complexity_curriculum,
        curriculum_ramp_ratio=curriculum_ramp_ratio,
        sft_lr=sft_lr,
        sft_grad_accum=sft_grad_accum,
        grpo_lr=grpo_lr,
        grpo_grad_accum=grpo_grad_accum,
        grpo_beta=grpo_beta,
        reward_ema_decay=reward_ema_decay,
    )


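# Example invocation (illustrative; assumes the Ghostexec env server is already
# running on http://127.0.0.1:8000 and that this file is saved as
# train_sft_then_grpo.py; substitute the actual script name):
#   python train_sft_then_grpo.py --model-preset small_iter_fast \
#       --training-preset hackathon_turbo --generate-sft-from-env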
if __name__ == "__main__":
    main()