Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- inference.py +48 -52
inference.py
CHANGED
|
@@ -6,7 +6,7 @@ MANDATORY COMPLIANCE
|
|
| 6 |
--------------------
|
| 7 |
- Named `inference.py`, placed in project root.
|
| 8 |
- Uses OpenAI client for all LLM calls.
|
| 9 |
-
- Reads: API_BASE_URL, MODEL_NAME,
|
| 10 |
- Emits [START] / [STEP] / [END] lines to stdout in the exact format below.
|
| 11 |
- Runs all 3 tasks; total runtime < 20 min on 2 vCPU / 8 GB.
|
| 12 |
|
|
@@ -27,38 +27,39 @@ from typing import List, Optional
|
|
| 27 |
|
| 28 |
from openai import OpenAI
|
| 29 |
|
| 30 |
-
# ── Configuration ──────────────────────────────────────────────────────────
|
| 31 |
-
#
|
| 32 |
-
#
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
#
|
| 36 |
-
#
|
| 37 |
-
#
|
| 38 |
-
#
|
| 39 |
-
|
| 40 |
|
| 41 |
-
#
|
| 42 |
-
# to HF_TOKEN. Both variable names must be checked.
|
| 43 |
-
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "") or os.getenv("OPENAI_API_KEY", "")
|
| 44 |
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
|
|
|
|
|
|
|
| 47 |
|
| 48 |
BENCHMARK = "nl2sql-bench"
|
| 49 |
MAX_STEPS = 5
|
| 50 |
-
TEMPERATURE = 0.2
|
| 51 |
MAX_TOKENS = 512
|
| 52 |
-
SUCCESS_THRESHOLD = 0.7
|
| 53 |
|
| 54 |
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
|
| 55 |
|
| 56 |
-
# ── Startup diagnostics (stderr — not scored) ─────────────────────────────
|
| 57 |
-
print(f"[DEBUG] API_BASE_URL = {API_BASE_URL}", file=sys.stderr, flush=True)
|
| 58 |
-
print(f"[DEBUG] MODEL_NAME = {MODEL_NAME}", file=sys.stderr, flush=True)
|
| 59 |
-
print(f"[DEBUG] API_KEY set = {bool(API_KEY)} (len={len(API_KEY)})", file=sys.stderr, flush=True)
|
| 60 |
-
print(f"[DEBUG] SPACE_URL = {SPACE_URL}", file=sys.stderr, flush=True)
|
| 61 |
-
|
| 62 |
# ── System prompt ──────────────────────────────────────────────────────────
|
| 63 |
SYSTEM_PROMPT = textwrap.dedent("""
|
| 64 |
You are an expert SQL analyst working with a SQLite e-commerce database.
|
|
@@ -94,6 +95,7 @@ def log_start(task: str, model: str) -> None:
|
|
| 94 |
def log_step(
|
| 95 |
step: int, action: str, reward: float, done: bool, error: Optional[str]
|
| 96 |
) -> None:
|
|
|
|
| 97 |
action_single = " ".join(action.split())
|
| 98 |
error_val = error.replace("\n", " ") if error else "null"
|
| 99 |
print(
|
|
@@ -146,14 +148,7 @@ def build_user_prompt(
|
|
| 146 |
|
| 147 |
|
| 148 |
def call_llm(client: OpenAI, user_prompt: str) -> str:
|
| 149 |
-
# CRITICAL: Do NOT silently swallow exceptions with a bare `except Exception`.
|
| 150 |
-
# Silent failure means inference.py "succeeds" but makes zero LLM API calls,
|
| 151 |
-
# which causes the competition's LLM Criteria Check to fail.
|
| 152 |
try:
|
| 153 |
-
print(
|
| 154 |
-
f"[DEBUG] Calling LLM: model={MODEL_NAME} base_url={API_BASE_URL}",
|
| 155 |
-
file=sys.stderr, flush=True
|
| 156 |
-
)
|
| 157 |
resp = client.chat.completions.create(
|
| 158 |
model=MODEL_NAME,
|
| 159 |
messages=[
|
|
@@ -165,7 +160,6 @@ def call_llm(client: OpenAI, user_prompt: str) -> str:
|
|
| 165 |
stream=False,
|
| 166 |
)
|
| 167 |
text = (resp.choices[0].message.content or "").strip()
|
| 168 |
-
print(f"[DEBUG] LLM raw response (first 120 chars): {text[:120]}", file=sys.stderr, flush=True)
|
| 169 |
# Strip markdown code fences if model wraps in ```sql ... ```
|
| 170 |
if text.startswith("```"):
|
| 171 |
lines = text.split("\n")
|
|
@@ -175,11 +169,8 @@ def call_llm(client: OpenAI, user_prompt: str) -> str:
|
|
| 175 |
).strip()
|
| 176 |
return text if text else "SELECT 1"
|
| 177 |
except Exception as exc:
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
# Re-raise so the episode is marked failed, not silently scored as 0.
|
| 181 |
-
# A visible failure is better than a silent one that breaks the LLM check.
|
| 182 |
-
raise
|
| 183 |
|
| 184 |
|
| 185 |
# ── Single-task episode ────────────────────────────────────────────────────
|
|
@@ -194,7 +185,11 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
|
|
| 194 |
log_start(task_name, MODEL_NAME)
|
| 195 |
|
| 196 |
try:
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
obs = result.observation
|
| 199 |
|
| 200 |
for step in range(1, MAX_STEPS + 1):
|
|
@@ -213,7 +208,7 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
|
|
| 213 |
|
| 214 |
sql = call_llm(client, user_prompt)
|
| 215 |
|
| 216 |
-
from models import NL2SQLAction
|
| 217 |
action = NL2SQLAction(query=sql)
|
| 218 |
result = await env.step(action)
|
| 219 |
obs = result.observation
|
|
@@ -230,12 +225,15 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
|
|
| 230 |
if done:
|
| 231 |
break
|
| 232 |
|
| 233 |
-
|
| 234 |
-
score
|
| 235 |
-
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
except Exception as exc:
|
| 238 |
-
print(f"[DEBUG] Episode error for {task_name}: {
|
| 239 |
finally:
|
| 240 |
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 241 |
|
|
@@ -245,21 +243,17 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
|
|
| 245 |
# ── Main ───────────────────────────────────────────────────────────────────
|
| 246 |
|
| 247 |
async def main() -> None:
|
| 248 |
-
# Validate that API_KEY is present — fail fast with a clear message
|
| 249 |
-
if not API_KEY:
|
| 250 |
-
print(
|
| 251 |
-
"[ERROR] No API key found. Set API_KEY or HF_TOKEN environment variable.",
|
| 252 |
-
file=sys.stderr, flush=True
|
| 253 |
-
)
|
| 254 |
-
sys.exit(1)
|
| 255 |
-
|
| 256 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 257 |
|
|
|
|
| 258 |
from client import NL2SQLEnv
|
| 259 |
|
| 260 |
all_results = []
|
| 261 |
|
| 262 |
for task_name in TASKS:
|
|
|
|
|
|
|
|
|
|
| 263 |
os.environ["NL2SQL_DEFAULT_TASK"] = task_name
|
| 264 |
|
| 265 |
try:
|
|
@@ -268,18 +262,20 @@ async def main() -> None:
|
|
| 268 |
all_results.append(result)
|
| 269 |
except Exception as exc:
|
| 270 |
print(
|
| 271 |
-
f"[DEBUG] Failed to connect for task {task_name}: {
|
| 272 |
file=sys.stderr,
|
| 273 |
flush=True,
|
| 274 |
)
|
|
|
|
| 275 |
log_end(success=False, steps=0, score=0.0, rewards=[])
|
| 276 |
all_results.append({"task": task_name, "success": False, "score": 0.0})
|
| 277 |
|
| 278 |
-
# Summary to stderr
|
| 279 |
print("\n=== Baseline Summary ===", file=sys.stderr)
|
| 280 |
for r in all_results:
|
| 281 |
print(
|
| 282 |
-
f" {r['task']:20s} score={r['score']:.3f}
|
|
|
|
| 283 |
file=sys.stderr,
|
| 284 |
)
|
| 285 |
avg = sum(r["score"] for r in all_results) / max(len(all_results), 1)
|
|
|
|
| 6 |
--------------------
|
| 7 |
- Named `inference.py`, placed in project root.
|
| 8 |
- Uses OpenAI client for all LLM calls.
|
| 9 |
+
- Reads: API_BASE_URL, MODEL_NAME, HF_TOKEN from environment.
|
| 10 |
- Emits [START] / [STEP] / [END] lines to stdout in the exact format below.
|
| 11 |
- Runs all 3 tasks; total runtime < 20 min on 2 vCPU / 8 GB.
|
| 12 |
|
|
|
|
| 27 |
|
| 28 |
from openai import OpenAI
|
| 29 |
|
| 30 |
+
# # ── Configuration ──────────────────────────────────────────────────────────
|
| 31 |
+
# API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 32 |
+
# MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
|
| 33 |
+
# API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
|
| 34 |
+
# IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
|
| 35 |
+
# SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
|
| 36 |
|
| 37 |
+
# BENCHMARK = "nl2sql-bench"
|
| 38 |
+
# MAX_STEPS = 5
|
| 39 |
+
# TEMPERATURE = 0.2 # Low temp for SQL generation
|
| 40 |
+
# MAX_TOKENS = 512
|
| 41 |
+
# SUCCESS_THRESHOLD = 0.7 # score >= 0.7 → success
|
| 42 |
|
| 43 |
+
# TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
# ── Configuration ──────────────────────────────────────────────────────────
|
| 46 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 47 |
+
# Points to your newly uploaded fine-tuned weights!
|
| 48 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "ritvik360/qwen-7b-nl2sql-merged_1")
|
| 49 |
+
# CRITICAL FIX: Looks for 'API_KEY' first to satisfy the evaluator's LiteLLM proxy
|
| 50 |
+
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "") or os.getenv("OPENAI_API_KEY")
|
| 51 |
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
|
| 52 |
+
# CRITICAL FIX: Point the default directly to your live HF Space!
|
| 53 |
+
SPACE_URL = os.getenv("SPACE_URL", "https://ritvik360-nl2sql-bench.hf.space")
|
| 54 |
|
| 55 |
BENCHMARK = "nl2sql-bench"
|
| 56 |
MAX_STEPS = 5
|
| 57 |
+
TEMPERATURE = 0.2 # Low temp for SQL generation
|
| 58 |
MAX_TOKENS = 512
|
| 59 |
+
SUCCESS_THRESHOLD = 0.7 # score >= 0.7 β success
|
| 60 |
|
| 61 |
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
# ── System prompt ──────────────────────────────────────────────────────────
|
| 64 |
SYSTEM_PROMPT = textwrap.dedent("""
|
| 65 |
You are an expert SQL analyst working with a SQLite e-commerce database.
|
|
|
|
| 95 |
def log_step(
|
| 96 |
step: int, action: str, reward: float, done: bool, error: Optional[str]
|
| 97 |
) -> None:
|
| 98 |
+
# Collapse multi-line SQL to single line for log compliance
|
| 99 |
action_single = " ".join(action.split())
|
| 100 |
error_val = error.replace("\n", " ") if error else "null"
|
| 101 |
print(
|
|
|
|
| 148 |
|
| 149 |
|
| 150 |
def call_llm(client: OpenAI, user_prompt: str) -> str:
|
|
|
|
|
|
|
|
|
|
| 151 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
resp = client.chat.completions.create(
|
| 153 |
model=MODEL_NAME,
|
| 154 |
messages=[
|
|
|
|
| 160 |
stream=False,
|
| 161 |
)
|
| 162 |
text = (resp.choices[0].message.content or "").strip()
|
|
|
|
| 163 |
# Strip markdown code fences if model wraps in ```sql ... ```
|
| 164 |
if text.startswith("```"):
|
| 165 |
lines = text.split("\n")
|
|
|
|
| 169 |
).strip()
|
| 170 |
return text if text else "SELECT 1"
|
| 171 |
except Exception as exc:
|
| 172 |
+
print(f"[DEBUG] LLM call failed: {exc}", file=sys.stderr, flush=True)
|
| 173 |
+
return "SELECT 1"
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
# ── Single-task episode ────────────────────────────────────────────────────
|
|
|
|
| 185 |
log_start(task_name, MODEL_NAME)
|
| 186 |
|
| 187 |
try:
|
| 188 |
+
# Reset — pass task_name via action payload or query param
|
| 189 |
+
# OpenEnv reset() may not accept task args via HTTP; we rely on
|
| 190 |
+
# NL2SQL_DEFAULT_TASK env-var being set before calling, OR we
|
| 191 |
+
# pass it as a reset parameter if the server supports it.
|
| 192 |
+
result = await env.reset() # changed
|
| 193 |
obs = result.observation
|
| 194 |
|
| 195 |
for step in range(1, MAX_STEPS + 1):
|
|
|
|
| 208 |
|
| 209 |
sql = call_llm(client, user_prompt)
|
| 210 |
|
| 211 |
+
from models import NL2SQLAction # local to avoid circular at module level
|
| 212 |
action = NL2SQLAction(query=sql)
|
| 213 |
result = await env.step(action)
|
| 214 |
obs = result.observation
|
|
|
|
| 225 |
if done:
|
| 226 |
break
|
| 227 |
|
| 228 |
+
# Compute final score
|
| 229 |
+
# CRITICAL: Evaluator requires score strictly in (0, 1) — not 0.0, not 1.0.
|
| 230 |
+
# A perfect solve gives 1.0 → clamp to 0.999. All-fail gives 0.0 → clamp to 0.001.
|
| 231 |
+
raw_score = sum(rewards) / max(len(rewards), 1)
|
| 232 |
+
score = round(min(max(raw_score, 0.001), 0.999), 4)
|
| 233 |
+
success = raw_score >= SUCCESS_THRESHOLD
|
| 234 |
|
| 235 |
except Exception as exc:
|
| 236 |
+
print(f"[DEBUG] Episode error for {task_name}: {exc}", file=sys.stderr, flush=True)
|
| 237 |
finally:
|
| 238 |
log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 239 |
|
|
|
|
| 243 |
# ── Main ───────────────────────────────────────────────────────────────────
|
| 244 |
|
| 245 |
async def main() -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 247 |
|
| 248 |
+
# Import here to avoid import errors if openenv not installed during lint
|
| 249 |
from client import NL2SQLEnv
|
| 250 |
|
| 251 |
all_results = []
|
| 252 |
|
| 253 |
for task_name in TASKS:
|
| 254 |
+
# Set the default task for the server session via env-var approach.
|
| 255 |
+
# For the hosted Space, we rely on the task cycling implemented in
|
| 256 |
+
# the task registry's round-robin iterator.
|
| 257 |
os.environ["NL2SQL_DEFAULT_TASK"] = task_name
|
| 258 |
|
| 259 |
try:
|
|
|
|
| 262 |
all_results.append(result)
|
| 263 |
except Exception as exc:
|
| 264 |
print(
|
| 265 |
+
f"[DEBUG] Failed to connect for task {task_name}: {exc}",
|
| 266 |
file=sys.stderr,
|
| 267 |
flush=True,
|
| 268 |
)
|
| 269 |
+
# Emit a zero-score END to keep log format valid
|
| 270 |
log_end(success=False, steps=0, score=0.0, rewards=[])
|
| 271 |
all_results.append({"task": task_name, "success": False, "score": 0.0})
|
| 272 |
|
| 273 |
+
# Summary to stderr (not scored, for human readability)
|
| 274 |
print("\n=== Baseline Summary ===", file=sys.stderr)
|
| 275 |
for r in all_results:
|
| 276 |
print(
|
| 277 |
+
f" {r['task']:20s} score={r['score']:.3f} "
|
| 278 |
+
f"success={r['success']}",
|
| 279 |
file=sys.stderr,
|
| 280 |
)
|
| 281 |
avg = sum(r["score"] for r in all_results) / max(len(all_results), 1)
|