dhuser's picture
add_support_ilaas_llm_provider (#2)
53aefb2
"""Prompt templates + ICL pool.
Templates are loaded from the tutorial repo so the app stays in sync with the
written material. ICLPool keeps a session-scoped, filterable bank of validated
or corrected examples.
"""
from __future__ import annotations

import json
import random
import re
from copy import deepcopy
from dataclasses import dataclass, field, asdict
from typing import Optional

from paths import TUTORIAL_PROMPTS_DIR, read_text
from schemas import AnnotationSchema, to_json_schema
# Prompt templates, read once at import time from the tutorial prompts
# directory so the app and the written tutorial share the same text.
# NOTE(review): a missing or unreadable file presumably fails at import time —
# confirm against paths.read_text.
DEFAULT_SYSTEM_PROMPT = read_text(TUTORIAL_PROMPTS_DIR / "00_system_role.txt")
DEFAULT_ZERO_SHOT = read_text(TUTORIAL_PROMPTS_DIR / "01_zero_shot_pos_lemma_morph.txt")
DEFAULT_FEW_SHOT = read_text(TUTORIAL_PROMPTS_DIR / "02_few_shot_pos_lemma_morph.txt")
# Presumably the follow-up prompt sent after a failed validation (inferred from
# the file name; its use is not visible in this module).
VALIDATION_RETRY = read_text(TUTORIAL_PROMPTS_DIR / "03_validation_retry.txt")
@dataclass
class ICLExample:
    """One in-context-learning example: a token sequence plus its annotation.

    ``language`` and ``schema_hash`` scope the example, ``tokens`` is the
    token sequence, and ``gold_annotation`` is the reference annotation for
    those tokens.
    """

    language: str
    schema_hash: str
    tokens: list[str]
    # Shape: {"tokens": [{"surface": ..., "lemma": ..., "pos": ...}, ...]}
    gold_annotation: dict
    # Provenance tag: "sandbox" | "uploaded" | "corrected"
    source: str = "sandbox"
    # Free-form remark; empty unless the caller supplies one.
    note: str = ""
@dataclass
class ICLPool:
    """Session-scoped pool of in-context examples.

    Filter by (language, schema_hash) so a POS-correction never leaks into NER.

    Examples inserted through :meth:`add` are unique per
    ``(language, schema_hash, tokens)``; ``version`` is incremented each time
    :meth:`add` inserts or replaces an entry.
    """

    entries: list[ICLExample] = field(default_factory=list)
    version: int = 0

    def _key(self, ex: ICLExample) -> tuple[str, str, tuple[str, ...]]:
        """Return the identity key of *ex*; missing parts become empty values."""
        return (
            ex.language or "",
            ex.schema_hash or "",
            tuple(ex.tokens or []),
        )

    def _same_content(self, a: ICLExample, b: ICLExample) -> bool:
        """Return True when both examples carry an equal gold annotation.

        Only ``gold_annotation`` is compared; a difference in ``source`` or
        ``note`` alone does not count as a change.
        """
        return a.gold_annotation == b.gold_annotation

    def add(self, ex: ICLExample) -> str:
        """Insert *ex*, or replace the stored entry that has the same key.

        The pool stores a deep copy, so later mutation of the caller's object
        does not affect the pool.

        Returns:
            ``"unchanged"`` if an entry with the same key and an equal gold
            annotation already exists, ``"updated"`` if an entry with the same
            key was replaced in place, ``"inserted"`` if *ex* was appended.
        """
        ex = deepcopy(ex)
        key = self._key(ex)
        for i, existing in enumerate(self.entries):
            if self._key(existing) == key:
                if self._same_content(existing, ex):
                    return "unchanged"
                # Replace in place: the entry keeps its original position.
                self.entries[i] = ex
                self.version += 1
                return "updated"
        self.entries.append(ex)
        self.version += 1
        return "inserted"

    def filter(self, language: str = "", schema_hash: str = "") -> list[ICLExample]:
        """Return the entries matching the given criteria, in pool order.

        An empty ``language`` or ``schema_hash`` disables that criterion. The
        result is always a new list, never ``self.entries`` itself, so callers
        can reorder or truncate it without corrupting the pool. The examples
        inside it are still the pool's own objects.
        """
        return [
            e
            for e in self.entries
            if (not language or e.language == language)
            and (not schema_hash or e.schema_hash == schema_hash)
        ]

    def sample(
        self,
        n: int,
        language: str = "",
        schema_hash: str = "",
        strategy: str = "random",
        seed: int = 0,
    ) -> list[ICLExample]:
        """Pick up to *n* examples from the filtered pool.

        Strategies:
            ``"most_recent_corrections"``: entries whose source is
            ``"corrected"`` come first, newest position first, followed by all
            other entries in pool order. Position is used as recency, so an
            entry that :meth:`add` replaced in place keeps its old rank.
            ``"by_language"``: the first *n* filtered entries in pool order.
            Any other value: a seeded random sample without replacement.

        Returns an empty list when nothing matches or ``n <= 0``.
        """
        pool = self.filter(language=language, schema_hash=schema_hash)
        if not pool or n <= 0:
            return []
        if strategy == "most_recent_corrections":
            corr = [e for e in reversed(pool) if e.source == "corrected"]
            others = [e for e in pool if e.source != "corrected"]
            return (corr + others)[:n]
        if strategy == "by_language":
            # already filtered by language; deterministic order
            return pool[:n]
        rng = random.Random(seed)
        return rng.sample(pool, min(n, len(pool)))

    def to_jsonl(self) -> str:
        """Serialize every entry as one JSON object per line (no trailing newline)."""
        return "\n".join(json.dumps(asdict(e), ensure_ascii=False) for e in self.entries)

    @classmethod
    def from_jsonl(cls, text: str) -> "ICLPool":
        """Build a pool from JSONL text, skipping blank lines.

        Entries are appended as-is: they are not de-duplicated and ``version``
        stays at its default of 0.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
            TypeError: if a decoded object does not match the ``ICLExample``
                fields (unknown or missing keys).
        """
        pool = cls()
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            d = json.loads(line)
            pool.entries.append(ICLExample(**d))
        return pool
# ---------------------------------------------------------------------------
# Prompt rendering
# ---------------------------------------------------------------------------
def render_few_shot_block(examples: list[ICLExample]) -> str:
    """Render ICL examples as numbered, fenced JSON blocks.

    Each example becomes a ``### Example <i>`` heading (numbered from 1)
    followed by a json code fence containing ``{"tokens": ..., "gold": ...}``
    pretty-printed with a two-space indent. Blocks are joined by one blank
    line; an empty input yields an empty string.
    """
    payloads = (
        json.dumps(
            {"tokens": example.tokens, "gold": example.gold_annotation},
            ensure_ascii=False,
            indent=2,
        )
        for example in examples
    )
    return "\n\n".join(
        f"### Example {number}\n```json\n{payload}\n```"
        for number, payload in enumerate(payloads, start=1)
    )
def render_inventory(schema: AnnotationSchema) -> tuple[str, str]:
    """Build the (upos_inventory, feature_inventory) text blobs for the template.

    Walks ``schema.fields`` and emits one markdown bullet per entry:

    * ``enum`` fields go to the first blob as ``- `name` ∈ {v1, v2}``;
    * ``object`` fields contribute one bullet per subfield to the second blob,
      as ``- `field.sub` ∈ {v1, v2}``, or ``∈ (free string)`` when the
      subfield declares no values;
    * every other field goes to the first blob as a free-string bullet, with
      a ``(nullable)`` suffix when ``nullable`` is truthy.

    Value lists use the same ``{a, b}`` notation in both blobs, so subfield
    values are never shown as a Python list repr.

    Returns:
        ``(upos_inventory, feature_inventory)``. The second element is
        ``"(no morphological subfields)"`` when no subfield bullet was
        produced.
    """

    def _value_set(values) -> str:
        # One notation for every closed value list in the prompt.
        return "{" + ", ".join(str(v) for v in values) + "}"

    upos_lines = []
    feature_lines = []
    for f in schema.fields:
        if f.type == "enum":
            # `or ()` keeps an enum declared without values from crashing.
            upos_lines.append(f"- `{f.name}` ∈ {_value_set(f.values or ())}")
        elif f.type == "object":
            for sub in f.subfields:
                if not sub.values:
                    vals = "(free string)"
                elif isinstance(sub.values, str):
                    # A pre-formatted description is passed through verbatim.
                    vals = sub.values
                else:
                    vals = _value_set(sub.values)
                feature_lines.append(f"- `{f.name}.{sub.name}` ∈ {vals}")
        else:
            upos_lines.append(f"- `{f.name}`: free string{' (nullable)' if f.nullable else ''}")
    return "\n".join(upos_lines), "\n".join(feature_lines) or "(no morphological subfields)"
def render_prompt(
    template: str,
    *,
    schema: AnnotationSchema,
    tokens: list[str],
    text: str = "",
    language: str = "",
    script: str = "",
    domain: str = "",
    sentence_id: str = "s1",
    few_shot_examples: Optional[list[ICLExample]] = None,
) -> str:
    """Fill the user-prompt template. Unknown braces are left untouched.

    Recognized placeholders: ``{language}``, ``{script}``, ``{domain}``,
    ``{sentence_id}``, ``{text}``, ``{tokens}``, ``{upos_inventory}``,
    ``{feature_inventory}``, ``{few_shot_examples}``, ``{schema}`` and
    ``{tagset}`` (an alias of ``{upos_inventory}``).

    Fallbacks: ``language`` falls back to ``schema.language`` and then to
    ``"(unspecified)"``; ``script`` and ``domain`` fall back to
    ``"(unspecified)"``; ``text`` falls back to the tokens joined by single
    spaces; an empty few-shot block is rendered as ``"(none)"``.

    Substitution is done in a single pass over *template*, so placeholder-like
    text inside a substituted value (for example a literal ``{schema}`` in the
    sentence) is emitted verbatim and not expanded again.
    """
    upos_inv, feat_inv = render_inventory(schema)
    fs_block = render_few_shot_block(few_shot_examples or [])
    schema_str = json.dumps(to_json_schema(schema), ensure_ascii=False, indent=2)
    mapping = {
        "{language}": language or schema.language or "(unspecified)",
        "{script}": script or "(unspecified)",
        "{domain}": domain or "(unspecified)",
        "{sentence_id}": sentence_id,
        "{text}": text or " ".join(tokens),
        "{tokens}": json.dumps(tokens, ensure_ascii=False),
        "{upos_inventory}": upos_inv,
        "{feature_inventory}": feat_inv,
        "{few_shot_examples}": fs_block or "(none)",
        "{schema}": schema_str,
        "{tagset}": upos_inv,
    }
    # One alternation over all placeholders: each match is replaced exactly
    # once, and the inserted text is never re-scanned. A callable replacement
    # avoids backslash/group-reference processing of the values, and str()
    # tolerates a non-string sentence_id.
    placeholder_re = re.compile("|".join(re.escape(key) for key in mapping))
    return placeholder_re.sub(lambda match: str(mapping[match.group(0)]), template)