#!/usr/bin/env python3
"""
Baseline inference script for DataDetective.

Uses an LLM via the OpenAI-compatible API to investigate each task by
running SQL queries and submitting a final analysis.

Required environment variables:
    API_BASE_URL — LLM endpoint (e.g. https://router.huggingface.co/v1)
    MODEL_NAME   — model identifier (e.g. gpt-4.1-mini)
    HF_TOKEN     — API key / Hugging Face token

Optional:
    ENV_URL      — DataDetective server URL (default http://localhost:7860)
"""

import asyncio
import json
import os
import re
import sys
import time

from openai import OpenAI

import websockets.asyncio.client as _wsc

# Make websocket connections opened through websockets' asyncio client (and
# through openenv's env client, which holds its own ``ws_connect`` reference)
# default to a 300 s ping interval and 300 s ping timeout.
_orig_ws_connect = _wsc.connect


def _patched_connect(*a, **kw):
    """Call the original websockets ``connect`` with 300 s keepalive defaults.

    Caller-supplied ``ping_interval`` / ``ping_timeout`` values win.
    """
    kw.setdefault("ping_interval", 300)
    kw.setdefault("ping_timeout", 300)
    return _orig_ws_connect(*a, **kw)


_wsc.connect = _patched_connect

import openenv.core.env_client as _ec  # noqa: E402  (must follow the patch above)

_ec.ws_connect = _patched_connect

from openenv.core.generic_client import GenericEnvClient  # noqa: E402

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4.1-mini")
HF_TOKEN = os.environ.get("HF_TOKEN")
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
BENCHMARK = "data_detective"
MAX_STEPS = 20

# Query output longer than this is truncated before it is shown to the LLM,
# so a single broad query cannot overflow the model's context window.
MAX_OBS_CHARS = 8000

# Total attempts per LLM call (first try + retries) and the base delay of the
# exponential backoff between attempts, in seconds.
LLM_MAX_ATTEMPTS = 3
LLM_RETRY_BASE_DELAY_S = 2.0

TASK_IDS = [
    "orders_drop",
    "returns_spike",
    "customer_churn",
    "shipping_delay",
    "revenue_paradox",
    "supplier_quality",
    "inventory_stockout",
    "fraud_detection",
    "repeat_purchase_decline",
]


def _build_llm_client() -> OpenAI:
    """Create the OpenAI-compatible client, exiting with status 1 if HF_TOKEN is unset."""
    if not HF_TOKEN:
        print(
            "ERROR: Set HF_TOKEN for LLM access. Exiting.",
            file=sys.stderr,
        )
        sys.exit(1)
    return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)


llm = _build_llm_client()

# Formatted with ``.format(max_steps=...)``; literal braces are doubled.
SYSTEM_PROMPT = """\
You are an expert data analyst investigating a business incident using a SQL
database. You have a LIMITED number of query steps, so be strategic.

At each turn respond with EXACTLY one JSON object (no extra text):
  {{"action_type": "query", "content": "<SQL query>"}}
  {{"action_type": "answer", "content": "<your final analysis>"}}

Investigation strategy:
1. EXPLORE (1-2 queries): List tables and sample key columns to understand
   the schema. Note all available tables -- some may hold critical clues.
2. HYPOTHESISE: Based on the task description, form 2-3 likely root causes.
3. QUERY (targeted): Run focused queries that confirm or reject each
   hypothesis. Use JOINs across tables, GROUP BY with aggregates, and compare
   time periods. Avoid broad SELECT * scans.
4. QUANTIFY: For every finding, gather specific numbers -- counts, totals,
   percentages, before/after comparisons.
5. ANSWER: Submit a thorough analysis naming every root cause with supporting
   evidence. Include specific product names, regions, customer segments,
   suppliers, dollar amounts, dates, and percentages.

You have {max_steps} steps total. Budget roughly 70 % for querying and reserve
the last few steps for your answer. Do NOT run out of steps without
submitting -- partial evidence is better than none.
"""

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_FENCED_JSON_RE = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
_JSON_DECODER = json.JSONDecoder()


def _extract_json(text: str) -> dict:
    """Parse an action dict out of raw LLM output.

    Tries, in order: the whole (stripped) text, the body of a fenced
    ```json block, then the first decodable JSON object that starts at any
    ``{`` in the text (nested braces and trailing prose are tolerated).
    Only dict results are accepted, so callers can always index the result.
    If nothing parses, the raw text is wrapped as a final answer.
    """
    text = text.strip()

    candidates = [text]
    fenced = _FENCED_JSON_RE.search(text)
    if fenced:
        candidates.append(fenced.group(1))
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    # raw_decode parses one complete value starting at the given index and
    # ignores whatever follows it, unlike a brace-counting regex.
    for brace in re.finditer(r"\{", text):
        try:
            parsed, _ = _JSON_DECODER.raw_decode(text, brace.start())
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    return {"action_type": "answer", "content": text}


def _log_start(task_id: str) -> None:
    """Print the [START] line for a task."""
    print(
        f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}",
        flush=True,
    )


