ritvik360 committed on
Commit
daab6a1
·
verified ·
1 Parent(s): 46e0615

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +48 -52
inference.py CHANGED
@@ -6,7 +6,7 @@ MANDATORY COMPLIANCE
6
  --------------------
7
  - Named `inference.py`, placed in project root.
8
  - Uses OpenAI client for all LLM calls.
9
- - Reads: API_BASE_URL, MODEL_NAME, API_KEY (+ HF_TOKEN fallback) from environment.
10
  - Emits [START] / [STEP] / [END] lines to stdout in the exact format below.
11
  - Runs all 3 tasks; total runtime < 20 min on 2 vCPU / 8 GB.
12
 
@@ -27,38 +27,39 @@ from typing import List, Optional
27
 
28
  from openai import OpenAI
29
 
30
- # ── Configuration ──────────────────────────────────────────────────────────
31
- # CRITICAL: API_BASE_URL and API_KEY are injected by the competition evaluator.
32
- # Do NOT hardcode values. The evaluator injects their LiteLLM proxy URL + key.
33
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
 
 
34
 
35
- # CRITICAL FIX: Default MODEL_NAME must be a model available on the HF router /
36
- # the competition's LiteLLM proxy. "ritvik360/qwen-7b-nl2sql-merged_1" is NOT
37
- # on their proxy — it would silently fail and produce SELECT 1 for all steps.
38
- # The competition injects MODEL_NAME if they want to override it.
39
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
40
 
41
- # CRITICAL FIX: Read API_KEY first (competition injects this), then fall back
42
- # to HF_TOKEN. Both variable names must be checked.
43
- API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "") or os.getenv("OPENAI_API_KEY", "")
44
 
45
- SPACE_URL = os.getenv("SPACE_URL", "https://ritvik360-nl2sql-bench.hf.space")
 
 
 
 
 
46
  IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
 
 
47
 
48
  BENCHMARK = "nl2sql-bench"
49
  MAX_STEPS = 5
50
- TEMPERATURE = 0.2
51
  MAX_TOKENS = 512
52
- SUCCESS_THRESHOLD = 0.7
53
 
54
  TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
55
 
56
- # ── Startup diagnostics (stderr — not scored) ─────────────────────────────
57
- print(f"[DEBUG] API_BASE_URL = {API_BASE_URL}", file=sys.stderr, flush=True)
58
- print(f"[DEBUG] MODEL_NAME = {MODEL_NAME}", file=sys.stderr, flush=True)
59
- print(f"[DEBUG] API_KEY set = {bool(API_KEY)} (len={len(API_KEY)})", file=sys.stderr, flush=True)
60
- print(f"[DEBUG] SPACE_URL = {SPACE_URL}", file=sys.stderr, flush=True)
61
-
62
  # ── System prompt ──────────────────────────────────────────────────────────
