File size: 1,918 Bytes
1b435f0
 
 
 
609c576
1b435f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
609c576
1b435f0
 
 
 
 
609c576
 
 
 
 
 
 
 
1b435f0
609c576
 
 
1b435f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import re
from collections.abc import Sequence

from huggingface_hub import InferenceClient


# Model identifiers offered as defaults by this module. The first entry is
# the default `model_id` argument of `query_model`.
DEFAULT_MODELS = [
    "google/flan-t5-small",
    "google/flan-t5-base",
]


def build_prompt(input_text: str) -> str:
    """Return the instruction prompt for a binary question.

    The prompt asks the model to reply with a single line of the form
    'label: 0' or 'label: 1', followed by a blank line and the question
    text given in *input_text*.
    """
    instruction = (
        "Answer this binary reasoning question. "
        "Return only one line in the format 'label: 0' or 'label: 1'.\n\n"
    )
    question_line = f"Question: {input_text}"
    return instruction + question_line


def query_model(prompt: str, model_id: str = DEFAULT_MODELS[0], timeout: int = 60) -> str:
    """Send *prompt* to a Hugging Face inference model and return its text.

    The access token is read from the ``HF_TOKEN`` environment variable.
    Failures are reported as a returned string beginning with ``"ERROR:"``
    rather than as an exception: one message when the token is missing or
    empty, another when client creation or the generation call raises.

    Args:
        prompt: Text passed to the model.
        model_id: Model to query; defaults to the first default model.
        timeout: Timeout value forwarded to ``InferenceClient``.

    Returns:
        The generated text with surrounding whitespace stripped, or an
        ``"ERROR: ..."`` message.
    """
    api_token = os.environ.get("HF_TOKEN")
    if not api_token:
        return "ERROR: HF_TOKEN is not set."

    try:
        inference = InferenceClient(model=model_id, token=api_token, timeout=timeout)
        response = inference.text_generation(
            prompt,
            max_new_tokens=32,
            return_full_text=False,
        )
    except Exception as exc:
        # Best-effort boundary: any failure becomes an error string.
        return f"ERROR: inference failed for {model_id}: {exc}"

    if isinstance(response, str):
        text = response
    else:
        # Non-string result: prefer its `generated_text` attribute, falling
        # back to the object's string form when that is missing or empty.
        text = getattr(response, "generated_text", None) or str(response)
    return text.strip()


def query_models(prompt: str, model_ids: Sequence[str]) -> dict[str, str]:
    """Query each model in *model_ids* with *prompt*, one after another.

    Returns a dict mapping each model id to the string that ``query_model``
    returned for it. If an id appears more than once it is queried each
    time and the last result is kept.
    """
    replies: dict[str, str] = {}
    for name in model_ids:
        replies[name] = query_model(prompt, model_id=name)
    return replies


def parse_binary_prediction(output: str) -> int | None:
    """Parse a structured binary label from model output."""
    normalized = output.strip().lower()
    if normalized.startswith("error:"):
        return None

    structured_patterns = [
        r"\blabel\s*[:=]\s*([01])\b",
        r"\banswer\s*[:=]\s*([01])\b",
        r"\bprediction\s*[:=]\s*([01])\b",
    ]
    for pattern in structured_patterns:
        match = re.search(pattern, normalized)
        if match:
            return int(match.group(1))

    if re.fullmatch(r"[01]", normalized):
        return int(normalized)

    return None