def _log_step(step: int, action: dict, reward: float, done: bool, error: str | None) -> None:
    """Print a one-line [STEP] record for an environment step."""
    action_str = json.dumps(action, separators=(",", ":"))
    if error:
        # Collapse whitespace so a multi-line exception message cannot break
        # the one-record-per-line log format.
        error_val = "'" + " ".join(error.split()) + "'"
    else:
        error_val = "null"
    print(
        f"[STEP] step={step} action={action_str} "
        f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
        flush=True,
    )


def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the [END] line summarising a task."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )


def _truncate(text: str, limit: int = MAX_OBS_CHARS) -> str:
    """Return *text* cut to *limit* characters plus a note of how much was dropped."""
    if len(text) <= limit:
        return text
    return f"{text[:limit]}\n... [truncated {len(text) - limit} characters]"


async def _call_llm(messages: list[dict]) -> str:
    """Request the next assistant turn and return its text ("" if empty).

    The OpenAI client is synchronous, so the request runs in a worker thread;
    calling it directly would block the event loop (and the env websocket)
    for the whole request. Failed attempts are retried with exponential
    backoff; the last exception is re-raised once LLM_MAX_ATTEMPTS is hit.
    """
    attempt = 1
    while True:
        try:
            completion = await asyncio.to_thread(
                llm.chat.completions.create,
                model=MODEL_NAME,
                messages=messages,
                temperature=0.1,
                max_completion_tokens=1024,
            )
            return completion.choices[0].message.content or ""
        except Exception:
            if attempt >= LLM_MAX_ATTEMPTS:
                raise
        await asyncio.sleep(LLM_RETRY_BASE_DELAY_S * 2 ** (attempt - 1))
        attempt += 1


def _feedback_message(result_obs: dict, step: int, remaining: int, done: bool) -> str:
    """Build the user turn that reports a step's observation back to the LLM.

    When three or fewer steps remain (and the episode is not over) the
    message demands an immediate final answer.
    """
    if not done and remaining <= 3:
        urgency = (
            f"URGENT: Only {remaining} step(s) left! "
            "You MUST submit your final answer NOW using "
            '{"action_type": "answer", "content": "..."}. '
            "Summarize ALL findings so far."
        )
    else:
        urgency = "Continue investigating or submit your final answer."
    output = _truncate(str(result_obs.get("output", "")))
    return (
        f"Query result:\n{output}\n\n"
        f"{result_obs.get('message', '')}\n\n"
        f"[Step {step}/{MAX_STEPS}] {urgency}"
    )


async def run_task(task_id: str) -> float:
    """Run one DataDetective episode for *task_id* and return its score.

    Alternates LLM turns and environment steps for at most MAX_STEPS steps,
    logging [START]/[STEP]/[END] lines. The score is the last step's reward
    if the episode finished, otherwise 0.0. Exceptions raised while talking
    to the environment are logged as an error step and not propagated.
    """
    _log_start(task_id)
    rewards: list[float] = []
    step = 0
    reward = 0.0
    done = False
    success = False
    error_msg = None

    try:
        async with GenericEnvClient(base_url=ENV_URL) as env:
            result = await env.reset(task_id=task_id)
            obs = result.observation

            messages = [
                {"role": "system", "content": SYSTEM_PROMPT.format(max_steps=MAX_STEPS)},
                {
                    "role": "user",
                    "content": (
                        f"## Investigation Task\n{obs.get('task_description', '')}\n\n"
                        f"## Database\n{obs.get('schema_info', '')}\n\n"
                        f"You have {MAX_STEPS} steps. Begin your investigation."
                    ),
                },
            ]

            while not done and step < MAX_STEPS:
                try:
                    llm_text = await _call_llm(messages)
                except Exception as exc:
                    # Retries exhausted: submit a placeholder answer so the
                    # episode still terminates instead of aborting.
                    llm_text = json.dumps({
                        "action_type": "answer",
                        "content": "Unable to complete analysis due to LLM error.",
                    })
                    error_msg = str(exc)

                action = _extract_json(llm_text)
                action.setdefault("action_type", "query")
                action.setdefault("content", llm_text)

                result = await env.step(action)
                step += 1
                done = result.done
                reward = result.reward or 0.0
                rewards.append(reward)
                remaining = MAX_STEPS - step

                _log_step(step, action, reward, done, error_msg)
                error_msg = None

                messages.append({"role": "assistant", "content": llm_text})
                messages.append({
                    "role": "user",
                    "content": _feedback_message(result.observation, step, remaining, done),
                })

            success = done and reward > 0.0

    except Exception as exc:
        error_msg = str(exc)
        _log_step(step + 1, {"action_type": "error"}, 0.0, False, error_msg)

    score = reward if done else 0.0
    _log_end(success=success, steps=step, score=score, rewards=rewards)
    return score


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

async def amain():
    """Run every task in TASK_IDS sequentially and print the average score."""
    total = 0.0
    for tid in TASK_IDS:
        try:
            r = await run_task(tid)
        except Exception as exc:
            # Keep stdout's [END] record intact; report the cause on stderr.
            print(f"ERROR: task {tid} failed: {exc}", file=sys.stderr)
            print("[END] success=false steps=0 score=0.000 rewards=", flush=True)
            r = 0.0
        total += r

    avg = total / len(TASK_IDS) if TASK_IDS else 0
    print(f"\n=== Overall average score: {avg:.2f} ===", flush=True)


def main():
    """Script entry point."""
    asyncio.run(amain())


if __name__ == "__main__":
    main()