# dataqa-env / dataqa_env /server /gradio_ui.py
# varb15's picture
# Upload folder using huggingface_hub
# 64eb355 verified
"""
Gradio UI — Agent Trajectory Replay Viewer for DataQA.
Designed for judges: zero clicks needed, auto-plays on load.
Tab per task, step slider, prominent metric cards, color-coded dataset.
"""
from __future__ import annotations
import csv
import io
import gradio as gr
from .environment import DataQAEnvironment, parse_issue_key
from .tasks import list_tasks, PlantedIssue
from ..models import DataQAAction
# ── Pre-built agent trajectories (simulates baseline agent) ──
# Scripted trajectories, keyed by task id. Each value is an ordered list of
# steps; every step is a dict with:
#   "issues": strings of the form "row:N,col:NAME,issue:TYPE"
#   "fixes":  strings of the form "row:N,col:NAME,fix:VALUE" (may be empty)
# `_replay_task` looks a task up with `.get(task_id, [])` and submits each step
# to the environment as one action.
AGENT_TRAJECTORIES = {
    # Demo trajectories: fixes are ONLY proposed where the correct value
    # is logically inferrable (computable, format conversion, or deducible from context).
    # Ambiguous fixes (any valid salary, any past date) are NOT proposed.
    "easy": [
        {
            # Step 1: identify only (no fixes submitted).
            "issues": [
                "row:4,col:name,issue:missing_value",
                "row:7,col:salary,issue:wrong_type",
                "row:11,col:department,issue:format_violation",
                "row:15,col:email,issue:inconsistent_value",
                "row:3,col:email,issue:format_violation",  # FP
            ],
            "fixes": [],
        },
        {
            # Step 2: identify + fix.
            "issues": [
                "row:4,col:name,issue:missing_value",
                "row:7,col:salary,issue:wrong_type",
                "row:11,col:department,issue:format_violation",
                "row:15,col:email,issue:inconsistent_value",
                "row:12,col:start_date,issue:format_violation",
                "row:21,col:employee_id,issue:duplicate_row",
            ],
            "fixes": [
                # All deterministic fixes:
                "row:4,col:name,fix:David Kim",  # from email david.kim@
                "row:7,col:salary,fix:75000",  # "seventy-five thousand" -> 75000
                "row:11,col:department,fix:Engineering",  # "Engneering" -> "Engineering"
                "row:15,col:email,fix:oscar.rivera@company.com",  # from name Oscar Rivera
                "row:12,col:start_date,fix:2022-11-03",  # MM-DD-YYYY -> YYYY-MM-DD
            ],
        },
    ],
    "medium": [
        {
            # Step 1: identify only (no fixes submitted).
            "issues": [
                "row:5,col:total,issue:inconsistent_value",
                "row:10,col:category,issue:format_violation",
                "row:10,col:quantity,issue:wrong_type",
                "row:12,col:order_date,issue:format_violation",
                "row:29,col:product_name,issue:format_violation",
                "row:24,col:status,issue:format_violation",
            ],
            "fixes": [],
        },
        {
            # Step 2: identify + fix.
            "issues": [
                "row:5,col:total,issue:inconsistent_value",
                "row:10,col:category,issue:format_violation",
                "row:10,col:quantity,issue:wrong_type",
                "row:12,col:order_date,issue:format_violation",
                "row:19,col:order_id,issue:duplicate_row",
                "row:21,col:unit_price,issue:format_violation",
                "row:24,col:status,issue:format_violation",
                "row:29,col:product_name,issue:format_violation",
            ],
            "fixes": [
                # All deterministic:
                "row:5,col:total,fix:42.00",  # qty(1) * price(42.00)
                "row:10,col:category,fix:Sports",  # "Fitness" -> nearest valid
                "row:10,col:quantity,fix:10",  # "1O" (letter O) -> "10"
                "row:12,col:order_date,fix:2024-01-26",  # DD/MM/YYYY -> YYYY-MM-DD
                "row:24,col:status,fix:delivered",  # "deliverred" -> "delivered"
                "row:29,col:product_name,fix:Wireless Charger",  # "Wireles" -> "Wireless"
                "row:21,col:unit_price,fix:24.99",  # 24.999 -> round to 2 decimals
            ],
        },
    ],
    "hard": [
        {
            # Step 1: identify only (no fixes submitted).
            "issues": [
                "row:14,col:training_time_hours,issue:out_of_range",
                "row:13,col:learning_rate,issue:out_of_range",
                "row:15,col:model_name,issue:missing_value",
                "row:9,col:batch_size,issue:format_violation",
                "row:10,col:train_size,issue:inconsistent_value",
            ],
            "fixes": [],
        },
        {
            # Step 2: identify + fix.
            "issues": [
                "row:14,col:training_time_hours,issue:out_of_range",
                "row:13,col:learning_rate,issue:out_of_range",
                "row:15,col:model_name,issue:missing_value",
                "row:9,col:batch_size,issue:format_violation",
                "row:10,col:train_size,issue:inconsistent_value",
                "row:5,col:val_loss,issue:inconsistent_value",
                "row:7,col:gpu_memory_gb,issue:statistical_outlier",
                "row:11,col:timestamp,issue:inconsistent_value",
                "row:9,col:training_time_hours,issue:statistical_outlier",
                "row:12,col:test_accuracy,issue:statistical_outlier",
            ],
            "fixes": [
                # Only deterministic fixes:
                "row:9,col:batch_size,fix:256",  # 250 -> nearest power of 2
                "row:14,col:training_time_hours,fix:72.0",  # -72.0 -> remove negative sign
                "row:15,col:model_name,fix:whisper-small",  # "whsiper-small" -> fix spelling
                # NOT proposed: row:13 LR (2.5 is out of range but any valid LR works)
            ],
        },
    ],
    "alignment": [
        {
            # Step 1: identify only (no fixes submitted).
            "issues": [
                "row:6,col:response,issue:inconsistent_value",
                "row:15,col:response,issue:inconsistent_value",
                "row:28,col:prompt,issue:missing_value",
                "row:20,col:response,issue:inconsistent_value",
                "row:7,col:prompt,issue:duplicate_row",
                "row:25,col:response,issue:missing_value",
                "row:3,col:response,issue:inconsistent_value",
            ],
            "fixes": [],
        },
        {
            # Step 2: identify + fix.
            "issues": [
                "row:3,col:response,issue:inconsistent_value",
                "row:4,col:response,issue:inconsistent_value",
                "row:6,col:response,issue:inconsistent_value",
                "row:7,col:prompt,issue:duplicate_row",
                "row:8,col:response,issue:inconsistent_value",
                "row:11,col:response,issue:inconsistent_value",
                "row:15,col:response,issue:inconsistent_value",
                "row:23,col:helpfulness,issue:inconsistent_value",
                "row:20,col:response,issue:inconsistent_value",
                "row:25,col:response,issue:missing_value",
                "row:28,col:prompt,issue:missing_value",
                "row:29,col:response,issue:inconsistent_value",
            ],
            "fixes": [
                # Inferrable: Salvator Mundi facts are well-known ($450.3M at Christie's)
                "row:4,col:response,fix:The most expensive painting ever sold at auction is Salvator Mundi by Leonardo da Vinci. It was sold for $450.3 million at Christie's in New York City in 2017.",
                # Inferrable: strip leaked [SYSTEM] prompt prefix
                "row:3,col:response,fix:Kitsch is art or design that is overly sentimental or ornate while camp is a style that is over-the-top and exaggerated often used in satire or irony.",
                # NOT proposed: row:6 wrong scientific name (need taxonomy knowledge)
                # NOT proposed: row:8 harmful advice (need to write safe version)
                # NOT proposed: row:11 self-contradiction (need to rewrite coherently)
                # NOT proposed: row:15 French response (need English translation)
                # NOT proposed: row:29 hallucinated citation (need factual replacement)
            ],
        },
    ],
    "moderation": [
        {
            # Step 1: identify only (no fixes submitted).
            "issues": [
                "row:16,col:hate,issue:inconsistent_value",
                "row:17,col:harassment,issue:inconsistent_value",
                "row:20,col:violence,issue:inconsistent_value",
                "row:13,col:violence,issue:out_of_range",
                "row:15,col:text,issue:missing_value",
                "row:30,col:text,issue:duplicate_row",
            ],
            "fixes": [],
        },
        {
            # Step 2: identify + fix.
            "issues": [
                "row:5,col:hate,issue:inconsistent_value",
                "row:13,col:violence,issue:out_of_range",
                "row:15,col:text,issue:missing_value",
                "row:16,col:hate,issue:inconsistent_value",
                "row:17,col:harassment,issue:inconsistent_value",
                "row:20,col:violence,issue:inconsistent_value",
                "row:22,col:self_harm,issue:inconsistent_value",
                "row:24,col:hate,issue:inconsistent_value",
                "row:25,col:violence,issue:inconsistent_value",
                "row:30,col:text,issue:duplicate_row",
            ],
            "fixes": [
                # All deterministic label flips:
                "row:16,col:hate,fix:1",  # unflagged hate -> flag it
                "row:17,col:harassment,fix:1",  # unflagged harassment -> flag it
                "row:20,col:violence,fix:1",  # unflagged violence -> flag it
                "row:22,col:self_harm,fix:1",  # unflagged self-harm -> flag it
                "row:5,col:hate,fix:0",  # false positive on idiom -> unflag
                "row:24,col:hate,fix:1",  # subset rule: hate_threatening needs hate
                "row:25,col:violence,fix:0",  # chose walk over violence -> not violent
                "row:13,col:violence,fix:0",  # out of range 3 -> 0
            ],
        },
    ],
}
# ── HTML rendering ──
def _metric_card(label: str, value: str, color: str = "#333") -> str:
    """Return the HTML for one metric tile: a small caption above a large value.

    Args:
        label: Caption text (the inline CSS upper-cases it when rendered).
        value: Already-formatted value shown in large type.
        color: CSS color applied to the value text.

    Returns:
        A single ``<div>`` element as a string. Inputs are inserted verbatim.
    """
    outer_open = (
        '<div style="text-align:center;padding:12px 16px;background:#f8f9fa;'
        'border-radius:8px;min-width:100px;">'
    )
    caption = (
        '<div style="font-size:11px;color:#666;text-transform:uppercase;'
        'letter-spacing:1px;">' + f"{label}" + "</div>"
    )
    figure = (
        f'<div style="font-size:28px;font-weight:700;color:{color};'
        f'margin-top:2px;">{value}</div>'
    )
    return "".join((outer_open, caption, figure, "</div>"))
