"""
AgentDebuggerEnv - GRPO Training Script
Model: Qwen2.5-Coder-3B-Instruct (float16/bfloat16 + LoRA, no quantization)
Algorithm: GRPO (Group Relative Policy Optimization) via HuggingFace TRL
GPU: auto-detected at runtime (A100/H100 -> bfloat16 + large batch, T4/V100 -> float16 + small batch)

Usage:
  # Local reward sanity-check (no GPU, no model loading):
  python training/train_grpo.py --test-local

  # Test run (Colab/GPU, 10 steps):
  python training/train_grpo.py --test

  # Full training run:
  python training/train_grpo.py

  # Resume from checkpoint:
  python training/train_grpo.py --resume ./checkpoints/checkpoint-400
"""

import os
import sys
import json
import argparse
import random
import subprocess
import tempfile
import shutil
from importlib import metadata

# ── Parse args ────────────────────────────────────────────────────────────────
parser = argparse.ArgumentParser()
parser.add_argument("--test", action="store_true", help="Run 10 steps for testing (Colab/GPU)")
parser.add_argument("--test-local", action="store_true", dest="test_local",
                    help="Sanity-check reward function locally without any model or GPU")
parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
parser.add_argument("--max_steps", type=int, default=500)
args = parser.parse_args()

# ── Runtime dependency install ─────────────────────────────────────────────────
# requirements.txt only has torch (too large to install at runtime).
# Everything else is installed here, after gradio is already up.
# NOTE: mergekit intentionally excluded - conflicts with accelerate/peft/trl.
if not args.test_local:
    # ── Ensure CUDA-enabled torch is present before anything else imports it ──
    # The default PyPI torch wheel is CPU-only. We must install from the
    # PyTorch CUDA index so that torch.cuda.is_available() returns True and
    # device_map="auto" maps the model to GPU, not RAM.
    import importlib.util, importlib
    _needs_cuda_torch = True
    if importlib.util.find_spec("torch") is not None:
        import torch as _t
        if _t.cuda.is_available():
            _needs_cuda_torch = False
        del _t
    if _needs_cuda_torch:
        print("Installing CUDA-enabled torch (cu121)...", flush=True)
        _r = os.system(
            f"{sys.executable} -m pip install -q --no-cache-dir "
            "torch --index-url https://download.pytorch.org/whl/cu121"
        )
        if _r != 0:
            print("ERROR: CUDA torch install failed.", flush=True)
            sys.exit(1)
        print("CUDA torch installed.", flush=True)

    _TRAIN_DEPS = [
        "wandb==0.18.7",
        "datasets==3.0.2",
        "transformers==4.48.3",
        "accelerate==1.0.1",
        "trl==0.15.2",
        "peft==0.13.2",
    ]
    print("Installing training dependencies...", flush=True)
    ret = os.system(
        f"{sys.executable} -m pip install -q --no-cache-dir " + " ".join(f'"{d}"' for d in _TRAIN_DEPS)
    )
    if ret != 0:
        print("ERROR: pip install failed. Training cannot continue.", flush=True)
        sys.exit(1)
    print("Dependencies installed.", flush=True)

# ── GPU/training imports (skipped in --test-local mode) ───────────────────────
if not args.test_local:
    import torch
    import wandb
    from datasets import Dataset
    from transformers import (
        AutoModelForCausalLM, AutoTokenizer, TrainerCallback
    )
    from peft import get_peft_model, LoraConfig, TaskType
    from trl import GRPOTrainer, GRPOConfig

    def _pkg_ver(name: str) -> str:
        try:
            return metadata.version(name)
        except metadata.PackageNotFoundError:
            return "not-installed"

    print(
        "Runtime package versions | "
        f"python={sys.version.split()[0]} "
        f"torch={_pkg_ver('torch')} "
        f"transformers={_pkg_ver('transformers')} "
        f"trl={_pkg_ver('trl')} "
        f"accelerate={_pkg_ver('accelerate')} "
        f"peft={_pkg_ver('peft')}"
    )

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from server.reward_calculator import DebugRewardCalculator
from server.models import parse_agent_output

# ── Configuration ─────────────────────────────────────────────────────────────
MODEL_NAME = "Qwen/Qwen2.5-Coder-3B-Instruct"
HF_REPO = "shashaank0707/AgentDebugger-trained"
MAX_STEPS = 10 if args.test else args.max_steps
CHECKPOINT_DIR = "./checkpoints"

# W&B - optional but strongly recommended for judging
WANDB_API_KEY = os.environ.get("WANDB_API_KEY", "") if not args.test_local else ""
if WANDB_API_KEY:
    wandb.init(
        project="AgentDebuggerEnv",
        name=f"grpo-qwen-7b-{'test' if args.test else 'full'}",
        config={
            "model": MODEL_NAME,
            "algorithm": "GRPO",
            "curriculum": "tier1->tier2->tier3",
            "max_steps": MAX_STEPS,
            "reward_components": ["format", "hypothesis", "localization", "fix", "semantic", "efficiency"],
            "paper_citations": ["Masud et al. 2026", "Ibrahim et al. 2024"],
        }
    )

# ── System prompt ─────────────────────────────────────────────────────────────
SYSTEM_PROMPT = """You are an expert Python debugger. You reason through bugs systematically.

You MUST respond in EXACTLY this format - no exceptions, no extra text:

OBSERVATION: [Specific observations about the code and error. Reference exact line numbers.]
HYPOTHESIS: [Your theory about the root cause. Must be at least 2 sentences. Reference specific variable names, operators, or logic.]
CONFIDENCE: [low | medium | high]
ACTION: [One of: inspect_lines | run_tests | propose_fix | request_context | give_up]
DETAIL: [For propose_fix: the complete corrected function code. For inspect_lines: line numbers. For others: specific details.]

Rules:
- Never omit any field
- HYPOTHESIS must explain WHY the bug causes the observed failure
- If proposing a fix, DETAIL must contain the complete function, not just the changed line
- Give up only if you have exhausted all reasonable hypotheses"""

# ── Load bugs ─────────────────────────────────────────────────────────────────
def load_bugs(tier: int) -> list[dict]:
    path = f"data/bugs_tier{tier}.jsonl"
    if not os.path.exists(path):
        print(f"WARNING: {path} not found. Run data/generate_bugs.py first.")
        return []
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

def get_bugs_for_step(step: int) -> list[dict]:
    # Thresholds must line up with CurriculumCallback's triggers ([150, 350])
    # and fit inside the default MAX_STEPS of 500, or tier 3 never trains.
    tier1 = load_bugs(1)
    if step < 150:
        return tier1
    elif step < 350:
        return tier1 + load_bugs(2)
    return tier1 + load_bugs(2) + load_bugs(3)

def bug_to_prompt(bug: dict) -> str:
    return (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n"
        f"Debug this Python function:\n\n```python\n{bug['buggy_code']}\n```\n\n"
        f"Initial failure: {bug.get('initial_error', 'Some tests are failing.')}\n"
        f"<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
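
# The prompt above hand-rolls Qwen's ChatML template. A sketch of the same thing
# via the tokenizer (assuming the checkpoint ships a chat template, as Qwen2.5
# instruct models do); user_msg is a hypothetical variable holding the user turn:
#
#   messages = [
#       {"role": "system", "content": SYSTEM_PROMPT},
#       {"role": "user", "content": user_msg},
#   ]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True)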

def _run_fix(proposed_code: str, bug: dict) -> dict:
    """Safely run proposed fix with subprocess timeout."""
    test_cases = bug.get("test_cases", [])
    func_name = bug.get("function_name", "")
    if not proposed_code or not test_cases or not func_name:
        return {"passed": 0, "failed": 0, "total": len(test_cases), "newly_broken": 0}

    passed = 0
    for test in test_cases:
        inp = test["input"]
        args_str = ", ".join(repr(x) for x in inp)
        script = (
            f"{proposed_code}\n"
            f"try:\n"
            f"    r={func_name}({args_str})\n"
            f"    print('PASS' if r=={repr(test['expected_output'])} else 'FAIL')\n"
            f"except Exception as e:\n"
            f"    print(f'ERROR: {{e}}')\n"
        )
        fname = None
        try:
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(script)
                fname = f.name
            python = shutil.which("python3") or shutil.which("python") or sys.executable
            r = subprocess.run([python, fname], capture_output=True, text=True, timeout=5)
            if "PASS" in r.stdout:
                passed += 1
        except Exception:
            pass  # timeouts, crashes, and unrunnable fixes all count as failures
        finally:
            if fname and os.path.exists(fname):
                os.unlink(fname)  # clean up even when the subprocess times out

    return {"passed": passed, "failed": len(test_cases) - passed, "total": len(test_cases), "newly_broken": 0}

# ── Mock completions for --test-local ─────────────────────────────────────────
MOCK_GOOD = """
OBSERVATION: The loop condition on line 4 uses <= instead of <
HYPOTHESIS: This causes an off-by-one error because Python lists are 
0-indexed, so the last valid index is len(arr)-1 not len(arr)
CONFIDENCE: high
ACTION: propose_fix
DETAIL: def binary_search(arr, target):
    left, right = 0, len(arr) - 1
    while left < right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1
"""

MOCK_BAD = """
I think there might be a bug somewhere in the code.
Let me try fixing it.
"""

# ── --test-local: reward sanity-check without any model ───────────────────────
if args.test_local:
    print("=" * 60)
    print("LOCAL TEST MODE β€” no model loaded, testing reward function only")
    print("=" * 60)

    bugs = load_bugs(1)
    if not bugs:
        print("ERROR: No bugs found in data/bugs_tier1.jsonl. Run data/generate_bugs.py first.")
        sys.exit(1)

    bug = bugs[0]
    print(f"\nUsing bug: {bug.get('function_name', '?')} β€” {bug.get('bug_type', '?')}\n")

    calculator_local = DebugRewardCalculator()

    def _score(label: str, completion: str) -> float:
        try:
            agent_output = parse_agent_output(completion)
            test_results = {"passed": 0, "failed": 0, "total": 0, "newly_broken": 0}
            if agent_output.action == "propose_fix":
                test_results = _run_fix(agent_output.detail, bug)
            breakdown = calculator_local.compute_turn_reward(
                agent_output=agent_output,
                ground_truth={
                    "bug_function": bug.get("bug_location", {}).get("function", ""),
                    "bug_line": bug.get("bug_location", {}).get("line_start", -1),
                    "bug_type": bug.get("bug_type", ""),
                    "canonical_fix_code": bug.get("original_code", ""),
                },
                test_results=test_results,
                turn_number=0,
            )
            print(f"--- {label} reward breakdown ---")
            for field, value in breakdown.__dict__.items():
                print(f"  {field}: {value}")
            print(f"  TOTAL: {breakdown.total}\n")
            return breakdown.total
        except Exception as e:
            print(f"Reward error for {label}: {e}")
            return -0.3

    good_score = _score("MOCK_GOOD", MOCK_GOOD)
    bad_score = _score("MOCK_BAD", MOCK_BAD)

    print(f"MOCK_GOOD score: {good_score:.4f}")
    print(f"MOCK_BAD  score: {bad_score:.4f}")

    assert good_score > bad_score, (
        f"ASSERTION FAILED: MOCK_GOOD ({good_score:.4f}) should be > MOCK_BAD ({bad_score:.4f})"
    )
    print("\nLOCAL TEST PASSED")
    sys.exit(0)

# ── Auto-detect GPU and set optimal config ────────────────────────────────────
_gpu_vram_gb = 0
_is_ampere_plus = False  # A100/H100 support bfloat16 natively (compute cap >= 8.0)
if torch.cuda.is_available():
    _props = torch.cuda.get_device_properties(0)
    _gpu_vram_gb = _props.total_memory / 1e9
    _is_ampere_plus = _props.major >= 8
    print(f"GPU: {_props.name} | VRAM: {_gpu_vram_gb:.1f}GB | "
          f"Compute cap: {_props.major}.{_props.minor} | "
          f"bfloat16: {'yes' if _is_ampere_plus else 'no'}")

COMPUTE_DTYPE = torch.bfloat16 if _is_ampere_plus else torch.float16

# Scale batch/generation config to available VRAM.
# GRPO constraint: per_device_train_batch_size % num_generations == 0
if _gpu_vram_gb >= 70:          # A100 80GB
    _batch       = 8
    _grad_accum  = 1            # effective batch = 8
    _num_gen     = 8            # 8 % 8 == 0
    _max_comp    = 256
    _lora_r      = 16
elif _gpu_vram_gb >= 40:        # A100 40GB
    _batch       = 4
    _grad_accum  = 2            # effective batch = 8
    _num_gen     = 4            # 4 % 4 == 0
    _max_comp    = 256
    _lora_r      = 16
elif _gpu_vram_gb >= 20:        # A10G 24GB / V100 32GB
    _batch       = 2
    _grad_accum  = 4
    _num_gen     = 2            # 2 % 2 == 0
    _max_comp    = 192
    _lora_r      = 8
else:                           # T4 15GB / anything smaller
    _batch       = 2
    _grad_accum  = 4
    _num_gen     = 2            # 2 % 2 == 0
    _max_comp    = 160
    _lora_r      = 8
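
# Worked example (80GB tier): batch=8 with num_gen=8 means one device batch is
# exactly one prompt's group of 8 completions, so group-relative statistics are
# computed without a group ever straddling a batch boundary.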

print(f"Training config: batch={_batch} grad_accum={_grad_accum} "
      f"num_gen={_num_gen} max_comp={_max_comp} lora_r={_lora_r} "
      f"dtype={COMPUTE_DTYPE}")

# ── Load model ────────────────────────────────────────────────────────────────
# Load in native float16/bfloat16 - no bitsandbytes needed.
# A10G (24GB) fits Qwen2.5-Coder-3B in float16 (~6GB of weights) with room for LoRA + activations.
print(f"Loading {MODEL_NAME} in {COMPUTE_DTYPE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token
tokenizer.padding_side = "left"            # decoder-only generation pads on the left

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=COMPUTE_DTYPE,
)
model.config.use_cache = False  # KV cache is incompatible with gradient checkpointing

lora_config = LoraConfig(
    r=_lora_r,
    lora_alpha=_lora_r * 2,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.0,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)
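# lora_alpha = 2*r keeps the effective LoRA scale (alpha/r) at 2 regardless of
# which VRAM-dependent rank was selected above.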
model = get_peft_model(model, lora_config)
model.enable_input_require_grads()    # keeps a grad path to LoRA weights under checkpointing
model.gradient_checkpointing_enable()
print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")

# ── Runtime device selection ──────────────────────────────────────────────────
def _select_runtime_device(model) -> str:
    """
    Pick the safest generation device without forcing CUDA init on broken drivers.
    """
    def _cuda_usable() -> bool:
        try:
            if not torch.cuda.is_available():
                return False
            # Force lightweight CUDA init probe.
            _ = torch.zeros(1, device="cuda")
            return True
        except Exception as e:
            print(f"WARNING: CUDA initialization failed ({e}). Falling back to CPU.")
            return False

    # Prefer model's current device when available.
    try:
        model_device = str(next(model.parameters()).device)
        if model_device.startswith("cuda") and not _cuda_usable():
            return "cpu"
        return model_device
    except Exception:
        pass

    # Fallback to torch capability checks.
    if _cuda_usable():
        return "cuda"
    return "cpu"


RUNTIME_DEVICE = _select_runtime_device(model)
print(f"Using generation/training runtime device: {RUNTIME_DEVICE}")

# ── Reward function ───────────────────────────────────────────────────────────
calculator = DebugRewardCalculator()

def reward_fn(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
    """
    GRPO reward function. Called on groups of completions for the same prompt.
    GRPO learns from RELATIVE differences within each group.
    """
    rewards = []
    bugs_raw = kwargs.get("bug_metadata", [{}] * len(completions))
    bugs = [json.loads(b) if isinstance(b, str) else b for b in bugs_raw]

    for completion, bug in zip(completions, bugs):
        try:
            agent_output = parse_agent_output(completion)

            # Run fix if agent proposes one
            test_results = {"passed": 0, "failed": 0, "total": 0, "newly_broken": 0}
            if agent_output.action == "propose_fix" and bug:
                test_results = _run_fix(agent_output.detail, bug)

            breakdown = calculator.compute_turn_reward(
                agent_output=agent_output,
                ground_truth={
                    "bug_function": bug.get("bug_location", {}).get("function", ""),
                    "bug_line": bug.get("bug_location", {}).get("line_start", -1),
                    "bug_type": bug.get("bug_type", ""),
                    "canonical_fix_code": bug.get("original_code", ""),
                },
                test_results=test_results,
                turn_number=0,
            )

            if WANDB_API_KEY:
                wandb.log({k: v for k, v in breakdown.__dict__.items()})

            rewards.append(breakdown.total)

        except Exception as e:
            print(f"Reward error: {e}")
            rewards.append(-0.3)

    return rewards

# ── Baseline evaluation (run BEFORE training) ─────────────────────────────────
def run_baseline(n: int = 20) -> dict:
    print("\nRunning baseline evaluation on UNTRAINED model...")
    model.eval()
    bugs = load_bugs(1)[:n]
    rewards = []
    solved = 0
    for bug in bugs:
        prompt = bug_to_prompt(bug)
        inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=200, do_sample=False)
        completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        r = reward_fn([completion], [prompt], bug_metadata=[bug])
        rewards.append(r[0])
        if r[0] > 0.20:
            solved += 1

    result = {"solve_rate": solved / max(len(bugs), 1), "avg_reward": sum(rewards) / max(len(rewards), 1), "rewards": rewards}
    with open("baseline_results.json", "w") as f:
        json.dump(result, f)
    print(f"Baseline: solve_rate={result['solve_rate']:.1%}, avg_reward={result['avg_reward']:.3f}")
    if WANDB_API_KEY:
        wandb.log({"baseline/solve_rate": result["solve_rate"], "baseline/avg_reward": result["avg_reward"]})
    return result

baseline = run_baseline()
model.train()

# ── Build initial dataset ─────────────────────────────────────────────────────
def make_dataset(step: int) -> Dataset:
    bugs = get_bugs_for_step(step)
    return Dataset.from_list([{"prompt": bug_to_prompt(b), "bug_metadata": json.dumps(b)} for b in bugs])
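
# bug_metadata is JSON-encoded so the Dataset column stays a plain string;
# reward_fn (above) decodes it back per completion before scoring.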

# ── Training config ───────────────────────────────────────────────────────────
config = GRPOConfig(
    output_dir=CHECKPOINT_DIR,
    max_steps=MAX_STEPS,
    per_device_train_batch_size=_batch,
    gradient_accumulation_steps=_grad_accum,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_steps=10 if args.test else 30,
    num_generations=_num_gen,
    max_completion_length=_max_comp,
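    # temperature > 0 keeps each prompt's num_generations samples diverse;
    # greedy decoding would make group members identical and zero the advantages.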
    temperature=0.9,
    logging_steps=5,
    save_steps=50,
    report_to="wandb" if WANDB_API_KEY else "none",
)

trainer = GRPOTrainer(
    model=model,
    args=config,
    train_dataset=make_dataset(0),
    reward_funcs=reward_fn,
    processing_class=tokenizer,
)

# ── Curriculum callback ───────────────────────────────────────────────────────
class CurriculumCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        step = state.global_step
        if step in [150, 350]:
            trainer.train_dataset = make_dataset(step)
            print(f"\nCurriculum advanced at step {step}!")
            if WANDB_API_KEY:
                wandb.log({"curriculum/step": step})

trainer.add_callback(CurriculumCallback())

# ── Train ─────────────────────────────────────────────────────────────────────
print(f"\nStarting GRPO training. Max steps: {MAX_STEPS}")
print(f"Baseline solve rate: {baseline['solve_rate']:.1%} β€” target: >60% after training")
trainer.train(resume_from_checkpoint=args.resume)

# ── Post-training evaluation ──────────────────────────────────────────────────
model.eval()
bugs = load_bugs(1)[:20]
post_rewards = []
post_solved = 0
for bug in bugs:
    prompt = bug_to_prompt(bug)
    inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    r = reward_fn([completion], [prompt], bug_metadata=[bug])
    post_rewards.append(r[0])
    if r[0] > 0.20:
        post_solved += 1

post_solve_rate = post_solved / max(len(bugs), 1)
print(f"\n{'='*60}")
print(f"RESULTS:")
print(f"Before training: {baseline['solve_rate']:.1%} solve rate")
print(f"After training:  {post_solve_rate:.1%} solve rate")
print(f"Improvement:     +{post_solve_rate - baseline['solve_rate']:.1%}")
print(f"{'='*60}")

if WANDB_API_KEY:
    wandb.log({"final/solve_rate": post_solve_rate, "final/improvement": post_solve_rate - baseline["solve_rate"]})
    wandb.finish()

# ── Save and push ─────────────────────────────────────────────────────────────
model.save_pretrained("./final_model")
tokenizer.save_pretrained("./final_model")
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN and not args.test:
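    # model is a PeftModel, so this uploads the LoRA adapter weights and config,
    # not a merged full checkpoint.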
    model.push_to_hub(HF_REPO, token=HF_TOKEN, private=True)
    tokenizer.push_to_hub(HF_REPO, token=HF_TOKEN, private=True)
    print(f"Pushed to https://huggingface.co/{HF_REPO}")