import os
import re
from collections.abc import Sequence
from huggingface_hub import InferenceClient

DEFAULT_MODELS = [
    "google/flan-t5-small",
    "google/flan-t5-base",
]


def build_prompt(input_text: str) -> str:
    """Wrap a question in instructions that constrain the model to a 'label: 0/1' answer."""
    return (
        "Answer this binary reasoning question. "
        "Return only one line in the format 'label: 0' or 'label: 1'.\n\n"
        f"Question: {input_text}"
    )


def query_model(prompt: str, model_id: str = DEFAULT_MODELS[0], timeout: int = 60) -> str:
    """Call the Hugging Face Inference API (routed) and return the model's text.

    Requires the HF_TOKEN environment variable. Failures are returned as
    'ERROR: ...' strings rather than raised, so callers can surface them
    alongside normal model output.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        return "ERROR: HF_TOKEN is not set."
    try:
        client = InferenceClient(model=model_id, token=token, timeout=timeout)
        generated = client.text_generation(
            prompt,
            max_new_tokens=32,
            return_full_text=False,
        )
    except Exception as exc:
        return f"ERROR: inference failed for {model_id}: {exc}"
    # Some client versions return a structured object rather than a plain string.
    if not isinstance(generated, str):
        generated = getattr(generated, "generated_text", None) or str(generated)
    return generated.strip()


def query_models(prompt: str, model_ids: Sequence[str]) -> dict[str, str]:
    """Send the same prompt to each model and map model id to its raw output."""
    return {model_id: query_model(prompt, model_id=model_id) for model_id in model_ids}


def parse_binary_prediction(output: str) -> int | None:
    """Parse a structured binary label from model output; return None if none is found."""
    normalized = output.strip().lower()
    if normalized.startswith("error:"):
        return None
    # Prefer explicit structured forms such as 'label: 1', 'answer=0', or 'prediction: 0'.
    structured_patterns = [
        r"\blabel\s*[:=]\s*([01])\b",
        r"\banswer\s*[:=]\s*([01])\b",
        r"\bprediction\s*[:=]\s*([01])\b",
    ]
    for pattern in structured_patterns:
        match = re.search(pattern, normalized)
        if match:
            return int(match.group(1))
    # Fall back to accepting a bare '0' or '1' as the entire output.
    if re.fullmatch(r"[01]", normalized):
        return int(normalized)
    return None
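

# Minimal usage sketch (an illustration, not part of the original module).
# It assumes HF_TOKEN is set in the environment, that the models listed in
# DEFAULT_MODELS are reachable through the Inference API, and uses a
# hypothetical example question.
if __name__ == "__main__":
    question = "Is 17 a prime number?"
    prompt = build_prompt(question)
    for model_id, raw_output in query_models(prompt, DEFAULT_MODELS).items():
        label = parse_binary_prediction(raw_output)
        print(f"{model_id}: raw={raw_output!r} parsed={label}")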