Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- inference.py +59 -44
inference.py
CHANGED
|
@@ -33,7 +33,7 @@ from models import SQLAction
|
|
| 33 |
|
| 34 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 35 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 36 |
-
MODEL_NAME = os.getenv("MODEL_NAME")
|
| 37 |
ENV_URL = os.getenv("ENV_URL", "https://prithvigg-queryforge.hf.space")
|
| 38 |
|
| 39 |
MAX_STEPS = 5
|
|
@@ -115,34 +115,42 @@ def hr(char="β", width=70):
|
|
| 115 |
# ββ Per-task agent loop ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
|
| 117 |
def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
| 121 |
log_start(task=task_id, model=MODEL_NAME)
|
| 122 |
|
| 123 |
-
if result.done:
|
| 124 |
-
print(f" ERROR loading task: {obs.feedback}")
|
| 125 |
-
log_end(success=False, steps=0, score=0.0, rewards=[])
|
| 126 |
-
return {"task_id": task_id, "best_score": 0.0, "attempts": 0, "done": False}
|
| 127 |
-
|
| 128 |
-
print(f"\n Task : {obs.task_title} [{obs.task_level}]")
|
| 129 |
-
|
| 130 |
-
messages = [
|
| 131 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 132 |
-
{
|
| 133 |
-
"role": "user",
|
| 134 |
-
"content": (
|
| 135 |
-
f"Here is your SQL challenge:\n\n{obs.task_description}\n\n"
|
| 136 |
-
"Provide your fixed SQL query."
|
| 137 |
-
),
|
| 138 |
-
},
|
| 139 |
-
]
|
| 140 |
-
|
| 141 |
-
step = 0
|
| 142 |
-
rewards: List[float] = []
|
| 143 |
-
success = False
|
| 144 |
-
|
| 145 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
while not result.done:
|
| 147 |
step += 1
|
| 148 |
|
|
@@ -171,10 +179,10 @@ def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
|
|
| 171 |
|
| 172 |
reward = result.reward or 0.0
|
| 173 |
rewards.append(reward)
|
|
|
|
| 174 |
|
| 175 |
-
# Determine error string for [STEP] log
|
| 176 |
if not obs.syntax_valid:
|
| 177 |
-
step_error = "syntax_error"
|
| 178 |
print(f" β Syntax error β query could not be parsed")
|
| 179 |
elif not obs.execution_success:
|
| 180 |
step_error = (obs.execution_error or "execution_error")[:120]
|
|
@@ -183,12 +191,12 @@ def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
|
|
| 183 |
step_error = None
|
| 184 |
print(f" β Executed Β· rows returned: {obs.rows_returned}")
|
| 185 |
|
| 186 |
-
done_marker = " β DONE" if
|
| 187 |
print(f" Score : {score_bar(reward)}{done_marker}")
|
| 188 |
|
| 189 |
-
log_step(step=step, action=sql, reward=reward, done=
|
| 190 |
|
| 191 |
-
if
|
| 192 |
break
|
| 193 |
|
| 194 |
print(f"\n β» Retrying β score {reward:.3f} below threshold")
|
|
@@ -211,27 +219,29 @@ def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
|
|
| 211 |
),
|
| 212 |
})
|
| 213 |
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
finally:
|
| 217 |
-
log_end(success=success, steps=step, score=
|
| 218 |
|
| 219 |
return {
|
| 220 |
"task_id": task_id,
|
| 221 |
-
"task_title":
|
| 222 |
-
"task_level":
|
| 223 |
-
"best_score":
|
| 224 |
-
"attempts":
|
| 225 |
-
"done":
|
| 226 |
}
|
| 227 |
|
| 228 |
|
| 229 |
# ββ Main βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 230 |
|
| 231 |
def main() -> None:
|
| 232 |
-
if not MODEL_NAME:
|
| 233 |
-
print("ERROR: MODEL_NAME env var is not set.")
|
| 234 |
-
sys.exit(1)
|
| 235 |
if not API_KEY:
|
| 236 |
print("ERROR: HF_TOKEN (or API_KEY) is not set.")
|
| 237 |
sys.exit(1)
|
|
@@ -246,10 +256,15 @@ def main() -> None:
|
|
| 246 |
hr()
|
| 247 |
|
| 248 |
results = []
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
# ββ Results summary βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 255 |
print(f"\n{'β' * 70}")
|
|
|
|
| 33 |
|
| 34 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 35 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 36 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
|
| 37 |
ENV_URL = os.getenv("ENV_URL", "https://prithvigg-queryforge.hf.space")
|
| 38 |
|
| 39 |
MAX_STEPS = 5
|
|
|
|
| 115 |
# ββ Per-task agent loop ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
|
| 117 |
def run_task(task_id: str, llm: OpenAI, env_client) -> dict:
|
| 118 |
+
# Initialise before anything that can throw β guarantees [END] is always emitted.
|
| 119 |
+
step = 0
|
| 120 |
+
rewards: List[float] = []
|
| 121 |
+
success = False
|
| 122 |
+
best_score = 0.0
|
| 123 |
+
task_title = task_id
|
| 124 |
+
task_level = "unknown"
|
| 125 |
+
attempts = 0
|
| 126 |
+
done = False
|
| 127 |
|
| 128 |
log_start(task=task_id, model=MODEL_NAME)
|
| 129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
try:
|
| 131 |
+
result = env_client.reset(task_id=task_id)
|
| 132 |
+
obs = result.observation
|
| 133 |
+
|
| 134 |
+
if result.done:
|
| 135 |
+
print(f" ERROR loading task: {obs.feedback}")
|
| 136 |
+
log_end(success=False, steps=0, score=0.0, rewards=[])
|
| 137 |
+
return {"task_id": task_id, "best_score": 0.0, "attempts": 0, "done": False}
|
| 138 |
+
|
| 139 |
+
task_title = obs.task_title
|
| 140 |
+
task_level = obs.task_level
|
| 141 |
+
print(f"\n Task : {task_title} [{task_level}]")
|
| 142 |
+
|
| 143 |
+
messages = [
|
| 144 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 145 |
+
{
|
| 146 |
+
"role": "user",
|
| 147 |
+
"content": (
|
| 148 |
+
f"Here is your SQL challenge:\n\n{obs.task_description}\n\n"
|
| 149 |
+
"Provide your fixed SQL query."
|
| 150 |
+
),
|
| 151 |
+
},
|
| 152 |
+
]
|
| 153 |
+
|
| 154 |
while not result.done:
|
| 155 |
step += 1
|
| 156 |
|
|
|
|
| 179 |
|
| 180 |
reward = result.reward or 0.0
|
| 181 |
rewards.append(reward)
|
| 182 |
+
done = result.done
|
| 183 |
|
|
|
|
| 184 |
if not obs.syntax_valid:
|
| 185 |
+
step_error: Optional[str] = "syntax_error"
|
| 186 |
print(f" β Syntax error β query could not be parsed")
|
| 187 |
elif not obs.execution_success:
|
| 188 |
step_error = (obs.execution_error or "execution_error")[:120]
|
|
|
|
| 191 |
step_error = None
|
| 192 |
print(f" β Executed Β· rows returned: {obs.rows_returned}")
|
| 193 |
|
| 194 |
+
done_marker = " β DONE" if done else ""
|
| 195 |
print(f" Score : {score_bar(reward)}{done_marker}")
|
| 196 |
|
| 197 |
+
log_step(step=step, action=sql, reward=reward, done=done, error=step_error)
|
| 198 |
|
| 199 |
+
if done:
|
| 200 |
break
|
| 201 |
|
| 202 |
print(f"\n β» Retrying β score {reward:.3f} below threshold")
|
|
|
|
| 219 |
),
|
| 220 |
})
|
| 221 |
|
| 222 |
+
best_score = obs.best_score
|
| 223 |
+
attempts = obs.attempt
|
| 224 |
+
success = best_score >= SUCCESS_SCORE_THRESHOLD
|
| 225 |
+
|
| 226 |
+
except Exception as exc:
|
| 227 |
+
print(f" FATAL error in task {task_id}: {exc}", flush=True)
|
| 228 |
|
| 229 |
finally:
|
| 230 |
+
log_end(success=success, steps=step, score=best_score, rewards=rewards)
|
| 231 |
|
| 232 |
return {
|
| 233 |
"task_id": task_id,
|
| 234 |
+
"task_title": task_title,
|
| 235 |
+
"task_level": task_level,
|
| 236 |
+
"best_score": best_score,
|
| 237 |
+
"attempts": attempts,
|
| 238 |
+
"done": done,
|
| 239 |
}
|
| 240 |
|
| 241 |
|
| 242 |
# ββ Main βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 243 |
|
| 244 |
def main() -> None:
|
|
|
|
|
|
|
|
|
|
| 245 |
if not API_KEY:
|
| 246 |
print("ERROR: HF_TOKEN (or API_KEY) is not set.")
|
| 247 |
sys.exit(1)
|
|
|
|
| 256 |
hr()
|
| 257 |
|
| 258 |
results = []
|
| 259 |
+
try:
|
| 260 |
+
env_ctx = QueryforgeEnv(base_url=ENV_URL).sync()
|
| 261 |
+
with env_ctx as env_client:
|
| 262 |
+
for task_id in TASK_IDS:
|
| 263 |
+
print(f"\n{'β' * 70}")
|
| 264 |
+
results.append(run_task(task_id, llm, env_client))
|
| 265 |
+
except Exception as exc:
|
| 266 |
+
print(f"FATAL: could not connect to environment at {ENV_URL}: {exc}", flush=True)
|
| 267 |
+
sys.exit(1)
|
| 268 |
|
| 269 |
# ββ Results summary βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 270 |
print(f"\n{'β' * 70}")
|