""" Gradio UI — Agent Trajectory Replay Viewer for DataQA. Designed for judges: zero clicks needed, auto-plays on load. Tab per task, step slider, prominent metric cards, color-coded dataset. """ from __future__ import annotations import csv import io import gradio as gr from .environment import DataQAEnvironment, parse_issue_key from .tasks import list_tasks, PlantedIssue from ..models import DataQAAction # ── Pre-built agent trajectories (simulates baseline agent) ── AGENT_TRAJECTORIES = { # Demo trajectories: fixes are ONLY proposed where the correct value # is logically inferrable (computable, format conversion, or deducible from context). # Ambiguous fixes (any valid salary, any past date) are NOT proposed. "easy": [ { "issues": [ "row:4,col:name,issue:missing_value", "row:7,col:salary,issue:wrong_type", "row:11,col:department,issue:format_violation", "row:15,col:email,issue:inconsistent_value", "row:3,col:email,issue:format_violation", # FP ], "fixes": [], }, { "issues": [ "row:4,col:name,issue:missing_value", "row:7,col:salary,issue:wrong_type", "row:11,col:department,issue:format_violation", "row:15,col:email,issue:inconsistent_value", "row:12,col:start_date,issue:format_violation", "row:21,col:employee_id,issue:duplicate_row", ], "fixes": [ # All deterministic fixes: "row:4,col:name,fix:David Kim", # from email david.kim@ "row:7,col:salary,fix:75000", # "seventy-five thousand" → 75000 "row:11,col:department,fix:Engineering", # "Engneering" → "Engineering" "row:15,col:email,fix:oscar.rivera@company.com", # from name Oscar Rivera "row:12,col:start_date,fix:2022-11-03", # MM-DD-YYYY → YYYY-MM-DD ], }, ], "medium": [ { "issues": [ "row:5,col:total,issue:inconsistent_value", "row:10,col:category,issue:format_violation", "row:10,col:quantity,issue:wrong_type", "row:12,col:order_date,issue:format_violation", "row:29,col:product_name,issue:format_violation", "row:24,col:status,issue:format_violation", ], "fixes": [], }, { "issues": [ 
"row:5,col:total,issue:inconsistent_value", "row:10,col:category,issue:format_violation", "row:10,col:quantity,issue:wrong_type", "row:12,col:order_date,issue:format_violation", "row:19,col:order_id,issue:duplicate_row", "row:21,col:unit_price,issue:format_violation", "row:24,col:status,issue:format_violation", "row:29,col:product_name,issue:format_violation", ], "fixes": [ # All deterministic: "row:5,col:total,fix:42.00", # qty(1) * price(42.00) "row:10,col:category,fix:Sports", # "Fitness" → nearest valid "row:10,col:quantity,fix:10", # "1O" (letter O) → "10" "row:12,col:order_date,fix:2024-01-26", # DD/MM/YYYY → YYYY-MM-DD "row:24,col:status,fix:delivered", # "deliverred" → "delivered" "row:29,col:product_name,fix:Wireless Charger", # "Wireles" → "Wireless" "row:21,col:unit_price,fix:24.99", # 24.999 → round to 2 decimals ], }, ], "hard": [ { "issues": [ "row:14,col:training_time_hours,issue:out_of_range", "row:13,col:learning_rate,issue:out_of_range", "row:15,col:model_name,issue:missing_value", "row:9,col:batch_size,issue:format_violation", "row:10,col:train_size,issue:inconsistent_value", ], "fixes": [], }, { "issues": [ "row:14,col:training_time_hours,issue:out_of_range", "row:13,col:learning_rate,issue:out_of_range", "row:15,col:model_name,issue:missing_value", "row:9,col:batch_size,issue:format_violation", "row:10,col:train_size,issue:inconsistent_value", "row:5,col:val_loss,issue:inconsistent_value", "row:7,col:gpu_memory_gb,issue:statistical_outlier", "row:11,col:timestamp,issue:inconsistent_value", "row:9,col:training_time_hours,issue:statistical_outlier", "row:12,col:test_accuracy,issue:statistical_outlier", ], "fixes": [ # Only deterministic fixes: "row:9,col:batch_size,fix:256", # 250 → nearest power of 2 "row:14,col:training_time_hours,fix:72.0", # -72.0 → remove negative sign "row:15,col:model_name,fix:whisper-small", # "whsiper-small" → fix spelling # NOT proposed: row:13 LR (2.5 is out of range but any valid LR works) ], }, ], "alignment": [ { 
"issues": [ "row:6,col:response,issue:inconsistent_value", "row:15,col:response,issue:inconsistent_value", "row:28,col:prompt,issue:missing_value", "row:20,col:response,issue:inconsistent_value", "row:7,col:prompt,issue:duplicate_row", "row:25,col:response,issue:missing_value", "row:3,col:response,issue:inconsistent_value", ], "fixes": [], }, { "issues": [ "row:3,col:response,issue:inconsistent_value", "row:4,col:response,issue:inconsistent_value", "row:6,col:response,issue:inconsistent_value", "row:7,col:prompt,issue:duplicate_row", "row:8,col:response,issue:inconsistent_value", "row:11,col:response,issue:inconsistent_value", "row:15,col:response,issue:inconsistent_value", "row:23,col:helpfulness,issue:inconsistent_value", "row:20,col:response,issue:inconsistent_value", "row:25,col:response,issue:missing_value", "row:28,col:prompt,issue:missing_value", "row:29,col:response,issue:inconsistent_value", ], "fixes": [ # Inferrable: Salvator Mundi facts are well-known ($450.3M at Christie's) "row:4,col:response,fix:The most expensive painting ever sold at auction is Salvator Mundi by Leonardo da Vinci. 
It was sold for $450.3 million at Christie's in New York City in 2017.", # Inferrable: strip leaked [SYSTEM] prompt prefix "row:3,col:response,fix:Kitsch is art or design that is overly sentimental or ornate while camp is a style that is over-the-top and exaggerated often used in satire or irony.", # NOT proposed: row:6 wrong scientific name (need taxonomy knowledge) # NOT proposed: row:8 harmful advice (need to write safe version) # NOT proposed: row:11 self-contradiction (need to rewrite coherently) # NOT proposed: row:15 French response (need English translation) # NOT proposed: row:29 hallucinated citation (need factual replacement) ], }, ], "moderation": [ { "issues": [ "row:16,col:hate,issue:inconsistent_value", "row:17,col:harassment,issue:inconsistent_value", "row:20,col:violence,issue:inconsistent_value", "row:13,col:violence,issue:out_of_range", "row:15,col:text,issue:missing_value", "row:30,col:text,issue:duplicate_row", ], "fixes": [], }, { "issues": [ "row:5,col:hate,issue:inconsistent_value", "row:13,col:violence,issue:out_of_range", "row:15,col:text,issue:missing_value", "row:16,col:hate,issue:inconsistent_value", "row:17,col:harassment,issue:inconsistent_value", "row:20,col:violence,issue:inconsistent_value", "row:22,col:self_harm,issue:inconsistent_value", "row:24,col:hate,issue:inconsistent_value", "row:25,col:violence,issue:inconsistent_value", "row:30,col:text,issue:duplicate_row", ], "fixes": [ # All deterministic label flips: "row:16,col:hate,fix:1", # unflagged hate → flag it "row:17,col:harassment,fix:1", # unflagged harassment → flag it "row:20,col:violence,fix:1", # unflagged violence → flag it "row:22,col:self_harm,fix:1", # unflagged self-harm → flag it "row:5,col:hate,fix:0", # false positive on idiom → unflag "row:24,col:hate,fix:1", # subset rule: hate_threatening needs hate "row:25,col:violence,fix:0", # chose walk over violence → not violent "row:13,col:violence,fix:0", # out of range 3 → 0 ], }, ], } # ── HTML rendering ── def 
_metric_card(label: str, value: str, color: str = "#333") -> str: return ( f'
' f'
{label}
' f'
{value}
' f'
' ) def _csv_to_html( csv_text: str, planted: list[PlantedIssue], correct: set[tuple[int, str]], fp: set[tuple[int, str]], missed: set[tuple[int, str]], fixed: dict[tuple[int, str], str], fix_values: dict[tuple[int, str], str] | None = None, ) -> str: """Render CSV as HTML with color-coded cells and inline fix proposals.""" fix_values = fix_values or {} desc_map = {(i.row, i.col): i for i in planted} reader = csv.reader(io.StringIO(csv_text.strip())) rows = list(reader) if not rows: return "" header = rows[0] header_lower = [h.strip().lower() for h in header] data = rows[1:] t = [''] t.append('') t.append('') for h in header: t.append(f'') t.append('') for i, row in enumerate(data): rn = i + 1 bg = "#fff" if i % 2 == 0 else "#f8f9fa" t.append(f'') t.append(f'') for j, val in enumerate(row): col = header_lower[j] if j < len(header_lower) else "" ck = (rn, col) s = "border:1px solid #dee2e6;padding:4px 8px;" tip = "" badge = "" issue = desc_map.get(ck) if ck in correct: s += "background:#d4edda;" tip = f"FOUND: {issue.description}" if issue else "" badge = 'TP' elif ck in fp: s += "background:#f8d7da;" badge = 'FP' elif ck in missed: s += "background:#fff3cd;" tip = f"MISSED: {issue.description}" if issue else "" badge = 'MISS' fx = fixed.get(ck) proposed = fix_values.get(ck) if fx == "correct": s += "box-shadow:inset 0 0 0 2px #28a745;" badge += 'FIX' elif fx == "partial": s += "box-shadow:inset 0 0 0 2px #ffc107;" badge += '~FIX' dv = val if val.strip() else 'empty' # Show proposed fix value below the corrupted value fix_line = "" if proposed is not None: fix_color = "#28a745" if fx == "correct" else ("#b8860b" if fx == "partial" else "#dc3545") fix_line = ( f'
' f'\u2192 {proposed}
' ) t.append(f'') t.append('') t.append('
Row{h}
{rn}{dv}{badge}{fix_line}
') return "".join(t) LEGEND_HTML = ( '
' 'Found (TP)' 'False Positive' 'Missed' 'Fix Correct' 'Fix Partial' '
# ── Core replay logic ──
def _replay_task(task_id: str) -> list[dict]:
    """Run the agent trajectory and collect per-step data.

    Returns a list of step dicts with keys "label", "html" (rendered table),
    "metrics" (reward/tp/fp/fn/identify/fix/fixes_correct) and "feedback".
    Step 0 is always the untouched corrupted dataset.
    """
    env = DataQAEnvironment()
    obs = env.reset(task_id=task_id)
    task = env._current_task
    planted_keys = {i.to_key() for i in task.planted_issues}
    steps_data = []
    # Step 0: initial state — nothing reported yet, so every planted issue is "fn".
    steps_data.append({
        "label": "Initial — corrupted dataset",
        "html": _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {}),
        "metrics": {"reward": 0.0, "tp": 0, "fp": 0, "fn": len(task.planted_issues),
                    "identify": 0.0, "fix": 0.0, "fixes_correct": 0},
        "feedback": f"Task: {task.name}\nIssues to find: {obs.num_issues_hint}\n\n{task.description}",
    })
    trajectory = AGENT_TRAJECTORIES.get(task_id, [])
    for i, step_data in enumerate(trajectory):
        action = DataQAAction(
            issues=step_data["issues"],
            fixes=step_data.get("fixes", []),
            task_id=task_id,
        )
        obs = env.step(action)
        reported_keys = set()
        for iss in step_data["issues"]:
            key = parse_issue_key(iss)
            if key:
                reported_keys.add(key)
        tp_keys = reported_keys & planted_keys
        fp_keys = reported_keys - planted_keys
        fn_keys = planted_keys - reported_keys
        correct = {_kc(k) for k in tp_keys}
        fp = {_kc(k) for k in fp_keys}
        # Only reveal misses on the final step, so earlier steps don't spoil them.
        missed = {_kc(k) for k in fn_keys} if obs.done else set()
        fixed: dict[tuple[int, str], str] = {}
        for d in obs.metadata.get("fix_details", []):
            c = (d["row"], d["col"])
            fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
        # Extract proposed fix values from the raw fix strings
        fix_values: dict[tuple[int, str], str] = {}
        from .environment import parse_fix
        for raw_fix in step_data.get("fixes", []):
            parsed = parse_fix(raw_fix)
            if parsed:
                row, col, val = parsed
                fix_values[(row, col)] = val
        html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp, missed, fixed, fix_values)
        has_fixes = bool(step_data.get("fixes"))
        if has_fixes:
            label = f"Step {i+1} — identify + fix"
        else:
            label = f"Step {i+1} — identify only"
        steps_data.append({
            "label": label,
            "html": html,
            "metrics": {
                "reward": obs.reward,
                "tp": obs.metadata["tp"],
                "fp": obs.metadata["fp"],
                "fn": obs.metadata["fn"],
                "identify": obs.metadata["identify_score"],
                "fix": obs.metadata["fix_score"],
                "fixes_correct": obs.metadata["fixes_correct"],
            },
            "feedback": obs.feedback,
        })
    return steps_data


def _kc(key: str) -> tuple[int, str]:
    """Convert an issue key "row:N,col:NAME,..." into a (row, col) tuple."""
    parts = key.split(",")
    return (int(parts[0].split(":")[1]), parts[1].split(":")[1])


# ── Gradio app ──
def build_gradio_ui():
    """Build and return the two-tab Gradio Blocks app (replay + live env).

    NOTE(review): the inline HTML wrappers in this function were stripped by
    extraction; the <div> markup below is reconstructed — confirm visually.
    """
    # Pre-compute all replays at startup
    all_replays: dict[str, list[dict]] = {}
    for tid in list_tasks():
        all_replays[tid] = _replay_task(tid)

    def show_step(task_id: str, step_idx: int):
        # Render one pre-computed replay step as (html, feedback).
        replay = all_replays.get(task_id, [])
        step_idx = int(step_idx)
        if step_idx >= len(replay):
            step_idx = len(replay) - 1
        sd = replay[step_idx]
        m = sd["metrics"]
        # Reward color
        r = m["reward"]
        rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
        cards = (
            '<div style="display:flex;gap:8px;margin:8px 0;">'
            + _metric_card("Reward", f"{r:.2f}", rc)
            + _metric_card("Found", str(m["tp"]), "#28a745")
            + _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
            + _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
            + _metric_card("Identify", f"{m['identify']:.2f}", "#333")
            + _metric_card("Fix", f"{m['fix']:.2f}", "#333")
            + '</div>'
        )
        full_html = (
            f'<div style="font-size:15px;font-weight:bold;">{sd["label"]}</div>'
            + cards
            + sd["html"]
            + LEGEND_HTML
        )
        return full_html, sd["feedback"]

    def on_task_change(task_id):
        # Reset the slider range for the newly selected task and show step 0.
        replay = all_replays.get(task_id, [])
        max_step = len(replay) - 1
        html, fb = show_step(task_id, 0)
        return (
            gr.update(maximum=max_step, value=0),
            html,
            fb,
        )

    def on_step_change(task_id, step_idx):
        html, fb = show_step(task_id, step_idx)
        return html, fb

    # ── Live agent runner (connects to the env server) ──
    live_env = DataQAEnvironment()
    live_state: dict = {"obs": None, "task_id": "easy", "steps": []}

    def live_reset(task_id):
        obs = live_env.reset(task_id=task_id)
        task = live_env._current_task
        live_state["obs"] = obs
        live_state["task_id"] = task_id
        live_state["steps"] = []
        html = _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {})
        info = f"**{task.name}** — {obs.num_issues_hint} issues to find, {obs.max_steps} steps max"
        return html, info, "", "0.000"

    def live_step(issues_text, fixes_text):
        if live_state["obs"] is None:
            return "Reset first.", "", "", ""
        task = live_env._current_task
        planted_keys = {i.to_key() for i in task.planted_issues}
        issues = [l.strip() for l in issues_text.strip().split("\n") if l.strip()]
        fixes = [l.strip() for l in fixes_text.strip().split("\n") if l.strip()] if fixes_text.strip() else []
        action = DataQAAction(issues=issues, fixes=fixes, task_id=live_state["task_id"])
        obs = live_env.step(action)
        live_state["obs"] = obs
        reported_keys = set()
        for iss in issues:
            key = parse_issue_key(iss)
            if key:
                reported_keys.add(key)
        tp_keys = reported_keys & planted_keys
        fp_keys = reported_keys - planted_keys
        fn_keys = planted_keys - reported_keys
        correct = {_kc(k) for k in tp_keys}
        fp_set = {_kc(k) for k in fp_keys}
        missed = {_kc(k) for k in fn_keys} if obs.done else set()
        fixed: dict[tuple[int, str], str] = {}
        for d in obs.metadata.get("fix_details", []):
            c = (d["row"], d["col"])
            fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
        from .environment import parse_fix
        fix_values: dict[tuple[int, str], str] = {}
        for raw in fixes:
            parsed = parse_fix(raw)
            if parsed:
                fix_values[(parsed[0], parsed[1])] = parsed[2]
        html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp_set, missed, fixed, fix_values)
        m = obs.metadata
        r = obs.reward
        rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
        cards = (
            '<div style="display:flex;gap:8px;margin:8px 0;">'
            + _metric_card("Reward", f"{r:.2f}", rc)
            + _metric_card("Found", str(m["tp"]), "#28a745")
            + _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
            + _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
            + '</div>'
        )
        full_html = cards + html + LEGEND_HTML
        # Last "" clears the issues textbox (it is wired as the 4th output).
        return full_html, obs.feedback, f"{r:.3f}", ""

    # ── Build the UI ──
    with gr.Blocks(title="DataQA Environment") as demo:
        gr.Markdown(
            "# DataQA — Data Quality Assurance Environment\n"
            "Two-phase RL environment: **Identify** data quality issues, then **Fix** them."
        )
        with gr.Tabs():
            # ── Tab 1: Demo replay ──
            with gr.Tab("Demo (Baseline Agent)"):
                gr.Markdown(
                    "*Replay of the baseline Qwen-72B agent. "
                    "Use the slider to step through the agent's trajectory.*"
                )
                with gr.Row():
                    task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
                    step_slider = gr.Slider(minimum=0, maximum=2, step=1, value=0, label="Step", scale=3)
                viz_html = gr.HTML()
                feedback_box = gr.Textbox(label="Agent Feedback", lines=10, interactive=False)
                task_dd.change(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
                step_slider.change(on_step_change, inputs=[task_dd, step_slider], outputs=[viz_html, feedback_box])
                # Auto-play on load so judges see a populated view with zero clicks.
                demo.load(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
            # ── Tab 2: Try your own agent ──
            with gr.Tab("Try Your Own Agent"):
                gr.Markdown(
                    "*Submit your own issues and fixes to see how the environment scores them. "
                    "This is the same environment the baseline agent talks to.*"
                )
                with gr.Row():
                    live_task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
                    live_reset_btn = gr.Button("Reset", variant="primary", scale=1)
                with gr.Row():
                    live_info = gr.Markdown()
                    live_reward = gr.Textbox(label="Reward", interactive=False, scale=1)
                live_viz = gr.HTML()
                with gr.Row():
                    live_issues = gr.Textbox(
                        label="Issues (one per line)",
                        placeholder="row:4,col:name,issue:missing_value\nrow:7,col:salary,issue:wrong_type",
                        lines=5,
                    )
                    live_fixes = gr.Textbox(
                        label="Fixes (one per line, optional)",
                        placeholder="row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000",
                        lines=5,
                    )
                live_step_btn = gr.Button("Submit Step", variant="primary")
                live_feedback = gr.Textbox(label="Feedback", lines=10, interactive=False)
                live_reset_btn.click(
                    live_reset,
                    inputs=[live_task_dd],
                    outputs=[live_viz, live_info, live_feedback, live_reward],
                )
                live_step_btn.click(
                    live_step,
                    inputs=[live_issues, live_fixes],
                    outputs=[live_viz, live_feedback, live_reward, live_issues],
                )
    return demo


if __name__ == "__main__":
    demo = build_gradio_ui()
    demo.launch()