# lrec2026-llm-annotator / provider.py — lterriel — fix_bug_adopt_suggest (#3) — 0aaf5f7
"""LLM provider client for structured-output annotation.
Supports four OpenAI-compatible providers via a parametric base URL:
- openrouter : https://openrouter.ai/api/v1 (model router: many models behind one key)
- mistral : https://api.mistral.ai/v1
- openai : https://api.openai.com/v1
- ilaas : https://llm.ilaas.fr/v1 (documentation: https://www.ilaas.fr/services-inference/)
All four accept the same OpenAI Chat Completions request shape, including
`response_format`: strict json_schema is requested on OpenAI and OpenRouter
(where support varies by model), json_object on Mistral and ILaaS; if the
requested mode is rejected, the client falls back to json_object and then to
no `response_format` at all.
"""
from __future__ import annotations
import asyncio
import json
import time
from dataclasses import dataclass
from typing import Optional
import httpx
from schemas import AnnotationSchema, to_json_schema, validate as schema_validate
from prompts import VALIDATION_RETRY
# Default request timeout, in seconds, handed to httpx by LLMClient.
DEFAULT_TIMEOUT = 60.0

# OpenAI-compatible API roots, keyed by provider id. Callers append
# "/chat/completions" to build the request URL.
BASE_URLS = {
    "openrouter": "https://openrouter.ai/api/v1",
    "mistral": "https://api.mistral.ai/v1",
    "openai": "https://api.openai.com/v1",
    "ilaas": "https://llm.ilaas.fr/v1",
}

# Supported provider ids, in the insertion order of BASE_URLS.
PROVIDERS = tuple(BASE_URLS)

# Suggested model ids per provider; the first entry of each list is the
# default model used by the connection test.
CURATED_MODELS_BY_PROVIDER: dict[str, list[str]] = {
    "openrouter": [
        "openai/gpt-oss-20b:free",
        "google/gemma-4-26b-a4b-it:free",
        "meta-llama/llama-3.3-70b-instruct:free",
        "qwen/qwen3-next-80b-a3b-instruct:free",
        "deepseek/deepseek-v4-flash:free",
        "mistralai/mistral-nemo",
        "mistralai/mistral-small-24b-instruct-2501",
        "mistralai/ministral-3b-2512",
    ],
    "mistral": [
        "mistral-small-2603",
        "mistral-large-2512",
        "ministral-8b-2512",
        "ministral-3b-2512",
    ],
    "openai": [
        "gpt-5-mini-2025-08-07",
        "gpt-5-nano-2025-08-07",
        "gpt-5-2025-08-07",
        "gpt-4o-mini-2024-07-18",
    ],
    "ilaas": [
        "gemma-4-31b",
        "gpt-oss-120b",
        "llama-3.1-8b",
        "llama-3.3-70b",
        "qwen-3.6-35b-instruct",
        "mistral-small-3.2-24b",
    ],
}

# Back-compat alias used by other modules: the OpenRouter list (same object).
CURATED_MODELS = CURATED_MODELS_BY_PROVIDER["openrouter"]
@dataclass
class ModelResult:
    """Outcome of a single model call made by ``LLMClient.annotate_one``.

    ``ok`` is True only when the reply parsed as JSON and passed schema
    validation; otherwise ``annotation`` is None and ``error`` explains why.
    """

    model: str  # model id the request was sent to
    ok: bool  # True when a validated annotation was obtained
    annotation: Optional[dict]  # parsed + validated JSON object; None on failure
    latency_s: float  # elapsed seconds for the whole call, retry included
    error: str = ""  # human-readable failure reason; empty on success
    raw: str = ""  # last raw reply text from the model, when one was received
