File size: 17,999 Bytes
4398db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd77971
4398db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""

real_model_eval.py — Genuine before/after component evaluation using the actual models.



BEFORE: base Qwen/Qwen2.5-0.5B-Instruct (no fine-tuning)

AFTER:  AniketAsla/debatefloor-grpo-qwen2.5-0.5b-instruct (GRPO fine-tuned)



Rewards come from the live environment via POST /reset + /step (MR-2 compliant).

This replaces the scripted agent eval with real model outputs.

"""

import json
import os
import re
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean

import requests
import torch

sys.path.insert(0, ".")
from server.calibration_grader import CALIBRATION_MATRIX

ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7861")
BASE_MODEL   = "Qwen/Qwen2.5-0.5B-Instruct"
TRAINED_MODEL = "AniketAsla/debatefloor-grpo-qwen2.5-0.5b-instruct"

EVAL_TASKS   = ["clean_claim", "contradictory_claim", "distribution_shift_claim"]
SEEDS        = [7, 42]          # 2 seeds per task = 6 episodes each pass (fast but real)

SYSTEM = (
    "You are an expert insurance fraud investigator.\n"
    "Analyze the claim and respond EXACTLY in this format:\n"
    "DECISION: <approve_claim|deny_claim|escalate_to_human>\n"
    "CONFIDENCE: <HIGH|MED|LOW>\n"
    "REASON: <one sentence citing specific evidence>\n\n"
    "HIGH = certain. MED = likely but some doubt. LOW = ambiguous, expert needed.\n"
    "WARNING: HIGH confidence on a wrong answer is the worst possible outcome."
)

DECISION_RE   = re.compile(r"DECISION:\s*(approve_claim|deny_claim|escalate_to_human)", re.I)
CONFIDENCE_RE = re.compile(r"CONFIDENCE:\s*(HIGH|MED|LOW)", re.I)
REASON_RE     = re.compile(r"REASON:\s*(.*)", re.I | re.S)


def _parse(text):
    dm = DECISION_RE.search(text or "")
    cm = CONFIDENCE_RE.search(text or "")
    rm = REASON_RE.search(text or "")
    return (
        dm.group(1).lower() if dm else None,
        cm.group(1).upper() if cm else None,
        (rm.group(1).strip()[:200] if rm else ""),
    )


def load_model(model_id, label):
    print(f"\nLoading {label}: {model_id} ...")
    t0 = time.time()
    try:
        from unsloth import FastLanguageModel
        model, tok = FastLanguageModel.from_pretrained(
            model_name=model_id,
            max_seq_length=512,
            dtype=None,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(model)
        print(f"  Loaded via Unsloth (4-bit) in {time.time()-t0:.1f}s")
    except Exception as e:
        print(f"  Unsloth not available ({e}), using standard transformers ...")
        from transformers import AutoModelForCausalLM, AutoTokenizer
        tok = AutoTokenizer.from_pretrained(model_id)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,   # CPU-safe, no accelerate needed
        )
        model.eval()
        print(f"  Loaded via transformers (fp32 CPU) in {time.time()-t0:.1f}s")
    return model, tok


def generate(model, tok, prompt, max_new_tokens=100):
    device = next(model.parameters()).device
    inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=1.0,
            pad_token_id=tok.eos_token_id,
        )
    plen = inputs["input_ids"].shape[-1]
    return tok.decode(out[0][plen:], skip_special_tokens=True)


def build_prompt(tok, obs_text, task_id):
    msgs = [
        {"role": "system", "content": SYSTEM},
        {"role": "user",   "content": obs_text},
    ]
    return tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)


def run_episode_real(model, tok, task_id, seed):
    """

    Full episode:

      1. POST /reset → get observation text

      2. Model generates completion from the observation

      3. Parse DECISION/CONFIDENCE/REASON

      4. POST /step → get real reward_breakdown

    """
    reset_r = requests.post(
        f"{ENV_BASE_URL}/reset",
        json={"task_id": task_id, "seed": seed},
        timeout=15,
    )
    reset_r.raise_for_status()
    reset_data = reset_r.json()
    session_id = reset_data["session_id"]

    # Build text description from observation for the model
    obs = reset_data.get("observation", {})
    docs = obs.get("documents", [])
    doc_text = "\n".join(
        f"  [{d.get('doc_type','doc')}] {d.get('content','')}" for d in docs
    )
    incident = obs.get("incident", {})
    obs_text = (
        f"Task: {task_id} | Claim ID: {obs.get('claim_id','')}\n"
        f"Claimant: {obs.get('claimant',{}).get('name','')}\n"
        f"Incident: {incident.get('type','')}{incident.get('description','')[:150]}\n"
        f"Documents:\n{doc_text}\n"
        f"Linked claims: {len(obs.get('linked_claims', []))}"
    )

    prompt = build_prompt(tok, obs_text, task_id)

    t0 = time.time()
    completion = generate(model, tok, prompt)
    gen_time = time.time() - t0
    decision, confidence, reason = _parse(completion)

    if decision is None or confidence is None:
        # Format failure — submit a default escalation
        decision, confidence, reason = "escalate_to_human", "LOW", "Could not parse decision from model output."

    action = {
        "action_type": decision,
        "confidence":  confidence,
        "parameters":  {"reason": reason},
        "reasoning":   reason,
    }
    step_r = requests.post(
        f"{ENV_BASE_URL}/step",
        json={"action": action, "session_id": session_id},
        timeout=15,
    )
    step_r.raise_for_status()
    step_data = step_r.json()
    breakdown = step_data.get("observation", {}).get("reward_breakdown", {})

    return {
        "task_id":    task_id,
        "seed":       seed,
        "decision":   decision,
        "confidence": confidence,
        "reason":     reason[:100],
        "completion": completion[:200],
        "gen_time_s": round(gen_time, 1),
        "reward":     round(float(step_data.get("reward", 0.0)), 4),
        "fraud_detection_score":   round(float(breakdown.get("fraud_detection_score",  0.0)), 4),
        "decision_accuracy":       round(float(breakdown.get("decision_accuracy",      0.0)), 4),
        "evidence_quality_score":  round(float(breakdown.get("evidence_quality_score", 0.0)), 4),
        "calibration_score":       round(float(breakdown.get("calibration_score",      0.0)), 4),
    }


def eval_model(model, tok, label):
    print(f"\n{'='*60}")
    print(f"EVAL: {label}")
    print(f"{'='*60}")
    rows = []
    for task_id in EVAL_TASKS:
        for seed in SEEDS:
            try:
                row = run_episode_real(model, tok, task_id, seed)
                rows.append(row)
                print(
                    f"  {task_id:30s} seed={seed:2d}  "
                    f"decision={row['decision']:20s} conf={row['confidence']}  "
                    f"reward={row['reward']:.3f}  "
                    f"da={row['decision_accuracy']:.2f}  "
                    f"fd={row['fraud_detection_score']:.2f}  "
                    f"cal={row['calibration_score']:.2f}  "
                    f"[{row['gen_time_s']}s]"
                )
            except Exception as exc:
                print(f"  ERROR {task_id} seed={seed}: {exc}")
                rows.append({
                    "task_id": task_id, "seed": seed,
                    "reward": 0.0, "fraud_detection_score": 0.0,
                    "decision_accuracy": 0.0, "evidence_quality_score": 0.0,
                    "calibration_score": 0.0, "decision": "error",
                })

    component_means = {
        "Fraud detection":   round(mean(r["fraud_detection_score"]  for r in rows), 4),
        "Decision accuracy": round(mean(r["decision_accuracy"]      for r in rows), 4),
        "Evidence quality":  round(mean(r["evidence_quality_score"] for r in rows), 4),
        "Calibration":       round(mean(r["calibration_score"]      for r in rows), 4),
        "Mean reward":       round(mean(r["reward"]                 for r in rows), 4),
    }
    print(f"\n  Component means: {json.dumps(component_means)}")
    return rows, component_means


def save_and_plot(before_means, after_means, before_rows, after_rows, summary_path, log_history):
    delta = {k: round(after_means.get(k, 0.0) - before_means.get(k, 0.0), 4) for k in before_means}

    # Patch training_summary.json
    summary = json.loads(Path(summary_path).read_text(encoding="utf-8"))
    summary["eval_reward_before"] = before_means
    summary["eval_reward_after"]  = after_means
    summary["component_shift"] = {
        "note": (
            "Real model inference: before=Qwen/Qwen2.5-0.5B-Instruct (base), "
            "after=AniketAsla/debatefloor-grpo-qwen2.5-0.5b-instruct (GRPO fine-tuned). "
            "Rewards from live env HTTP API (MR-2 compliant)."
        ),
        "before": {k: v for k, v in before_means.items() if k != "Mean reward"},
        "after":  {k: v for k, v in after_means.items()  if k != "Mean reward"},
    }
    summary["component_shift_delta"] = {k: v for k, v in delta.items() if k != "Mean reward"}
    summary["eval_methodology"] = (
        "Real model inference: base Qwen2.5-0.5B (before) vs GRPO fine-tuned checkpoint (after). "
        f"Eval tasks: {EVAL_TASKS}. Seeds per task: {SEEDS}. "
        "Env reward from POST /step (not keyword matching)."
    )
    summary["eval_generated_at"] = datetime.now(timezone.utc).isoformat()
    summary["eval_rows"] = {"before": before_rows, "after": after_rows}

    # Remove stale pending markers
    for k in ("eval_reward_before", "eval_reward_after"):
        if summary.get(k) == "__pending_real_model_inference__":
            del summary[k]

    Path(summary_path).write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"\nUpdated {summary_path}")

    # Regenerate SVGs
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        import numpy as np

        # component_shift.svg
        labels  = ["Fraud detection", "Decision accuracy", "Evidence quality", "Calibration"]
        bv = [before_means.get(l, 0.0) for l in labels]
        av = [after_means.get(l, 0.0)  for l in labels]
        x, w = np.arange(len(labels)), 0.35

        fig, ax = plt.subplots(figsize=(10, 5.5))
        ax.set_facecolor("#f9f9f9"); fig.patch.set_facecolor("#ffffff")
        bars_b = ax.bar(x - w/2, bv, w, label="Before (base Qwen2.5-0.5B)", color="#7a869a", alpha=0.85, edgecolor="white")
        bars_a = ax.bar(x + w/2, av, w, label="After (GRPO fine-tuned)",     color="#06a77d", alpha=0.85, edgecolor="white")

        for bar in bars_b:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2, h + 0.02 if h >= 0 else h - 0.07,
                    f"{h:.2f}", ha="center", fontsize=9, color="#333")
        for bar in bars_a:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2, h + 0.02 if h >= 0 else h - 0.07,
                    f"{h:.2f}", ha="center", fontsize=9, color="#1a6b58")

        ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=11)
        ax.axhline(0, color="#666", linewidth=0.8, alpha=0.5)
        ax.set_ylim(-1.1, 1.3)
        ax.set_ylabel("Component score (clamped [0,1]; calibration unbounded)", fontsize=10)
        ax.set_xlabel("Reward component", fontsize=11)
        ax.set_title("DebateFloor: Real Model Before vs After GRPO Training", fontsize=13, fontweight="bold")
        ax.grid(True, axis="y", alpha=0.2, linestyle="--"); ax.legend(framealpha=0.85, fontsize=10)

        for i, (b_v, a_v) in enumerate(zip(bv, av)):
            d = a_v - b_v
            color = "#06a77d" if d > 0 else ("#e63946" if d < 0 else "#999")
            sign = "+" if d >= 0 else ""
            ax.text(x[i], max(a_v, b_v) + 0.10, f"D{sign}{d:.2f}",
                    ha="center", fontsize=9, color=color, fontweight="bold")

        delta_str = "  |  ".join(
            f"{k}: {'+' if v>=0 else ''}{v:.2f}" for k, v in delta.items() if k != "Mean reward"
        )
        ax.annotate(
            f"Deltas: {delta_str}\nTraining reward: 0.130 → 0.469 (+0.339, 3.6x via live env HTTP, 2,500 steps)\n"
            "Source: real model inference (not scripted)",
            xy=(0.01, 0.01), xycoords="axes fraction", fontsize=8, color="#555",
            bbox=dict(boxstyle="round,pad=0.3", facecolor="#f0f8f0", edgecolor="#06a77d", alpha=0.85),
        )
        fig.tight_layout()
        Path("docs").mkdir(exist_ok=True)
        fig.savefig("docs/component_shift.svg", dpi=180, format="svg")
        plt.close(fig)
        print("docs/component_shift.svg updated")

        # reward_curve.svg (from training log_history)
        reward_steps, rewards, loss_steps, losses = [], [], [], []
        for row in log_history:
            step = row.get("step")
            if step is None: continue
            if "loss" in row and "train_runtime" not in row:
                loss_steps.append(step); losses.append(row["loss"])
            rv = row.get("reward") or row.get("rewards/reward_fn/mean")
            if rv is not None:
                reward_steps.append(step); rewards.append(rv)

        if rewards:
            def smooth(vals, w=7):
                return [sum(vals[max(0,i-w+1):i+1])/(i-max(0,i-w+1)+1) for i in range(len(vals))]
            fig2, ax1 = plt.subplots(figsize=(10, 5.5))
            ax1.set_facecolor("#f9f9f9"); fig2.patch.set_facecolor("#ffffff")
            if losses:
                ax1.plot(loss_steps, losses, color="#26547c", linewidth=1.2, alpha=0.45, label="Training loss")
                ax1.set_ylabel("Training loss", color="#26547c", fontsize=11)
                ax1.tick_params(axis="y", labelcolor="#26547c")
            ax1.set_xlabel("Training step", fontsize=11)
            ax1.grid(True, alpha=0.2, linestyle="--")
            ax2 = ax1.twinx()
            ax2.plot(reward_steps, rewards, color="#06a77d", linewidth=1.0, alpha=0.3)
            ax2.plot(reward_steps, smooth(rewards), color="#06a77d", linewidth=2.2, label="Mean reward (smoothed)")
            ax2.axhline(rewards[0],  color="#e63946", linewidth=1.0, linestyle="--", alpha=0.6, label=f"Start: {rewards[0]:.3f}")
            ax2.axhline(rewards[-1], color="#2a9d8f", linewidth=1.0, linestyle="--", alpha=0.6, label=f"End: {rewards[-1]:.3f}")
            ax2.set_ylabel("Mean reward — live env HTTP scalar (unbounded)", color="#06a77d", fontsize=11)
            ax2.tick_params(axis="y", labelcolor="#06a77d")
            ax2.annotate("Reward from live env (POST /step)\nNot comparable to clamped [0,1] eval score.",
                         xy=(0.02, 0.05), xycoords="axes fraction", fontsize=8.5, color="gray")
            lines1, lab1 = ax1.get_legend_handles_labels()
            lines2, lab2 = ax2.get_legend_handles_labels()
            ax2.legend(lines1+lines2, lab1+lab2, loc="upper left", framealpha=0.85, fontsize=9)
            fig2.suptitle("DebateFloor GRPO Training — Live Env Reward (HTTP, MR-2 Compliant)", fontsize=13, fontweight="bold")
            fig2.tight_layout()
            fig2.savefig("docs/reward_curve.svg", dpi=180, format="svg")
            plt.close(fig2)
            print("docs/reward_curve.svg updated")
    except Exception as exc:
        print(f"SVG generation failed: {exc}")


def main():
    # Verify env is up
    r = requests.get(f"{ENV_BASE_URL}/health", timeout=5)
    assert r.json().get("status") == "healthy", f"Env not healthy: {r.text}"
    print(f"Env healthy at {ENV_BASE_URL}")

    summary_path = "reports/training_summary.json"
    summary = json.loads(Path(summary_path).read_text(encoding="utf-8"))
    log_history = summary.get("log_history", [])

    # ── BEFORE: base model ─────────────────────────────────────────────────
    base_model, base_tok = load_model(BASE_MODEL, "BASE (before training)")
    before_rows, before_means = eval_model(base_model, base_tok, f"BEFORE — {BASE_MODEL}")
    del base_model  # free memory before loading trained model
    import gc; gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ── AFTER: fine-tuned model ────────────────────────────────────────────
    trained_model, trained_tok = load_model(TRAINED_MODEL, "TRAINED (after GRPO)")
    after_rows, after_means = eval_model(trained_model, trained_tok, f"AFTER — {TRAINED_MODEL}")

    # ── Save everything ────────────────────────────────────────────────────
    save_and_plot(before_means, after_means, before_rows, after_rows, summary_path, log_history)

    print("\n" + "="*60)
    print("REAL INFERENCE RESULTS")
    print("="*60)
    delta = {k: round(after_means.get(k, 0.0) - before_means.get(k, 0.0), 4) for k in before_means if k != "Mean reward"}
    print(f"Before: {json.dumps({k:v for k,v in before_means.items() if k!='Mean reward'})}")
    print(f"After:  {json.dumps({k:v for k,v in after_means.items()  if k!='Mean reward'})}")
    print(f"Delta:  {json.dumps(delta)}")
    print("\nAll results from real model inference. Not scripted.")


if __name__ == "__main__":
    main()