Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| DataQA Inference Script — Two-Phase Agent | |
| ------------------------------------------ | |
| LLM agent that plays the DataQA environment in two phases: | |
| Phase 1: Identify all data quality issues | |
| Phase 2: Propose fixes for identified issues | |
| Uses the OpenAI client to interact with any OpenAI-compatible LLM API. | |
| Required environment variables: | |
| API_BASE_URL - LLM API endpoint (e.g., https://router.huggingface.co/v1) | |
| MODEL_NAME - Model identifier (e.g., Qwen/Qwen2.5-72B-Instruct) | |
| HF_TOKEN - HuggingFace token / API key | |
| STDOUT FORMAT (mandatory for evaluation): | |
| [START] task=<task_name> env=<benchmark> model=<model_name> | |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn> | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import re | |
| import sys | |
| import time | |
| from typing import List, Optional | |
| import requests | |
| from openai import OpenAI | |
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
# LLM endpoint and model; both can be overridden through the environment.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# HF_TOKEN wins over API_KEY; the result is None when neither is set.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
# Base URL of the DataQA environment server.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")

# Fixed run parameters.
BENCHMARK = "dataqa_env"
TASKS = ["easy", "medium", "hard", "alignment", "coding", "toolcalling", "moderation"]
MAX_STEPS_PER_TASK = 3
| # --------------------------------------------------------------------------- | |
| # Logging helpers (structured stdout — exact format required by evaluation) | |
| # --------------------------------------------------------------------------- | |
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line for a task run to stdout, flushing immediately.

    Args:
        task: Task name placed after ``task=``.
        env: Benchmark/environment name placed after ``env=``.
        model: Model identifier placed after ``model=``.
    """
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line to stdout, flushing immediately.

    Args:
        step: Step number within the task.
        action: Text printed verbatim after ``action=``.
        reward: Reward for the step, printed with two decimals.
        done: Episode-finished flag, printed in lowercase.
        error: Error message, or a falsy value to print ``null``.
    """
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line for a task run to stdout, flushing immediately.

    Args:
        success: Outcome flag, printed in lowercase.
        steps: Number of steps taken.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals each.
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={','.join(formatted_rewards)}"
    )
    print(summary, flush=True)
| # --------------------------------------------------------------------------- | |
| # Environment HTTP client | |
| # --------------------------------------------------------------------------- | |
class EnvHTTPClient:
    """Thin HTTP wrapper around the DataQA environment's endpoints.

    Talks to ``/health``, ``/reset`` and ``/step`` under ``base_url`` through
    one shared ``requests.Session``.
    """

    def __init__(self, base_url: str):
        # Drop trailing slashes so endpoint paths can be appended with a single "/".
        self.base_url = base_url.rstrip("/")
        self.session = requests.Session()

    def _post_json(self, path: str, payload: dict) -> dict:
        """POST ``payload`` as JSON to ``path`` and return the decoded JSON reply.

        Calls ``raise_for_status()`` on the response before decoding it.
        """
        reply = self.session.post(f"{self.base_url}{path}", json=payload, timeout=30)
        reply.raise_for_status()
        return reply.json()

    def health(self) -> bool:
        """Return True only when ``GET /health`` answers with status 200.

        Any exception raised while probing is reported as False.
        """
        try:
            probe = self.session.get(f"{self.base_url}/health", timeout=10)
            return probe.status_code == 200
        except Exception:
            return False

    def reset(self, task_id: str = "easy") -> dict:
        """POST ``{"task_id": task_id}`` to ``/reset`` and return the JSON reply."""
        return self._post_json("/reset", {"task_id": task_id})

    def step(self, issues: list[str], fixes: list[str], task_id: str = "easy") -> dict:
        """POST one action (issues, fixes, task id) to ``/step`` and return the JSON reply."""
        action = {"issues": issues, "fixes": fixes, "task_id": task_id}
        return self._post_json("/step", {"action": action})
