| import os |
| import re |
| from collections.abc import Sequence |
|
|
| from huggingface_hub import InferenceClient |
|
|
|
|
# Model ids to query when the caller does not choose one explicitly; the
# first entry is the default model for query_model().
DEFAULT_MODELS = [
    "google/flan-t5-small",
    "google/flan-t5-base",
]
|
|
|
|
def build_prompt(input_text: str) -> str:
    """Wrap *input_text* in the instruction template for a binary question.

    The instructions ask the model to answer in the machine-parseable form
    'label: 0' / 'label: 1' that parse_binary_prediction() expects.
    """
    instructions = (
        "Answer this binary reasoning question. "
        "Return only one line in the format 'label: 0' or 'label: 1'."
    )
    return f"{instructions}\n\nQuestion: {input_text}"
|
|
|
|
def query_model(prompt: str, model_id: str = DEFAULT_MODELS[0], timeout: int = 60) -> str:
    """Run *prompt* through a hosted Hugging Face model and return its text.

    Every failure mode (missing credentials, any raised inference error) is
    reported as a string starting with "ERROR:" instead of an exception, so
    callers can handle all outcomes uniformly.
    """
    api_token = os.environ.get("HF_TOKEN")
    if api_token is None or api_token == "":
        return "ERROR: HF_TOKEN is not set."

    try:
        hf_client = InferenceClient(model=model_id, token=api_token, timeout=timeout)
        raw_output = hf_client.text_generation(
            prompt,
            max_new_tokens=32,
            return_full_text=False,
        )
    except Exception as exc:  # surface any backend failure as an error string
        return f"ERROR: inference failed for {model_id}: {exc}"

    # Some client versions return a response object rather than a plain str;
    # fall back to its generated_text attribute (or repr) in that case.
    if not isinstance(raw_output, str):
        raw_output = getattr(raw_output, "generated_text", None) or str(raw_output)
    return raw_output.strip()
|
|
|
|
def query_models(prompt: str, model_ids: Sequence[str]) -> dict[str, str]:
    """Send *prompt* to each model in *model_ids*; map model id -> response."""
    responses: dict[str, str] = {}
    for current_id in model_ids:
        responses[current_id] = query_model(prompt, model_id=current_id)
    return responses
|
|
|
|
| def parse_binary_prediction(output: str) -> int | None: |
| """Parse a structured binary label from model output.""" |
| normalized = output.strip().lower() |
| if normalized.startswith("error:"): |
| return None |
|
|
| structured_patterns = [ |
| r"\blabel\s*[:=]\s*([01])\b", |
| r"\banswer\s*[:=]\s*([01])\b", |
| r"\bprediction\s*[:=]\s*([01])\b", |
| ] |
| for pattern in structured_patterns: |
| match = re.search(pattern, normalized) |
| if match: |
| return int(match.group(1)) |
|
|
| if re.fullmatch(r"[01]", normalized): |
| return int(normalized) |
|
|
| return None |
|
|