Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Baseline inference script for DataDetective. | |
| Uses an LLM via the OpenAI-compatible API to investigate each task by | |
| running SQL queries and submitting a final analysis. | |
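| Environment variables (only HF_TOKEN is mandatory; the others fall back to the defaults shown): | |
| API_BASE_URL — LLM endpoint (default https://router.huggingface.co/v1) | |
| MODEL_NAME — model identifier (default gpt-4.1-mini) | |
| HF_TOKEN — API key / Hugging Face token (required; the script exits if it is unset) | |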
| Optional: | |
| ENV_URL — DataDetective server URL (default http://localhost:7860) | |
| """ | |
| import asyncio | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import time | |
| from openai import OpenAI | |
| import websockets.asyncio.client as _wsc | |
# Keep a handle on the library's original ``connect`` so the wrapper below can
# delegate to it after adjusting the keyword arguments.
_orig_ws_connect = _wsc.connect


def _patched_connect(*a, **kw):
    """Call the original websockets ``connect`` with relaxed keepalive defaults.

    ``ping_interval`` and ``ping_timeout`` are set to 300 only when the caller
    did not supply them; all other positional and keyword arguments are
    forwarded untouched.
    NOTE(review): the values are presumably seconds -- confirm against the
    websockets documentation.
    """
    kw.setdefault("ping_interval", 300)
    kw.setdefault("ping_timeout", 300)
    return _orig_ws_connect(*a, **kw)


# Install the wrapper on the websockets client module itself ...
_wsc.connect = _patched_connect
import openenv.core.env_client as _ec
# ... and on the ``ws_connect`` name held by openenv's env_client module.
# NOTE(review): presumably env_client binds websockets' connect under that
# name at import time, which is why it is overwritten separately -- confirm.
_ec.ws_connect = _patched_connect
| from openenv.core.generic_client import GenericEnvClient | |
| # --------------------------------------------------------------------------- | |
| # Configuration | |
| # --------------------------------------------------------------------------- | |
# LLM endpoint and model; both fall back to a default when the variable is unset.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")
# API key for the LLM endpoint; stays None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Environment server URL, stored without any trailing slash.
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")

# Benchmark name printed in the [START] log line.
BENCHMARK = "data_detective"
# Upper bound on environment steps per task.
MAX_STEPS = 20

# Tasks are run in exactly this order.
TASK_IDS = [
    "orders_drop",
    "returns_spike",
    "customer_churn",
    "shipping_delay",
    "revenue_paradox",
    "supplier_quality",
    "inventory_stockout",
    "fraud_detection",
    "repeat_purchase_decline",
]
def _build_llm_client() -> OpenAI:
    """Return an OpenAI-compatible client built from the module configuration.

    When ``HF_TOKEN`` is empty or unset, an error is written to stderr and the
    process exits with status 1.
    """
    if HF_TOKEN:
        return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    print("ERROR: Set HF_TOKEN for LLM access. Exiting.", file=sys.stderr)
    sys.exit(1)


# Shared client instance, created once at import time.
llm = _build_llm_client()
# System prompt template. It is rendered with ``str.format(max_steps=...)``,
# so ``{max_steps}`` is the only placeholder and the doubled braces around the
# two JSON examples come out as single literal braces.
SYSTEM_PROMPT = """\
You are an expert data analyst investigating a business incident using a
SQL database. You have a LIMITED number of query steps, so be strategic.
At each turn respond with EXACTLY one JSON object (no extra text):
{{"action_type": "query", "content": "<SQL query>"}}
{{"action_type": "answer", "content": "<your analysis>"}}
Investigation strategy:
1. EXPLORE (1-2 queries): List tables and sample key columns to understand
the schema. Note all available tables -- some may hold critical clues.
2. HYPOTHESISE: Based on the task description, form 2-3 likely root causes.
3. QUERY (targeted): Run focused queries that confirm or reject each
hypothesis. Use JOINs across tables, GROUP BY with aggregates, and
compare time periods. Avoid broad SELECT * scans.
4. QUANTIFY: For every finding, gather specific numbers -- counts, totals,
percentages, before/after comparisons.
5. ANSWER: Submit a thorough analysis naming every root cause with
supporting evidence. Include specific product names, regions, customer
segments, suppliers, dollar amounts, dates, and percentages.
You have {max_steps} steps total. Budget roughly 70 % for querying and
reserve the last few steps for your answer. Do NOT run out of steps
without submitting -- partial evidence is better than none.
"""
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| def _extract_json(text: str) -> dict: | |
| text = text.strip() | |
| try: | |
| return json.loads(text) | |
| except json.JSONDecodeError: | |
| pass | |
| m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL) | |
| if m: | |
| try: | |
| return json.loads(m.group(1)) | |
| except json.JSONDecodeError: | |
| pass | |
| m = re.search(r"\{[^{}]*\}", text, re.DOTALL) | |
| if m: | |
| try: | |
| return json.loads(m.group(0)) | |
| except json.JSONDecodeError: | |
| pass | |
| return {"action_type": "answer", "content": text} | |
def _log_start(task_id: str) -> None:
    """Print the ``[START]`` line naming the task, benchmark and model."""
    banner = f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}"
    print(banner, flush=True)
| def _log_step(step: int, action: dict, reward: float, done: bool, error: str | None) -> None: | |
| action_str = json.dumps(action, separators=(",", ":")) | |
| error_val = f"'{error}'" if error else "null" | |
| print( | |
| f"[STEP] step={step} action={action_str} " | |
| f"reward={reward:.2f} done={str(done).lower()} error={error_val}", | |
| flush=True, | |
| ) | |
def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the ``[END]`` summary line for a finished task.

    Rewards are shown as a comma-separated list with two decimals each and the
    score with three decimals.
    """
    formatted_rewards = ",".join(map("{:.2f}".format, rewards))
    outcome = str(success).lower()
    summary = f"[END] success={outcome} steps={steps} score={score:.3f} rewards={formatted_rewards}"
    print(summary, flush=True)
async def _request_llm_text(messages: list[dict]) -> tuple[str, str | None]:
    """Ask the LLM for its next action and return ``(reply_text, error)``.

    The OpenAI client is synchronous, so the request is run in a worker thread
    to keep the event loop free while waiting for the completion.  On any
    failure the reply is a canned ``answer`` action and *error* carries the
    exception text; otherwise *error* is None.
    """
    try:
        completion = await asyncio.to_thread(
            llm.chat.completions.create,
            model=MODEL_NAME,
            messages=messages,
            temperature=0.1,
            max_completion_tokens=1024,
        )
        return completion.choices[0].message.content or "", None
    except Exception as exc:
        fallback = json.dumps({
            "action_type": "answer",
            "content": "Unable to complete analysis due to LLM error.",
        })
        return fallback, str(exc)


async def run_task(task_id: str) -> float:
    """Run one investigation episode and return its score.

    The environment is reset for *task_id*, then the LLM and the environment
    alternate for at most ``MAX_STEPS`` steps: the model's reply is parsed
    into an action, the action is sent to the environment, and the resulting
    observation is fed back to the model.  ``[START]``, ``[STEP]`` and
    ``[END]`` lines are printed along the way.

    Args:
        task_id: Identifier passed to the environment's ``reset``.

    Returns:
        The reward of the last step when the episode finished, otherwise 0.0
        (step budget exhausted or an error occurred before completion).
    """
    _log_start(task_id)
    rewards: list[float] = []
    step = 0
    reward = 0.0
    done = False
    success = False
    error_msg = None
    try:
        async with GenericEnvClient(base_url=ENV_URL) as env:
            result = await env.reset(task_id=task_id)
            obs = result.observation
            system = SYSTEM_PROMPT.format(max_steps=MAX_STEPS)
            messages = [
                {"role": "system", "content": system},
                {
                    "role": "user",
                    "content": (
                        f"## Investigation Task\n{obs.get('task_description', '')}\n\n"
                        f"## Database\n{obs.get('schema_info', '')}\n\n"
                        f"You have {MAX_STEPS} steps. Begin your investigation."
                    ),
                },
            ]
            while not done and step < MAX_STEPS:
                llm_text, error_msg = await _request_llm_text(messages)

                action = _extract_json(llm_text)
                if not isinstance(action, dict):
                    # A reply that decodes to a non-object cannot be keyed;
                    # submit the raw text as the answer instead of crashing.
                    action = {"action_type": "answer", "content": llm_text}
                action.setdefault("action_type", "query")
                action.setdefault("content", llm_text)

                result = await env.step(action)
                step += 1
                done = result.done
                reward = result.reward or 0.0
                rewards.append(reward)
                result_obs = result.observation
                remaining = MAX_STEPS - step
                _log_step(step, action, reward, done, error_msg)
                error_msg = None

                messages.append({"role": "assistant", "content": llm_text})
                if not done and remaining <= 3:
                    # Push the model to answer before the step budget runs out.
                    urgency = (
                        f"URGENT: Only {remaining} step(s) left! "
                        "You MUST submit your final answer NOW using "
                        '{"action_type": "answer", "content": "..."}. '
                        "Summarize ALL findings so far."
                    )
                else:
                    urgency = "Continue investigating or submit your final answer."
                messages.append({
                    "role": "user",
                    "content": (
                        f"Query result:\n{result_obs.get('output', '')}\n\n"
                        f"{result_obs.get('message', '')}\n\n"
                        f"[Step {step}/{MAX_STEPS}] {urgency}"
                    ),
                })
            success = done and reward > 0.0
    except Exception as exc:
        # Connection or environment failure: log it as an extra step so the
        # [END] line below is still emitted for this task.
        error_msg = str(exc)
        _log_step(step + 1, {"action_type": "error"}, 0.0, False, error_msg)
    score = reward if done else 0.0
    _log_end(success=success, steps=step, score=score, rewards=rewards)
    return score
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
async def amain():
    """Run every task in ``TASK_IDS`` sequentially and print the average score.

    A task whose run raises is counted as 0.0: the exception is reported on
    stderr and a placeholder ``[END]`` line is printed so stdout keeps one
    ``[END]`` record per failed run.
    """
    total = 0.0
    for tid in TASK_IDS:
        try:
            score = await run_task(tid)
        except Exception as exc:
            # Report the cause on stderr; stdout carries only the log records.
            print(f"ERROR: task {tid} failed: {exc!r}", file=sys.stderr, flush=True)
            print("[END] success=false steps=0 score=0.000 rewards=", flush=True)
            score = 0.0
        total += score
    avg = total / len(TASK_IDS) if TASK_IDS else 0
    print(f"\n=== Overall average score: {avg:.2f} ===", flush=True)
def main():
    """Synchronous entry point: run the async driver to completion."""
    driver = amain()
    asyncio.run(driver)


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()