# openenv_project / inference_local.py
# Provenance: "Deploy OpenEnv Submission" by ark406 (commit 0b55673, verified).
"""
inference_local.py β€” Local inference for Python Bug Fixer OpenEnv.
This script runs all 3 tasks using pre-written correct solutions
(no external LLM required). Demonstrates the full environment loop.
Usage:
# 1. Start the server first:
# uvicorn app.main:app --host 0.0.0.0 --port 7860
# 2. Run this script:
# python inference_local.py
"""
import json
import requests
from datetime import datetime, timezone
SPACE_URL = "http://localhost:7860"
# ── Pre-written correct solutions for each task ───────────────────────────────
SOLUTIONS = {
"task_easy": """\
def get_last_element(lst):
return lst[len(lst) - 1]
def compute_sum(numbers):
total = 0
for i in range(len(numbers)):
total += numbers[i]
return total
result = get_last_element([1, 2, 3, 4, 5])
print(result)
total = compute_sum([10, 20, 30])
print(total)
""",
"task_medium": """\
def binary_search(arr, target):
left, right = 0, len(arr) - 1
while left <= right:
mid = (left + right) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
left = mid + 1
else:
right = mid - 1
return -1
arr = [1, 3, 5, 7, 9, 11, 13]
print(binary_search(arr, 7))
print(binary_search(arr, 11))
print(binary_search(arr, 4))
""",
"task_hard": """\
class DataProcessor:
\"\"\"Processes a list of employee records.\"\"\"
def __init__(self):
self.data = []
def add_record(self, record: dict):
self.data.append(record)
def get_average(self, field: str) -> float:
\"\"\"Return the average value of a numeric field.\"\"\"
if not self.data:
return 0.0
return sum(r[field] for r in self.data) / len(self.data)
def filter_records(self, field: str, value):
\"\"\"Return all records where record[field] == value.\"\"\"
return [r for r in self.data if r[field] == value]
def get_sorted(self, field: str, reverse: bool = False):
\"\"\"Return records sorted by field. reverse=True means descending.\"\"\"
return sorted(self.data, key=lambda x: x[field], reverse=reverse)
def get_max(self, field: str) -> dict:
\"\"\"Return the record with the highest value for field.\"\"\"
return max(self.data, key=lambda x: x[field])
p = DataProcessor()
p.add_record({"name": "Alice", "score": 85})
p.add_record({"name": "Bob", "score": 92})
p.add_record({"name": "Charlie", "score": 78})
print(round(p.get_average("score"), 1))
print(len(p.filter_records("name", "Alice")))
print(p.get_sorted("score", reverse=True)[0]["name"])
print(p.get_max("score")["name"])
""",
}
TASK_IDS = ["task_easy", "task_medium", "task_hard"]
def now_iso() -> str:
    """Return the current moment as a timezone-aware UTC ISO-8601 string."""
    moment = datetime.now(timezone.utc)
    return moment.isoformat()
def run_task(task_id: str) -> dict:
    """Run a single task end-to-end: reset a session, submit the pre-written
    fix as one step, and report the server's reward.

    Returns a summary dict with keys: task_id, reward, steps, success.
    """
    # Open a fresh session for this task.
    reset_resp = requests.post(
        f"{SPACE_URL}/reset", json={"task_id": task_id}, timeout=30
    )
    reset_resp.raise_for_status()
    session_id = reset_resp.json()["session_id"]

    start_log = {
        "task_id": task_id,
        "session_id": session_id,
        "model": "local-solver",
        "timestamp": now_iso(),
    }
    print(f"[START] {json.dumps(start_log)}", flush=True)

    # Single-shot episode: submit the known-correct solution as the action.
    code = SOLUTIONS[task_id]
    step_resp = requests.post(
        f"{SPACE_URL}/step",
        json={"session_id": session_id, "action": code},
        timeout=30,
    )
    step_resp.raise_for_status()
    outcome = step_resp.json()
    reward = outcome["reward"]
    passed = reward >= 0.8  # success threshold used for the summary

    step_log = {
        "step": 1,
        "action_chars": len(code),
        "reward": reward,
        "done": outcome["done"],
        # Truncate the observation so log lines stay readable.
        "observation": outcome["observation"][:200],
    }
    print(f"[STEP] {json.dumps(step_log)}", flush=True)

    end_log = {
        "task_id": task_id,
        "session_id": session_id,
        "total_reward": reward,
        "steps": 1,
        "success": passed,
        "timestamp": now_iso(),
    }
    print(f"[END] {json.dumps(end_log)}", flush=True)

    return {
        "task_id": task_id,
        "reward": reward,
        "steps": 1,
        "success": passed,
    }
def main():
    """Run every task against the local server and print a summary table."""
    print("=" * 60)
    print(" Python Bug Fixer β€” Local Inference (no LLM needed)")
    print("=" * 60)
    print(f" Server: {SPACE_URL}")
    print(f" Tasks: {', '.join(TASK_IDS)}")
    print("-" * 60)

    # Run each task in order, separating their logs with a rule.
    results = []
    for tid in TASK_IDS:
        results.append(run_task(tid))
        print("-" * 60)

    # Summary table plus the mean reward over all tasks.
    print("\n=== SUMMARY ===")
    rewards = []
    for entry in results:
        status = "βœ… PASS" if entry["success"] else "❌ FAIL"
        print(
            f" [{status}] {entry['task_id']:15s} "
            f"reward={entry['reward']:.2f} steps={entry['steps']}"
        )
        rewards.append(entry["reward"])
    avg = sum(rewards) / len(rewards)
    print(f"\n Average reward: {avg:.2f}")
    print("=== END SUMMARY ===")


if __name__ == "__main__":
    main()