# api-debug-env / inference.py
# Hugging Face Space by avichauhan.
# Uploaded with huggingface_hub (commit d73bfc0, verified).
"""
Baseline inference script for the API Debug Environment.
MANDATORY:
- Must be named inference.py and placed in the root directory.
- Must use OpenAI Client for all LLM calls.
- Must read env vars: API_BASE_URL, MODEL_NAME, HF_TOKEN.
- Must emit [START], [STEP], [END] structured logs to stdout.
STDOUT FORMAT:
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""
import asyncio
import json
import os
import re
import textwrap
from typing import List, Optional
from openai import OpenAI
from client import APIDebugEnv
from models import APIDebugAction
# Environment variables (mandatory for hackathon evaluation).
# Fallbacks use `or` rather than a getenv default so that a variable set to an
# empty string is treated the same as an unset one.
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
HF_TOKEN = os.environ.get("HF_TOKEN")
# First non-empty credential wins: HF_TOKEN, then OPENAI_API_KEY, then API_KEY.
API_KEY = HF_TOKEN or os.environ.get("OPENAI_API_KEY") or os.environ.get("API_KEY")
MODEL_NAME = os.environ.get("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
ENV_URL = os.environ.get("ENV_URL") or "https://avichauhan-api-debug-env.hf.space"
IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")

# Task configuration.
# Step budget per task; the key order is also the order tasks are run in.
MAX_STEPS = {
    "easy": 3,
    "classify": 4,
    "medium": 5,
    "headers": 4,
    "response": 4,
    "hard": 7,
}
TASKS = list(MAX_STEPS)
EPISODES_PER_TASK = 3
BENCHMARK_NAME = "api_debug"
# =========================================================================
# Structured logging (exact format required by evaluator)
# =========================================================================
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens an episode's structured log.

    Args:
        task: Task name written to the ``task=`` field.
        env: Benchmark name written to the ``env=`` field.
        model: Model identifier written to the ``model=`` field.
    """
    record = f"[START] task={task} env={env} model={model}"
    print(record, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` record to stdout.

    Args:
        step: Step number within the episode.
        action: Short text describing the action taken.
        reward: Reward for the step; printed with four decimal places.
        done: Episode-finished flag; printed as its lowercased ``str()``.
        error: Error message, or a falsy value to print ``null``.
    """
    fields = (
        "[STEP]",
        f"step={step}",
        f"action={action}",
        f"reward={reward:.4f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    )
    print(" ".join(fields), flush=True)
def log_end(
    success: bool, steps: int, score: float, rewards: List[float]
) -> None:
    """Print the ``[END]`` record that closes an episode's structured log.

    Args:
        success: Outcome flag; printed as its lowercased ``str()``.
        steps: Number of steps taken in the episode.
        score: Final score; printed with four decimal places.
        rewards: Per-step rewards; printed comma-separated with four decimal
            places each (an empty list prints an empty ``rewards=`` field).
    """
    formatted_rewards = ",".join(format(value, ".4f") for value in rewards)
    outcome = str(success).lower()
    record = (
        f"[END] success={outcome} steps={steps} "
        f"score={score:.4f} rewards={formatted_rewards}"
    )
    print(record, flush=True)
# =========================================================================
# System prompts per task
# =========================================================================
# System prompt for each task, keyed by the same task names used in TASKS and
# MAX_STEPS; run_episode() sends SYSTEM_PROMPTS[task] as the system message.
# Every prompt asks the model to reply with a single JSON object, and the JSON
# keys named in the prompts are the ones build_action() reads from the reply.
# The text is passed through textwrap.dedent(...).strip() before use.
SYSTEM_PROMPTS = {
# easy: diagnose one error -> "error_type" + "affected_fields".
"easy": textwrap.dedent("""
You are an API debugging expert. You receive a broken API request and its specification.
Your job: identify the error type and the affected fields.
Respond with ONLY a JSON object in this format:
{"error_type": "<type>", "affected_fields": ["field1", "field2"]}
Valid error types:
missing_required_field, wrong_field_type, invalid_email_format,
missing_auth_header, extra_unknown_field, null_value_in_required,
wrong_http_method, malformed_json_value, invalid_enum_value,
datetime_format_error, wrong_content_type, expired_auth_token
""").strip(),
# classify: diagnose several errors at once -> "error_types" (plural) + "affected_fields".
"classify": textwrap.dedent("""
You are an API debugging expert. You receive a broken API request with MULTIPLE errors.
Your job: identify ALL error types and ALL affected fields.
Respond with ONLY a JSON object in this format:
{"error_types": ["type1", "type2"], "affected_fields": ["field1", "field2"]}
Valid error types:
missing_required_field, wrong_field_type, invalid_email_format,
missing_auth_header, extra_unknown_field, null_value_in_required,
wrong_http_method, malformed_json_value, invalid_enum_value,
datetime_format_error, wrong_content_type, expired_auth_token
""").strip(),
# medium: repair the request -> "fixed_request" (a JSON string) + "fixed_headers".
"medium": textwrap.dedent("""
You are an API debugging expert. You receive a broken API request and its specification.
Your job: fix the request so it matches the spec.
Respond with ONLY a JSON object in this format:
{"fixed_request": "<valid JSON string matching the spec>", "fixed_headers": {"Header": "value"}}
The fixed_request must be a valid JSON string. Include all required fields with correct types.
""").strip(),
# headers: header-level errors only -> "error_type" + "fixed_headers".
"headers": textwrap.dedent("""
You are an API debugging expert. You receive a broken API request with header-level errors.
Your job: identify the header error type and provide the corrected headers.
Respond with ONLY a JSON object in this format:
{"error_type": "<type>", "fixed_headers": {"Header-Name": "correct-value"}}
Valid header error types:
missing_auth_header, wrong_content_type, expired_auth_token
Common headers: Authorization (Bearer token), Content-Type (application/json)
""").strip(),
# response: validate the server response -> "response_issues", "affected_fields",
# and optionally "expected_status_code".
"response": textwrap.dedent("""
You are an API response validation expert. You receive an API request, its specification,
and the server's response. Your job: identify issues in the response.
Respond with ONLY a JSON object in this format:
{"response_issues": ["issue_type1", "issue_type2"], "affected_fields": ["field1"], "expected_status_code": 200}
Valid response issue types:
wrong_status_code, missing_response_field, wrong_response_type,
extra_response_field, inconsistent_error_format
Only include expected_status_code if you detect a wrong_status_code issue.
""").strip(),
# hard: diagnose, fix, and explain in one reply -> "error_type", "affected_fields",
# "fixed_request", "fixed_headers", "explanation".
"hard": textwrap.dedent("""
You are an API debugging expert. You receive a broken API request with multiple errors.
Your job: diagnose the errors, fix the request, and explain the fix for a developer.
Respond with ONLY a JSON object in this format:
{
"error_type": "<primary error type>",
"affected_fields": ["field1"],
"fixed_request": "<valid JSON string>",
"fixed_headers": {"Header": "value"},
"explanation": "Clear explanation of what was wrong and how to fix it."
}
""").strip(),
}
# =========================================================================
# Prompt building
# =========================================================================
def build_user_prompt(obs, step_num: int) -> str:
    """Render the current observation as the user message for the LLM.

    The request line, error count, step counter, request body, request headers
    and API specification are always included. The response status/body pair is
    added only when ``obs.response_body`` is truthy, and the feedback section
    only when ``obs.feedback`` is truthy.

    Args:
        obs: Observation object exposing the attributes read below.
        step_num: Current step number, shown as ``Step <n>/<max_steps>``.

    Returns:
        The sections joined with newlines.
    """
    request_line = f"API: {obs.http_method} {obs.endpoint} ({obs.api_name})"
    headers_json = json.dumps(obs.broken_headers)

    sections = [
        request_line,
        f"Error count: {obs.error_count}",
        f"Step {step_num}/{obs.max_steps}",
        f"\nRequest body:\n{obs.broken_request}",
        f"\nRequest headers: {headers_json}",
        f"\nAPI Specification:\n{obs.api_spec}",
    ]

    # Response data is only present for the response validation task.
    if obs.response_body:
        sections += [
            f"\nResponse status code: {obs.response_status_code}",
            f"\nResponse body:\n{obs.response_body}",
        ]

    if obs.feedback:
        sections.append(f"\nFeedback from previous attempt:\n{obs.feedback}")

    return "\n".join(sections)
# =========================================================================
# LLM response parsing
# =========================================================================
def _load_json_object(candidate: str) -> Optional[dict]:
    """Return *candidate* parsed as a JSON object, or None if it is not one."""
    try:
        value = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    # Valid JSON that is not an object (list, string, number, ...) is useless
    # to callers, which rely on dict methods such as .get().
    return value if isinstance(value, dict) else None


def parse_llm_response(text: str) -> dict:
    """Extract a JSON object from the LLM response.

    Three strategies are tried in order; the first that yields a JSON *object*
    wins:

    1. Parse the whole text as JSON.
    2. Parse the contents of the first markdown code block.
    3. Scan the text for the first position of a ``{`` from which a complete
       JSON object can be decoded. Nested objects are handled, so the
       outermost object is returned rather than an inner one.

    Args:
        text: Raw reply text from the model; may be empty.

    Returns:
        The decoded object, or ``{}`` when no JSON object could be extracted.
        The result is always a dict, never a list or scalar.
    """
    if not text:
        return {}

    # Try direct parse first.
    parsed = _load_json_object(text)
    if parsed is not None:
        return parsed

    # Try extracting from a markdown code block.
    code_block = re.search(r"```(?:json)?\s*\n?(.*?)\n?\s*```", text, re.DOTALL)
    if code_block:
        parsed = _load_json_object(code_block.group(1))
        if parsed is not None:
            return parsed

    # Try decoding an object starting at each "{" in turn. raw_decode parses
    # one complete JSON value and ignores whatever follows it, so surrounding
    # prose, nested objects and braces inside strings are all handled.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            value, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            value = None
        if isinstance(value, dict):
            return value
        start = text.find("{", start + 1)

    return {}
def build_action(data: dict) -> APIDebugAction:
    """Build an APIDebugAction from the parsed LLM reply.

    Each action field is looked up in *data* by the same name; keys that are
    absent are passed as ``None``.

    Args:
        data: Parsed JSON object from the model.

    Returns:
        The action populated from *data*.
    """
    field_names = (
        "error_type",
        "error_types",
        "affected_fields",
        "fixed_request",
        "fixed_headers",
        "explanation",
        "response_issues",
        "expected_status_code",
    )
    values = {name: data.get(name) for name in field_names}

    # The model sometimes returns fixed_request as an object instead of a JSON
    # string; serialize it so the action always carries a string.
    if isinstance(values["fixed_request"], dict):
        values["fixed_request"] = json.dumps(values["fixed_request"])

    return APIDebugAction(**values)
# =========================================================================
# Episode runner
# =========================================================================
def _request_completion(llm_client: OpenAI, task: str, user_prompt: str) -> str:
    """Blocking chat-completion call; returns the reply text, or "" on failure."""
    try:
        completion = llm_client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPTS[task]},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=500,
            temperature=0.0,
        )
        return completion.choices[0].message.content or ""
    except Exception as exc:
        # Best effort: a failed LLM call becomes an empty action rather than
        # aborting the episode.
        print(f"[DEBUG] LLM request failed: {exc}", flush=True)
        return ""


async def run_episode(
    env: APIDebugEnv,
    llm_client: OpenAI,
    task: str,
) -> float:
    """Run a single episode for the given task and return the final score.

    Emits one ``[START]`` record, one ``[STEP]`` record per environment step,
    and exactly one ``[END]`` record. The ``[END]`` record is written even if
    the environment raises mid-episode, so every ``[START]`` is paired.

    Args:
        env: Connected environment client.
        llm_client: OpenAI client used for the chat-completion calls.
        task: Task name; must be a key of MAX_STEPS and SYSTEM_PROMPTS.

    Returns:
        The best step reward of the episode, clamped to [0.001, 0.999].

    Raises:
        Exception: Whatever ``env.reset`` / ``env.step`` raise is propagated
            after the ``[END]`` record has been written, so the caller can
            reconnect.
    """
    log_start(task=task, env=BENCHMARK_NAME, model=MODEL_NAME)
    rewards: List[float] = []
    steps_taken = 0
    try:
        result = await env.reset(task=task)
        obs = result.observation
        max_steps = MAX_STEPS[task]
        loop = asyncio.get_running_loop()
        for step in range(1, max_steps + 1):
            if result.done:
                break
            user_prompt = build_user_prompt(obs, step)
            # The OpenAI client is synchronous; run it in the default executor
            # so the event loop stays free while the request is in flight.
            llm_text = await loop.run_in_executor(
                None, _request_completion, llm_client, task, user_prompt
            )
            # Parse LLM output into an action; anything that is not a JSON
            # object is treated as an empty reply.
            parsed = parse_llm_response(llm_text)
            if not isinstance(parsed, dict):
                parsed = {}
            action = build_action(parsed)
            # Step the environment.
            result = await env.step(action)
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done
            rewards.append(reward)
            steps_taken = step
            # Build a short action summary for the log.
            action_summary = _action_summary(action, task)
            log_step(step=step, action=action_summary, reward=reward, done=done, error=None)
            if done:
                break
    finally:
        # Final score is the max reward achieved (environment already tracks best).
        # Clamp to open interval (0, 1) - evaluator rejects exactly 0.0 and 1.0.
        score = max(rewards) if rewards else 0.001
        score = min(max(score, 0.001), 0.999)
        success = score >= 0.5
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return score
def _action_summary(action: APIDebugAction, task: str) -> str:
"""Short summary of the action for logging."""
if task == "easy":
return f"diagnose:{action.error_type or 'none'}"
elif task == "classify":
types = action.error_types or [action.error_type or "none"]
return f"classify:{','.join(str(t) for t in types)}"
elif task == "medium":
fix_len = len(action.fixed_request or "")
return f"fix:len={fix_len}"
elif task == "headers":
hdr_count = len(action.fixed_headers or {})
return f"headers:{action.error_type or 'none'}+fix:{hdr_count}"
elif task == "response":
issues = action.response_issues or []
return f"response:{','.join(issues) or 'none'}+status:{action.expected_status_code or 'none'}"
else:
fix_len = len(action.fixed_request or "")
exp_len = len(action.explanation or "")
return f"fix:len={fix_len}+explain:len={exp_len}"
# =========================================================================
# Main
# =========================================================================
async def main() -> None:
    """Run every task for EPISODES_PER_TASK episodes and print average scores.

    The environment is started from the Docker image named by IMAGE_NAME when
    that is set, falling back to ENV_URL if that fails; otherwise ENV_URL is
    used directly. An episode that raises is scored 0.0 and the environment
    client is replaced with a fresh ENV_URL connection before continuing.
    """
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    def connect_by_url():
        # Use longer timeout for HF Spaces (LLM calls can be slow).
        return APIDebugEnv(base_url=ENV_URL, message_timeout_s=120.0)

    # Connect to environment (via Docker image or direct URL).
    if not IMAGE_NAME:
        env = connect_by_url()
    else:
        try:
            env = await APIDebugEnv.from_docker_image(IMAGE_NAME)
        except Exception as exc:
            print(f"[DEBUG] from_docker_image failed ({exc}), falling back to ENV_URL", flush=True)
            env = connect_by_url()

    all_scores: dict = {}
    try:
        for task in TASKS:
            task_scores = []
            for _ in range(EPISODES_PER_TASK):
                try:
                    score = await run_episode(env, llm_client, task)
                except Exception as exc:
                    print(f"[DEBUG] Episode failed: {exc}", flush=True)
                    # Reconnect on WebSocket failure; closing the old client
                    # is best effort.
                    try:
                        await env.close()
                    except Exception:
                        pass
                    env = connect_by_url()
                    score = 0.0
                task_scores.append(score)
            all_scores[task] = sum(task_scores) / len(task_scores)

        # Print summary.
        print("\n--- Baseline Scores ---", flush=True)
        for task, avg in all_scores.items():
            print(f" {task}: {avg:.3f}", flush=True)
        overall = sum(all_scores.values()) / len(all_scores)
        print(f" overall: {overall:.3f}", flush=True)
    finally:
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error: {e}", flush=True)


# Run the baseline only when executed as a script.
if __name__ == "__main__":
    asyncio.run(main())