def _build_headers(provider: str, api_key: str) -> dict:
h = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
if provider == "openrouter":
h["HTTP-Referer"] = "https://lrec2026-llm-annotator.local"
h["X-Title"] = "LREC2026 LLM-as-Annotator"
return h
class LLMClient:
    """Async chat-completions client for one of the providers in ``BASE_URLS``.

    Entering it as an async context manager (``async with LLMClient(...)``)
    opens one pooled ``httpx.AsyncClient`` (up to 20 connections) shared by
    every call until exit; outside a context manager each ``annotate_one``
    call opens and closes its own client.
    """

    # Statuses for which retrying with a looser ``response_format`` cannot help
    # (bad key, payment required, forbidden, rate-limited): surface the error
    # at once instead of repeating the request with other formats.
    _NO_FALLBACK_STATUSES = frozenset({401, 402, 403, 429})

    def __init__(self, provider: str, api_key: str, timeout: float = DEFAULT_TIMEOUT):
        """Configure the client for *provider*.

        Args:
            provider: one of the keys of ``BASE_URLS``.
            api_key: sent as a bearer token on every request.
            timeout: default request timeout in seconds.

        Raises:
            ValueError: if *provider* is not a key of ``BASE_URLS``.
        """
        if provider not in BASE_URLS:
            raise ValueError(f"Unknown provider {provider!r}; expected one of {tuple(BASE_URLS)}")
        self.provider = provider
        self.api_key = api_key
        self.base_url = BASE_URLS[provider]
        self.endpoint = self.base_url + "/chat/completions"
        self.headers = _build_headers(provider, api_key)
        self.timeout = timeout
        self._client: httpx.AsyncClient | None = None

    async def __aenter__(self):
        self._client = httpx.AsyncClient(
            timeout=self.timeout,
            limits=httpx.Limits(
                max_connections=20,
                max_keepalive_connections=10,
            ),
        )
        return self

    async def __aexit__(self, exc_type, exc, tb):
        if self._client:
            await self._client.aclose()
            self._client = None

    async def annotate_one(
        self,
        *,
        system: str,
        user: str,
        schema: AnnotationSchema,
        model: str,
        temperature: float = 0.0,
        timeout: Optional[float] = None,
    ) -> ModelResult:
        """Call one model and validate its JSON reply against *schema*.

        On a parse/validation failure the model gets exactly one retry, with
        the validator errors and its previous reply fed back to it.

        Args:
            system: system prompt.
            user: user prompt.
            schema: annotation schema used both for ``response_format`` and
                for validating the reply.
            model: provider-specific model id.
            temperature: sampling temperature.
            timeout: per-request timeout in seconds; ``None`` uses the
                timeout given to the constructor.

        Returns:
            A ``ModelResult``. Exceptions raised while calling the provider or
            parsing the reply are caught and reported with ``ok=False``;
            errors from ``to_json_schema`` propagate to the caller.
        """
        print(f"[LLM] start provider={self.provider} model={model}")
        json_schema = to_json_schema(schema)
        start = time.perf_counter()
        effective_timeout = self.timeout if timeout is None else timeout
        msgs = [{"role": "system", "content": system}, {"role": "user", "content": user}]
        try:
            client = self._client
            owns_client = client is None
            if owns_client:
                # Not inside ``async with``: use a throwaway client for this call only.
                client = httpx.AsyncClient(timeout=effective_timeout)
            try:
                raw_text = await self._call(client, msgs, json_schema, model, temperature, effective_timeout)
                ann, err = self._parse_and_validate(raw_text, schema)
                if err:
                    # One corrective round-trip: show the model its own answer plus the validator errors.
                    retry_msg = VALIDATION_RETRY + f"\n\nValidator errors:\n{err}\n\nPrevious response:\n{raw_text}"
                    msgs.append({"role": "assistant", "content": raw_text})
                    msgs.append({"role": "user", "content": retry_msg})
                    raw_text = await self._call(client, msgs, json_schema, model, temperature, effective_timeout)
                    ann, err = self._parse_and_validate(raw_text, schema)
            finally:
                if owns_client:
                    await client.aclose()
        except Exception as e:
            latency = time.perf_counter() - start
            print(f"[LLM] error provider={self.provider} model={model} latency={latency:.2f}s error={e}")
            return ModelResult(model=model, ok=False, annotation=None, latency_s=latency, error=str(e))

        latency = time.perf_counter() - start
        if err:
            print(f"[LLM] error provider={self.provider} model={model} latency={latency:.2f}s error={err}")
            return ModelResult(model=model, ok=False, annotation=None, latency_s=latency, error=err, raw=raw_text)
        print(f"[LLM] done provider={self.provider} model={model} latency={latency:.2f}s")
        return ModelResult(model=model, ok=True, annotation=ann, latency_s=latency, raw=raw_text)

    async def annotate_many(
        self,
        *,
        models: list[str],
        system: str,
        user: str,
        schema: AnnotationSchema,
        temperature: float = 0.0,
        timeout: Optional[float] = None,
    ) -> list[ModelResult]:
        """Run ``annotate_one`` for every model concurrently.

        Results are returned in the same order as *models*. ``timeout`` has
        the same meaning as in ``annotate_one``.
        """
        coros = [
            self.annotate_one(
                system=system, user=user, schema=schema, model=m, temperature=temperature, timeout=timeout
            )
            for m in models
        ]
        return await asyncio.gather(*coros)

    async def _call(
        self,
        client: httpx.AsyncClient,
        msgs: list[dict],
        json_schema: dict,
        model: str,
        temperature: float,
        timeout: Optional[float] = None,
    ) -> str:
        """POST one chat completion and return the assistant message text.

        The first attempt uses the strictest ``response_format`` the provider
        is expected to accept: strict json_schema, or json_object on Mistral
        and ILaaS. If a request is rejected, the method retries with
        json_object and then with no ``response_format``. It stops early on
        auth, payment or rate-limit errors, where another format cannot help.

        Raises:
            httpx.HTTPStatusError: if the last attempt still fails.
        """
        if self.provider in {"mistral", "ilaas"}:
            formats: list[Optional[dict]] = [{"type": "json_object"}, None]
        else:
            formats = [
                {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "annotation",
                        "strict": True,
                        "schema": json_schema,
                    },
                },
                {"type": "json_object"},
                None,
            ]
        payload: dict = {"model": model, "messages": msgs, "temperature": temperature}
        # Only override the client's default timeout when one is given.
        post_kwargs = {} if timeout is None else {"timeout": timeout}

        for response_format in formats:
            if response_format is None:
                payload.pop("response_format", None)
            else:
                payload["response_format"] = response_format
            resp = await client.post(self.endpoint, headers=self.headers, json=payload, **post_kwargs)
            if resp.status_code == 400 and "temperature" in payload and "temperature" in resp.text.lower():
                # Some models (OpenAI's gpt-5 reasoning models, for instance) only
                # accept the default temperature: drop it for this and later attempts.
                del payload["temperature"]
                resp = await client.post(self.endpoint, headers=self.headers, json=payload, **post_kwargs)
            if resp.status_code < 400 or resp.status_code in self._NO_FALLBACK_STATUSES:
                break
        resp.raise_for_status()
        data = resp.json()
        return data["choices"][0]["message"]["content"] or ""

    @staticmethod
    def _parse_and_validate(raw_text: str, schema: AnnotationSchema) -> tuple[Optional[dict], str]:
        """Parse *raw_text* as JSON and validate it against *schema*.

        Accepts a bare JSON document or one wrapped in a Markdown code fence
        (optionally tagged ``json``). As a fallback it also accepts the span
        from the first '{' to the last '}', for replies that surround the
        object with prose.

        Returns:
            ``(annotation, "")`` on success, ``(None, error_message)``
            otherwise; at most the first 10 validator errors are reported.
        """
        text = (raw_text or "").strip()
        if text.startswith("```"):
            text = text.strip("`")
            if text.lower().startswith("json"):
                text = text[4:]
            text = text.strip()
        try:
            ann = json.loads(text)
        except json.JSONDecodeError as e:
            ann = LLMClient._loads_outer_object(text)
            if ann is None:
                return None, f"Invalid JSON: {e}"
        ok, errors = schema_validate(schema, ann)
        if not ok:
            return None, "; ".join(errors[:10])
        return ann, ""

    @staticmethod
    def _loads_outer_object(text: str) -> Optional[dict]:
        """Decode ``text`` from its first '{' to its last '}', or return None."""
        lo, hi = text.find("{"), text.rfind("}")
        if lo == -1 or hi <= lo:
            return None
        try:
            return json.loads(text[lo : hi + 1])
        except json.JSONDecodeError:
            return None