63
  SYSTEM_PROMPT = textwrap.dedent("""
64
  You are an expert SQL analyst working with a SQLite e-commerce database.
@@ -94,6 +95,7 @@ def log_start(task: str, model: str) -> None:
94
  def log_step(
95
  step: int, action: str, reward: float, done: bool, error: Optional[str]
96
  ) -> None:
 
97
  action_single = " ".join(action.split())
98
  error_val = error.replace("\n", " ") if error else "null"
99
  print(
@@ -146,14 +148,7 @@ def build_user_prompt(
146
 
147
 
148
  def call_llm(client: OpenAI, user_prompt: str) -> str:
149
- # CRITICAL: Do NOT silently swallow exceptions with a bare `except Exception`.
150
- # Silent failure means inference.py "succeeds" but makes zero LLM API calls,
151
- # which causes the competition's LLM Criteria Check to fail.
152
  try:
153
- print(
154
- f"[DEBUG] Calling LLM: model={MODEL_NAME} base_url={API_BASE_URL}",
155
- file=sys.stderr, flush=True
156
- )
157
  resp = client.chat.completions.create(
158
  model=MODEL_NAME,
159
  messages=[
@@ -165,7 +160,6 @@ def call_llm(client: OpenAI, user_prompt: str) -> str:
165
  stream=False,
166
  )
167
  text = (resp.choices[0].message.content or "").strip()
168
- print(f"[DEBUG] LLM raw response (first 120 chars): {text[:120]}", file=sys.stderr, flush=True)
169
  # Strip markdown code fences if model wraps in ```sql ... ```
170
  if text.startswith("```"):
171
  lines = text.split("\n")
@@ -175,11 +169,8 @@ def call_llm(client: OpenAI, user_prompt: str) -> str:
175
  ).strip()
176
  return text if text else "SELECT 1"
177
  except Exception as exc:
178
- # Log the full error — this is the signal that tells you what went wrong
179
- print(f"[DEBUG] LLM call FAILED: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
180
- # Re-raise so the episode is marked failed, not silently scored as 0.
181
- # A visible failure is better than a silent one that breaks the LLM check.
182
- raise
183
 
184
 
185
  # ── Single-task episode ────────────────────────────────────────────────────
@@ -194,7 +185,11 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
194
  log_start(task_name, MODEL_NAME)
195
 
196
  try:
197
- result = await env.reset()
 
 
 
 
198
  obs = result.observation
199
 
200
  for step in range(1, MAX_STEPS + 1):
@@ -213,7 +208,7 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
213
 
214
  sql = call_llm(client, user_prompt)
215
 
216
- from models import NL2SQLAction
217
  action = NL2SQLAction(query=sql)
218
  result = await env.step(action)
219
  obs = result.observation
@@ -230,12 +225,15 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
230
  if done:
231
  break
232
 
233
- score = sum(rewards) / max(len(rewards), 1)
234
- score = round(min(max(score, 0.0), 1.0), 4)
235
- success = score >= SUCCESS_THRESHOLD
 
 
 
236
 
237
  except Exception as exc:
238
- print(f"[DEBUG] Episode error for {task_name}: {type(exc).__name__}: {exc}", file=sys.stderr, flush=True)
239
  finally:
240
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
241
 
@@ -245,21 +243,17 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
245
  # ── Main ───────────────────────────────────────────────────────────────────
246
 
247
  async def main() -> None:
248
- # Validate that API_KEY is present β€” fail fast with a clear message
249
- if not API_KEY:
250
- print(
251
- "[ERROR] No API key found. Set API_KEY or HF_TOKEN environment variable.",
252
- file=sys.stderr, flush=True
253
- )
254
- sys.exit(1)
255
-
256
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
257
 
 
258
  from client import NL2SQLEnv
259
 
260
  all_results = []
261
 
262
  for task_name in TASKS:
 
 
 
263
  os.environ["NL2SQL_DEFAULT_TASK"] = task_name
264
 
265
  try:
@@ -268,18 +262,20 @@ async def main() -> None:
268
  all_results.append(result)
269
  except Exception as exc:
270
  print(
271
- f"[DEBUG] Failed to connect for task {task_name}: {type(exc).__name__}: {exc}",
272
  file=sys.stderr,
273
  flush=True,
274
  )
 
275
  log_end(success=False, steps=0, score=0.0, rewards=[])
276
  all_results.append({"task": task_name, "success": False, "score": 0.0})
277
 
278
- # Summary to stderr
279
  print("\n=== Baseline Summary ===", file=sys.stderr)
280
  for r in all_results:
281
  print(
282
- f" {r['task']:20s} score={r['score']:.3f} success={r['success']}",
 
283
  file=sys.stderr,
284
  )
285
  avg = sum(r["score"] for r in all_results) / max(len(all_results), 1)
 
6
  --------------------
7
  - Named `inference.py`, placed in project root.
8
  - Uses OpenAI client for all LLM calls.
9
+ - Reads: API_BASE_URL, MODEL_NAME, HF_TOKEN from environment.
10
  - Emits [START] / [STEP] / [END] lines to stdout in the exact format below.
11
  - Runs all 3 tasks; total runtime < 20 min on 2 vCPU / 8 GB.
12
 
 
27
 
28
  from openai import OpenAI
29
 
30
+ # # ── Configuration ──────────────────────────────────────────────────────────
31
+ # API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
32
+ # MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
33
+ # API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
34
+ # IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
35
+ # SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
36
 
37
+ # BENCHMARK = "nl2sql-bench"
38
+ # MAX_STEPS = 5
39
+ # TEMPERATURE = 0.2 # Low temp for SQL generation
40
+ # MAX_TOKENS = 512
41
+ # SUCCESS_THRESHOLD = 0.7 # score >= 0.7 → success
42
 
43
+ # TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
 
 
44
 
45
+ # ── Configuration ──────────────────────────────────────────────────────────
46
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
47
+ # Points to your newly uploaded fine-tuned weights!
48
+ MODEL_NAME = os.getenv("MODEL_NAME", "ritvik360/qwen-7b-nl2sql-merged_1")
49
+ # CRITICAL FIX: Looks for 'API_KEY' first to satisfy the evaluator's LiteLLM proxy
50
+ API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "") or os.getenv("OPENAI_API_KEY")
51
  IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
52
+ # CRITICAL FIX: Point the default directly to your live HF Space!
53
+ SPACE_URL = os.getenv("SPACE_URL", "https://ritvik360-nl2sql-bench.hf.space")
54
 
55
  BENCHMARK = "nl2sql-bench"
56
  MAX_STEPS = 5
57
+ TEMPERATURE = 0.2 # Low temp for SQL generation
58
  MAX_TOKENS = 512
59
+ SUCCESS_THRESHOLD = 0.7 # score >= 0.7 → success
60
 
61
  TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
62
 
 
 
 
 
 
 
63
  # ── System prompt ──────────────────────────────────────────────────────────
64
  SYSTEM_PROMPT = textwrap.dedent("""
65
  You are an expert SQL analyst working with a SQLite e-commerce database.
 
95
  def log_step(
96
  step: int, action: str, reward: float, done: bool, error: Optional[str]
97
  ) -> None:
98
+ # Collapse multi-line SQL to single line for log compliance
99
  action_single = " ".join(action.split())
100
  error_val = error.replace("\n", " ") if error else "null"
101
  print(
 
148
 
149
 
150
  def call_llm(client: OpenAI, user_prompt: str) -> str:
 
 
 
151
  try:
 
 
 
 
152
  resp = client.chat.completions.create(
153
  model=MODEL_NAME,
154
  messages=[
 
160
  stream=False,
161
  )
162
  text = (resp.choices[0].message.content or "").strip()
 
163
  # Strip markdown code fences if model wraps in ```sql ... ```
164
  if text.startswith("```"):
165
  lines = text.split("\n")
 
169
  ).strip()
170
  return text if text else "SELECT 1"
171
  except Exception as exc:
172
+ print(f"[DEBUG] LLM call failed: {exc}", file=sys.stderr, flush=True)
173
+ return "SELECT 1"
 
 
 
174
 
175
 
176
  # ── Single-task episode ────────────────────────────────────────────────────
 
185
  log_start(task_name, MODEL_NAME)
186
 
187
  try:
188
+ # Reset β€” pass task_name via action payload or query param
189
+ # OpenEnv reset() may not accept task args via HTTP; we rely on
190
+ # NL2SQL_DEFAULT_TASK env-var being set before calling, OR we
191
+ # pass it as a reset parameter if the server supports it.
192
+ result = await env.reset() # changed
193
  obs = result.observation
194
 
195
  for step in range(1, MAX_STEPS + 1):
 
208
 
209
  sql = call_llm(client, user_prompt)
210
 
211
+ from models import NL2SQLAction # local to avoid circular at module level
212
  action = NL2SQLAction(query=sql)
213
  result = await env.step(action)
214
  obs = result.observation
 
225
  if done:
226
  break
227
 
228
+ # Compute final score
229
+ # CRITICAL: Evaluator requires score strictly in (0, 1) — not 0.0, not 1.0.
230
+ # A perfect solve gives 1.0 → clamp to 0.999. All-fail gives 0.0 → clamp to 0.001.
231
+ raw_score = sum(rewards) / max(len(rewards), 1)
232
+ score = round(min(max(raw_score, 0.001), 0.999), 4)
233
+ success = raw_score >= SUCCESS_THRESHOLD
234
 
235
  except Exception as exc:
236
+ print(f"[DEBUG] Episode error for {task_name}: {exc}", file=sys.stderr, flush=True)
237
  finally:
238
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
239
 
 
243
  # ── Main ───────────────────────────────────────────────────────────────────
244
 
245
  async def main() -> None:
 
 
 
 
 
 
 
 
246
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
247
 
248
+ # Import here to avoid import errors if openenv not installed during lint
249
  from client import NL2SQLEnv
250
 
251
  all_results = []
252
 
253
  for task_name in TASKS:
254
+ # Set the default task for the server session via env-var approach.
255
+ # For the hosted Space, we rely on the task cycling implemented in
256
+ # the task registry's round-robin iterator.
257
  os.environ["NL2SQL_DEFAULT_TASK"] = task_name
258
 
259
  try:
 
262
  all_results.append(result)
263
  except Exception as exc:
264
  print(
265
+ f"[DEBUG] Failed to connect for task {task_name}: {exc}",
266
  file=sys.stderr,
267
  flush=True,
268
  )
269
+ # Emit a zero-score END to keep log format valid
270
  log_end(success=False, steps=0, score=0.0, rewards=[])
271
  all_results.append({"task": task_name, "success": False, "score": 0.0})
272
 
273
+ # Summary to stderr (not scored, for human readability)
274
  print("\n=== Baseline Summary ===", file=sys.stderr)
275
  for r in all_results:
276
  print(
277
+ f" {r['task']:20s} score={r['score']:.3f} "
278
+ f"success={r['success']}",
279
  file=sys.stderr,
280
  )
281
  avg = sum(r["score"] for r in all_results) / max(len(all_results), 1)