import os import re from collections.abc import Sequence from huggingface_hub import InferenceClient DEFAULT_MODELS = [ "google/flan-t5-small", "google/flan-t5-base", ] def build_prompt(input_text: str) -> str: return ( "Answer this binary reasoning question. " "Return only one line in the format 'label: 0' or 'label: 1'.\n\n" f"Question: {input_text}" ) def query_model(prompt: str, model_id: str = DEFAULT_MODELS[0], timeout: int = 60) -> str: """Call Hugging Face Inference (routed API) and return model text.""" token = os.environ.get("HF_TOKEN") if not token: return "ERROR: HF_TOKEN is not set." try: client = InferenceClient(model=model_id, token=token, timeout=timeout) generated = client.text_generation( prompt, max_new_tokens=32, return_full_text=False, ) except Exception as exc: return f"ERROR: inference failed for {model_id}: {exc}" if not isinstance(generated, str): generated = getattr(generated, "generated_text", None) or str(generated) return generated.strip() def query_models(prompt: str, model_ids: Sequence[str]) -> dict[str, str]: return {model_id: query_model(prompt, model_id=model_id) for model_id in model_ids} def parse_binary_prediction(output: str) -> int | None: """Parse a structured binary label from model output.""" normalized = output.strip().lower() if normalized.startswith("error:"): return None structured_patterns = [ r"\blabel\s*[:=]\s*([01])\b", r"\banswer\s*[:=]\s*([01])\b", r"\bprediction\s*[:=]\s*([01])\b", ] for pattern in structured_patterns: match = re.search(pattern, normalized) if match: return int(match.group(1)) if re.fullmatch(r"[01]", normalized): return int(normalized) return None