| # --------------------------------------------------------------------------- | |
| # LLM Prompts | |
| # --------------------------------------------------------------------------- | |
# System prompt for the identification phase. It asks the model to reply with
# one `row:<n>,col:<name>,issue:<type>` entry per line and nothing else.
IDENTIFY_SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
You will be given:
1. A dataset in CSV format
2. A schema describing expected column types and constraints
3. Validation rules that the data should satisfy
You must identify ALL data quality issues and report each one in EXACTLY this format:
row:<row_number>,col:<column_name>,issue:<issue_type>
Supported issue types:
- missing_value (null, empty, or whitespace-only)
- wrong_type (value doesn't match expected type)
- duplicate_row (exact duplicate or duplicate key)
- out_of_range (value outside valid range)
- format_violation (wrong format, invalid enum value)
- inconsistent_value (computed field doesn't match, logical inconsistency)
- statistical_outlier (value is unreasonable given context)
- referential_integrity (foreign key violation)
CRITICAL INSTRUCTIONS FOR ROW NUMBERING:
- Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
- Row 1 = the FIRST data row after the header
- Row 2 = the SECOND data row after the header
- DO NOT use the employee_id, order_id, or experiment_id as the row number
- Column names must match exactly (use the CSV header names, lowercase)
- Check EVERY row and EVERY column systematically
- Consider cross-column consistency (e.g., total = quantity * price)
- Look for subtle issues like whitespace-only values, near-duplicates
- Report ALL issues you find, even if uncertain
Respond with ONLY the list of issues, one per line. No other text.
Example: row:3,col:salary,issue:missing_value"""
# System prompt for the repair phase. It asks the model to reply with one
# `row:<n>,col:<name>,fix:<value>` entry per line and nothing else.
# The "→" arrows in the two examples replace mis-encoded characters that were
# previously sent to the model as a stray "β".
FIX_SYSTEM_PROMPT = """You are a data repair specialist. You have already identified data quality issues in a dataset. Now you must propose the correct values to fix each issue.
For each issue you identified, propose a fix in EXACTLY this format:
row:<row_number>,col:<column_name>,fix:<corrected_value>
Guidelines for proposing fixes:
- For missing_value: infer the correct value from context, schema, and other rows
- For wrong_type: convert to the correct type (e.g., "seventy-five thousand" → "75000")
- For out_of_range: propose a value within the valid range that makes sense in context
- For format_violation: correct the format (e.g., "26/01/2024" → "2024-01-26")
- For inconsistent_value: compute the correct value from related fields
- For duplicate_row: propose a corrected unique key or indicate removal
- For statistical_outlier: propose a reasonable value given the model/context
Use the schema, validation rules, and surrounding data to determine the correct fix.
Respond with ONLY the list of fixes, one per line. No other text.
Example: row:3,col:salary,fix:75000"""
def build_user_prompt(observation: dict, include_fixes: bool = False) -> str:
    """Assemble the user-side prompt text from an environment observation.

    Sections are joined with blank lines in this order: task description (only
    when present), schema, validation rules, dataset, issue-count hint (only
    when non-zero), feedback from the previous attempt, and the fix
    instruction when requested.

    Args:
        observation: Observation mapping. The keys read are
            ``task_description``, ``schema_description``, ``validation_rules``,
            ``dataset_csv``, ``num_issues_hint`` and ``feedback``; all are
            optional. Anything that is not a dict (for example ``None``) is
            treated as an empty observation.
        include_fixes: When True, append the instruction asking for fixes.

    Returns:
        The prompt text.
    """
    # The previous expression returned `observation` on both branches, so a
    # null observation crashed on `.get`; fall back to an empty mapping.
    obs = observation if isinstance(observation, dict) else {}
    parts = []
    if obs.get("task_description"):
        parts.append(f"TASK: {obs['task_description']}")
    parts.append(f"SCHEMA:\n{obs.get('schema_description', '')}")
    parts.append(f"VALIDATION RULES:\n{obs.get('validation_rules', '')}")
    parts.append(f"DATASET:\n{obs.get('dataset_csv', '')}")
    hint = obs.get("num_issues_hint", 0)
    if hint:
        parts.append(f"HINT: There are exactly {hint} issues to find.")
    feedback = obs.get("feedback", "")
    # Feedback mentioning "reset" is left out of the prompt.
    if feedback and "reset" not in feedback.lower():
        parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
    if include_fixes:
        parts.append(
            "Now propose fixes for ALL issues. "
            "Use format: row:<N>,col:<name>,fix:<corrected_value>"
        )
    return "\n\n".join(parts)
# Patterns used by parse_llm_response, compiled once at import time.
_ISSUE_NUMBERING_RE = re.compile(r"^\s*[\d]+[.\)]\s*")
_ISSUE_BULLET_RE = re.compile(r"^\s*[-*]\s*")
_ISSUE_ENTRY_RE = re.compile(
    r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
    re.IGNORECASE,
)


def parse_llm_response(response: str) -> list[str]:
    """Pull normalized issue entries out of raw LLM text.

    Each line is trimmed, stripped of a leading list number ("1." / "1)") and
    of a leading "-" or "*" bullet, then searched for a row/col/issue triple.
    Lines without such a triple are ignored.

    Returns:
        Entries of the form ``row:<n>,col:<name>,issue:<type>`` in input
        order, with the column name and issue type lowercased.
    """
    found = []
    for raw in response.strip().split("\n"):
        text = raw.strip()
        if not text:
            continue
        text = _ISSUE_NUMBERING_RE.sub("", text)
        text = _ISSUE_BULLET_RE.sub("", text)
        text = text.strip()
        lowered = text.lower()
        if "row" not in lowered or "col" not in lowered:
            continue
        hit = _ISSUE_ENTRY_RE.search(text)
        if hit is None:
            continue
        row, column, issue = hit.groups()
        found.append(f"row:{row},col:{column.lower()},issue:{issue.lower()}")
    return found
# Patterns used by parse_fix_response, compiled once at import time.
_FIX_NUMBERING_RE = re.compile(r"^\s*[\d]+[.\)]\s*")
_FIX_BULLET_RE = re.compile(r"^\s*[-*]\s*")
_FIX_ENTRY_RE = re.compile(
    r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+fix\s*[:=]\s*(.+?)$",
    re.IGNORECASE,
)


def parse_fix_response(response: str) -> list[str]:
    """Pull normalized fix entries out of raw LLM text.

    Each line is trimmed, stripped of a leading list number ("1." / "1)") and
    of a leading "-" or "*" bullet, then searched for a row/col/fix triple.
    Lines without such a triple are ignored.

    Returns:
        Entries of the form ``row:<n>,col:<name>,fix:<value>`` in input order.
        The column name is lowercased; the fix value is kept as written apart
        from surrounding whitespace.
    """
    found = []
    for raw in response.strip().split("\n"):
        text = raw.strip()
        if not text:
            continue
        text = _FIX_NUMBERING_RE.sub("", text)
        text = _FIX_BULLET_RE.sub("", text)
        text = text.strip()
        lowered = text.lower()
        if "row" not in lowered or "fix" not in lowered:
            continue
        hit = _FIX_ENTRY_RE.search(text)
        if hit is None:
            continue
        row, column, value = hit.groups()
        found.append(f"row:{row},col:{column.lower()},fix:{value.strip()}")
    return found
def call_llm(client: OpenAI, system_prompt: str, user_prompt: str) -> str:
    """Send one chat-completion request, retrying when rate limited.

    Up to three attempts are made. An error whose text contains "rate_limit"
    or "429" triggers a retry after a growing pause (10 s, then 20 s); any
    other error stops immediately. Nothing is raised to the caller.

    Args:
        client: OpenAI-compatible client used for the request.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.

    Returns:
        The text of the first choice, or "" when the reply has no content,
        the call failed, or every attempt was rate limited.
    """
    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.1,
                max_tokens=2048,
            )
            return response.choices[0].message.content or ""
        except Exception as e:
            detail = str(e)
            if "rate_limit" not in detail.lower() and "429" not in detail:
                print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
                return ""
            if attempt == max_attempts:
                # No attempt follows, so waiting here would only delay the
                # empty result.
                print(
                    f"[DEBUG] Rate limited, giving up after {max_attempts} attempts",
                    file=sys.stderr,
                    flush=True,
                )
                break
            wait = 10 * attempt
            print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
            time.sleep(wait)
    return ""
def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
    """Play one task for up to MAX_STEPS_PER_TASK steps and return the best reward.

    Every step asks the LLM to identify issues from the current observation.
    From step 2 onward, when issues were parsed, a second LLM call proposes
    fixes for them. Issues and fixes are then submitted to the environment,
    whose reply supplies the next observation, the reward and the done flag.

    A ``[START]`` line is printed first, one ``[STEP]`` line per submitted
    step, and an ``[END]`` line always, even when an exception propagates.

    Args:
        client: OpenAI-compatible client used for both LLM calls.
        env: Environment client used for reset and step.
        task_id: Identifier of the task to run.

    Returns:
        The highest step reward observed (0.0 if no step completed).
    """
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    rewards: List[float] = []
    steps_taken = 0
    best_score = 0.0
    success = False
    try:
        reset_response = env.reset(task_id=task_id)
        # Use the reply itself when it has no "observation" key.
        observation = reset_response.get("observation", reset_response)
        for step_num in range(1, MAX_STEPS_PER_TASK + 1):
            # ── Phase 1: Identify issues ──
            user_prompt = build_user_prompt(observation)
            identify_output = call_llm(client, IDENTIFY_SYSTEM_PROMPT, user_prompt)
            issues = parse_llm_response(identify_output)
            error_msg = None if issues else "no issues parsed from LLM response"

            # ── Phase 2: Propose fixes (step 2 onward, only when issues were found) ──
            fixes: list[str] = []
            if issues and step_num >= 2:
                # Build a fix prompt that includes the identified issues
                fix_prompt = build_user_prompt(observation, include_fixes=True)
                fix_prompt += "\n\nISSUES FOUND:\n" + "\n".join(issues)
                fix_output = call_llm(client, FIX_SYSTEM_PROMPT, fix_prompt)
                fixes = parse_fix_response(fix_output)

            # ── Submit to environment ──
            # Only the first few entries go into the log line; the
            # environment still receives the full lists.
            action_str = ";".join(issues[:5]) if issues else "none"
            if fixes:
                action_str += "|fixes:" + ";".join(fixes[:3])
            step_response = env.step(issues, fixes, task_id=task_id)
            observation = step_response.get("observation", step_response)
            # `or 0.0` turns a null reward into zero.
            reward = float(step_response.get("reward", 0.0) or 0.0)
            done = bool(step_response.get("done", False))
            best_score = max(best_score, reward)
            rewards.append(reward)
            steps_taken = step_num
            log_step(
                step=step_num,
                action=action_str,
                reward=reward,
                done=done,
                error=error_msg,
            )
            if done:
                break
        success = best_score >= 0.5
    finally:
        log_end(success=success, steps=steps_taken, score=best_score, rewards=rewards)
    return best_score
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """Run every task in TASKS against the environment and report the scores.

    Diagnostics go to stderr. The process exits with status 1 when the
    environment health check fails. A task that raises is recorded with a
    score of 0.0 and the remaining tasks still run.
    """
    startup_lines = (
        "[DEBUG] DataQA Inference starting",
        f"[DEBUG] ENV_URL={ENV_URL}",
        f"[DEBUG] API_BASE_URL={API_BASE_URL}",
        f"[DEBUG] MODEL_NAME={MODEL_NAME}",
    )
    for text in startup_lines:
        print(text, file=sys.stderr, flush=True)

    env = EnvHTTPClient(ENV_URL)
    # A placeholder key is passed to the client when no token is configured.
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY or "no-key")

    if not env.health():
        print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
        sys.exit(1)
    print("[DEBUG] Environment is healthy", file=sys.stderr, flush=True)

    scores = {}
    for task_id in TASKS:
        try:
            scores[task_id] = run_task(llm_client, env, task_id)
        except Exception as e:
            print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
            scores[task_id] = 0.0

    if scores:
        avg_score = sum(scores.values()) / len(scores)
    else:
        avg_score = 0.0
    print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()