# DataDetective / inference.py
# (Hugging Face Space file header — commit 08e8825, "Fix HF_TOKEN: no default
# value per hackathon checklist", author: Viani)
#!/usr/bin/env python3
"""
Baseline inference script for DataDetective.
Uses an LLM via the OpenAI-compatible API to investigate each task by
running SQL queries and submitting a final analysis.
Required environment variables:
API_BASE_URL — LLM endpoint (e.g. https://router.huggingface.co/v1)
MODEL_NAME — model identifier (e.g. gpt-4.1-mini)
HF_TOKEN — API key / Hugging Face token
Optional:
ENV_URL — DataDetective server URL (default http://localhost:7860)
"""
import asyncio
import json
import os
import re
import sys
import time
from openai import OpenAI
import websockets.asyncio.client as _wsc

# Keep a reference to the real connect() so the patch can delegate to it.
_orig_ws_connect = _wsc.connect


def _patched_connect(*a, **kw):
    """Call websockets' connect() with relaxed keepalive settings.

    Long-running environment steps can exceed the default ping timeout and
    get the socket dropped; 300 s gives slow SQL/LLM turns room to finish.
    setdefault() means explicit caller-supplied values still win.
    """
    kw.setdefault("ping_interval", 300)
    kw.setdefault("ping_timeout", 300)
    return _orig_ws_connect(*a, **kw)


# Patch the websockets module itself...
_wsc.connect = _patched_connect
import openenv.core.env_client as _ec

# ...and openenv's own reference, which may have been bound at import time
# before the patch above took effect.
_ec.ws_connect = _patched_connect
from openenv.core.generic_client import GenericEnvClient
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# OpenAI-compatible endpoint and model used for all completions.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4.1-mini")
# Deliberately no default: _build_llm_client() exits when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# DataDetective server URL; trailing slash stripped so path joins stay clean.
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
BENCHMARK = "data_detective"
# Hard cap on env.step() calls per task (SQL queries plus the final answer).
MAX_STEPS = 20
# Benchmark investigation tasks, run sequentially in this order.
TASK_IDS = [
    "orders_drop",
    "returns_spike",
    "customer_churn",
    "shipping_delay",
    "revenue_paradox",
    "supplier_quality",
    "inventory_stockout",
    "fraud_detection",
    "repeat_purchase_decline",
]
def _build_llm_client() -> OpenAI:
    """Return an OpenAI-compatible client, or exit(1) if HF_TOKEN is unset."""
    if HF_TOKEN:
        return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    # Fail fast at startup rather than on the first LLM call.
    print(
        "ERROR: Set HF_TOKEN for LLM access. Exiting.",
        file=sys.stderr,
    )
    sys.exit(1)
# Single shared client; the process exits here at import time if HF_TOKEN
# is missing.
llm = _build_llm_client()

# Agent instructions. Rendered later with .format(max_steps=...), so the
# literal JSON braces in the examples are doubled ({{ }}) to survive
# formatting.
SYSTEM_PROMPT = """\
You are an expert data analyst investigating a business incident using a
SQL database. You have a LIMITED number of query steps, so be strategic.
At each turn respond with EXACTLY one JSON object (no extra text):
{{"action_type": "query", "content": "<SQL query>"}}
{{"action_type": "answer", "content": "<your analysis>"}}
Investigation strategy:
1. EXPLORE (1-2 queries): List tables and sample key columns to understand
the schema. Note all available tables -- some may hold critical clues.
2. HYPOTHESISE: Based on the task description, form 2-3 likely root causes.
3. QUERY (targeted): Run focused queries that confirm or reject each
hypothesis. Use JOINs across tables, GROUP BY with aggregates, and
compare time periods. Avoid broad SELECT * scans.
4. QUANTIFY: For every finding, gather specific numbers -- counts, totals,
percentages, before/after comparisons.
5. ANSWER: Submit a thorough analysis naming every root cause with
supporting evidence. Include specific product names, regions, customer
segments, suppliers, dollar amounts, dates, and percentages.
You have {max_steps} steps total. Budget roughly 70 % for querying and
reserve the last few steps for your answer. Do NOT run out of steps
without submitting -- partial evidence is better than none.
"""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _extract_json(text: str) -> dict:
text = text.strip()
try:
return json.loads(text)
except json.JSONDecodeError:
pass
m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
if m:
try:
return json.loads(m.group(1))
except json.JSONDecodeError:
pass
m = re.search(r"\{[^{}]*\}", text, re.DOTALL)
if m:
try:
return json.loads(m.group(0))
except json.JSONDecodeError:
pass
return {"action_type": "answer", "content": text}
def _log_start(task_id: str) -> None:
    """Emit the harness-readable [START] marker for one task."""
    line = f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}"
    print(line, flush=True)