def test_connection_sync(
    api_key: str,
    provider: str = "openrouter",
    model: Optional[str] = None,
    timeout: float = 20.0,
) -> tuple[bool, str]:
    """Quick blocking test for the 'Test' button in the key modal.

    Sends a tiny chat completion to *provider* and reports whether it worked.

    Args:
        api_key: key to test.
        provider: one of the keys of ``BASE_URLS``.
        model: model to query; defaults to the provider's first curated model.
        timeout: request timeout in seconds.

    Returns:
        ``(ok, message)``. Exceptions are caught and returned as
        ``(False, str(exc))``.
    """
    if provider not in BASE_URLS:
        return False, f"Unknown provider {provider!r}"
    if not model:
        models = CURATED_MODELS_BY_PROVIDER.get(provider) or []
        model = models[0] if models else None
    if not model:
        return False, "No model configured for this provider."
    url = BASE_URLS[provider] + "/chat/completions"
    headers = _build_headers(provider, api_key)
    messages = [{"role": "user", "content": "Reply with the single word: OK"}]
    try:
        with httpx.Client(timeout=timeout) as client:
            resp = client.post(
                url,
                headers=headers,
                json={
                    "model": model,
                    "messages": messages,
                    "max_tokens": 5,
                    "temperature": 0.0,
                },
            )
            if resp.status_code == 400:
                # Some models (e.g. OpenAI's gpt-5 reasoning models) reject
                # `max_tokens` and non-default temperatures; retry with a bare request.
                resp = client.post(url, headers=headers, json={"model": model, "messages": messages})
        if resp.status_code >= 400:
            return False, f"HTTP {resp.status_code}: {resp.text[:200]}"
        content = resp.json()["choices"][0]["message"]["content"]
        return True, f"Connected ({provider} / {model}). Reply: {content!r}"
    except Exception as e:
        return False, str(e)
# Back-compat shim — old callers used OpenRouterClient(api_key=...).
class OpenRouterClient(LLMClient):
    """Back-compat shim: an ``LLMClient`` pinned to the OpenRouter provider."""

    def __init__(self, api_key: str, **kwargs):
        # Forward ``timeout`` (previously dropped silently); any other legacy
        # keyword arguments are still accepted and ignored, as before.
        super().__init__(
            provider="openrouter",
            api_key=api_key,
            timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
        )