# Web-page scrape residue from the file's hosting page, commented out so the
# script stays valid Python:
# ehsaaniqbal's picture
# init
# 9a0ecd1 unverified
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import concurrent.futures
import csv
import json
import os
import re
import subprocess
import sys
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from urllib.error import URLError
from urllib.request import urlopen
# Directory containing this script; used as cwd for every subprocess call.
ROOT = Path(__file__).resolve().parent
# Default OpenAI-compatible inference endpoint; overridable via API_BASE_URL.
DEFAULT_API_BASE_URL = "https://router.huggingface.co/v1"
# Models swept when neither --models nor --models-file is supplied.
DEFAULT_MODELS = [
    "zai-org/GLM-5.1",
    "openai/gpt-oss-120b",
    "MiniMaxAI/MiniMax-M2.5",
    "moonshotai/Kimi-K2.5",
    # "google/gemma-4-31B-it",
]
# Per-task score columns surfaced in the summary table/CSV.
TASK_COLUMNS = ["easy", "medium", "medium_plus", "hard"]
@dataclass
class ServerHandle:
    """Handle to the local InvoiceOps server, whether spawned here or adopted."""

    port: int
    # Spawned server process; None when an already-running server is reused.
    process: subprocess.Popen[str] | None
    # Location of the server's combined stdout/stderr log.
    log_path: Path
    # Open file object backing log_path; None for a reused server.
    log_handle: Any | None
    # True when an external server on the port was adopted instead of started.
    reused: bool = False

    @property
    def base_url(self) -> str:
        # Servers are always local; only the port varies.
        return f"http://localhost:{self.port}"
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the batch sweep runner."""
    parser = argparse.ArgumentParser(
        description="Run a batch of HF models against the local InvoiceOps environment."
    )
    parser.add_argument("--models", nargs="+", help="Override the default model list.")
    parser.add_argument(
        "--models-file",
        help="Optional text file with one HF model id per line.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port for the local InvoiceOps server.",
    )
    # Boolean switches, declared in the order they should appear in --help.
    boolean_flags = [
        ("--sync", "Run `uv sync --extra dev` before starting."),
        ("--validate", "Run `openenv validate --url` before the sweep."),
        ("--verbose", "Echo inference stderr while runs complete."),
        (
            "--dry-run",
            "Print the planned configuration without starting servers or calling models.",
        ),
    ]
    for flag, help_text in boolean_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument(
        "--jobs",
        type=int,
        default=1,
        help="Number of concurrent model runs.",
    )
    parser.add_argument(
        "--reuse-running-server",
        action="store_true",
        help="Reuse an already-running server on the target port instead of failing fast.",
    )
    return parser.parse_args()
def slugify(value: str) -> str:
    """Collapse *value* into a filesystem-friendly slug; never returns empty."""
    # Any run of characters outside [A-Za-z0-9._-] becomes a single dash,
    # then dangling separators are trimmed from both ends.
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "-", value.strip()).strip("-._")
    if cleaned:
        return cleaned
    return "value"
def load_models(args: argparse.Namespace) -> list[str]:
    """Resolve the model list: CLI override, then file contents, then defaults."""
    if args.models:
        return args.models
    if not args.models_file:
        return DEFAULT_MODELS
    path = Path(args.models_file).expanduser().resolve()
    models: list[str] = []
    # One model id per line; blank lines and '#' comment lines are skipped.
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if entry and not entry.startswith("#"):
            models.append(entry)
    if not models:
        raise RuntimeError(f"No models found in {path}")
    return models
def is_healthy(base_url: str, timeout_s: float = 1.0) -> bool:
    """Best-effort probe of GET {base_url}/health; any failure reads as unhealthy."""
    try:
        with urlopen(f"{base_url}/health", timeout=timeout_s) as response:
            healthy = response.status == 200
    except Exception:
        # URLError, socket timeouts, malformed URLs — the caller only
        # needs a boolean, so everything maps to "not healthy".
        return False
    return healthy
def wait_for_health(base_url: str, timeout_s: float = 20.0) -> bool:
    """Poll ``is_healthy`` every 0.5s until success or *timeout_s* elapses.

    Returns True as soon as the server reports healthy, False on timeout.
    """
    # time.monotonic is immune to wall-clock adjustments (NTP, DST), unlike
    # time.time, so the deadline cannot jump forward or backward mid-wait.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if is_healthy(base_url, timeout_s=1.0):
            return True
        time.sleep(0.5)
    return False
def start_server(
    port: int,
    batch_dir: Path,
    *,
    reuse_running_server: bool,
) -> ServerHandle:
    """Launch `uv run server` on *port*, or adopt an already-healthy one.

    Raises RuntimeError when a healthy server occupies the port and
    *reuse_running_server* is False, or when a freshly spawned server never
    becomes healthy (the tail of its log is included in the message).
    """
    base_url = f"http://localhost:{port}"
    if is_healthy(base_url):
        if not reuse_running_server:
            raise RuntimeError(
                "A healthy server is already running at "
                f"{base_url}. Stop it first or rerun with --reuse-running-server."
            )
        print(f"[batch] reusing running invoiceops_env at {base_url}", file=sys.stderr)
        # Adopted server: no process or log handle to own, so stop_server
        # later becomes a no-op for it.
        return ServerHandle(
            port=port,
            process=None,
            log_path=batch_dir / "logs" / "invoiceops_env__server.log",
            log_handle=None,
            reused=True,
        )
    # batch_dir/logs is created by make_batch_dir before this is called.
    log_path = batch_dir / "logs" / "invoiceops_env__server.log"
    log_handle = log_path.open("w", encoding="utf-8")
    process = subprocess.Popen(
        ["uv", "run", "server", "--port", str(port)],
        cwd=ROOT,
        stdout=log_handle,
        stderr=subprocess.STDOUT,  # interleave stderr into the same log file
        text=True,
    )
    if not wait_for_health(base_url):
        # Server never came up: shut it down, then surface the log tail.
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
        log_handle.close()
        tail = log_path.read_text(encoding="utf-8", errors="replace")[-4000:]
        raise RuntimeError(f"Failed to start invoiceops_env.\n{tail}")
    print(f"[batch] started invoiceops_env at {base_url}", file=sys.stderr)
    return ServerHandle(
        port=port,
        process=process,
        log_path=log_path,
        log_handle=log_handle,
        reused=False,
    )
def stop_server(handle: ServerHandle) -> None:
    """Terminate a spawned server (if any) and close its log file."""
    proc = handle.process
    if proc is not None:
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Graceful shutdown did not finish in time; force-kill and reap.
            proc.kill()
            proc.wait(timeout=5)
    if handle.log_handle is not None:
        handle.log_handle.close()
def validate_server(handle: ServerHandle) -> None:
    """Run the external `openenv validate` check against the server.

    Raises CalledProcessError (via check=True) if validation fails.
    """
    command = [
        "uvx",
        "--from",
        "openenv-core",
        "openenv",
        "validate",
        "--url",
        handle.base_url,
    ]
    subprocess.run(command, cwd=ROOT, check=True)
def parse_output_path(stderr_text: str) -> Path | None:
    """Return the path from the last ``wrote=<path>`` line in *stderr_text*.

    Returns None when no such line exists.
    """
    marker = "wrote="
    # Scan from the end so the most recent report wins.
    for line in reversed(stderr_text.splitlines()):
        if line.startswith(marker):
            _, _, remainder = line.partition("=")
            return Path(remainder.strip())
    return None
def make_batch_dir() -> Path:
    """Create ``batch_runs/<UTC timestamp>/logs`` under ROOT; return the batch dir."""
    stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    batch_dir = ROOT / "batch_runs" / stamp
    logs_dir = batch_dir / "logs"
    # parents=True also creates batch_runs/<stamp> in one call.
    logs_dir.mkdir(parents=True, exist_ok=True)
    return batch_dir
def _collect_request_errors(node: Any) -> list[str]:
errors: list[str] = []
if isinstance(node, dict):
if node.get("failure_reason") == "request_error":
message = node.get("error_message")
if isinstance(message, str) and message.strip():
errors.append(message.strip())
for value in node.values():
errors.extend(_collect_request_errors(value))
elif isinstance(node, list):
for value in node:
errors.extend(_collect_request_errors(value))
return errors
def classify_status(
    *,
    returncode: int,
    payload: dict[str, Any] | None,
    request_errors: list[str],
) -> str:
    """Bucket a run: failed / ok / invalid_model / provider_credit_error / request_error."""
    if returncode != 0 or payload is None:
        return "failed"
    if not request_errors:
        return "ok"
    haystack = "\n".join(request_errors).lower()
    # Checked in priority order: model problems trump credit exhaustion.
    needle_to_status = (
        ("model_not_supported", "invalid_model"),
        ("not a chat model", "invalid_model"),
        ("depleted your monthly included credits", "provider_credit_error"),
    )
    for needle, status in needle_to_status:
        if needle in haystack:
            return status
    return "request_error"
def extract_scores(payload: dict[str, Any]) -> tuple[dict[str, float], int, int]:
    """Pull per-task scores plus fallback / parse-failure counts from *payload*.

    Entries without a string task_id or numeric score contribute no score but
    are still counted toward the fallback / parse-failure tallies.
    """
    scores: dict[str, float] = {}
    fallbacks = 0
    parse_failures = 0
    for entry in payload.get("results") or []:
        task_id = entry.get("task_id")
        score = entry.get("score")
        if isinstance(task_id, str) and isinstance(score, (int, float)):
            scores[task_id] = float(score)
        # Booleans coerce to 0/1, so these count exact-True / exact-False hits.
        fallbacks += entry.get("used_fallback") is True
        parse_failures += entry.get("decision_parsed") is False
    return scores, fallbacks, parse_failures
def run_inference(
    handle: ServerHandle,
    *,
    model_name: str,
    hf_token: str,
    api_base_url: str,
    batch_name: str,
    logs_dir: Path,
    verbose: bool,
) -> dict[str, Any]:
    """Run inference.py once for *model_name* and summarize the outcome.

    Captures stdout/stderr to per-model log files under *logs_dir*, loads the
    JSON payload the script reports via a ``wrote=<path>`` stderr line, and
    returns a flat row dict (scores, counts, status, log paths) suitable for
    the summary table and CSV.
    """
    model_slug = slugify(model_name)
    stdout_path = logs_dir / f"invoiceops_env__{model_slug}.stdout.log"
    stderr_path = logs_dir / f"invoiceops_env__{model_slug}.stderr.log"
    # inference.py is configured entirely through environment variables.
    env = os.environ.copy()
    env.update(
        {
            "HF_TOKEN": hf_token,
            "API_BASE_URL": api_base_url,
            "MODEL_NAME": model_name,
            "ENV_URL": handle.base_url,
            "EVAL_RUN_NAME": batch_name,
        }
    )
    started_at = time.time()
    result = subprocess.run(
        ["uv", "run", "python", "inference.py"],
        cwd=ROOT,
        env=env,
        capture_output=True,
        text=True,
        check=False,  # a non-zero exit is recorded as status, not raised
    )
    duration_s = round(time.time() - started_at, 2)
    stdout_path.write_text(result.stdout, encoding="utf-8")
    stderr_path.write_text(result.stderr, encoding="utf-8")
    if verbose and result.stderr.strip():
        sys.stderr.write(result.stderr)
        if not result.stderr.endswith("\n"):
            sys.stderr.write("\n")
    # The JSON results file location is reported on stderr as "wrote=<path>".
    output_path = parse_output_path(result.stderr)
    payload: dict[str, Any] | None = None
    if output_path is not None and output_path.exists():
        payload = json.loads(output_path.read_text(encoding="utf-8"))
    scores: dict[str, float] = {}
    fallback_count = 0
    parse_failure_count = 0
    mean_score = None
    request_errors: list[str] = []
    if payload is not None:
        # Prefer mean_score; fall back to raw_mean_score when it is absent.
        if isinstance(payload.get("mean_score"), (int, float)):
            mean_score = float(payload["mean_score"])
        elif isinstance(payload.get("raw_mean_score"), (int, float)):
            mean_score = float(payload["raw_mean_score"])
        scores, fallback_count, parse_failure_count = extract_scores(payload)
        request_errors = _collect_request_errors(payload)
    status = classify_status(
        returncode=result.returncode,
        payload=payload,
        request_errors=request_errors,
    )
    return {
        "model": model_name,
        "status": status,
        "returncode": result.returncode,
        "duration_s": duration_s,
        "mean_score": mean_score,
        "fallback_count": fallback_count,
        "parse_failure_count": parse_failure_count,
        "request_error_count": len(request_errors),
        "first_request_error": request_errors[0] if request_errors else "",
        "output_json": str(output_path) if output_path is not None else "",
        "stdout_log": str(stdout_path),
        "stderr_log": str(stderr_path),
        # One column per task; missing tasks render as None.
        **{task_id: scores.get(task_id) for task_id in TASK_COLUMNS},
    }
def print_summary(rows: list[dict[str, Any]]) -> None:
    """Print an aligned plain-text results table to stdout."""
    headers = [
        "model",
        "mean",
        *TASK_COLUMNS,
        "fallbacks",
        "parse_fail",
        "req_err",
        "status",
        "sec",
    ]

    def fmt_score(value: Any) -> str:
        # Missing scores render as a dash, present ones with four decimals.
        return "-" if value is None else f"{value:.4f}"

    display_rows: list[dict[str, str]] = []
    for row in rows:
        cells = {
            "model": row["model"],
            "mean": fmt_score(row["mean_score"]),
            "fallbacks": str(row["fallback_count"]),
            "parse_fail": str(row["parse_failure_count"]),
            "req_err": str(row["request_error_count"]),
            "status": row["status"],
            "sec": f"{row['duration_s']:.1f}",
        }
        for task_id in TASK_COLUMNS:
            cells[task_id] = fmt_score(row.get(task_id))
        display_rows.append(cells)
    # Column widths: at least as wide as the header, then the widest cell.
    widths = {header: len(header) for header in headers}
    for cells in display_rows:
        for header in headers:
            widths[header] = max(widths[header], len(cells[header]))
    print(" ".join(header.ljust(widths[header]) for header in headers))
    print(" ".join("-" * widths[header] for header in headers))
    for cells in display_rows:
        print(" ".join(cells[header].ljust(widths[header]) for header in headers))
def write_summary_files(
    batch_dir: Path, rows: list[dict[str, Any]]
) -> tuple[Path, Path]:
    """Write summary.csv and summary.json into *batch_dir*; return both paths."""
    csv_path = batch_dir / "summary.csv"
    json_path = batch_dir / "summary.json"
    fieldnames = [
        "model",
        "mean_score",
        *TASK_COLUMNS,
        "fallback_count",
        "parse_failure_count",
        "request_error_count",
        "status",
        "duration_s",
        "returncode",
        "first_request_error",
        "output_json",
        "stdout_log",
        "stderr_log",
    ]
    # An empty batch still gets a header row; a lone "model" column marks
    # the degenerate case for downstream tooling.
    chosen_fields = fieldnames if rows else ["model"]
    with csv_path.open("w", encoding="utf-8", newline="") as sink:
        writer = csv.DictWriter(sink, fieldnames=chosen_fields)
        writer.writeheader()
        writer.writerows(rows)
    json_path.write_text(json.dumps(rows, indent=2), encoding="utf-8")
    return csv_path, json_path
def main() -> int:
    """Drive the full sweep: start the server, run every model, summarize.

    Returns 0 on success (the exit status). The server started (if any) is
    always stopped, even when the sweep raises.
    """
    args = parse_args()
    if args.jobs < 1:
        raise RuntimeError("--jobs must be at least 1.")
    models = load_models(args)
    api_base_url = os.getenv("API_BASE_URL", DEFAULT_API_BASE_URL)
    hf_token = os.getenv("HF_TOKEN")
    # A dry run needs no token; real runs fail fast without one.
    if not hf_token and not args.dry_run:
        raise RuntimeError("Set HF_TOKEN in the shell before running ./batch.")
    if args.dry_run:
        print("Dry run only.")
        print(f"API_BASE_URL={api_base_url}")
        print(f"models={','.join(models)}")
        print(f"jobs={args.jobs}")
        print(f"invoiceops_env -> http://localhost:{args.port}")
        return 0
    batch_dir = make_batch_dir()
    batch_name = batch_dir.name
    logs_dir = batch_dir / "logs"
    rows: list[dict[str, Any]] = []
    handle: ServerHandle | None = None
    try:
        if args.sync:
            subprocess.run(["uv", "sync", "--extra", "dev"], cwd=ROOT, check=True)
        handle = start_server(
            args.port,
            batch_dir,
            reuse_running_server=args.reuse_running_server,
        )
        if args.validate:
            validate_server(handle)
        print(f"[batch] batch={batch_name}", file=sys.stderr)
        print(f"[batch] api_base_url={api_base_url}", file=sys.stderr)
        if args.jobs == 1:
            # Sequential mode: run models one at a time, reporting as we go.
            for model_name in models:
                print(
                    f"[batch] running invoiceops_env :: {model_name}", file=sys.stderr
                )
                row = run_inference(
                    handle,
                    model_name=model_name,
                    hf_token=hf_token,
                    api_base_url=api_base_url,
                    batch_name=batch_name,
                    logs_dir=logs_dir,
                    verbose=args.verbose,
                )
                rows.append(row)
                mean_display = (
                    "-" if row["mean_score"] is None else f"{row['mean_score']:.4f}"
                )
                print(
                    f"[batch] result invoiceops_env :: {model_name} mean={mean_display} status={row['status']}",
                    file=sys.stderr,
                )
        else:
            # Concurrent mode: threads suffice since run_inference blocks on a
            # subprocess; results are reported in completion order.
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=args.jobs
            ) as executor:
                futures = {
                    executor.submit(
                        run_inference,
                        handle,
                        model_name=model_name,
                        hf_token=hf_token,
                        api_base_url=api_base_url,
                        batch_name=batch_name,
                        logs_dir=logs_dir,
                        verbose=args.verbose,
                    ): model_name
                    for model_name in models
                }
                for future in concurrent.futures.as_completed(futures):
                    model_name = futures[future]
                    row = future.result()
                    rows.append(row)
                    mean_display = (
                        "-" if row["mean_score"] is None else f"{row['mean_score']:.4f}"
                    )
                    print(
                        f"[batch] result invoiceops_env :: {model_name} mean={mean_display} status={row['status']}",
                        file=sys.stderr,
                    )
        # Restore the original request order before summarizing, since
        # concurrent completion order is arbitrary.
        order = {model: index for index, model in enumerate(models)}
        rows.sort(key=lambda row: order[row["model"]])
        csv_path, json_path = write_summary_files(batch_dir, rows)
        print_summary(rows)
        print(f"\nsummary_csv={csv_path}")
        print(f"summary_json={json_path}")
        print(f"logs_dir={logs_dir}")
        return 0
    finally:
        if handle is not None:
            stop_server(handle)
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())