def _log_step(step: int, action: dict, reward: float, done: bool, error: str | None) -> None:
action_str = json.dumps(action, separators=(",", ":"))
error_val = f"'{error}'" if error else "null"
print(
f"[STEP] step={step} action={action_str} "
f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
flush=True,
)
def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
print(
f"[END] success={str(success).lower()} steps={steps} "
f"score={score:.3f} rewards={rewards_str}",
flush=True,
)
async def run_task(task_id: str) -> float:
    """Run one DataDetective task as an LLM agent loop and return its score.

    Flow: reset the environment, then alternate LLM completion -> env.step()
    until the episode reports done or MAX_STEPS is exhausted. Returns the
    final step's reward if the episode finished, otherwise 0.0.
    """
    _log_start(task_id)
    rewards: list[float] = []  # per-step rewards, echoed in the [END] line
    step = 0
    reward = 0.0
    done = False
    success = False
    error_msg = None  # reported on the next [STEP] line, then cleared
    try:
        async with GenericEnvClient(base_url=ENV_URL) as env:
            result = await env.reset(task_id=task_id)
            obs = result.observation
            system = SYSTEM_PROMPT.format(max_steps=MAX_STEPS)
            # Seed the conversation with the task brief and schema.
            messages = [
                {"role": "system", "content": system},
                {
                    "role": "user",
                    "content": (
                        f"## Investigation Task\n{obs.get('task_description', '')}\n\n"
                        f"## Database\n{obs.get('schema_info', '')}\n\n"
                        f"You have {MAX_STEPS} steps. Begin your investigation."
                    ),
                },
            ]
            while not done and step < MAX_STEPS:
                try:
                    completion = llm.chat.completions.create(
                        model=MODEL_NAME,
                        messages=messages,
                        temperature=0.1,
                        max_completion_tokens=1024,
                    )
                    llm_text = completion.choices[0].message.content or ""
                except Exception as exc:
                    # LLM call failed: fabricate a final "answer" action so
                    # the episode still terminates instead of stalling.
                    llm_text = json.dumps({
                        "action_type": "answer",
                        "content": "Unable to complete analysis due to LLM error.",
                    })
                    error_msg = str(exc)
                action = _extract_json(llm_text)
                # Backfill required keys if the model's JSON was incomplete.
                if "action_type" not in action:
                    action["action_type"] = "query"
                if "content" not in action:
                    action["content"] = llm_text
                result = await env.step(action)
                step += 1
                done = result.done
                reward = result.reward or 0.0
                rewards.append(reward)
                result_obs = result.observation
                remaining = MAX_STEPS - step
                _log_step(step, action, reward, done, error_msg)
                error_msg = None  # only report an error on the step it occurred
                messages.append({"role": "assistant", "content": llm_text})
                if not done and remaining <= 3:
                    # Near the step budget: push the model to submit now.
                    urgency = (
                        f"URGENT: Only {remaining} step(s) left! "
                        "You MUST submit your final answer NOW using "
                        '{"action_type": "answer", "content": "..."}. '
                        "Summarize ALL findings so far."
                    )
                else:
                    urgency = "Continue investigating or submit your final answer."
                messages.append({
                    "role": "user",
                    "content": (
                        f"Query result:\n{result_obs.get('output', '')}\n\n"
                        f"{result_obs.get('message', '')}\n\n"
                        f"[Step {step}/{MAX_STEPS}] {urgency}"
                    ),
                })
            success = done and reward > 0.0
    except Exception as exc:
        # Environment/transport failure: record a synthetic error step.
        error_msg = str(exc)
        _log_step(step + 1, {"action_type": "error"}, 0.0, False, error_msg)
    score = reward if done else 0.0  # unfinished episodes score zero
    _log_end(success=success, steps=step, score=score, rewards=rewards)
    return score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
async def amain():
    """Run every benchmark task sequentially and print the average score.

    A crash in one task is logged to stderr and scored 0.0 so the
    remaining tasks still run.
    """
    total = 0.0
    for tid in TASK_IDS:
        try:
            r = await run_task(tid)
        except Exception as exc:
            # Previously the exception was captured but never reported;
            # surface it on stderr, then keep the harness-readable [END]
            # line byte-identical for a zero-score task.
            print(f"[TASK ERROR] {tid}: {exc}", file=sys.stderr, flush=True)
            print("[END] success=false steps=0 score=0.000 rewards=", flush=True)
            r = 0.0
        total += r
    avg = total / len(TASK_IDS) if TASK_IDS else 0
    print(f"\n=== Overall average score: {avg:.2f} ===", flush=True)
def main():
    """Synchronous entry point: run the async driver to completion."""
    asyncio.run(amain())


if __name__ == "__main__":
    main()