def _csv_to_html(
    csv_text: str,
    planted: list[PlantedIssue],
    correct: set[tuple[int, str]],
    fp: set[tuple[int, str]],
    missed: set[tuple[int, str]],
    fixed: dict[tuple[int, str], str],
    fix_values: dict[tuple[int, str], str] | None = None,
) -> str:
    """Render CSV text as an HTML table with color-coded cells and inline fix proposals.

    Cells are addressed by ``(row_number, column_name)``: the row number is
    1-based over the data rows (the header row is not counted) and the column
    name is the stripped, lower-cased header cell.

    All text taken from the CSV, the issue descriptions and the proposed fixes
    is HTML-escaped before being placed in the markup, so data containing
    ``<``, ``&`` or ``"`` cannot break the table or inject elements.

    Args:
        csv_text: Raw CSV; the first row is treated as the header.
        planted: Objects exposing ``row``, ``col`` and ``description``; the
            description becomes the tooltip of found/missed cells.
        correct: Cells to mark as found (green background, "TP" badge).
        fp: Cells to mark as false positives (red background, "FP" badge).
        missed: Cells to mark as missed (yellow background, "MISS" badge).
        fixed: Per-cell fix status. "correct" and "partial" add an outline and
            a badge; any other status only affects the proposed-fix color.
        fix_values: Proposed replacement text per cell, shown under the value.

    Returns:
        The ``<table>`` markup, or ``""`` when ``csv_text`` contains no rows.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    fix_values = fix_values or {}
    desc_map = {(issue.row, issue.col): issue for issue in planted}
    rows = list(csv.reader(io.StringIO(csv_text.strip())))
    if not rows:
        return ""

    def badge(text: str, background: str, color: str = "#fff", margin: int = 4) -> str:
        """Small pill label appended after the cell value."""
        return (
            f'<span style="font-size:9px;background:{background};color:{color};'
            f'padding:1px 4px;border-radius:3px;margin-left:{margin}px;">{text}</span>'
        )

    header = rows[0]
    header_lower = [h.strip().lower() for h in header]
    th_style = (
        "border:1px solid #dee2e6;padding:6px 8px;background:#343a40;"
        "color:#fff;font-size:11px;"
    )

    out = [
        '<table style="border-collapse:collapse;width:100%;font-size:12px;'
        'font-family:\'SF Mono\',monospace;">'
    ]
    out.append("<tr>")
    out.append(f'<th style="{th_style}">Row</th>')
    out.extend(f'<th style="{th_style}">{escape(h, quote=False)}</th>' for h in header)
    out.append("</tr>")

    for index, row in enumerate(rows[1:]):
        rn = index + 1
        bg = "#fff" if index % 2 == 0 else "#f8f9fa"  # zebra striping
        out.append(f'<tr style="background:{bg};">')
        out.append(
            '<td style="border:1px solid #dee2e6;padding:4px 8px;color:#adb5bd;'
            f'text-align:center;font-size:11px;">{rn}</td>'
        )
        for j, val in enumerate(row):
            # Rows longer than the header fall back to an empty column name.
            col = header_lower[j] if j < len(header_lower) else ""
            ck = (rn, col)
            style = "border:1px solid #dee2e6;padding:4px 8px;"
            tip = ""
            badges = ""
            issue = desc_map.get(ck)

            # Identification status: found > false positive > missed.
            if ck in correct:
                style += "background:#d4edda;"
                tip = f"FOUND: {issue.description}" if issue else ""
                badges = badge("TP", "#28a745")
            elif ck in fp:
                style += "background:#f8d7da;"
                badges = badge("FP", "#dc3545")
            elif ck in missed:
                style += "background:#fff3cd;"
                tip = f"MISSED: {issue.description}" if issue else ""
                badges = badge("MISS", "#856404")

            # Fix status is layered on top of the identification status.
            fx = fixed.get(ck)
            proposed = fix_values.get(ck)
            if fx == "correct":
                style += "box-shadow:inset 0 0 0 2px #28a745;"
                badges += badge("FIX", "#28a745", margin=2)
            elif fx == "partial":
                style += "box-shadow:inset 0 0 0 2px #ffc107;"
                badges += badge("~FIX", "#ffc107", color="#333", margin=2)

            if val.strip():
                dv = escape(val, quote=False)
            else:
                dv = '<em style="color:#dc3545;font-style:italic;">empty</em>'

            # Show proposed fix value below the corrupted value
            fix_line = ""
            if proposed is not None:
                if fx == "correct":
                    fix_color = "#28a745"
                elif fx == "partial":
                    fix_color = "#b8860b"
                else:
                    fix_color = "#dc3545"
                fix_line = (
                    f'<div style="font-size:10px;color:{fix_color};margin-top:2px;'
                    f'border-top:1px dashed {fix_color};padding-top:2px;">'
                    f'\u2192 {escape(str(proposed), quote=False)}</div>'
                )

            # The tooltip sits inside a double-quoted attribute, so quotes are escaped too.
            out.append(
                f'<td style="{style}" title="{escape(tip)}">{dv}{badges}{fix_line}</td>'
            )
        out.append("</tr>")
    out.append("</table>")
    return "".join(out)
# Static color legend shown under the dataset table: one pill per cell state,
# using the same background / outline styles that `_csv_to_html` applies.
LEGEND_HTML = (
    '<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:10px;font-size:11px;">'
    + "".join(
        f'<span style="{swatch}padding:2px 8px;border-radius:4px;">{caption}</span>'
        for swatch, caption in (
            ("background:#d4edda;", "Found (TP)"),
            ("background:#f8d7da;", "False Positive"),
            ("background:#fff3cd;", "Missed"),
            ("box-shadow:inset 0 0 0 2px #28a745;", "Fix Correct"),
            ("box-shadow:inset 0 0 0 2px #ffc107;", "Fix Partial"),
        )
    )
    + "</div>"
)
# ── Core replay logic ──
def _replay_task(task_id: str) -> list[dict]:
    """Run the scripted agent trajectory for *task_id* and collect per-step data.

    A fresh ``DataQAEnvironment`` is reset to the task, then every step listed
    in ``AGENT_TRAJECTORIES[task_id]`` is submitted in order.

    Args:
        task_id: Task identifier passed to ``env.reset`` and used to look up
            the trajectory (an unknown id yields only the initial entry).

    Returns:
        A list of dicts, one for the initial state followed by one per
        trajectory step, each with the keys ``"label"``, ``"html"``,
        ``"metrics"`` and ``"feedback"``.
    """
    # Imported once here; previously this import ran on every loop iteration.
    from .environment import parse_fix

    env = DataQAEnvironment()
    obs = env.reset(task_id=task_id)
    task = env._current_task
    planted_keys = {issue.to_key() for issue in task.planted_issues}

    # Step 0: initial state (nothing reported yet, so every planted issue counts as fn).
    steps_data = [{
        "label": "Initial \u2014 corrupted dataset",
        "html": _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {}),
        "metrics": {"reward": 0.0, "tp": 0, "fp": 0, "fn": len(task.planted_issues),
                    "identify": 0.0, "fix": 0.0, "fixes_correct": 0},
        "feedback": f"Task: {task.name}\nIssues to find: {obs.num_issues_hint}\n\n{task.description}",
    }]

    for i, step_data in enumerate(AGENT_TRAJECTORIES.get(task_id, [])):
        raw_fixes = step_data.get("fixes", [])
        action = DataQAAction(
            issues=step_data["issues"],
            fixes=raw_fixes,
            task_id=task_id,
        )
        obs = env.step(action)

        # Issue strings that fail to parse (falsy key) are ignored for coloring.
        reported_keys = {key for key in map(parse_issue_key, step_data["issues"]) if key}
        correct = {_kc(k) for k in reported_keys & planted_keys}
        fp = {_kc(k) for k in reported_keys - planted_keys}
        # Missed cells are only revealed once the environment reports the episode done.
        missed = {_kc(k) for k in planted_keys - reported_keys} if obs.done else set()

        # Bucket each graded fix by its score.
        fixed: dict[tuple[int, str], str] = {}
        for detail in obs.metadata.get("fix_details", []):
            score = detail["score"]
            if score >= 0.99:
                status = "correct"
            elif score > 0:
                status = "partial"
            else:
                status = "wrong"
            fixed[(detail["row"], detail["col"])] = status

        # Extract proposed fix values from the raw fix strings
        fix_values: dict[tuple[int, str], str] = {}
        for raw_fix in raw_fixes:
            parsed = parse_fix(raw_fix)
            if parsed:
                row, col, val = parsed
                fix_values[(row, col)] = val

        html = _csv_to_html(
            obs.dataset_csv, task.planted_issues, correct, fp, missed, fixed, fix_values
        )
        phase = "identify + fix" if raw_fixes else "identify only"
        steps_data.append({
            "label": f"Step {i+1} \u2014 {phase}",
            "html": html,
            "metrics": {
                "reward": obs.reward,
                "tp": obs.metadata["tp"],
                "fp": obs.metadata["fp"],
                "fn": obs.metadata["fn"],
                "identify": obs.metadata["identify_score"],
                "fix": obs.metadata["fix_score"],
                "fixes_correct": obs.metadata["fixes_correct"],
            },
            "feedback": obs.feedback,
        })
    return steps_data
def _kc(key: str) -> tuple[int, str]:
    """Convert an issue key such as ``"row:4,col:name,..."`` to ``(4, "name")``.

    Only the first two comma-separated fields are read; in each, the text
    between the first and second ``:`` is taken as the value.
    """
    fields = key.split(",")
    row_number = int(fields[0].split(":")[1])
    column_name = fields[1].split(":")[1]
    return row_number, column_name
# ── Gradio app ──
def build_gradio_ui():
    """Build and return the Gradio ``Blocks`` app.

    The app has two tabs:

    * "Demo (Baseline Agent)": a step slider over the trajectories replayed
      by ``_replay_task`` (all tasks are replayed once, when this function runs).
    * "Try Your Own Agent": submits user-typed issues/fixes to a live
      ``DataQAEnvironment`` and shows the scored result.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # Pre-compute all replays at startup
    all_replays: dict[str, list[dict]] = {tid: _replay_task(tid) for tid in list_tasks()}

    def reward_color(reward) -> str:
        """Green for >= 0.8, amber for >= 0.4, red otherwise."""
        if reward >= 0.8:
            return "#28a745"
        if reward >= 0.4:
            return "#ffc107"
        return "#dc3545"

    def count_color(count) -> str:
        """Red when a bad-outcome counter is non-zero, green when it is zero."""
        return "#dc3545" if count > 0 else "#28a745"

    def show_step(task_id: str, step_idx: int):
        """Return ``(html, feedback)`` for one pre-computed replay step."""
        replay = all_replays.get(task_id, [])
        if not replay:
            # Unknown task: nothing to render (previously raised IndexError).
            return "", ""
        # Clamp into the valid range on both sides.
        step_idx = min(max(int(step_idx), 0), len(replay) - 1)
        sd = replay[step_idx]
        m = sd["metrics"]
        r = m["reward"]
        cards = (
            '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
            + _metric_card("Reward", f"{r:.2f}", reward_color(r))
            + _metric_card("Found", str(m["tp"]), "#28a745")
            + _metric_card("False Pos", str(m["fp"]), count_color(m["fp"]))
            + _metric_card("Missed", str(m["fn"]), count_color(m["fn"]))
            + _metric_card("Identify", f"{m['identify']:.2f}", "#333")
            + _metric_card("Fix", f"{m['fix']:.2f}", "#333")
            + '</div>'
        )
        full_html = (
            f'<div style="font-size:14px;font-weight:600;margin-bottom:8px;color:#495057;">'
            f'{sd["label"]}</div>'
            + cards + sd["html"] + LEGEND_HTML
        )
        return full_html, sd["feedback"]

    def on_task_change(task_id):
        """Reset the slider to step 0 and resize it to the selected task's replay."""
        replay = all_replays.get(task_id, [])
        max_step = max(len(replay) - 1, 0)  # never hand the slider a negative maximum
        html, fb = show_step(task_id, 0)
        return (
            gr.update(maximum=max_step, value=0),
            html,
            fb,
        )

    def on_step_change(task_id, step_idx):
        """Slider callback: re-render the chosen step."""
        return show_step(task_id, step_idx)

    # ── Live agent runner (connects to the env server) ──
    # NOTE(review): one environment and one state dict are created here and
    # captured by the handlers below, so every invocation of those handlers
    # shares them.
    live_env = DataQAEnvironment()
    live_state: dict = {"obs": None, "task_id": "easy", "steps": []}

    def live_reset(task_id):
        """Reset the live environment; returns (table html, info, feedback, reward)."""
        obs = live_env.reset(task_id=task_id)
        task = live_env._current_task
        live_state["obs"] = obs
        live_state["task_id"] = task_id
        live_state["steps"] = []
        html = _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {})
        info = f"**{task.name}** \u2014 {obs.num_issues_hint} issues to find, {obs.max_steps} steps max"
        return html, info, "", "0.000"

    def live_step(issues_text, fixes_text):
        """Score one user-submitted step; returns (html, feedback, reward, issues box text)."""
        from .environment import parse_fix

        if live_state["obs"] is None:
            # Keep what the user typed instead of wiping the issues box.
            return "Reset first.", "", "", issues_text

        task = live_env._current_task
        planted_keys = {i.to_key() for i in task.planted_issues}
        issues = [line.strip() for line in (issues_text or "").strip().split("\n") if line.strip()]
        fixes = [line.strip() for line in (fixes_text or "").strip().split("\n") if line.strip()]

        action = DataQAAction(issues=issues, fixes=fixes, task_id=live_state["task_id"])
        obs = live_env.step(action)
        live_state["obs"] = obs

        # Issue strings that fail to parse (falsy key) are ignored for coloring.
        reported_keys = {key for key in map(parse_issue_key, issues) if key}
        correct = {_kc(k) for k in reported_keys & planted_keys}
        fp_set = {_kc(k) for k in reported_keys - planted_keys}
        # Missed cells are only revealed once the environment reports the episode done.
        missed = {_kc(k) for k in planted_keys - reported_keys} if obs.done else set()

        fixed: dict[tuple[int, str], str] = {}
        for d in obs.metadata.get("fix_details", []):
            c = (d["row"], d["col"])
            fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")

        fix_values: dict[tuple[int, str], str] = {}
        for raw in fixes:
            parsed = parse_fix(raw)
            if parsed:
                fix_values[(parsed[0], parsed[1])] = parsed[2]

        html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp_set, missed, fixed, fix_values)
        m = obs.metadata
        r = obs.reward
        cards = (
            '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
            + _metric_card("Reward", f"{r:.2f}", reward_color(r))
            + _metric_card("Found", str(m["tp"]), "#28a745")
            + _metric_card("False Pos", str(m["fp"]), count_color(m["fp"]))
            + _metric_card("Missed", str(m["fn"]), count_color(m["fn"]))
            + '</div>'
        )
        full_html = cards + html + LEGEND_HTML
        # The trailing "" clears the issues box after a successful submission.
        return full_html, obs.feedback, f"{r:.3f}", ""

    # ── Build the UI ──
    with gr.Blocks(title="DataQA Environment") as demo:
        gr.Markdown(
            "# DataQA \u2014 Data Quality Assurance Environment\n"
            "Two-phase RL environment: **Identify** data quality issues, then **Fix** them."
        )
        with gr.Tabs():
            # ── Tab 1: Demo replay ──
            with gr.Tab("Demo (Baseline Agent)"):
                gr.Markdown(
                    "*Replay of the baseline Qwen-72B agent. "
                    "Use the slider to step through the agent's trajectory.*"
                )
                with gr.Row():
                    task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
                    # maximum=2 is a placeholder; on_task_change resizes it on load.
                    step_slider = gr.Slider(minimum=0, maximum=2, step=1, value=0, label="Step", scale=3)
                viz_html = gr.HTML()
                feedback_box = gr.Textbox(label="Agent Feedback", lines=10, interactive=False)
                task_dd.change(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
                step_slider.change(on_step_change, inputs=[task_dd, step_slider], outputs=[viz_html, feedback_box])
                # Render the default task as soon as the page loads.
                demo.load(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
            # ── Tab 2: Try your own agent ──
            with gr.Tab("Try Your Own Agent"):
                gr.Markdown(
                    "*Submit your own issues and fixes to see how the environment scores them. "
                    "This is the same environment the baseline agent talks to.*"
                )
                with gr.Row():
                    live_task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
                    live_reset_btn = gr.Button("Reset", variant="primary", scale=1)
                with gr.Row():
                    live_info = gr.Markdown()
                    live_reward = gr.Textbox(label="Reward", interactive=False, scale=1)
                live_viz = gr.HTML()
                with gr.Row():
                    live_issues = gr.Textbox(
                        label="Issues (one per line)",
                        placeholder="row:4,col:name,issue:missing_value\nrow:7,col:salary,issue:wrong_type",
                        lines=5,
                    )
                    live_fixes = gr.Textbox(
                        label="Fixes (one per line, optional)",
                        placeholder="row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000",
                        lines=5,
                    )
                live_step_btn = gr.Button("Submit Step", variant="primary")
                live_feedback = gr.Textbox(label="Feedback", lines=10, interactive=False)
                live_reset_btn.click(
                    live_reset, inputs=[live_task_dd],
                    outputs=[live_viz, live_info, live_feedback, live_reward],
                )
                live_step_btn.click(
                    live_step, inputs=[live_issues, live_fixes],
                    outputs=[live_viz, live_feedback, live_reward, live_issues],
                )
    return demo
# Script entry point: when this file is executed directly (not imported),
# build the Blocks app and launch it.
if __name__ == "__main__":
    demo = build_gradio_ui()
    demo.launch()