# Author: obversarystudios
# Use huggingface_hub InferenceClient (routed inference API)
# Commit: 609c576 (verified)
import os
import re
from collections.abc import Sequence
from huggingface_hub import InferenceClient
# Hugging Face model ids tried by default; the first entry is the
# single-model default used by query_model.
DEFAULT_MODELS = [
    "google/flan-t5-small",
    "google/flan-t5-base",
]
def build_prompt(input_text: str) -> str:
    """Wrap *input_text* in the instruction prompt expected by the models.

    The instruction asks for exactly one line of the form 'label: 0' or
    'label: 1', which parse_binary_prediction knows how to read back.
    """
    instructions = (
        "Answer this binary reasoning question. "
        "Return only one line in the format 'label: 0' or 'label: 1'."
    )
    return f"{instructions}\n\nQuestion: {input_text}"
def query_model(prompt: str, model_id: str = DEFAULT_MODELS[0], timeout: int = 60) -> str:
    """Call Hugging Face Inference (routed API) and return model text.

    Failures are reported as strings starting with "ERROR:" rather than
    raised, so callers (and parse_binary_prediction) can handle them inline.
    Requires the HF_TOKEN environment variable to be set.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        return "ERROR: HF_TOKEN is not set."
    try:
        client = InferenceClient(model=model_id, token=token, timeout=timeout)
        raw = client.text_generation(
            prompt,
            max_new_tokens=32,
            return_full_text=False,
        )
    except Exception as exc:
        return f"ERROR: inference failed for {model_id}: {exc}"
    # Some client versions return a response object instead of a plain str;
    # fall back to its generated_text attribute, then to str().
    if isinstance(raw, str):
        return raw.strip()
    text = getattr(raw, "generated_text", None) or str(raw)
    return text.strip()
def query_models(prompt: str, model_ids: Sequence[str]) -> dict[str, str]:
    """Query every model id with *prompt*; map each id to its raw output."""
    results: dict[str, str] = {}
    for mid in model_ids:
        results[mid] = query_model(prompt, model_id=mid)
    return results
def parse_binary_prediction(output: str) -> int | None:
"""Parse a structured binary label from model output."""
normalized = output.strip().lower()
if normalized.startswith("error:"):
return None
structured_patterns = [
r"\blabel\s*[:=]\s*([01])\b",
r"\banswer\s*[:=]\s*([01])\b",
r"\bprediction\s*[:=]\s*([01])\b",
]
for pattern in structured_patterns:
match = re.search(pattern, normalized)
if match:
return int(match.group(1))
if re.fullmatch(r"[01]", normalized):
return int(normalized)
return None